# Source code for M2Crypto.SSL.Checker

"""
SSL peer certificate checking routines

Copyright (c) 2004-2007 Open Source Applications Foundation.
All rights reserved.

Copyright 2008 Heikki Toivonen. All rights reserved.
"""

# Public names exported by a star-import: the exception hierarchy first,
# then the checker callable itself.
__all__ = ["SSLVerificationError", "NoCertificate", "WrongCertificate",
           "WrongHost", "Checker"]

import re
import socket

from M2Crypto import X509, m2  # noqa
from typing import Optional, Union  # noqa

try:
    from re import Pattern
except ImportError:
    from typing import Pattern


class SSLVerificationError(Exception):
    """Base class for the peer-certificate verification errors of this module."""
class NoCertificate(SSLVerificationError):
    """Raised when the peer did not present a certificate at all."""
class WrongCertificate(SSLVerificationError):
    """
    Raised when the peer certificate is unacceptable for a reason other than
    a host-name mismatch, e.g. its fingerprint does not match the expected
    one, or it has no commonName to check the host against.
    """
class WrongHost(SSLVerificationError):
    """Raised when the peer certificate names a different host than expected."""

    def __init__(
        self,
        expectedHost: str,
        actualHost: Union[str, bytes],
        fieldName: str = "commonName",
    ) -> None:
        """
        This exception will be raised if the certificate returned by the
        peer was issued for a different host than we tried to connect to.
        This could be due to a server misconfiguration or an active attack.

        :param expectedHost: The name of the host we expected to find in the
                             certificate.
        :param actualHost:   The name of the host we actually found in the
                             certificate.
        :param fieldName:    The field name where we noticed the error. This
                             should be either 'commonName' or
                             'subjectAltName'.
        :raises ValueError:  If *fieldName* is neither of those two values.
        """
        if fieldName not in ("commonName", "subjectAltName"):
            raise ValueError(
                "Unknown fieldName, should be either commonName "
                + "or subjectAltName"
            )

        # Forward the constructor arguments so they end up in ``self.args``.
        # BaseException.__reduce__ rebuilds an exception as ``cls(*self.args)``;
        # with an empty ``args`` a pickled WrongHost could not be unpickled
        # (``__init__`` would be called without its required arguments).
        super().__init__(expectedHost, actualHost, fieldName)
        self.expectedHost = expectedHost
        self.actualHost = actualHost
        self.fieldName = fieldName

    def __str__(self) -> str:
        # Decode leniently: formatting an exception must not itself raise
        # UnicodeDecodeError because of odd bytes taken from a certificate.
        actual = (
            self.actualHost.decode(errors="replace")
            if isinstance(self.actualHost, bytes)
            else self.actualHost
        )
        return "Peer certificate %s does not match host, expected %s, got %s" % (
            self.fieldName,
            self.expectedHost,
            actual,
        )
class Checker:
    """
    Callable that verifies a peer certificate.

    Depending on configuration, an instance checks that the certificate's
    digest matches an expected fingerprint and/or that the certificate was
    issued for the expected host. For the host check, subjectAltName DNS and
    IP entries are used first, and commonName is the fallback when the
    certificate carries no such entries. Calling the instance returns
    ``True`` on success and raises NoCertificate, WrongCertificate or
    WrongHost otherwise (ValueError for an unsupported digest).
    """

    # Matches names made only of digits and dots (dotted-decimal IPv4 or a
    # bare number); such names are never matched against wildcards.
    # COMPATIBILITY: the annotation is a string because ``re.Pattern`` only
    # became subscriptable in Python 3.9; evaluating ``Pattern[str]`` at class
    # creation raises TypeError on 3.7/3.8, where ``from re import Pattern``
    # nevertheless succeeds.
    numericIpMatch: "Pattern[str]" = re.compile(r"^[0-9]+(\.[0-9]+)*$")

    def __init__(
        self,
        host: Optional[str] = None,
        peerCertHash: Optional[bytes] = None,
        peerCertDigest: str = "sha256",
    ) -> None:
        """
        :param host:           Host name or IP address the certificate must
                               have been issued for. ``None`` skips the host
                               check unless a host is passed when calling.
        :param peerCertHash:   Expected certificate fingerprint as a hex
                               string (bytes or str; 64 characters for
                               sha256), compared case-insensitively. ``None``
                               or empty skips the fingerprint check.
        :param peerCertDigest: Digest used for the fingerprint. Only
                               ``"sha256"`` is supported.
        """
        self.host = host
        self.fingerprint = peerCertHash
        self.digest: str = peerCertDigest

    def __call__(
        self, peerCert: "Optional[X509.X509]", host: Optional[str] = None
    ) -> bool:
        """
        Verify *peerCert* against the configured fingerprint and host.

        :param peerCert: Certificate presented by the peer, or ``None``.
        :param host:     When given, replaces ``self.host`` (the change
                         persists on the instance) before the host check.
        :return:         ``True`` when every configured check passes.
        :raises NoCertificate:    *peerCert* is ``None``.
        :raises ValueError:       The configured digest is not supported.
        :raises WrongCertificate: The fingerprint has the wrong length or
                                  value, or the commonName fallback is needed
                                  but the subject has no commonName.
        :raises WrongHost:        The certificate names a different host.
        """
        if peerCert is None:
            raise NoCertificate("peer did not return certificate")

        if host is not None:
            self.host = host

        if self.fingerprint:
            self._checkFingerprint(peerCert)

        if self.host:
            self._checkHost(peerCert, self.host)

        return True

    def _checkFingerprint(self, peerCert: "X509.X509") -> None:
        """Raise unless *peerCert*'s digest equals ``self.fingerprint``."""
        # Fingerprint length, in hex characters, for each accepted digest.
        # (The old ``digest not in ("sha256")`` was a substring test on a
        # plain string, not a membership test on a tuple.)
        expected_len = {"sha256": 64}.get(self.digest)
        if expected_len is None:
            raise ValueError('unsupported digest "%s"' % self.digest)

        if len(self.fingerprint) != expected_len:
            raise WrongCertificate(
                (
                    "peer certificate fingerprint length does not match\n"
                    + "fingerprint: {0}\nexpected = {1}\n"
                    + "observed = {2}"
                ).format(
                    self.fingerprint,
                    expected_len,
                    len(self.fingerprint),
                )
            )

        expected_fingerprint = (
            self.fingerprint.decode()
            if isinstance(self.fingerprint, bytes)
            else self.fingerprint
        )
        observed_fingerprint = peerCert.get_fingerprint(md=self.digest)
        # Hex digits carry no case information, so compare case-insensitively
        # (e.g. a lower-case hashlib hexdigest must match an upper-case one).
        if observed_fingerprint.upper() != expected_fingerprint.upper():
            raise WrongCertificate(
                (
                    "peer certificate fingerprint does not match\n"
                    + "expected = {0},\n"
                    + "observed = {1}"
                ).format(expected_fingerprint, observed_fingerprint)
            )

    def _checkHost(self, peerCert: "X509.X509", host: str) -> None:
        """
        Raise unless *peerCert* was issued for *host*.

        subjectAltName DNS/IP entries take precedence. The commonName is only
        consulted when there is no subjectAltName extension, or when the
        extension holds no DNS/IP entries (RFC 2818, section 3.1).
        """
        self.useSubjectAltNameOnly = False

        # subjectAltName=DNS:somehost[, ...]*
        try:
            subjectAltName = peerCert.get_ext("subjectAltName").get_value()
        except LookupError:
            # Assumed: get_ext signals an absent extension with LookupError.
            subjectAltName = None

        if subjectAltName is not None:
            if self._splitSubjectAltName(host, subjectAltName):
                return
            if self.useSubjectAltNameOnly:
                # DNS/IP names were listed and none matched, so falling back
                # to commonName is not allowed.
                raise WrongHost(
                    expectedHost=host,
                    actualHost=subjectAltName,
                    fieldName="subjectAltName",
                )

        # commonName=somehost[, ...]*
        commonNames = []
        for entry in peerCert.get_subject().get_entries_by_nid(
            m2.NID_commonName
        ):
            commonName = entry.get_data().as_text()
            if self._match(host, commonName):
                return
            commonNames.append(commonName)

        if not commonNames:
            raise WrongCertificate("no commonName in peer certificate")
        raise WrongHost(
            expectedHost=host,
            actualHost=",".join(commonNames),
            fieldName="commonName",
        )

    def _splitSubjectAltName(
        self,
        host: Union[str, bytes],
        subjectAltName: Union[str, bytes],
    ) -> bool:
        """
        Return True if *host* matches a DNS or IP entry of the textual
        subjectAltName value.

        As a side effect, sets ``self.useSubjectAltNameOnly`` to True when at
        least one DNS or IP entry was seen, and to False otherwise.

        NOTE(review): this parses a comma-separated text rendering, which is
        ambiguous if a presented name itself contains ", DNS:". Matching on
        the structured names would be safer if the bindings expose them.

        >>> check = Checker()
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='DNS:my.example.com')
        True
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='DNS:*.example.com')
        True
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='DNS:m*.example.com')
        True
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='DNS:m*ample.com')
        False
        >>> check.useSubjectAltNameOnly
        True
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='DNS:m*ample.com, othername:<unsupported>')
        False
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='DNS:m*ample.com, DNS:my.example.org')
        False
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='DNS:m*ample.com, DNS:my.example.com')
        True
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='DNS:my.example.com, DNS:my.example.org')
        True
        >>> check.useSubjectAltNameOnly
        True
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='')
        False
        >>> check._splitSubjectAltName(host='my.example.com',
        ...     subjectAltName='othername:<unsupported>')
        False
        >>> check.useSubjectAltNameOnly
        False
        """
        host_str = host.decode() if isinstance(host, bytes) else host
        san_str = (
            subjectAltName.decode()
            if isinstance(subjectAltName, bytes)
            else subjectAltName
        )
        self.useSubjectAltNameOnly = False
        for certHost in san_str.split(","):
            certHost = certHost.lower().strip()
            if certHost.startswith("dns:"):
                self.useSubjectAltNameOnly = True
                if self._match(host_str, certHost[4:]):
                    return True
            elif certHost.startswith("ip address:"):
                self.useSubjectAltNameOnly = True
                if self._matchIPAddress(host_str, certHost[11:]):
                    return True
        return False

    def _match(self, host: str, certHost: str) -> bool:
        """
        Return True if *host* matches the certificate name *certHost*.

        The comparison is case-insensitive. *certHost* may contain at most one
        ``*`` wildcard, only in its left-most label, and the wildcard matches
        any run of characters other than ``.``. Numeric (dotted-digit) names
        are only ever accepted on exact equality.

        >>> check = Checker()
        >>> check._match(host='my.example.com', certHost='my.example.com')
        True
        >>> check._match(host='my.example.com', certHost='*.example.com')
        True
        >>> check._match(host='my.example.com', certHost='m*.example.com')
        True
        >>> check._match(host='my.example.com', certHost='m*.EXAMPLE.com')
        True
        >>> check._match(host='my.example.com', certHost='m*ample.com')
        False
        >>> check._match(host='my.example.com', certHost='*.*.com')
        False
        >>> check._match(host='my.example.com', certHost='my.*.com')
        False
        >>> check._match(host='xevil.org', certHost='x|my.example.com')
        False
        >>> check._match(host='my.example.com', certHost='(my).example.com')
        False
        >>> check._match(host='1.2.3.4', certHost='1.2.3.4')
        True
        >>> check._match(host='1.2.3.4', certHost='*.2.3.4')
        False
        >>> check._match(host='1234', certHost='1234')
        True
        """
        # XXX See RFC 2818 and 3280 for matching rules, this may not
        # XXX yet be complete.
        host = host.lower()
        certHost = certHost.lower()
        if host == certHost:
            return True

        if certHost.count("*") > 1:
            # Not sure about this, but being conservative
            return False

        if self.numericIpMatch.match(host) or self.numericIpMatch.match(
            certHost.replace("*", "")
        ):
            # Not sure if * allowed in numeric IP, but think not.
            return False

        if "\\" in certHost:
            # Conservative: reject names with backslashes instead of trying
            # to interpret whatever escaping they might represent.
            return False

        # RFC 6125 section 6.4.3: only honour a wildcard in the left-most
        # label, so "www.*.com" cannot stand in for "www.example.com".
        if "*" in certHost.partition(".")[2]:
            return False

        # Escape each literal piece so characters taken from the certificate
        # ("|", "(", "?", "+", "[", ...) cannot act as regex syntax; the
        # wildcard spans any run of non-dot characters within its label.
        pattern = "[^.]*".join(re.escape(piece) for piece in certHost.split("*"))
        # fullmatch, unlike "^...$" with match(), rejects a trailing newline.
        return re.fullmatch(pattern, host) is not None

    def _matchIPAddress(
        self, host: Union[str, bytes], certHost: Union[str, bytes]
    ) -> bool:
        """
        Return True if *host* and *certHost* are numeric IP addresses that
        resolve to the same address; non-numeric input yields False.

        >>> check = Checker()
        >>> check._matchIPAddress(host='my.example.com',
        ...     certHost='my.example.com')
        False
        >>> check._matchIPAddress(host='1.2.3.4', certHost='1.2.3.4')
        True
        >>> check._matchIPAddress(host='1.2.3.4', certHost='*.2.3.4')
        False
        >>> check._matchIPAddress(host='1.2.3.4', certHost='1.2.3.40')
        False
        >>> check._matchIPAddress(host='::1', certHost='::1')
        True
        >>> check._matchIPAddress(host='::1', certHost='0:0:0:0:0:0:0:1')
        True
        >>> check._matchIPAddress(host='::1', certHost='::2')
        False
        """
        try:
            # AI_NUMERICHOST: parse only, never do a DNS lookup.
            canonical = socket.getaddrinfo(
                host, 0, 0, socket.SOCK_STREAM, 0, socket.AI_NUMERICHOST
            )
            certCanonical = socket.getaddrinfo(
                certHost, 0, 0, socket.SOCK_STREAM, 0, socket.AI_NUMERICHOST
            )
        except (OSError, ValueError, TypeError):
            # socket.gaierror (an OSError) for non-numeric names; ValueError
            # (including UnicodeError) / TypeError for unencodable input.
            return False
        return canonical == certCanonical
if __name__ == "__main__":
    # Executed directly (not imported): run this module's doctests.
    from doctest import testmod

    testmod()