# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""Unit tests for the X.509 extension wrappers in ``anchor.X509.extension``."""

import unittest

import netaddr
from pyasn1.codec.der import encoder
from pyasn1.type import univ

from anchor.asn1 import rfc5280
from anchor.X509 import errors
from anchor.X509 import extension


def _unknown_extension_asn1():
    """Build a non-critical ``rfc5280.Extension`` with OID 1.2.3.4.

    The tests below treat this OID as an unknown extension (its string form
    is expected to be ``"1.2.3.4: <unknown>"``), so it exercises the generic
    ``X509Extension`` code paths. A fresh object is returned on every call
    so tests cannot affect one another through shared ASN.1 state.
    """
    asn1 = rfc5280.Extension()
    asn1['extnID'] = univ.ObjectIdentifier('1.2.3.4')
    asn1['critical'] = False
    asn1['extnValue'] = "foobar"
    return asn1


class TestExtensionBase(unittest.TestCase):
    """Tests for the generic ``X509Extension`` and ``construct_extension``."""

    def test_no_spec(self):
        with self.assertRaises(errors.X509Error):
            extension.X509Extension()

    def test_invalid_asn(self):
        with self.assertRaises(errors.X509Error):
            extension.X509Extension("foobar")

    def test_unknown_extension_str(self):
        ext = extension.X509Extension(_unknown_extension_asn1())
        self.assertEqual("1.2.3.4: <unknown>", str(ext))

    def test_construct(self):
        ext = extension.construct_extension(_unknown_extension_asn1())
        self.assertIsInstance(ext, extension.X509Extension)

    def test_construct_invalid_type(self):
        with self.assertRaises(errors.X509Error):
            extension.construct_extension("foobar")

    def test_critical(self):
        ext = extension.construct_extension(_unknown_extension_asn1())
        self.assertFalse(ext.get_critical())
        ext.set_critical(True)
        self.assertTrue(ext.get_critical())

    def test_serialise(self):
        asn1 = _unknown_extension_asn1()
        ext = extension.construct_extension(asn1)
        # The wrapper's DER output must match encoding the source ASN.1.
        self.assertEqual(ext.as_der(), encoder.encode(asn1))

    def test_broken_set_value(self):
        class SomeExt(extension.X509Extension):
            spec = rfc5280.Extension
            _oid = univ.ObjectIdentifier('1.2.3.4')

            @classmethod
            def _get_default_value(cls):
                # Deliberately the wrong type for the spec.
                return 1234

        # assertRaisesRegexp was deprecated in Python 3.2 and removed in
        # 3.12; only fall back to it where assertRaisesRegex does not exist
        # (Python 2.7).
        raises_regex = (getattr(self, 'assertRaisesRegex', None) or
                        self.assertRaisesRegexp)
        with raises_regex(errors.X509Error, 'incorrect type'):
            SomeExt()


class TestBasicConstraints(unittest.TestCase):
    """Tests for ``X509ExtensionBasicConstraints``."""

    def setUp(self):
        self.ext = extension.X509ExtensionBasicConstraints()

    def test_str(self):
        self.assertEqual("basicConstraints: CA: FALSE, pathLen: None",
                         str(self.ext))

    def test_ca(self):
        self.ext.set_ca(True)
        self.assertTrue(self.ext.get_ca())
        self.ext.set_ca(False)
        self.assertFalse(self.ext.get_ca())

    def test_pathlen(self):
        self.ext.set_path_len_constraint(1)
        self.assertEqual(1, self.ext.get_path_len_constraint())


class TestKeyUsage(unittest.TestCase):
    """Tests for ``X509ExtensionKeyUsage``."""

    def setUp(self):
        self.ext = extension.X509ExtensionKeyUsage()

    def test_usage_set(self):
        self.ext.set_usage('digitalSignature', True)
        self.ext.set_usage('keyAgreement', False)
        self.assertTrue(self.ext.get_usage('digitalSignature'))
        self.assertFalse(self.ext.get_usage('keyAgreement'))

    def test_usage_reset(self):
        self.ext.set_usage('digitalSignature', True)
        self.ext.set_usage('digitalSignature', False)
        self.assertFalse(self.ext.get_usage('digitalSignature'))

    def test_usage_unset(self):
        self.assertFalse(self.ext.get_usage('keyAgreement'))

    def test_get_all_usage(self):
        self.ext.set_usage('digitalSignature', True)
        self.ext.set_usage('keyAgreement', False)
        self.ext.set_usage('keyEncipherment', True)
        self.assertEqual(set(['digitalSignature', 'keyEncipherment']),
                         set(self.ext.get_all_usages()))

    def test_str(self):
        self.ext.set_usage('digitalSignature', True)
        self.assertEqual("keyUsage: digitalSignature", str(self.ext))


class TestSubjectAltName(unittest.TestCase):
    """Tests for ``X509ExtensionSubjectAltName``."""

    def setUp(self):
        self.ext = extension.X509ExtensionSubjectAltName()
        self.domain = 'example.com'
        self.ip = netaddr.IPAddress('1.2.3.4')
        self.ip6 = netaddr.IPAddress('::1')

    def test_dns_ids(self):
        self.ext.add_dns_id(self.domain)
        self.ext.add_ip(self.ip)
        self.assertEqual([self.domain], self.ext.get_dns_ids())

    def test_ips(self):
        self.ext.add_dns_id(self.domain)
        self.ext.add_ip(self.ip)
        self.assertEqual([self.ip], self.ext.get_ips())

    def test_ipv6(self):
        self.ext.add_ip(self.ip6)
        self.assertEqual([self.ip6], self.ext.get_ips())

    def test_add_ip_invalid(self):
        with self.assertRaises(errors.X509Error):
            self.ext.add_ip("abcdef")

    def test_str(self):
        self.ext.add_dns_id(self.domain)
        self.ext.add_ip(self.ip)
        self.assertEqual("subjectAltName: DNS:example.com, IP:1.2.3.4",
                         str(self.ext))


class TestNameConstraints(unittest.TestCase):
    """Tests for ``X509ExtensionNameConstraints``."""

    def setUp(self):
        self.ext = extension.X509ExtensionNameConstraints()

    def test_length(self):
        self.assertEqual(0, self.ext.get_permitted_length())
        self.assertEqual(0, self.ext.get_excluded_length())

    def test_add(self):
        test_name = 'example.com'
        test_type = 'dNSName'
        self.assertEqual(0, self.ext.get_permitted_length())
        self.assertEqual(0, self.ext.get_excluded_length())
        self.ext.add_permitted(test_type, test_name)
        self.assertEqual(1, self.ext.get_permitted_length())
        self.assertEqual(0, self.ext.get_excluded_length())
        self.ext.add_excluded(test_type, test_name)
        self.assertEqual(1, self.ext.get_permitted_length())
        self.assertEqual(1, self.ext.get_excluded_length())

    def test_excluded(self):
        self.ext.add_excluded('dNSName', 'example.com')
        self.assertEqual((0, None), self.ext.get_excluded_range(0))
        self.assertEqual(('dNSName', b'example.com'),
                         self.ext.get_excluded_name(0))

    def test_permitted(self):
        self.ext.add_permitted('dNSName', 'example.com')
        self.assertEqual((0, None), self.ext.get_permitted_range(0))
        self.assertEqual(('dNSName', b'example.com'),
                         self.ext.get_permitted_name(0))


class TestExtendedKeyUsage(unittest.TestCase):
    """Tests for ``X509ExtensionExtendedKeyUsage``."""

    def setUp(self):
        self.ext = extension.X509ExtensionExtendedKeyUsage()

    def test_get_all(self):
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.ext.set_usage(rfc5280.id_kp_codeSigning, True)
        usages = self.ext.get_all_usages()
        self.assertEqual(2, len(usages))
        self.assertIn(rfc5280.id_kp_clientAuth, usages)

    def test_get_one(self):
        self.assertFalse(self.ext.get_usage(rfc5280.id_kp_clientAuth))
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.assertTrue(self.ext.get_usage(rfc5280.id_kp_clientAuth))

    def test_set(self):
        self.assertEqual(0, len(self.ext.get_all_usages()))
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.assertEqual(1, len(self.ext.get_all_usages()))
        # Setting the same usage twice must not duplicate it.
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.assertEqual(1, len(self.ext.get_all_usages()))
        self.ext.set_usage(rfc5280.id_kp_codeSigning, True)
        self.assertEqual(2, len(self.ext.get_all_usages()))

    def test_unset(self):
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.ext.set_usage(rfc5280.id_kp_clientAuth, False)
        self.assertEqual(0, len(self.ext.get_all_usages()))
        # Clearing an already-absent usage is a no-op.
        self.ext.set_usage(rfc5280.id_kp_clientAuth, False)
        self.assertEqual(0, len(self.ext.get_all_usages()))

    def test_str(self):
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.ext.set_usage(rfc5280.id_kp_codeSigning, True)
        self.assertEqual(
            "extKeyUsage: TLS Web Client Authentication, Code Signing",
            str(self.ext))

    def test_invalid_usage(self):
        unknown_oid = univ.ObjectIdentifier('1.2.3.4')
        self.assertRaises(ValueError, self.ext.get_usage, unknown_oid)
        # Argument order is (usage, state), matching every other
        # set_usage call; the unknown OID itself must trigger the error.
        self.assertRaises(ValueError, self.ext.set_usage, unknown_oid, True)


class TestAuthorityKeyId(unittest.TestCase):
    """Tests for ``X509ExtensionAuthorityKeyId``."""

    def setUp(self):
        self.ext = extension.X509ExtensionAuthorityKeyId()

    def test_key_id(self):
        key_id = b"12345678"
        self.ext.set_key_id(key_id)
        self.assertEqual(key_id, self.ext.get_key_id())

    def test_name_serial(self):
        s = 12345678
        self.ext.set_serial(s)
        self.assertEqual(s, self.ext.get_serial())


class TestSubjectKeyId(unittest.TestCase):
    """Tests for ``X509ExtensionSubjectKeyId``."""

    def setUp(self):
        self.ext = extension.X509ExtensionSubjectKeyId()

    def test_key_id(self):
        key_id = b"12345678"
        self.ext.set_key_id(key_id)
        self.assertEqual(key_id, self.ext.get_key_id())