"""
A module for effectively caching the public keys of various token issuer endpoints.
"""
import json
import logging
import os
import re
import sqlite3
import time
from contextlib import closing
from urllib.error import URLError

try:
    import urllib.request as request
except ImportError:
    import urllib2 as request
try:
    import urlparse
except ImportError:
    import urllib.parse as urlparse

from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, load_pem_public_key
from cryptography.hazmat.primitives import serialization
import cryptography.hazmat.backends as backends
import cryptography.hazmat.primitives.asymmetric.ec as ec
import cryptography.hazmat.primitives.asymmetric.rsa as rsa

from scitokens.utils.errors import SciTokensException, MissingKeyException, NonHTTPSIssuer, UnableToCreateCache, UnsupportedKeyException
from scitokens.utils import long_from_bytes
import scitokens.utils.config as config
# File name of the SQLite database holding the cached keys; it is joined onto
# the resolved cache directory in KeyCache._get_cache_file.
CACHE_FILENAME = "scitokens_keycache.sqllite"
# Process-wide singleton, created lazily by KeyCache.getinstance().
KEYCACHE_INSTANCE = None
class UnableToWriteKeyCache(SciTokensException):
    """
    Raised when a row could not be written to the key cache.
    """
class KeyCache(object):
    """
    Object that persistently caches signing keys associated with a token issuer endpoint.

    Keys live in a SQLite database (table ``keycache``) whose path is resolved
    once, at construction time, by :meth:`_get_cache_file`. Each row stores the
    issuer, an expiration timestamp, the key id, a JSON blob carrying the PEM
    encoded public key (or an empty string for a "negative" entry that records
    a failed lookup) and the time at which a refresh should next be attempted.
    """

    # ``IS`` rather than ``=``: ``key_id = NULL`` is never true in SQL, so rows
    # stored without a key id could otherwise never be matched for deletion.
    _DELETE_ENTRY_SQL = "DELETE FROM keycache WHERE issuer = ? AND key_id IS ?"
    _INSERT_ENTRY_SQL = "INSERT OR REPLACE INTO keycache VALUES(?, ?, ?, ?, ?)"

    def __init__(self):
        # Resolve (and, if necessary, create) the on-disk cache.
        # May be None when no cache directory could be created.
        self.cache_location = self._get_cache_file()

    @staticmethod
    def getinstance():
        """
        Return the singleton instance of the KeyCache.
        """
        global KEYCACHE_INSTANCE
        if KEYCACHE_INSTANCE is None:
            KEYCACHE_INSTANCE = KeyCache()
        return KEYCACHE_INSTANCE

    def addkeyinfo(self, issuer, key_id, public_key, cache_timer=0, next_update=0):
        """
        Add a single, known public key to the cache.

        Failures to write are logged and swallowed: the cache is best-effort.

        :param str issuer: URI of the issuer
        :param str key_id: Key Identifier
        :param public_key: Cryptography public_key object
        :param int cache_timer: Cache lifetime of the public_key
        :param int next_update: Seconds until next update time
        """
        # If the next_update is 0, then set it to 1 hour
        if next_update == 0:
            next_update = 3600

        try:
            with closing(sqlite3.connect(self.cache_location)) as conn:
                curs = conn.cursor()
                # Explicit delete first: INSERT OR REPLACE alone cannot replace
                # a row whose key_id is NULL.
                curs.execute(self._DELETE_ENTRY_SQL, [issuer, key_id])
                KeyCache._addkeyinfo(curs, issuer, key_id, public_key,
                                     cache_timer=cache_timer, next_update=next_update)
                conn.commit()
        except Exception as ex:
            logger = logging.getLogger("scitokens")
            logger.error(f'Keycache file is immutable. Detailed error: {ex}')

    @staticmethod
    def _addkeyinfo(curs, issuer, key_id, public_key, cache_timer=0, next_update=0):
        """
        Given an open database cursor to a key cache, insert a key.

        :raises UnableToWriteKeyCache: if the insert did not affect exactly one row
        """
        pem = public_key.public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo)
        keydata = {'pub_key': pem.decode('ascii')}
        now = time.time()
        curs.execute(KeyCache._INSERT_ENTRY_SQL,
                     [issuer, now + cache_timer, key_id, json.dumps(keydata), now + next_update])
        if curs.rowcount != 1:
            raise UnableToWriteKeyCache("Unable to insert into key cache")

    def _parse_key_data(self, issuer, kid, keydata):
        """
        Keydata is stored as a JSON object inside the DB. Therefore, we must extract it.

        If the stored value cannot be decoded, or does not carry a ``pub_key``
        member, the cache entry is deleted so that it gets re-created.

        :param str issuer: Token Issuer in keydata
        :param str kid: Key ID
        :param str keydata: Raw JSON key data (at least, it should be)
        :returns str: encoded public key, otherwise None
        """
        try:
            return json.loads(keydata)['pub_key']
        except (ValueError, KeyError, TypeError):
            # ValueError: not JSON; KeyError/TypeError: JSON of an unexpected shape
            logging.getLogger("scitokens").exception(
                "Unable to parse JSON stored in keycache. "
                "This likely means the database format needs "
                "to be updated, which we will now do automatically")
            self._delete_cache_entry(issuer, kid)
            return None

    def _load_cached_key(self, row):
        """
        Turn a ``keycache`` row into a public key object, or None if the row is unusable.
        """
        pem = self._parse_key_data(row['issuer'], row['key_id'], row['keydata'])
        if pem:
            return load_pem_public_key(pem.encode(), backend=backends.default_backend())
        return None

    def _delete_cache_entry(self, issuer, key_id):
        """
        Delete a cache entry. Failures are logged and swallowed.
        """
        try:
            with closing(sqlite3.connect(self.cache_location)) as conn:
                conn.execute(self._DELETE_ENTRY_SQL, [issuer, key_id])
                conn.commit()
        except Exception as ex:
            logger = logging.getLogger("scitokens")
            logger.error(f'Keycache file is immutable. Detailed error: {ex}')

    def _add_negative_cache_entry(self, issuer, key_id, cache_retry_interval):
        """
        Add a negative cache entry (empty keydata) so that a failed lookup is
        not retried before ``cache_retry_interval`` seconds have passed.
        Failures are logged and swallowed.
        """
        logger = logging.getLogger("scitokens")
        try:
            with closing(sqlite3.connect(self.cache_location)) as conn:
                curs = conn.cursor()
                retry_at = time.time() + cache_retry_interval
                # Drop any previous entry first so NULL key ids do not pile up
                curs.execute(self._DELETE_ENTRY_SQL, [issuer, key_id])
                curs.execute(self._INSERT_ENTRY_SQL, [issuer, retry_at, key_id, '', retry_at])
                if curs.rowcount != 1:
                    logger.error(UnableToWriteKeyCache("Unable to insert into key cache"))
                conn.commit()
        except Exception as ex:
            logger.error(f'Keycache file is immutable. Detailed error: {ex}')

    def _download_and_add_key(self, issuer, key_id, insecure, force_refresh, cache_retry_interval):
        """
        Download key data and add key (if possible)

        :returns: the downloaded public key
        :raises ValueError: re-raised unchanged from the download
        :raises URLError: re-raised unchanged from the download
        :raises MissingKeyException: for any other download failure; unless
            ``force_refresh`` is set a negative cache entry is recorded first
        """
        logger = logging.getLogger("scitokens")
        try:
            public_key, cache_timer = self._get_issuer_publickey(issuer, key_id, insecure)
        except ValueError as ex:
            logger.error(ex)
            raise
        except URLError as ex:
            logger.error("Unable to get key from issuer.\n{0}".format(str(ex)))
            raise
        except Exception as ex:
            logger.error("No key was found in keycache and unable to get key: {0}".format(str(ex)))
            if not force_refresh:
                # If NOT forced, create negative cache
                try:
                    self._add_negative_cache_entry(issuer, key_id, cache_retry_interval)
                except Exception as cache_ex:
                    # Distinct name: re-using ``ex`` here would unbind the outer
                    # exception and break the raise below.
                    logger.error(cache_ex)
            raise MissingKeyException(ex) from ex

        # Separate download and add key to avoid keycache deadlocks
        try:
            self.addkeyinfo(issuer, key_id, public_key, cache_timer)
        except Exception as ex:
            logger.error("Unable to add new key data to keycache.\n{0}".format(ex))
        return public_key

    def getkeyinfo(self, issuer, key_id=None, insecure=False, force_refresh=False, cache_retry_interval=300):
        """
        Get the key information

        :param str issuer: The issuer URI
        :param str key_id: Text key id to identify the key
        :param bool insecure: Whether insecure methods are acceptable (defaults to False).
        :param bool force_refresh: Download the key again even if a usable copy is cached.
        :param int cache_retry_interval: Seconds a failed lookup is remembered before retrying.
        :returns: None if a failed lookup is still negatively cached. Else, returns the public key
        :raises MissingKeyException: if no key is cached and none could be downloaded
        """
        logger = logging.getLogger("scitokens")

        # Check the sql database
        if key_id is not None:
            key_query = "SELECT * FROM keycache WHERE issuer = ? AND key_id = ?"
            query_params = [issuer, key_id]
        else:
            key_query = "SELECT * FROM keycache WHERE issuer = ?"
            query_params = [issuer]

        row = None
        try:
            with closing(sqlite3.connect(self.cache_location)) as conn:
                conn.row_factory = sqlite3.Row
                row = conn.execute(key_query, query_params).fetchone()
        except Exception as ex:
            logger.error(f'Keycache file is immutable. Detailed error: {ex}')

        # Negative cache handling: an empty keydata marks a remembered failure
        if row is not None and row['keydata'] == '':
            now = time.time()
            if not force_refresh and row['next_update'] > now:
                logger.warning("Retry in {} seconds".format(int(row['next_update'] - now)))
                return None
            # Force refresh or cache_retry_interval is over
            self._delete_cache_entry(row['issuer'], row['key_id'])
            row = None

        if row is not None:
            if int(row['next_update']) < time.time() and self._check_validity(row):
                # Time to update the key, but the cached one is still valid:
                # try to refresh and fall back to the saved key on failure.
                try:
                    public_key, cache_timer = self._get_issuer_publickey(issuer, key_id, insecure)
                    self.addkeyinfo(issuer, key_id, public_key, cache_timer)
                    return public_key
                except Exception as ex:
                    logger.warning("Unable to get key triggered by next update: {0}".format(str(ex)))
                cached_key = self._load_cached_key(row)
                if cached_key is not None:
                    return cached_key
            elif self._check_validity(row):
                # Not yet time to update and the key is still valid
                if force_refresh:
                    # The caller asked for a fresh key: hand back what was just
                    # downloaded instead of the stale cached copy.
                    return self._download_and_add_key(issuer, key_id, insecure,
                                                      force_refresh, cache_retry_interval)
                cached_key = self._load_cached_key(row)
                if cached_key is not None:
                    return cached_key
            else:
                # There is a row for the key, but it is not valid anymore
                self._delete_cache_entry(row['issuer'], row['key_id'])

        # If it reaches here, then no usable key was found in the SQL
        return self._download_and_add_key(issuer, key_id, insecure, force_refresh, cache_retry_interval)

    @classmethod
    def _check_validity(cls, key_info):
        """
        Check the key to see if it has expired

        :returns bool: True while ``key_info['expiration']`` lies in the future
        """
        return key_info['expiration'] > time.time()

    @staticmethod
    def _get_issuer_publickey(issuer, key_id=None, insecure=False):
        """
        Download the issuer's public key via its OpenID discovery document.

        :param str issuer: The issuer URI
        :param str key_id: Key id to select from the key set; if None the key
            set must contain exactly one key
        :param bool insecure: Allow non-HTTPS URLs
        :return: Tuple containing (public_key, cache_lifetime). cache_lifetime
            is the number of seconds the public key may be cached
        :raises NonHTTPSIssuer: if a URL is not HTTPS and ``insecure`` is False
        :raises MissingKeyException: if no key with ``key_id`` is published
        :raises UnsupportedKeyException: if the key type is neither RSA nor EC
        """
        # Set the user agent so Cloudflare isn't mad at us
        # Import the __version__ value in scitokens for the scitokens version
        from scitokens import __version__ as PKG_VERSION
        request_headers = {'User-Agent': 'SciTokens/{}'.format(PKG_VERSION)}

        # Go to the issuer's website, and download the OAuth well known bits
        # https://tools.ietf.org/html/draft-ietf-oauth-discovery-07
        well_known_uri = ".well-known/openid-configuration"
        if not issuer.endswith("/"):
            issuer = issuer + "/"
        parsed_url = urlparse.urlparse(issuer)
        url_parts = list(parsed_url)
        url_parts[2] = urlparse.urljoin(parsed_url.path, well_known_uri)
        meta_uri = urlparse.urlunparse(url_parts)

        # Make sure the protocol is https
        if not insecure and urlparse.urlparse(meta_uri).scheme != "https":
            raise NonHTTPSIssuer("Issuer is not over HTTPS. RFC requires it to be over HTTPS")
        with closing(request.urlopen(request.Request(meta_uri, headers=request_headers))) as response:
            data = json.loads(response.read().decode('utf-8'))

        # Get the keys URL from the openid-configuration
        jwks_uri = data['jwks_uri']
        if not insecure and urlparse.urlparse(jwks_uri).scheme != "https":
            raise NonHTTPSIssuer("jwks_uri is not over HTTPS, insecure!")

        # Now, get the keys, together with the cache data from the headers
        cache_timer = 0
        with closing(request.urlopen(request.Request(jwks_uri, headers=request_headers))) as response:
            response_headers = response.info()
            if "Cache-Control" in response_headers:
                # Parse out the max-age, if it's there.
                match = re.search(r".*max-age=(\d+)", response_headers['Cache-Control'])
                if match:
                    cache_timer = int(match.group(1))
            keys_data = json.loads(response.read().decode('utf-8'))

        # Never cache for less than the configured lifetime, no matter what the remote says
        cache_timer = max(cache_timer, config.get_int("cache_lifetime"))

        jwk_list = keys_data['keys']
        if key_id is None:
            # No kid to match on: only unambiguous when exactly one key is published
            if len(jwk_list) != 1:
                raise NotImplementedError("No kid in header, but multiple keys in "
                                          "response from certs server. Don't know which key to use!")
            raw_key = jwk_list[0]
        else:
            # "kid" is optional in a JWK, so use .get() rather than indexing
            raw_key = next((key for key in jwk_list if key.get('kid') == key_id), None)

        if raw_key is None:
            raise MissingKeyException("Unable to find key at issuer {}".format(jwks_uri))

        key_type = raw_key.get('kty')
        if key_type == "RSA":
            public_key_numbers = rsa.RSAPublicNumbers(
                long_from_bytes(raw_key['e']),
                long_from_bytes(raw_key['n'])
            )
        elif key_type == 'EC':
            public_key_numbers = ec.EllipticCurvePublicNumbers(
                long_from_bytes(raw_key['x']),
                long_from_bytes(raw_key['y']),
                ec.SECP256R1()
            )
        else:
            raise UnsupportedKeyException("SciToken signed with an unsupported key type")

        return public_key_numbers.public_key(backends.default_backend()), cache_timer

    def _get_cache_file(self):
        """
        Get the Cache file location

        1. Configuration cache location
        2. $XDG_CACHE_HOME (ignored when empty)
        3. .cache subdirectory of the home directory

        The database is created if it does not exist yet.

        :returns: path of the cache database, or None if its directory could not be created
        """
        logger = logging.getLogger("scitokens")

        config_cache_location = config.get('cache_location')
        xdg_cache_home = os.environ.get("XDG_CACHE_HOME", None)

        if config_cache_location != "":
            cache_dir = config_cache_location
        elif xdg_cache_home:
            cache_dir = xdg_cache_home
        else:
            cache_dir = os.path.join(os.path.expanduser("~"), ".cache")

        keycache_dir = os.path.join(cache_dir, "scitokens")
        try:
            # Creates cache_dir as well; exist_ok avoids a check-then-create race
            os.makedirs(keycache_dir, exist_ok=True)
        except OSError as ose:
            # Unable to create a cache is not a fatal error
            logger.warning("Unable to create cache directory at {}: {}".format(keycache_dir, str(ose)))
            return None

        keycache_file = os.path.join(keycache_dir, CACHE_FILENAME)
        if not os.path.exists(keycache_file):
            self._initialize_cachedb(keycache_file)
        return keycache_file

    @staticmethod
    def _initialize_cachedb(sql_file):
        """
        Create a simple flat sqllite cache

        Safe to call on a database that already contains the table.
        """
        with closing(sqlite3.connect(sql_file)) as conn:
            conn.execute("CREATE TABLE IF NOT EXISTS keycache ("
                         "issuer text NOT NULL,"
                         "expiration integer NOT NULL,"
                         "key_id text,"
                         "keydata text NOT NULL,"
                         "next_update integer NOT NULL,"
                         "PRIMARY KEY (issuer, key_id))")
            # Save (commit) the changes
            conn.commit()

    def list_keys(self):
        """
        List all keys in keycache

        :returns: list of (issuer, expiration, key_id, keydata, next_update)
            tuples, with both timestamps rendered by SQLite's DATETIME
        """
        with closing(sqlite3.connect(self.cache_location)) as conn:
            return conn.execute(
                "SELECT issuer, DATETIME(expiration, 'unixepoch'), key_id, keydata, "
                "DATETIME(next_update, 'unixepoch') FROM keycache"
            ).fetchall()

    def remove_key(self, issuer, key_id):
        """
        Remove a specific key from keycache

        :returns bool: True if an entry was removed, False if none matched
        """
        with closing(sqlite3.connect(self.cache_location)) as conn:
            curs = conn.cursor()
            curs.execute(self._DELETE_ENTRY_SQL, [issuer, key_id])
            removed = curs.rowcount > 0
            conn.commit()
        return removed

    def add_key(self, issuer, key_id, force_refresh=False):
        """
        Add a key or update an existing one in keycache

        :returns: the PEM encoded public key, or None if no key is available
        """
        pubkey = self.getkeyinfo(issuer, key_id, force_refresh=force_refresh)
        if pubkey is None:
            return None
        return pubkey.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )

    def update_all_keys(self, force_refresh=False):
        """
        Update all keys in keycache

        If force_refresh is True, we refresh all keys regardless of update time

        :returns: list with the result of :meth:`add_key` for every entry that
            did not raise; failures are logged and skipped
        """
        with closing(sqlite3.connect(self.cache_location)) as conn:
            cached_entries = conn.execute("SELECT issuer, key_id FROM keycache").fetchall()

        results = []
        for issuer, key_id in cached_entries:
            try:
                results.append(self.add_key(issuer, key_id, force_refresh=force_refresh))
            except Exception as ex:
                logger = logging.getLogger("scitokens")
                logger.error("Unable to update key: {0} {1}".format(issuer, key_id))
                logger.error(ex)
        return results