# Source code for scitokens.utils.keycache


"""
A module for effectively caching the public keys of various token issuer endpoints.
"""

import os
import sqlite3
import time
import re
import logging
from urllib.error import URLError

try:
    import urllib.request as request
except ImportError:
    import urllib2 as request

try:
    import urlparse
except ImportError:
    import urllib.parse as urlparse

import json

from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, load_pem_public_key
import cryptography.hazmat.backends as backends
import cryptography.hazmat.primitives.asymmetric.ec as ec
import cryptography.hazmat.primitives.asymmetric.rsa as rsa
from scitokens.utils.errors import SciTokensException, MissingKeyException, NonHTTPSIssuer, UnableToCreateCache, UnsupportedKeyException
from scitokens.utils import long_from_bytes
import scitokens.utils.config as config
from cryptography.hazmat.primitives import serialization
from urllib.error import URLError


# File name of the SQLite key cache inside the "scitokens" cache directory.
# NOTE: the "sqllite" spelling is load-bearing -- KeyCache._get_cache_file looks
# for this exact name, so renaming it would stop existing cache files being found.
CACHE_FILENAME = "scitokens_keycache.sqllite"
# Process-wide KeyCache singleton, created lazily by KeyCache.getinstance().
KEYCACHE_INSTANCE = None

class UnableToWriteKeyCache(SciTokensException):
    """
    Raised when a key could not be written to the key cache, whatever the reason.
    """
class KeyCache(object):
    """
    Object that persistently caches signing keys associated with a token issuer endpoint.

    Keys live in a SQLite database with a single ``keycache`` table whose
    primary key is ``(issuer, key_id)``.  A row whose ``keydata`` column is the
    empty string is a *negative* cache entry: a recent download of that key
    failed, and it should not be retried before the row's ``next_update`` time.
    """

    def __init__(self):
        # Locate (and create if needed) the cache database.  This is None when
        # the cache directory could not be created.
        self.cache_location = self._get_cache_file()

    @staticmethod
    def getinstance():
        """
        Return the singleton instance of the KeyCache, creating it on first use.
        """
        global KEYCACHE_INSTANCE
        if KEYCACHE_INSTANCE is None:
            KEYCACHE_INSTANCE = KeyCache()
        return KEYCACHE_INSTANCE

    def addkeyinfo(self, issuer, key_id, public_key, cache_timer=0, next_update=0):
        """
        Add a single, known public key to the cache.

        Failures are logged, not raised.

        :param str issuer: URI of the issuer
        :param str key_id: Key Identifier
        :param public_key: Cryptography public_key object
        :param int cache_timer: Cache lifetime of the public_key, in seconds
        :param int next_update: Seconds until next update time (0 means 1 hour)
        """
        # If the next_update is 0, then set it to 1 hour
        if next_update == 0:
            next_update = 3600
        try:
            conn = sqlite3.connect(self.cache_location)
            try:
                conn.row_factory = sqlite3.Row
                curs = conn.cursor()
                curs.execute("DELETE FROM keycache WHERE issuer = ? AND key_id = ?", [issuer, key_id])
                KeyCache._addkeyinfo(curs, issuer, key_id, public_key,
                                     cache_timer=cache_timer, next_update=next_update)
                conn.commit()
            finally:
                # Close even if a statement failed, so the connection is not leaked.
                conn.close()
        except Exception as ex:
            logger = logging.getLogger("scitokens")
            logger.error(f'Keycache file is immutable. Detailed error: {ex}')

    @staticmethod
    def _addkeyinfo(curs, issuer, key_id, public_key, cache_timer=0, next_update=0):
        """
        Given an open database cursor to a key cache, insert (or replace) a key.

        The key is stored as a JSON object ``{"pub_key": <PEM text>}``.

        :raises UnableToWriteKeyCache: if the insert did not affect exactly one row
        """
        insert_key_statement = "INSERT OR REPLACE INTO keycache VALUES(?, ?, ?, ?, ?)"
        keydata = {
            'pub_key': public_key.public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo).decode('ascii'),
        }
        # Column order: issuer, expiration, key_id, keydata, next_update
        curs.execute(insert_key_statement, [issuer, time.time() + cache_timer, key_id,
                                            json.dumps(keydata), time.time() + next_update])
        if curs.rowcount != 1:
            raise UnableToWriteKeyCache("Unable to insert into key cache")

    def _parse_key_data(self, issuer, kid, keydata):
        """
        Extract the PEM-encoded public key from the JSON stored in the DB.

        If the stored data is not a JSON object with a ``pub_key`` member (for
        example, data written by an older cache format), the row is deleted so
        the key will be downloaded again.

        :param str issuer: Token issuer of the row
        :param str kid: Key ID of the row
        :param str keydata: Raw JSON key data (at least, it should be)
        :returns str: encoded public key, otherwise None
        """
        try:
            return json.loads(keydata)['pub_key']
        except (ValueError, KeyError, TypeError):
            # ValueError: not JSON; KeyError: no "pub_key"; TypeError: JSON but not an object.
            logging.exception("Unable to parse JSON stored in keycache. "
                              "This likely means the database format needs "
                              "to be updated, which we will now do automatically")
            self._delete_cache_entry(issuer, kid)
            return None

    def _delete_cache_entry(self, issuer, key_id):
        """
        Delete a cache entry.  Failures are logged, not raised.
        """
        try:
            conn = sqlite3.connect(self.cache_location)
            try:
                curs = conn.cursor()
                curs.execute("DELETE FROM keycache WHERE issuer = ? AND key_id = ?", [issuer, key_id])
                conn.commit()
            finally:
                conn.close()
        except Exception as ex:
            logger = logging.getLogger("scitokens")
            logger.error(f'Keycache file is immutable. Detailed error: {ex}')

    def _add_negative_cache_entry(self, issuer, key_id, cache_retry_interval):
        """
        Add a negative cache entry (empty ``keydata``) that expires, and becomes
        due for retry, ``cache_retry_interval`` seconds from now.

        Failures are logged, not raised.
        """
        try:
            conn = sqlite3.connect(self.cache_location)
            try:
                conn.row_factory = sqlite3.Row
                curs = conn.cursor()
                insert_key_statement = "INSERT OR REPLACE INTO keycache VALUES(?, ?, ?, ?, ?)"
                keydata = ''
                curs.execute(insert_key_statement, [issuer, time.time() + cache_retry_interval, key_id,
                                                    keydata, time.time() + cache_retry_interval])
                if curs.rowcount != 1:
                    logger = logging.getLogger("scitokens")
                    logger.error(UnableToWriteKeyCache("Unable to insert into key cache"))
                conn.commit()
            finally:
                conn.close()
        except Exception as ex:
            logger = logging.getLogger("scitokens")
            logger.error(f'Keycache file is immutable. Detailed error: {ex}')

    def _download_and_add_key(self, issuer, key_id, insecure, force_refresh, cache_retry_interval):
        """
        Download key data from the issuer and add the key to the cache (if possible).

        :returns: the downloaded public key
        :raises ValueError: if the issuer returned unparsable data
        :raises URLError: if the issuer could not be contacted
        :raises MissingKeyException: for any other download failure; unless
            ``force_refresh`` is set, a negative cache entry is recorded first
        """
        logger = logging.getLogger("scitokens")
        try:
            public_key, cache_timer = self._get_issuer_publickey(issuer, key_id, insecure)
        except ValueError as ex:
            logger.error(ex)
            raise
        except URLError as ex:
            logger.error("Unable to get key from issuer.\n{0}".format(str(ex)))
            raise
        except Exception as ex:
            logger.error("No key was found in keycache and unable to get key: {0}".format(str(ex)))
            if not force_refresh:
                # If NOT forced, create negative cache.  Use a distinct name for the
                # inner exception: Python 3 unbinds an ``except ... as`` name when its
                # handler ends, which would otherwise break the raise below.
                try:
                    self._add_negative_cache_entry(issuer, key_id, cache_retry_interval)
                except Exception as cache_ex:
                    logger.error(cache_ex)
            raise MissingKeyException(ex) from ex

        # Separate download and add key to avoid keycache deadlocks
        try:
            self.addkeyinfo(issuer, key_id, public_key, cache_timer)
        except Exception as ex:
            logger.error("Unable to add new key data to keycache.\n{0}".format(ex))
        return public_key

    def getkeyinfo(self, issuer, key_id=None, insecure=False, force_refresh=False, cache_retry_interval=300):
        """
        Get the key information, from the cache if possible, else from the issuer.

        :param str issuer: The issuer URI
        :param str key_id: Text key id to identify the key (None: any key of the issuer)
        :param bool insecure: Whether insecure methods are acceptable (defaults to False).
        :param bool force_refresh: Ignore negative cache entries and re-download
            keys that are still valid in the cache.
        :param int cache_retry_interval: Seconds before a failed download may be retried.
        :returns: None if an unexpired negative cache entry exists (and
            ``force_refresh`` is False).  Else, returns the public key.
        :raises: whatever ``_download_and_add_key`` raises when a download is
            required and fails (e.g. MissingKeyException, URLError, ValueError).
        """
        logger = logging.getLogger("scitokens")

        if key_id is not None:
            key_query = "SELECT * FROM keycache WHERE issuer = ? AND key_id = ?"
            query_params = [issuer, key_id]
        else:
            key_query = "SELECT * FROM keycache WHERE issuer = ?"
            query_params = [issuer]

        row = None
        try:
            conn = sqlite3.connect(self.cache_location)
            try:
                conn.row_factory = sqlite3.Row
                curs = conn.cursor()
                curs.execute(key_query, query_params)
                row = curs.fetchone()
            finally:
                conn.close()
        except Exception as ex:
            logger.error(f'Keycache file is immutable. Detailed error: {ex}')

        if row is not None and row['keydata'] == '':
            # Negative cache entry: a previous download of this key failed.
            if not force_refresh and row['next_update'] > time.time():
                logger.warning("Retry in {} seconds".format(int(row['next_update'] - time.time())))
                return None
            # Forced refresh, or the retry interval is over: drop the entry and re-download.
            self._delete_cache_entry(row['issuer'], row['key_id'])
            row = None

        if row is not None and int(row['next_update']) < time.time() and self._check_validity(row):
            # It's time to update the key, but the cached key is still valid:
            # try to refresh it, and fall back to the saved key if that fails.
            try:
                public_key, cache_timer = self._get_issuer_publickey(issuer, key_id, insecure)
                self.addkeyinfo(issuer, key_id, public_key, cache_timer)
                return public_key
            except Exception as ex:
                logger.warning("Unable to get key triggered by next update: {0}".format(str(ex)))
            # Upgrade proof: None (and the row deleted) if the stored data is unusable.
            keydata = self._parse_key_data(row['issuer'], row['key_id'], row['keydata'])
            if keydata:
                return load_pem_public_key(keydata.encode(), backend=backends.default_backend())
        elif row is not None and self._check_validity(row):
            # Not yet time to update, and the cached key is still valid.
            if force_refresh:
                # Return the freshly downloaded key, not the cached copy it replaces.
                return self._download_and_add_key(issuer, key_id, insecure, force_refresh,
                                                  cache_retry_interval)
            keydata = self._parse_key_data(row['issuer'], row['key_id'], row['keydata'])
            if keydata:
                return load_pem_public_key(keydata.encode(), backend=backends.default_backend())
        elif row is not None:
            # The row exists but the key has expired: remove it.
            self._delete_cache_entry(row['issuer'], row['key_id'])

        # No usable key in the cache: download it.
        return self._download_and_add_key(issuer, key_id, insecure, force_refresh, cache_retry_interval)

    @classmethod
    def _check_validity(cls, key_info):
        """
        Return True if the key's ``expiration`` time is still in the future.
        """
        return key_info['expiration'] > time.time()

    @staticmethod
    def _get_issuer_publickey(issuer, key_id=None, insecure=False):
        """
        Download the issuer's public key via its OpenID configuration and JWKS.

        :return: Tuple containing (public_key, cache_lifetime).  Cache_lifetime is
            how long, in seconds, the public key may be cached.
        :raises NonHTTPSIssuer: if a required URL is not HTTPS and not ``insecure``
        :raises MissingKeyException: if no key with ``key_id`` is published
        :raises UnsupportedKeyException: for unsupported key types or EC curves
        """
        # Set the user agent so Cloudflare isn't mad at us
        # Import the __version__ value in scitokens for the scitokens version
        from scitokens import __version__ as PKG_VERSION
        request_headers = {'User-Agent': 'SciTokens/{}'.format(PKG_VERSION)}

        # Go to the issuer's website, and download the OAuth well known bits
        # https://tools.ietf.org/html/draft-ietf-oauth-discovery-07
        well_known_uri = ".well-known/openid-configuration"
        if not issuer.endswith("/"):
            issuer = issuer + "/"
        parsed_url = urlparse.urlparse(issuer)
        updated_url = urlparse.urljoin(parsed_url.path, well_known_uri)
        parsed_url_list = list(parsed_url)
        parsed_url_list[2] = updated_url
        meta_uri = urlparse.urlunparse(parsed_url_list)

        # Make sure the protocol is https
        if not insecure:
            parsed_url = urlparse.urlparse(meta_uri)
            if parsed_url.scheme != "https":
                raise NonHTTPSIssuer("Issuer is not over HTTPS. RFC requires it to be over HTTPS")
        with request.urlopen(request.Request(meta_uri, headers=request_headers)) as response:
            data = json.loads(response.read().decode('utf-8'))

        # Get the keys URL from the openid-configuration
        jwks_uri = data['jwks_uri']

        # Now, get the keys
        if not insecure:
            parsed_url = urlparse.urlparse(jwks_uri)
            if parsed_url.scheme != "https":
                raise NonHTTPSIssuer("jwks_uri is not over HTTPS, insecure!")
        with request.urlopen(request.Request(jwks_uri, headers=request_headers)) as response:
            response_headers = response.info()
            keys_data = json.loads(response.read().decode('utf-8'))

        # Get the cache data from the headers: parse out the max-age, if it's there.
        cache_timer = 0
        cache_control = response_headers.get("Cache-Control")
        if cache_control and "max-age" in cache_control:
            match = re.search(r".*max-age=(\d+)", cache_control)
            if match:
                cache_timer = int(match.group(1))
        # Enforce the configured minimum cache lifetime, no matter what the remote says
        cache_timer = max(cache_timer, config.get_int("cache_lifetime"))

        keys = keys_data['keys']
        raw_key = None
        if key_id is None:
            # If there is no kid in the header, then just take the only key
            if len(keys) != 1:
                raise NotImplementedError("No kid in header, but multiple keys in "
                                          "response from certs server. Don't know which key to use!")
            raw_key = keys[0]
        else:
            # Find the right key; "kid" is optional in a JWK, so tolerate its absence
            for key in keys:
                if key.get('kid') == key_id:
                    raw_key = key
                    break
        if raw_key is None:
            raise MissingKeyException("Unable to find key at issuer {}".format(jwks_uri))

        if raw_key['kty'] == "RSA":
            public_key_numbers = rsa.RSAPublicNumbers(
                long_from_bytes(raw_key['e']),
                long_from_bytes(raw_key['n'])
            )
            public_key = public_key_numbers.public_key(backends.default_backend())
        elif raw_key['kty'] == 'EC':
            # JWA (RFC 7518 sec. 6.2.1.1) curve names; default to P-256 when "crv" is absent.
            curves = {'P-256': ec.SECP256R1, 'P-384': ec.SECP384R1, 'P-521': ec.SECP521R1}
            curve_cls = curves.get(raw_key.get('crv', 'P-256'))
            if curve_cls is None:
                raise UnsupportedKeyException("SciToken signed with an unsupported EC curve")
            public_key_numbers = ec.EllipticCurvePublicNumbers(
                long_from_bytes(raw_key['x']),
                long_from_bytes(raw_key['y']),
                curve_cls()
            )
            public_key = public_key_numbers.public_key(backends.default_backend())
        else:
            raise UnsupportedKeyException("SciToken signed with an unsupported key type")

        return public_key, cache_timer

    def _get_cache_file(self):
        """
        Get the cache file location, creating directories and the database as needed.

        The cache directory is chosen in this order:

        1. The ``cache_location`` configuration value, if non-empty
        2. $XDG_CACHE_HOME
        3. The .cache subdirectory of the home directory (as resolved by
           os.path.expanduser)

        :returns: path of the SQLite cache file, or None if a cache directory
            could not be created (not a fatal error).
        """
        logger = logging.getLogger("scitokens")
        config_cache_location = config.get('cache_location')
        xdg_cache_home = os.environ.get("XDG_CACHE_HOME", None)
        home_dir = os.path.expanduser("~")
        if config_cache_location != "":
            cache_dir = config_cache_location
        elif xdg_cache_home is not None:
            cache_dir = xdg_cache_home
        elif home_dir is not None:
            cache_dir = os.path.join(home_dir, ".cache")

        try:
            # exist_ok: another process may create the directory concurrently.
            os.makedirs(cache_dir, exist_ok=True)
        except OSError as ose:
            # Unable to create a cache is not a fatal error
            logger.warning("Unable to create cache directory at {}: {}".format(cache_dir, str(ose)))
            return None

        keycache_dir = os.path.join(cache_dir, "scitokens")
        try:
            os.makedirs(keycache_dir, exist_ok=True)
        except OSError as ose:
            # Unable to create directories is not a fatal error
            logger.warning("Unable to create cache directory at {}: {}".format(keycache_dir, str(ose)))
            return None

        keycache_file = os.path.join(keycache_dir, CACHE_FILENAME)
        if not os.path.exists(keycache_file):
            self._initialize_cachedb(keycache_file)
        return keycache_file

    @staticmethod
    def _initialize_cachedb(sql_file):
        """
        Create a simple flat SQLite cache containing the ``keycache`` table.
        """
        conn = sqlite3.connect(sql_file)
        try:
            curs = conn.cursor()
            # IF NOT EXISTS: tolerate another process initializing the same file.
            curs.execute("CREATE TABLE IF NOT EXISTS keycache ("
                         "issuer text NOT NULL,"
                         "expiration integer NOT NULL,"
                         "key_id text,"
                         "keydata text NOT NULL,"
                         "next_update integer NOT NULL,"
                         "PRIMARY KEY (issuer, key_id))")
            conn.commit()
        finally:
            conn.close()

    def list_keys(self):
        """
        List all keys in keycache.

        :returns: list of (issuer, expiration datetime, key_id, keydata,
            next_update datetime) tuples; datetimes are SQLite DATETIME strings.
        """
        conn = sqlite3.connect(self.cache_location)
        try:
            curs = conn.cursor()
            res = curs.execute("SELECT issuer, DATETIME(expiration, 'unixepoch'), key_id, keydata, "
                               "DATETIME(next_update, 'unixepoch') FROM keycache")
            return res.fetchall()
        finally:
            conn.close()

    def remove_key(self, issuer, key_id):
        """
        Remove a specific key from keycache.

        :returns: True if the key was present and removed, False otherwise.
        """
        conn = sqlite3.connect(self.cache_location)
        try:
            curs = conn.cursor()
            res = curs.execute("SELECT * FROM keycache WHERE issuer = ? AND key_id = ?", [issuer, key_id])
            if res.fetchone() is None:
                return False
            curs.execute("DELETE FROM keycache WHERE issuer = ? AND key_id = ?", [issuer, key_id])
            conn.commit()
            return True
        finally:
            conn.close()

    def add_key(self, issuer, key_id, force_refresh=False):
        """
        Add a key or update an existing one in keycache.

        :returns: the key as PEM bytes, or None if a negative cache entry blocks it.
        """
        pubkey = self.getkeyinfo(issuer, key_id, force_refresh=force_refresh)
        if pubkey is None:
            return None
        pubkey_pem = pubkey.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        return pubkey_pem

    def update_all_keys(self, force_refresh=False):
        """
        Update all keys in keycache.

        If force_refresh is True, we refresh all keys regardless of update time.
        Per-key failures are logged and skipped.

        :returns: list of ``add_key`` results for the keys that did not fail
        """
        logger = logging.getLogger("scitokens")
        conn = sqlite3.connect(self.cache_location)
        try:
            curs = conn.cursor()
            tokens = curs.execute("SELECT issuer, key_id FROM keycache").fetchall()
        finally:
            conn.close()

        res = []
        for issuer, key_id in tokens:
            try:
                res.append(self.add_key(issuer, key_id, force_refresh=force_refresh))
            except Exception as ex:
                logger.error("Unable to update key: {0} {1}".format(issuer, key_id))
                logger.error(ex)
        return res