# Source code for saml2.utils

# -*- coding: utf-8 -*-

# Copyright (c) 2014, OneLogin, Inc.
# All rights reserved.

import base64
from datetime import datetime
import calendar
from hashlib import sha1
from isodate import parse_duration as duration_parser
from lxml import etree
from lxml.etree import ElementBase
from os.path import basename, dirname, join
import re
from sys import stderr
from tempfile import NamedTemporaryFile
from textwrap import wrap
from urllib import quote_plus
from uuid import uuid4
from xml.dom.minidom import Document, parseString, Element
from xml.etree.ElementTree import tostring
import zlib

import dm.xmlsec.binding as xmlsec
from dm.xmlsec.binding.tmpl import EncData, Signature
from M2Crypto import X509

from saml2.constants import OneLogin_Saml2_Constants
from saml2.errors import OneLogin_Saml2_Error


def _(msg):
    # Fixme Add i18n support
    return msg


class OneLogin_Saml2_Utils:
    """Auxiliary helpers shared by the SAML toolkit.

    Every method is a static method; the class only acts as a namespace. It
    groups deflate/base64 encoding, XSD validation, URL reconstruction from a
    request dict, SAML timestamp conversion, certificate formatting and
    fingerprinting, and XML signing/encryption through dm.xmlsec.
    """

    @staticmethod
    def _xml_parser():
        """
        Returns an lxml parser for XML that may come from a third party.

        Entity substitution and network access are disabled, so a document
        cannot pull in external or local resources (XXE).

        :returns: A new parser instance
        :rtype: lxml.etree.XMLParser
        """
        return etree.XMLParser(resolve_entities=False, no_network=True)

    @staticmethod
    def decode_base64_and_inflate(value):
        """
        Base64-decodes and then inflates a raw DEFLATE (RFC 1951) stream.

        :param value: A deflated and base64-encoded string
        :type: string
        :returns: The decoded and inflated string
        :rtype: string
        """
        # wbits=-15: the stream has no zlib header/trailer (raw DEFLATE).
        return zlib.decompress(base64.b64decode(value), -15)

    @staticmethod
    def deflate_and_base64_encode(value):
        """
        Deflates (raw RFC 1951 DEFLATE) and then base64-encodes a string.

        :param value: The string to deflate and encode
        :type: string
        :returns: The deflated and encoded string
        :rtype: string
        """
        # zlib.compress() adds a 2-byte zlib header and a 4-byte Adler-32
        # trailer around the DEFLATE data; strip both to get raw DEFLATE.
        return base64.b64encode(zlib.compress(value)[2:-4])

    @staticmethod
    def validate_xml(xml, schema, debug=False):
        """
        Validates an XML document against one of the bundled XSD schemas.

        :param xml: The XML to validate
        :type: string | Document
        :param schema: File name of the schema inside the 'schemas' directory
                       next to this module
        :type: string
        :param debug: If True, write the validation errors to stderr
        :type: bool
        :returns: 'unloaded_xml' if the XML cannot be parsed, 'invalid_xml'
                  if it does not validate, otherwise the parsed document
        :rtype: string | Document
        """
        assert isinstance(xml, basestring) or isinstance(xml, Document)
        assert isinstance(schema, basestring)

        if isinstance(xml, Document):
            xml = xml.toxml()

        # Switch to lxml for schema validation
        try:
            dom = etree.fromstring(xml, OneLogin_Saml2_Utils._xml_parser())
        except Exception:
            return 'unloaded_xml'

        schema_file = join(dirname(__file__), 'schemas', schema)
        # 'with' closes the schema file even if parsing it fails.
        with open(schema_file, 'r') as f:
            schema_doc = etree.parse(f)
        xmlschema = etree.XMLSchema(schema_doc)

        if not xmlschema.validate(dom):
            if debug:
                stderr.write('Errors validating the metadata')
                stderr.write(':\n\n')
                # The error log itself is iterable; each entry has .message.
                for error in xmlschema.error_log:
                    stderr.write('%s\n' % error.message)
            return 'invalid_xml'

        return parseString(etree.tostring(dom))

    @staticmethod
    def format_cert(cert, heads=True):
        """
        Returns a x509 cert, normalized and optionally wrapped in PEM headers.

        Line breaks, spaces and any existing BEGIN/END CERTIFICATE markers are
        removed first; with heads=True the body is re-wrapped at 64 columns
        between fresh markers.

        :param cert: A x509 unformatted cert
        :type: string
        :param heads: True if we want to include head and footer
        :type: boolean
        :returns: Formatted cert ('' stays '')
        :rtype: string
        """
        # '\x0D' and '\r' are the same character, so one replace suffices.
        x509_cert = cert.replace('\r', '').replace('\n', '')
        if len(x509_cert) > 0:
            x509_cert = x509_cert.replace('-----BEGIN CERTIFICATE-----', '')
            x509_cert = x509_cert.replace('-----END CERTIFICATE-----', '')
            x509_cert = x509_cert.replace(' ', '')

            if heads:
                x509_cert = ('-----BEGIN CERTIFICATE-----\n' +
                             '\n'.join(wrap(x509_cert, 64)) +
                             '\n-----END CERTIFICATE-----\n')

        return x509_cert

    @staticmethod
    def redirect(url, parameters=None, request_data=None):
        """
        Builds the target url of a redirection, with extra query parameters.

        :param url: The target url; a path starting with '/' is made absolute
                    with the current host taken from request_data
        :type: string
        :param parameters: Extra parameters to be passed as part of the url.
                           A None value adds the bare name; a list value adds
                           name[]=v for each item.
        :type: dict
        :param request_data: The request as a dict
        :type: dict
        :returns: Url
        :rtype: string
        :raises OneLogin_Saml2_Error: if the url is not http(s)
        """
        # None defaults avoid sharing one mutable dict between calls.
        if parameters is None:
            parameters = {}
        if request_data is None:
            request_data = {}
        assert isinstance(url, basestring)
        assert isinstance(parameters, dict)

        if url.startswith('/'):
            url = '%s%s' % (OneLogin_Saml2_Utils.get_self_url_host(request_data), url)

        # Verify that the URL is to a http or https site.
        if re.search('^https?://', url) is None:
            raise OneLogin_Saml2_Error(
                'Redirect to invalid URL: ' + url,
                OneLogin_Saml2_Error.REDIRECT_INVALID_URL
            )

        # Add encoded parameters
        param_prefix = '&' if '?' in url else '?'
        for name, value in parameters.items():
            if value is None:
                # Flag-style parameter: the name alone, no '=value'.
                param = quote_plus(name)
            elif isinstance(value, list):
                # PHP-style array parameter: name[]=v1&name[]=v2
                param = '&'.join(quote_plus(name) + '[]=' + quote_plus(val) for val in value)
            else:
                param = quote_plus(name) + '=' + quote_plus(value)

            url += param_prefix + param
            param_prefix = '&'

        return url

    @staticmethod
    def get_self_url_host(request_data):
        """
        Returns the protocol + the current host + the port (if different than
        the default port of the protocol).

        :param request_data: The request as a dict; 'server_port' may be a
                             string or an int
        :type: dict
        :returns: Url
        :rtype: string
        """
        current_host = OneLogin_Saml2_Utils.get_self_host(request_data)
        protocol = 'https' if OneLogin_Saml2_Utils.is_https(request_data) else 'http'

        port = ''
        if 'server_port' in request_data:
            # Normalize to str so int ports work and compare correctly.
            port_number = str(request_data['server_port'])
            if (protocol, port_number) not in (('http', '80'), ('https', '443')):
                port = ':' + port_number

        return '%s://%s%s' % (protocol, current_host, port)

    @staticmethod
    def get_self_host(request_data):
        """
        Returns the current host, without any trailing ':<port>'.

        :param request_data: The request as a dict ('http_host' is preferred
                             over 'server_name')
        :type: dict
        :returns: The current host
        :rtype: string
        :raises Exception: if neither key is present
        """
        if 'http_host' in request_data:
            current_host = request_data['http_host']
        elif 'server_name' in request_data:
            current_host = request_data['server_name']
        else:
            raise Exception('No hostname defined')

        # Strip a numeric ':<port>' suffix without cutting into an IPv6
        # literal: '[::1]:8080' -> '[::1]', while '[::1]' and '::1' are kept.
        host, sep, port = current_host.rpartition(':')
        if sep and port.isdigit() and (':' not in host or host.endswith(']')):
            current_host = host

        return current_host

    @staticmethod
    def is_https(request_data):
        """
        Checks if https or http.

        :param request_data: The request as a dict
        :type: dict
        :returns: True if 'https' is set to anything but 'off', or the
                  server port is 443 (string or int)
        :rtype: boolean
        """
        is_https = 'https' in request_data and request_data['https'] != 'off'
        is_https = is_https or ('server_port' in request_data and str(request_data['server_port']) == '443')
        return is_https

    @staticmethod
    def get_self_url_no_query(request_data):
        """
        Returns the URL of the current host + current view (no query string).

        :param request_data: The request as a dict ('script_name' and
                             'path_info' are optional)
        :type: dict
        :returns: The url of current host + current view
        :rtype: string
        """
        self_url_host = OneLogin_Saml2_Utils.get_self_url_host(request_data)
        script_name = request_data.get('script_name', '')
        # An empty script name would otherwise raise IndexError on [0].
        if script_name and not script_name.startswith('/'):
            script_name = '/' + script_name

        self_url_host += script_name
        if 'path_info' in request_data:
            self_url_host += request_data['path_info']

        return self_url_host

    @staticmethod
    def get_self_url(request_data):
        """
        Returns the URL of the current host + current view + query.

        :param request_data: The request as a dict; 'request_uri' may be a
                             path or an absolute http(s) URI
        :type: dict
        :returns: The url of current host + current view + query
        :rtype: string
        """
        self_url_host = OneLogin_Saml2_Utils.get_self_url_host(request_data)

        request_uri = request_data.get('request_uri', '')
        if request_uri and not request_uri.startswith('/'):
            # Absolute URI: keep only the path part, which may be missing
            # (e.g. 'http://host'); without it nothing is appended.
            match = re.search('^https?://[^/]*(/.*)?', request_uri)
            if match is not None:
                request_uri = match.group(1) or ''

        return self_url_host + request_uri

    @staticmethod
    def generate_unique_id():
        """
        Generates a unique string (used for example as ID for assertions).

        :returns: 'ONELOGIN_' followed by 40 hex characters
        :rtype: string
        """
        # sha1() needs bytes; encoding the ASCII hex keeps this portable.
        return 'ONELOGIN_%s' % sha1(uuid4().hex.encode('ascii')).hexdigest()

    @staticmethod
    def parse_time_to_SAML(time):
        """
        Converts a UNIX timestamp to a SAML2 timestamp yyyy-mm-ddThh:mm:ssZ.

        Sub-second precision is dropped.

        :param time: The UNIX timestamp (anything float() accepts)
        :type: string
        :returns: SAML2 timestamp
        :rtype: string
        """
        data = datetime.utcfromtimestamp(float(time))
        return data.strftime('%Y-%m-%dT%H:%M:%SZ')

    @staticmethod
    def parse_SAML_to_time(timestr):
        """
        Converts a SAML2 timestamp to a UNIX timestamp.

        Accepts yyyy-mm-ddThh:mm:ssZ with an optional fractional-seconds part
        of any length before the 'Z'; the sub-second part is ignored.

        :param timestr: The SAML timestamp
        :type: string
        :returns: Converted to a unix timestamp
        :rtype: int
        :raises ValueError: if timestr is not in that form
        """
        if not timestr.endswith('Z'):
            raise ValueError('Invalid SAML timestamp: %r' % (timestr,))
        # strptime's %f takes at most 6 digits, so longer fractions would be
        # rejected; split the fraction off by hand instead of parsing it.
        seconds_part, dot, fraction = timestr[:-1].partition('.')
        if dot and not fraction.isdigit():
            raise ValueError('Invalid SAML timestamp: %r' % (timestr,))
        data = datetime.strptime(seconds_part, '%Y-%m-%dT%H:%M:%S')
        return calendar.timegm(data.utctimetuple())

    @staticmethod
    def parse_duration(duration, timestamp=None):
        """
        Interprets an ISO 8601 duration value relative to a given timestamp.

        :param duration: The duration, as a string
        :type: string
        :param timestamp: The unix timestamp we should apply the duration to;
                          defaults to the current UTC time
        :type: int
        :returns: The new timestamp, after the duration is applied
        :rtype: int
        """
        assert isinstance(duration, basestring)
        assert timestamp is None or isinstance(timestamp, int)

        timedelta = duration_parser(duration)
        if timestamp is None:
            data = datetime.utcnow() + timedelta
        else:
            data = datetime.utcfromtimestamp(timestamp) + timedelta
        return calendar.timegm(data.utctimetuple())

    @staticmethod
    def get_expire_time(cache_duration=None, valid_until=None):
        """
        Returns the earliest of a cache duration and a valid-until date.

        :param cache_duration: The duration, as an ISO 8601 string
        :type: string
        :param valid_until: The valid until date, as a SAML timestamp string
                            or as an int UNIX timestamp
        :type: string | int
        :returns: The expiration time as a decimal string, or None when
                  neither argument is given
        :rtype: string | None
        """
        expire_time = None

        if cache_duration is not None:
            expire_time = OneLogin_Saml2_Utils.parse_duration(cache_duration)

        if valid_until is not None:
            if isinstance(valid_until, int):
                valid_until_time = valid_until
            else:
                valid_until_time = OneLogin_Saml2_Utils.parse_SAML_to_time(valid_until)
            if expire_time is None or expire_time > valid_until_time:
                expire_time = valid_until_time

        if expire_time is not None:
            return '%d' % expire_time
        return None

    @staticmethod
    def query(dom, query, context=None):
        """
        Extracts nodes that match the XPath query, using the toolkit NSMAP.

        :param dom: The root of the lxml object
        :type: Element
        :param query: Xpath Expression
        :type: string
        :param context: Context Node; when given, the query runs from it
        :type: Element
        :returns: The queried nodes
        :rtype: list
        """
        if context is None:
            return dom.xpath(query, namespaces=OneLogin_Saml2_Constants.NSMAP)
        else:
            return context.xpath(query, namespaces=OneLogin_Saml2_Constants.NSMAP)

    @staticmethod
    def delete_local_session(callback=None):
        """
        Deletes the local session by invoking the given callback, if any.

        :param callback: Zero-argument callable doing the actual deletion
        :type: callable
        """
        if callback is not None:
            callback()

    @staticmethod
    def calculate_x509_fingerprint(x509_cert):
        """
        Calculates the SHA-1 fingerprint of a x509 cert.

        :param x509_cert: x509 cert (PEM, or the bare base64 body)
        :type: string
        :returns: Lower-case hex fingerprint, or None if the input is a public
                  or RSA private key instead of a certificate
        :rtype: string
        """
        assert isinstance(x509_cert, basestring)

        chunks = []
        for line in x509_cert.split('\n'):
            # Remove '\r' (and other trailing whitespace) from end of line.
            line = line.rstrip()
            if line == '-----BEGIN CERTIFICATE-----':
                # Delete junk from before the certificate.
                chunks = []
            elif line == '-----END CERTIFICATE-----':
                # Ignore data after the certificate.
                break
            elif line in ('-----BEGIN PUBLIC KEY-----', '-----BEGIN RSA PRIVATE KEY-----'):
                # This isn't an X509 certificate.
                return None
            else:
                chunks.append(line)

        # The joined chunks are the base64 certificate body; its fingerprint is
        # the sha1-hash of the decoded bytes.
        return sha1(base64.b64decode(''.join(chunks))).hexdigest().lower()

    @staticmethod
    def format_finger_print(fingerprint):
        """
        Formats a fingerprint: removes ':' separators and lower-cases it.

        :param fingerprint: fingerprint
        :type: string
        :returns: Formatted fingerprint
        :rtype: string
        """
        formated_fingerprint = fingerprint.replace(':', '')
        return formated_fingerprint.lower()

    @staticmethod
    def generate_name_id(value, sp_nq, sp_format, key=None):
        """
        Generates a saml:NameID, optionally encrypted into a saml:EncryptedID.

        :param value: The NameID value
        :type: string
        :param sp_nq: SP Name Qualifier
        :type: string
        :param sp_format: SP Format
        :type: string
        :param key: PEM key used to encrypt the nameID (None: no encryption)
        :type: string
        :returns: The serialized saml:NameID or saml:EncryptedID element
        :rtype: string
        """
        doc = Document()
        name_id = doc.createElement('saml:NameID')
        name_id.setAttribute('SPNameQualifier', sp_nq)
        name_id.setAttribute('Format', sp_format)
        name_id.appendChild(doc.createTextNode(value))
        doc.appendChild(name_id)

        if key is None:
            return doc.saveXML(name_id)

        xmlsec.initialize()

        # Load the key. It is written as given: stripping its PEM header lines
        # and newlines (format_cert(heads=False)) would make it unloadable.
        mngr = xmlsec.KeysMngr()
        file_key = OneLogin_Saml2_Utils.write_temp_file(key)
        try:
            key_data = xmlsec.Key.load(file_key.name, xmlsec.KeyDataFormatPem, None)
            key_data.name = key_name = basename(file_key.name)
            mngr.addKey(key_data)
        finally:
            file_key.close()

        # Prepare for encryption
        enc_data = EncData(xmlsec.TransformAes128Cbc, type=xmlsec.TypeEncElement)
        enc_data.ensureCipherValue()
        key_info = enc_data.ensureKeyInfo()
        enc_key = key_info.addEncryptedKey(xmlsec.TransformRsaPkcs1)
        enc_key.ensureCipherValue()
        enc_key_info = enc_key.ensureKeyInfo()
        enc_key_info.addKeyName(key_name)

        # dm.xmlsec operates on lxml elements, not minidom ones: re-parse the
        # NameID with lxml, declaring the saml prefix so lxml can resolve it.
        name_id.setAttribute('xmlns:saml', 'urn:oasis:names:tc:SAML:2.0:assertion')
        name_id_elem = etree.fromstring(doc.saveXML(name_id))

        # Encrypt with a fresh AES-128 session key.
        enc_ctx = xmlsec.EncCtx(mngr)
        enc_ctx.encKey = xmlsec.Key.generate(xmlsec.KeyDataAes, 128, xmlsec.KeyDataTypeSession)
        ed = enc_ctx.encryptXml(enc_data, name_id_elem)

        # Build XML with encrypted data; the lxml result is re-parsed with
        # minidom because minidom cannot import lxml nodes directly.
        newdoc = Document()
        encrypted_id = newdoc.createElement('saml:EncryptedID')
        newdoc.appendChild(encrypted_id)
        encrypted_data = parseString(etree.tostring(ed)).documentElement
        encrypted_id.appendChild(newdoc.importNode(encrypted_data, True))
        return newdoc.saveXML(encrypted_id)

    @staticmethod
    def get_status(dom):
        """
        Gets Status from a Response.

        :param dom: The Response, as an lxml element/tree
        :type: Element
        :returns: {'code': StatusCode Value, 'msg': StatusMessage text or ''}
        :rtype: dict
        :raises Exception: if the Status, StatusCode or its Value is missing
        """
        status = {}

        status_entry = OneLogin_Saml2_Utils.query(dom, '/samlp:Response/samlp:Status')
        if len(status_entry) == 0:
            raise Exception('Missing Status on response')

        code_entry = OneLogin_Saml2_Utils.query(dom, '/samlp:Response/samlp:Status/samlp:StatusCode', status_entry[0])
        if len(code_entry) == 0:
            raise Exception('Missing Status Code on response')
        # Read Value by name; the first attribute is not necessarily Value.
        code = code_entry[0].get('Value')
        if code is None:
            raise Exception('Missing Status Code on response')
        status['code'] = code

        message_entry = OneLogin_Saml2_Utils.query(dom, '/samlp:Response/samlp:Status/samlp:StatusMessage', status_entry[0])
        if len(message_entry) == 0:
            status['msg'] = ''
        else:
            status['msg'] = message_entry[0].text

        return status

    @staticmethod
    def decrypt_element(encrypted_data, enc_ctx):
        """
        Decrypts an encrypted element.

        :param encrypted_data: The encrypted data (minidom or lxml element)
        :type: Element
        :param enc_ctx: The encryption context
        :type: Encryption Context
        :returns: The serialized XML if the result is an ElementBase, else
                  whatever enc_ctx.decrypt returned
        :rtype: string | object
        """
        if isinstance(encrypted_data, Element):
            # Minidom element: dm.xmlsec needs an lxml element.
            encrypted_data = etree.fromstring(encrypted_data.toxml())
        decrypted = enc_ctx.decrypt(encrypted_data)
        # NOTE(review): plain lxml elements are etree._Element and not
        # necessarily ElementBase; in that case the element is returned as is.
        if isinstance(decrypted, ElementBase):
            # lxml element, decrypted xml data. Serialize with lxml: the stdlib
            # ElementTree serializer cannot handle an lxml ElementTree.
            return etree.tostring(decrypted.getroottree())
        else:
            # decrypted binary data
            return decrypted

    @staticmethod
    def write_temp_file(content):
        """
        Writes some content into a temporary file and returns it.

        The file is created with delete=True, so it disappears when closed.

        :param content: The file content
        :type: string
        :returns: The temporary file
        :rtype: file-like object
        """
        f = NamedTemporaryFile(delete=True)
        f.file.write(content)
        f.file.flush()
        return f

    @staticmethod
    def add_sign(xml, key, cert):
        """
        Adds an enveloped signature (with the sender's certificate) to the
        root element of a message or metadata document.

        The ds:Signature is placed right after the first saml:Issuer when
        there is one, otherwise as the first child of the root. If xml is a
        Document it is updated in place.

        :param xml: The element we should sign
        :type: string | Document
        :param key: The private key (PEM)
        :type: string
        :param cert: The public cert (PEM)
        :type: string
        :returns: The signed XML
        :rtype: string
        """
        if isinstance(xml, Document):
            dom = xml
        else:
            if xml == '':
                raise Exception('Empty string supplied as input')
            try:
                dom = parseString(xml)
            except Exception:
                raise Exception('Error parsing xml string')

        xmlsec.initialize()

        # TODO the key and cert could be file descriptors instead
        # Load the private key.
        file_key = OneLogin_Saml2_Utils.write_temp_file(key)
        try:
            sign_key = xmlsec.Key.load(file_key.name, xmlsec.KeyDataFormatPem, None)
        finally:
            file_key.close()

        # Add the certificate to the signature.
        file_cert = OneLogin_Saml2_Utils.write_temp_file(cert)
        try:
            sign_key.loadCert(file_cert.name, xmlsec.KeyDataFormatPem)
        finally:
            file_cert.close()

        # dm.xmlsec signs lxml trees, and the template must be inside the
        # document being signed before signing; otherwise the digest covers
        # only the detached template. Work on an lxml copy of the root.
        root = etree.fromstring(dom.documentElement.toxml())

        signature = Signature(xmlsec.TransformExclC14N, xmlsec.TransformRsaSha1)
        issuers = OneLogin_Saml2_Utils.query(root, '//saml:Issuer')
        if len(issuers) > 0:
            issuers[0].addnext(signature)
        else:
            # e.g. metadata: the schema puts ds:Signature first.
            root.insert(0, signature)

        # NOTE(review): the Reference has no URI; combined with the enveloped
        # transform it is meant to cover the whole document - confirm.
        ref = signature.addReference(xmlsec.TransformSha1)
        ref.addTransform(xmlsec.TransformEnveloped)
        key_info = signature.ensureKeyInfo()
        key_info.addX509Data()

        dsig_ctx = xmlsec.DSigCtx()
        dsig_ctx.signKey = sign_key
        dsig_ctx.sign(signature)

        # Copy the signed root back into the minidom document. Prefixes are
        # not renamed afterwards: that would invalidate the SignedInfo digest.
        signed_root = parseString(etree.tostring(root)).documentElement
        dom.replaceChild(dom.importNode(signed_root, True), dom.documentElement)
        return dom.toxml()

    @staticmethod
    def validate_sign(xml, cert=None, fingerprint=None):
        """
        Validates the ds:Signature child of the root element.

        If no cert is given but a fingerprint is, the X509Certificate embedded
        in the signature's KeyInfo is used when its SHA-1 fingerprint matches.
        Failure is signalled by exceptions (including whatever dm.xmlsec's
        verify raises); nothing is returned.

        :param xml: The element we should validate
        :type: string | Document
        :param cert: The public cert (PEM)
        :type: string
        :param fingerprint: The fingerprint of the public cert
        :type: string
        :raises Exception: on empty/unparsable input, missing signature or no
                           usable certificate
        """
        parser = OneLogin_Saml2_Utils._xml_parser()
        if isinstance(xml, Document):
            dom = etree.fromstring(xml.toxml(), parser)
        else:
            if xml == '':
                raise Exception('Empty string supplied as input')
            try:
                dom = etree.fromstring(xml, parser)
            except Exception:
                raise Exception('Error parsing xml string')

        xmlsec.initialize()

        # Find signature in the dom
        signature_nodes = OneLogin_Saml2_Utils.query(dom, 'ds:Signature')
        if len(signature_nodes) == 0:
            raise Exception('Missing Signature')
        signature_node = signature_nodes[0]

        if not cert and fingerprint:
            # Trust the embedded certificate only if it matches the fingerprint.
            x509_nodes = OneLogin_Saml2_Utils.query(
                dom, './ds:KeyInfo/ds:X509Data/ds:X509Certificate', signature_node)
            if len(x509_nodes) > 0:
                embedded_cert = OneLogin_Saml2_Utils.format_cert(x509_nodes[0].text or '')
                expected = OneLogin_Saml2_Utils.format_finger_print(fingerprint)
                if OneLogin_Saml2_Utils.calculate_x509_fingerprint(embedded_cert) == expected:
                    cert = embedded_cert
        if not cert:
            raise Exception('No certificate available to validate the signature')

        # Prepare context and load cert into it
        dsig_ctx = xmlsec.DSigCtx()
        sign_cert = X509.load_cert_string(str(cert), X509.FORMAT_PEM)
        pub_key = sign_cert.get_pubkey().get_rsa()
        sign_key = xmlsec.Key.loadMemory(pub_key.as_pem(cipher=None), xmlsec.KeyDataFormatPem)
        dsig_ctx.signKey = sign_key

        # Verify signature
        dsig_ctx.verify(signature_node)