Source code for scitokens.scitokens


"""
SciTokens reference library.

This library provides the primitives necessary for working with SciTokens
authorization tokens.
"""

import time

import os
import jwt
from . import urltools
import logging

LOGGER = logging.getLogger("scitokens")
import uuid

import cryptography.hazmat.backends as backends
from .utils import keycache as KeyCache
from .utils import config
from .utils.errors import MissingIssuerException, InvalidTokenFormat, MissingKeyException, UnsupportedKeyException
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from cryptography.hazmat.primitives.asymmetric import rsa, ec

class SciToken(object):
    """
    An object representing the contents of a SciToken.

    Claims are kept in two dictionaries: ``_claims`` holds claims that were set
    locally and have not been signed yet, while ``_verified_claims`` holds
    claims that came out of :meth:`deserialize` or that were already written
    out by :meth:`serialize`.
    """

    def __init__(self, key=None, algorithm=None, key_id=None, parent=None, claims=None):
        """
        Construct a SciToken object.

        :param key: Private key to sign the SciToken with.  The signing
            algorithm can only be derived automatically when this is an RSA or
            EC private key object; otherwise ``algorithm`` must be given.
        :param algorithm: Private key algorithm to sign the SciToken with.
            Only ``RS256`` and ``ES256`` are accepted.  When neither a key nor
            an algorithm is given, the configured ``default_alg`` is used.
        :param str key_id: A string representing the Key ID that is used at the issuer
        :param parent: Parent SciToken that will be chained
        :param claims: Not implemented; must be ``None``.
        :raises NotImplementedError: if ``claims`` is not ``None``.
        :raises UnsupportedKeyException: if the algorithm cannot be determined,
            conflicts with the one derived from the key, or is not supported.
        """
        if claims is not None:
            raise NotImplementedError()

        self._key = key
        derived_alg = self._derive_algorithm(key) if key else None

        # Make sure we support the key algorithm
        if key and not algorithm and not derived_alg:
            # We don't know the key algorithm
            raise UnsupportedKeyException("Key was given for SciToken, but algorithm was not "
                                          "passed to SciToken creation and it cannot be derived "
                                          "from the provided key")
        elif derived_alg and not algorithm:
            self._key_alg = derived_alg
        elif derived_alg and algorithm and derived_alg != algorithm:
            error_str = ("Key provided reports algorithm type: {0}, ".format(derived_alg) +
                         "while scitoken creation argument was {0}".format(algorithm))
            raise UnsupportedKeyException(error_str)
        elif key and algorithm:
            self._key_alg = algorithm
        else:
            # If key is not specified, and neither is algorithm
            self._key_alg = algorithm if algorithm is not None else config.get('default_alg')

        if self._key_alg not in ["RS256", "ES256"]:
            raise UnsupportedKeyException()

        self._key_id = key_id
        self._parent = parent
        self._claims = {}
        self._verified_claims = {}
        self.insecure = False
        self._serialized_token = None

    @staticmethod
    def _derive_algorithm(key):
        """
        Derive the signing algorithm from a private key object.

        :param key: The key to inspect.
        :returns: ``"RS256"`` for an RSA private key, ``"ES256"`` for an EC
            private key on the ``secp256r1`` curve, otherwise ``None``.
        """
        if isinstance(key, rsa.RSAPrivateKey):
            return "RS256"
        if isinstance(key, ec.EllipticCurvePrivateKey) and key.curve.name == "secp256r1":
            return "ES256"
        # If it gets here, we don't know what type of key
        return None

    def claims(self):
        """
        Return an iterator of (key, value) pairs of claims, starting with the
        claims from the first token in the chain.

        For each token the verified claims are yielded before the unverified
        ones, so the same claim name may be yielded more than once.
        """
        if self._parent:
            for claim, value in self._parent.claims():
                yield claim, value
        for claim, value in self._verified_claims.items():
            yield claim, value
        for claim, value in self._claims.items():
            yield claim, value

    def verify(self):
        """
        Verify the claims of the in-memory token.

        Not implemented; always raises :class:`NotImplementedError`.
        """
        raise NotImplementedError()

    def serialize(self, include_key=False, issuer=None, lifetime=600):
        """
        Serialize the existing SciToken.

        The ``iss``, ``exp``, ``iat`` and ``nbf`` claims are always set here
        and override any value already stored in the token.  A ``jti`` is
        generated when the token does not have one.  After signing, all
        pending claims are moved to the verified claims.

        :param bool include_key: When true, include the public key to the
            serialized token.  Not implemented; must be ``False``.
        :param str issuer: A string indicating the issuer for the token.  It
            should be an HTTPS address, as specified in
            https://tools.ietf.org/html/draft-ietf-oauth-discovery-07
            When omitted or empty, the ``iss`` claim of the token is used.
        :param int lifetime: Number of seconds that the token should be valid
        :return bytes: base64 encoded token
        :raises NotImplementedError: if ``include_key`` is not ``False``.
        :raises MissingKeyException: if the token has no private key.
        :raises MissingIssuerException: if no issuer is available.
        """
        if include_key is not False:
            raise NotImplementedError()

        if self._key is None:
            raise MissingKeyException("Unable to serialize, missing private key")

        # Issuer needs to be available, otherwise throw an error.  An empty
        # issuer argument is treated like a missing one so that the caller
        # gets MissingIssuerException instead of a bare KeyError.
        if not issuer:
            if 'iss' not in self:
                raise MissingIssuerException("Issuer not specific in claims or as argument")
            issuer = self['iss']

        # Set the issue and expiration time of the token
        issue_time = int(time.time())
        exp_time = int(issue_time + lifetime)

        # Add to validated and other claims
        payload = dict(self._verified_claims)
        payload.update(self._claims)

        # Anything below will override what is in the claims
        payload.update({
            "iss": issuer,
            "exp": exp_time,
            "iat": issue_time,
            "nbf": issue_time
        })

        if 'jti' not in payload:
            # Create a jti from a uuid
            payload['jti'] = str(uuid.uuid4())
            self._claims['jti'] = payload['jti']

        # Only send a 'kid' header when a key id was configured
        encode_kwargs = {"algorithm": self._key_alg}
        if self._key_id is not None:
            encode_kwargs["headers"] = {'kid': self._key_id}
        encoded = jwt.encode(payload, self._key, **encode_kwargs)

        if isinstance(encoded, bytes):
            # pyjwt < 2 returns bytes
            encoded = encoded.decode("utf-8")
        self._serialized_token = encoded

        # Move claims over to verified claims
        self._verified_claims.update(self._claims)
        self._claims = {}

        LOGGER.info("Signed Token: %s", payload)

        # Encode the returned string for backwards compatibility.
        # Previous versions of PyJWT returned bytes
        return encoded.encode("utf-8")

    def update_claims(self, claims):
        """
        Add new claims to the token.

        :param claims: Dictionary of claims to add to the token
        """
        self._claims.update(claims)

    def __setitem__(self, claim, value):
        """
        Assign a new (unverified) claim to the token.
        """
        self._claims[claim] = value

    def __getitem__(self, claim):
        """
        Access the value corresponding to a particular claim; will
        return claims from both the verified and unverified claims, with the
        unverified value taking precedence.

        If a claim is not present, then a KeyError is thrown.
        """
        for source in (self._claims, self._verified_claims):
            if claim in source:
                return source[claim]
        raise KeyError(claim)

    def __contains__(self, claim):
        """
        Check if the claim exists in the SciToken (verified or not).
        """
        return claim in self._claims or claim in self._verified_claims

    def __delitem__(self, claim):
        """
        Delete the claim from both the verified and unverified claims.

        :raises KeyError: if the claim is present in neither.
        """
        found = False
        for source in (self._claims, self._verified_claims):
            if claim in source:
                del source[claim]
                found = True
        if not found:
            raise KeyError(claim)

    def get(self, claim, default=None, verified_only=False):
        """
        Return the value associated with a claim, returning the
        default if the claim is not present.  If `verified_only` is
        True, then a claim is returned only if it is in the verified claims
        """
        verified_value = self._verified_claims.get(claim, default)
        if verified_only:
            return verified_value
        return self._claims.get(claim, verified_value)

    def clone_chain(self):
        """
        Return a new, empty SciToken.

        Not implemented; always raises :class:`NotImplementedError`.
        """
        raise NotImplementedError()

    def _deserialize_key(self, key_serialized, unverified_headers):
        """
        Given a serialized key and a set of UNVERIFIED headers, return
        a corresponding private key object.

        This is currently a placeholder: it has no body and returns ``None``.
        """
        return None

    @staticmethod
    def deserialize(serialized_token, audience=None, require_key=False, insecure=False, public_key=None):
        """
        Given a serialized SciToken, load it into a SciTokens object.

        Verifies the signature of the token, either with the supplied
        ``public_key`` or with the key looked up for the token's issuer.

        :param str serialized_token: The serialized token (``str`` or ``bytes``).
        :param str audience: (Legacy, not checked) The audience URI that this principle is claiming.
            Default: None. Audience is not checked no matter the value.
        :param bool require_key: When True, require the key.  Not implemented;
            must be ``False``.
        :param bool insecure: When True, allow insecure methods to verify the issuer,
            including allowing "localhost" issuer (useful in testing).  Default=False
        :param str public_key: A PEM formatted public key (``str`` or ``bytes``)
            to be used to validate the token
        :raises NotImplementedError: if ``require_key`` is not ``False``.
        :raises InvalidTokenFormat: if the token does not have 3 or 4 segments.
        :raises MissingIssuerException: if the issuer key is needed but the
            token carries no ``iss`` claim.
        """
        if require_key is not False:
            raise NotImplementedError()

        if isinstance(serialized_token, bytes):
            serialized_token = serialized_token.decode('utf8')

        info = serialized_token.split(".")
        if len(info) != 3 and len(info) != 4:  # header, format, signature[, key]
            raise InvalidTokenFormat("Serialized token is not a readable format.")

        # Currently unreachable because require_key is rejected above; kept
        # for when key-carrying tokens are implemented.
        if (len(info) != 4) and require_key:
            raise MissingKeyException("No key present in serialized token")

        serialized_jwt = ".".join(info[:3])

        unverified_headers = jwt.get_unverified_header(serialized_jwt)
        unverified_payload = jwt.decode(serialized_jwt, algorithms=['RS256', 'ES256'],
                                        audience=audience,
                                        options={"verify_signature": False,
                                                 "verify_aud": False})

        # Get the public key from the issuer
        keycache = KeyCache.KeyCache().getinstance()
        if public_key is None:
            if 'iss' not in unverified_payload:
                raise MissingIssuerException('Issuer not provided')
            issuer_public_key = keycache.getkeyinfo(unverified_payload['iss'],
                                                    key_id=unverified_headers.get('kid'),
                                                    insecure=insecure)
        else:
            # The PEM loader only accepts bytes, while callers are documented
            # to pass a string.
            if isinstance(public_key, str):
                public_key = public_key.encode("utf-8")
            issuer_public_key = load_pem_public_key(public_key, backend=backends.default_backend())

        # NOTE(review): the full serialized_token is verified here, not the
        # 3-segment serialized_jwt, so a 4-segment token is passed to
        # jwt.decode unchanged -- confirm this is intended.
        claims = jwt.decode(serialized_token, issuer_public_key,
                            algorithms=['RS256', 'ES256'],
                            options={"verify_aud": False})

        to_return = SciToken()
        to_return._verified_claims = claims
        to_return._serialized_token = serialized_token
        return to_return

    @staticmethod
    def discover(audience=None, require_key=False, insecure=False, public_key=None):
        """
        Create a SciToken by looking for a token with WLCG Bearer Token Discovery protocol
        https://github.com/WLCG-AuthZ-WG/bearer-token-discovery/blob/master/specification.md

        The following locations are tried in order: the ``BEARER_TOKEN``
        environment variable, the file named by ``BEARER_TOKEN_FILE``, and
        finally ``bt_u<euid>`` inside ``XDG_RUNTIME_DIR`` (or ``/tmp`` when
        that variable is not set).

        The serialized token is read in and passed to the deserialize() method to load it into a SciTokens object.

        Raises OSError if a token cannot be found, or the errors of SciTokens.deserialize() if there is an
        error reading the discovered token.

        :param str audience: The audience URI that this principle is claiming.  Default: None
        :param bool require_key: When True, require the key
        :param bool insecure: When True, allow insecure methods to verify the issuer,
            including allowing "localhost" issuer (useful in testing).  Default=False
        :param str public_key: A PEM formatted public key string to be used to validate the token
        """
        env_token = os.environ.get('BEARER_TOKEN')
        if env_token:
            return SciToken.deserialize(env_token.strip(), audience, require_key, insecure, public_key)

        env_token_file = os.environ.get('BEARER_TOKEN_FILE')
        if env_token_file and os.path.isfile(env_token_file):
            with open(env_token_file) as token_file:
                token_data = token_file.read().strip()
            return SciToken.deserialize(token_data, audience, require_key, insecure, public_key)

        bt_file = 'bt_u{}'.format(os.geteuid())
        runtime_dir = os.environ.get('XDG_RUNTIME_DIR')
        bt_path = os.path.join(runtime_dir if runtime_dir else '/tmp', bt_file)
        if os.path.isfile(bt_path):
            with open(bt_path) as token_file:
                token_data = token_file.read().strip()
            return SciToken.deserialize(token_data, audience, require_key, insecure, public_key)

        raise OSError(
            "failed to identify a valid bearer token",
        )
class ValidationFailure(Exception):
    """
    Base error for token validation: validation was attempted but
    failed for an unknown reason.
    """
class NoRegisteredValidator(ValidationFailure):
    """
    Raised when the Validator object encounters a claim in the token
    for which no validator has been registered.
    """
class ClaimInvalid(ValidationFailure):
    """
    Raised when one of the validation callbacks registered for a claim
    marks the value of that claim as invalid.
    """
class MissingClaims(ValidationFailure):
    """
    Raised when one or more claims marked as critical are missing
    from the token being validated.
    """
class Validator(object):
    """
    Validate the contents of a SciToken.

    Given a SciToken, validate the contents of its claims.  Unlike verification,
    which checks that the token is correctly signed, validation provides an easy-to-use
    interface that ensures the claims in the token are understood by the user.
    """

    def __init__(self):
        # Maps a claim name to the list of callbacks registered for it
        self._callbacks = {}

    def add_validator(self, claim, validate_op):
        """
        Add a validation callback for a given claim.  When the given ``claim``
        encountered in a token, ``validate_op`` object will be called with the
        following signature::

        >>> validate_op(value)

        where ``value`` is the value of the token's claim converted to a python
        object.

        The validator should return ``True`` if the value is acceptable and ``False``
        otherwise.  Several validators may be registered for the same claim;
        all of them must accept the value.
        """
        self._callbacks.setdefault(claim, []).append(validate_op)

    def validate(self, token, critical_claims=None):
        """
        Validate the claims of a token.

        This will iterate through all claims in the given :class:`SciToken`
        and determine whether all claims a valid, given the current set of
        validators.

        If ``critical_claims`` is specified, then validation will fail if one
        or more claim in this list is not present in the token.

        A claim without any registered validator is an error, unless the token
        declares ``ver`` to be ``scitoken:2.0``, in which case it is skipped.

        This will throw an exception if the token is invalid and return ``True``
        if the token is valid.

        :raises NoRegisteredValidator: if a claim has no registered validator.
        :raises ClaimInvalid: if a validator rejects the value of a claim.
        :raises MissingClaims: if a critical claim is absent from the token.
        """
        missing_claims = set(critical_claims) if critical_claims else set()

        for claim, value in token.claims():
            missing_claims.discard(claim)

            # Read-only lookup: validating a token must not register empty
            # entries for every unknown claim it happens to contain.
            validator_list = self._callbacks.get(claim, ())
            if not validator_list:
                if "ver" not in token or token["ver"] != "scitoken:2.0":
                    raise NoRegisteredValidator("No validator was registered to handle encountered claim '%s'" % claim)

            for validator in validator_list:
                if not validator(value):
                    raise ClaimInvalid("Validator rejected value of '%s' for claim '%s'" % (value, claim))

        if missing_claims:
            # Sorted so that the message is stable between runs
            raise MissingClaims("Validation failed because the following claims are missing: %s" % \
                                ", ".join(sorted(missing_claims)))

        return True

    def __call__(self, token):
        """
        Shorthand for :meth:`validate` without critical claims.
        """
        return self.validate(token)
class EnforcementError(Exception):
    """
    Base class for errors raised while enforcing a SciToken.
    """
class InvalidPathError(EnforcementError):
    """
    Raised when the Enforcer object is given an invalid test path.

    Test paths must be absolute paths (start with '/')
    """
class InvalidAuthorizationResource(EnforcementError):
    """
    Raised when a scope carries an invalid authorization.

    Examples include:

    - Authorizations that require paths (read, write) but none were included.
    - Scopes that include relative paths (read:~/foo)
    """
class Enforcer(object):
    """
    Enforce SciTokens-specific validation logic.

    Allows one to test if a given token has a particular authorization.

    This class is NOT thread safe; a separate object is needed for every thread.
    """

    # Authorizations whose scope must name a resource path
    _authz_requiring_path = set(["read", "write"])

    # An array of versions of scitokens that we understand and can enforce
    _versions_understood = [1, "scitoken:2.0"]

    def __init__(self, issuer, audience=None):
        """
        :param issuer: The only issuer whose tokens are accepted.
        :param audience: Audience (or list of audiences) this enforcer accepts.
        :raises EnforcementError: if ``issuer`` is empty.
        """
        self._issuer = issuer
        self.last_failure = None
        if not self._issuer:
            raise EnforcementError("Issuer must be specified.")
        self._audience = audience
        self._test_access = False
        self._test_authz = None
        self._test_path = None
        self._token_scopes = set()
        self._now = 0
        self._validator = Validator()
        default_validators = (
            ("exp", self._validate_exp),
            ("nbf", self._validate_nbf),
            ("iss", self._validate_iss),
            ("iat", self._validate_iat),
            ("aud", self._validate_aud),
            ("scp", self._validate_scp),
            ("scope", self._validate_scope),
            ("jti", self._validate_jti),
            ("sub", self._validate_sub),
            ("ver", self._validate_ver),
            ("opt", self._validate_opt),
        )
        for claim, callback in default_validators:
            self._validator.add_validator(claim, callback)

    def _reset_state(self):
        """
        Reset the internal state variables of the Enforcer object.  Automatically
        invoked each time the Enforcer is used to test or generate_acls
        """
        self._test_authz = None
        self._test_path = None
        self._test_access = False
        self._token_scopes = set()
        # All time-based claims of one run are compared to the same instant
        self._now = time.time()
        self.last_failure = None

    def add_validator(self, claim, validator):
        """
        Add a user-defined validator in addition to the default enforcer logic.
        """
        self._validator.add_validator(claim, validator)

    def test(self, token, authz, path=None):
        """
        Test whether a given token has the requested permission within the
        current enforcer context.

        :param token: The token to test.
        :param authz: The authorization to look for (for example ``read``).
        :param path: Absolute path the authorization is requested for.
            Required for authorizations that need a path.
        :returns: ``True`` if the token passes validation, ``False`` otherwise;
            in the latter case the reason is stored in ``last_failure``.
        :raises InvalidPathError: if ``path`` is missing but required, or is
            not an absolute path.
        """
        self._reset_state()
        self._test_access = True

        critical_claims = set(["scope"])

        # Check for the older "scp" attribute
        if 'scope' not in token and 'scp' in token:
            critical_claims = set(["scp"])

        # In scitokens 2.0, some claims are required
        if 'ver' in token and token['ver'] == "scitoken:2.0":
            critical_claims.update(['aud', 'ver'])

        if not path and (authz in self._authz_requiring_path):
            raise InvalidPathError("Enforcer provided with an empty path.")
        if path and not path.startswith("/"):
            raise InvalidPathError("Enforcer was given an invalid relative path to test; absolute path required.")

        self._test_path = path
        self._test_authz = authz
        self.last_failure = None
        try:
            self._validator.validate(token, critical_claims=critical_claims)
        except ValidationFailure as validation_failure:
            self.last_failure = str(validation_failure)
            return False
        return True

    def generate_acls(self, token):
        """
        Given a SciToken object and the expected issuer, return the valid ACLs.

        :returns: A list of ``(authz, path)`` tuples collected from the scopes
            of the token.
        :raises ValidationFailure: if the token does not validate; the reason
            is also stored in ``last_failure``.
        """
        self._reset_state()

        critical_claims = set(["scope"])

        # Check for the older "scp" attribute
        if 'scope' not in token and 'scp' in token:
            critical_claims = set(["scp"])

        try:
            self._validator.validate(token, critical_claims=critical_claims)
        except ValidationFailure as verify_fail:
            self.last_failure = str(verify_fail)
            raise
        return list(self._token_scopes)

    def _validate_exp(self, value):
        """
        Accept the token only while the current time is strictly before ``exp``.
        """
        return float(value) > self._now

    def _validate_nbf(self, value):
        """
        Accept the token once the current time is at or after ``nbf``.
        """
        return float(value) <= self._now

    def _validate_iss(self, value):
        """
        Accept only the issuer this enforcer was created for.
        """
        return self._issuer == value

    def _validate_iat(self, value):
        """
        Reject tokens that claim to be issued in the future.
        """
        return float(value) <= self._now

    def _validate_aud(self, value):
        """
        Accept the token when one of its audiences matches one of ours.

        An enforcer without audience, or configured with the audience ``ANY``,
        never accepts an ``aud`` claim; a token with the audience ``ANY`` is
        accepted by any other enforcer.
        """
        if not self._audience or self._audience == "ANY":
            return False
        if value == "ANY":
            return True

        # Convert the value and self._audience both to sets
        # Then perform set intersection
        token_audiences = value if isinstance(value, list) else [value]
        accepted_audiences = self._audience if isinstance(self._audience, list) else [self._audience]
        return bool(set(token_audiences) & set(accepted_audiences))

    def _validate_ver(self, value):
        """
        Accept only token versions listed in ``_versions_understood``.
        """
        return value in self._versions_understood

    @classmethod
    def _validate_opt(cls, value):
        """
        Opt is optional information.  We don't know what's in there, so just
        return true.
        """
        del value
        return True

    @classmethod
    def _validate_sub(cls, value):
        """
        SUB, or subject, should always pass.  It's mostly used for identifying
        a tokens origin.
        """
        # Fix for unused argument
        del value
        return True

    @classmethod
    def _validate_jti(cls, value):
        """
        JTI, or json token id, should always pass.  It's mostly used for logging
        and auditing.
        """
        LOGGER.info("Validating SciToken with jti: %s", value)
        return True

    def _check_scope(self, scope):
        """
        Given a scope, make sure it contains a resource
        for scope types that require resources.

        Returns a tuple of the (authz, path).  If path is
        not in the scope (and is not required to be explicitly inside
        the scope), it will default to '/'.

        :raises InvalidAuthorizationResource: if a required path is missing or
            the path is not absolute.
        """
        info = scope.split(":", 1)
        authz = info[0]
        if authz in self._authz_requiring_path and (len(info) == 1):
            raise InvalidAuthorizationResource("Token contains an authorization requiring a resource"
                                               "(path), but no path is present")
        if len(info) == 2:
            path = info[1]
            if not path.startswith("/"):
                raise InvalidAuthorizationResource("Token contains a relative path in scope")
            norm_path = urltools.normalize_path(path)
        else:
            norm_path = '/'
        return (authz, norm_path)

    @staticmethod
    def _path_within(requested, granted):
        """
        Return True if ``requested`` is ``granted`` itself or lies below it.

        The comparison respects path-segment boundaries, so a scope for
        ``/foo`` covers ``/foo`` and ``/foo/bar`` but not ``/foobar``.
        """
        if requested == granted:
            return True
        prefix = granted if granted.endswith("/") else granted + "/"
        return requested.startswith(prefix)

    def _process_scopes(self, scopes):
        """
        Shared implementation for the ``scp`` and ``scope`` validators.

        While testing access, return whether any scope grants the requested
        authorization on the requested path.  Otherwise record every scope in
        ``_token_scopes`` and return True.
        """
        if self._test_access:
            if not self._test_path:
                norm_requested_path = '/'
            else:
                norm_requested_path = urltools.normalize_path(self._test_path)
            for scope in scopes:
                authz, norm_path = self._check_scope(scope)
                if (self._test_authz == authz) and self._path_within(norm_requested_path, norm_path):
                    return True
            return False

        for scope in scopes:
            authz, norm_path = self._check_scope(scope)
            self._token_scopes.add((authz, norm_path))
        return True

    def _validate_scp(self, value):
        """
        Validate the legacy ``scp`` claim: a single scope or a list of scopes.
        """
        if not isinstance(value, list):
            value = [value]
        return self._process_scopes(value)

    def _validate_scope(self, value):
        """
        Validate the ``scope`` claim: a string of space separated scopes.

        :raises InvalidAuthorizationResource: if the value is not a string.
        """
        if not isinstance(value, str):
            raise InvalidAuthorizationResource("Scope is invalid. Must be a space separated string")
        # Split on spaces
        return self._process_scopes(value.split(" "))