# $Id: hpke_test.py $
# $Date: 2023-10-24 11:479Z $

# ****************************** LICENSE ***********************************
# Copyright (C) 2023 David Ireland, DI Management Services Pty Limited.
# All rights reserved. <https://di-mgt.com.au> <https://cryptosys.net>
# The code in this module is licensed under the terms of the MIT license.
# @license MIT
# For a copy, see <http://opensource.org/licenses/MIT>
# **************************************************************************

"""
This is a "proof-of-concept" program to show that cryptosyspki v22.0 has all
the features necessary to carry out the required computations for RFC9180
"Hybrid Public Key Encryption".

It reads a cut-down version of the test-vectors.json file and computes and
verifies the values for mode 0. Test vectors for other modes, or for the
export-only AEAD (0xFFFF), are reported as "NOT SUPPORTED" and skipped once
the values this program can compute for them have been checked.

EC keys (skX, pkX) are stored in cryptosyspki ephemeral "internal" key string
form (treat these as "blobs" valid only for this session). The serialized form
is in a byte array; use ``Cnv.tohex(key)`` to print.

There is some confusion over clamping X25519 and X448 serialized private keys:
the RFC says they MUST be masked/clamped but the test vectors do not do that.
For this version, cryptosyspki does NOT clamp the serialized private keys.
See
https://www.rfc-editor.org/errata_search.php?rfc=9180
https://www.rfc-editor.org/errata/eid7121

Requires installation of CryptoSys PKI v22.0 or later available from
https://cryptosys.net/pki/
and installation of the Python package ``cryptosyspki`` available from PyPI
https://pypi.org/project/cryptosyspki/
``pip install cryptosyspki``
"""

from cryptosyspki import *
import json

PKI_MIN_VERSION = 220000
infile = "test-vectors-1.json"

# LOOKUP TABLES
# KEM id -> cryptosyspki curve name
curveNames = {
    0x10: Hpke.CurveName.P_256,
    0x11: Hpke.CurveName.P_384,
    0x12: Hpke.CurveName.P_521,
    0x20: Hpke.CurveName.X25519,
    0x21: Hpke.CurveName.X448,
}
# AEAD id -> cryptosyspki AEAD algorithm
aeadAlgs = {
    0x0001: Hpke.AeadAlg.AES_128_GCM,
    0x0002: Hpke.AeadAlg.AES_256_GCM,
    0x0003: Hpke.AeadAlg.CHACHA20_POLY1305,
    0xFFFF: None,  # Export-only: no AEAD; verifyTestVector() skips these suites
}

# KEM dictionary RFC9180 Table 2
kems = {
    0x10: {"name": "DHKEM(P-256, HKDF-SHA256)", "Nsecret": 32, "Nenc": 65, "Npk": 65, "Nsk": 32},
    0x11: {"name": "DHKEM(P-384, HKDF-SHA384)", "Nsecret": 48, "Nenc": 97, "Npk": 97, "Nsk": 48},
    0x12: {"name": "DHKEM(P-521, HKDF-SHA512)", "Nsecret": 64, "Nenc": 133, "Npk": 133, "Nsk": 66},
    0x20: {"name": "DHKEM(X25519, HKDF-SHA256)", "Nsecret": 32, "Nenc": 32, "Npk": 32, "Nsk": 32},
    0x21: {"name": "DHKEM(X448, HKDF-SHA512)", "Nsecret": 64, "Nenc": 56, "Npk": 56, "Nsk": 56},
}

# KDF dictionary RFC9180 Table 3
kdfs = {
    0x0001: {"name": "HKDF-SHA256", "Nh": 32},
    0x0002: {"name": "HKDF-SHA384", "Nh": 48},
    0x0003: {"name": "HKDF-SHA512", "Nh": 64},
}

# AEAD dictionary RFC9180 Table 5
aeads = {
    0x0001: {"name": "AES-128-GCM", "Nk": 16, "Nn": 12, "Nt": 16},
    0x0002: {"name": "AES-256-GCM", "Nk": 32, "Nn": 12, "Nt": 16},
    0x0003: {"name": "ChaCha20Poly1305", "Nk": 32, "Nn": 12, "Nt": 16},
    0xFFFF: {"name": "Export-only", "Nk": None, "Nn": None, "Nt": None},
}

# CONSTANTS: HPKE modes (RFC9180 Table 1)
mode_base = 0x00
mode_psk = 0x01
mode_auth = 0x02
mode_auth_psk = 0x03


# Utility functions

def I2OSP(x: int, w: int) -> bytearray:
    """
    Convert non-negative integer x to a w-length, big-endian byte string,
    as described in [RFC8017] section 4.1.

    :param x: non-negative integer to convert
    :param w: length in bytes of the output
    :return: big-endian encoding of x, left-padded with zeros to w bytes
    :raises ValueError: if w or x is negative, or if x does not fit in w bytes
        ("integer too large" in RFC8017)
    """
    if w < 0:
        raise ValueError('length w must be a nonegative integer')
    if x < 0:
        raise ValueError("integer x must be non-negative")
    if x >= 1 << (8 * w):
        # RFC8017 requires an error here rather than silent truncation
        raise ValueError("integer too large")
    return bytearray(x.to_bytes(w, "big"))


def OS2IP(x) -> int:
    """Convert byte string x to a non-negative integer, as described in
    [RFC8017], assuming big-endian byte order. An empty input gives 0."""
    return int.from_bytes(x, "big")


def xor_bytes(b1, b2) -> bytes:
    """
    Return the bytewise XOR of two byte strings of equal length.

    :raises ValueError: if the lengths differ (zip would otherwise silently
        truncate to the shorter input)
    """
    if len(b1) != len(b2):
        raise ValueError("xor_bytes: inputs must have equal length")
    return bytes(a ^ b for a, b in zip(b1, b2))


def ComputeNonce(seq, base_nonce, Nn):
    """Return the nonce for sequence number seq: base_nonce XOR I2OSP(seq, Nn)
    (RFC9180 section 5.2)."""
    seq_bytes = I2OSP(seq, Nn)
    return xor_bytes(base_nonce, seq_bytes)


def wrap(text: str, width: int = 70) -> str:
    """Simple textwrap to display string: split text into lines of at most
    width characters (no regard for word boundaries)."""
    return "\n".join(text[i:i + width] for i in range(0, len(text), width))


def wrap_hex(b, width=70):
    """Encode byte array in wrapped, lower-case hex-encoded string."""
    return wrap(Cnv.tohex(b).lower(), width)


def VerifyPSKInputs(mode, psk, psk_id):
    """Check the psk/psk_id inputs are consistent with the mode.
    Copied from RFC9180 section 5.1."""
    default_psk = b""
    default_psk_id = b""
    got_psk = (psk != default_psk)
    got_psk_id = (psk_id != default_psk_id)
    if got_psk != got_psk_id:
        raise Exception("Inconsistent PSK inputs")
    if got_psk and (mode in [mode_base, mode_auth]):
        raise Exception("PSK input provided when not needed")
    if (not got_psk) and (mode in [mode_psk, mode_auth_psk]):
        raise Exception("Missing required PSK input")


class CipherSuite:
    """
    An HPKE cipher suite (mode, KEM, KDF, AEAD) backed by cryptosyspki.

    Unless a method says otherwise, keys are cryptosyspki "internal" key
    strings, valid only for the current session.
    """

    def __init__(self, mode, kem_id, kdf_id, aead_id):
        self.mode = mode
        self.kem_id = kem_id
        self.kdf_id = kdf_id
        self.aead_id = aead_id
        # Raises KeyError for an id not present in the lookup tables
        self.curveName = curveNames[kem_id]
        # None for the export-only AEAD (0xFFFF)
        self.aeadAlg = aeadAlgs[aead_id]

    def name(self):
        """Return a string of the form
        "DHKEM(X25519, HKDF-SHA256), HKDF-SHA256, AES-128-GCM"."""
        return (f"{kems[self.kem_id]['name']}, {kdfs[self.kdf_id]['name']}, "
                f"{aeads[self.aead_id]['name']}")

    def DeriveKeyPair(self, ikm):
        """
        Derive a key pair from input keying material.

        :param ikm: input keying material in bytes
        :return: (sk, pk) as internal key strings valid only for this session
        """
        # In this implementation, the curve group defines all parameters to
        # derive the key
        sk = Hpke.derive_private_key(ikm, self.curveName)
        pk = Ecc.publickey_from_private(sk)
        return sk, pk

    @staticmethod
    def SerializePublicKey(pk):
        """
        Serialize the public key.

        :param pk: public key in internal string form
        :return: serialized key in bytes
        """
        keyhex = Ecc.query_key(pk, "publicKey")
        return Cnv.fromhex(keyhex)

    @staticmethod
    def SerializePrivateKey(sk):
        """
        Serialize the private key.

        :param sk: private key in internal string form
        :return: serialized key in bytes
        """
        keyhex = Ecc.query_key(sk, "privateKey")
        return Cnv.fromhex(keyhex)

    def DeserializePublicKey(self, pk):
        """
        Deserialize the public key.

        :param pk: public key in bytes form
        :return: Public key in internal string form
        """
        return Ecc.read_key_by_curve(Cnv.tohex(pk), self.curveName, ispublic=True)

    def DeserializePrivateKey(self, sk):
        """
        Deserialize the private key.

        :param sk: private key in bytes form
        :return: Private key in internal string form
        """
        return Ecc.read_key_by_curve(Cnv.tohex(sk), self.curveName, ispublic=False)

    @staticmethod
    def verifyPrivateKeysEqual(skhex, sk):
        """
        Assert that serialized private key sk (bytes) matches hex string skhex
        (case-insensitive).

        CAUTION: the sk returned by Hpke.derive_private_key is *not* clamped.
        This is to match the test vectors, which are also not clamped.
        Ref: https://www.rfc-editor.org/errata_search.php?rfc=9180
             https://www.rfc-editor.org/errata/eid7121
        NOTE: to be changed if the test vectors are ever updated (for X25519
        and X448 the expected key would then need clamping before comparing).
        """
        assert skhex.lower() == Cnv.tohex(sk).lower()

    @staticmethod
    def verifyPublicKeysEqual(pkhex, pk):
        """Assert that serialized public key pk (bytes) matches hex string
        pkhex (case-insensitive)."""
        assert pkhex.lower() == Cnv.tohex(pk).lower()

    @staticmethod
    def validatePublicKey(pk):
        """Validate public key (internal string form); raise AssertionError
        if cryptosyspki reports it as not valid."""
        # NOTE(review): assumes query_key returns a falsy value (e.g. 0) for an
        # invalid key -- confirm against the cryptosyspki documentation.
        r = Ecc.query_key(pk, "isValid")
        assert bool(r)

    def DH(self, skX, pkY):
        """Validate pkY, then return the DH shared secret of skX and pkY
        (both in internal key form)."""
        self.validatePublicKey(pkY)
        return Ecc.dh_shared_secret(skX, pkY)

    def ExtractAndExpand(self, dh, kem_context):
        """Derive the KEM shared_secret from dh and kem_context
        (RFC9180 section 4.1)."""
        # No AEAD argument: these labeled calls use the KEM suite_id
        eae_prk = Hpke.labeled_extract(b'', "eae_prk", dh, self.curveName)
        Nsecret = kems[self.kem_id]['Nsecret']
        return Hpke.labeled_expand(Nsecret, eae_prk, "shared_secret", kem_context, self.curveName)

    def Decap(self, enc, skR):
        """Recover the shared_secret from encapsulated key enc (bytes) using
        recipient private key skR (internal form)."""
        pkE = self.DeserializePublicKey(enc)
        dh = self.DH(skR, pkE)
        pkR = Ecc.publickey_from_private(skR)
        pkRm = self.SerializePublicKey(pkR)
        kem_context = enc + pkRm  # In bytes
        return self.ExtractAndExpand(dh, kem_context)


def _assert_hex_equal(label, value, expected_hex):
    """
    Assert that byte string value hex-encodes to expected_hex.

    The comparison is case-insensitive, so upper- or lower-case hex in the
    test-vector file is accepted.

    :param label: name of the value, used in the assertion message
    :param value: computed value in bytes
    :param expected_hex: expected value as a hex string
    :raises AssertionError: if the values differ
    """
    assert Cnv.tohex(value).lower() == expected_hex.lower(), f"{label} does not match test vector"


def verifyTestVector(tv):
    """
    Recompute the RFC9180 values for one test vector and check them.

    Key pairs, the DH value and enc are checked for every mode. The shared
    secret is checked for modes 0 and 1. The key schedule and the
    encryptions are checked for mode 0 only. Unsupported modes and the
    export-only AEAD print a "NOT SUPPORTED" message and return early.

    :param tv: one test vector (dict) as loaded from the JSON file
    :raises AssertionError: if a computed value disagrees with the test vector
    """
    print("\nTestVector:")
    suite = CipherSuite(tv['mode'], tv['kem_id'], tv['kdf_id'], tv['aead_id'])
    print(suite.name())
    print("mode:", suite.mode)
    print("kem_id:", suite.kem_id)
    print("kdf_id:", suite.kdf_id)
    print("aead_id:", suite.aead_id)
    print("info:", tv['info'])

    # Derive keys from ikm...
    print("ikmE:", wrap(tv['ikmE']))
    skE, pkE = suite.DeriveKeyPair(Cnv.fromhex(tv['ikmE']))
    pkEm = suite.SerializePublicKey(pkE)
    skEm = suite.SerializePrivateKey(skE)
    print("pkEm:", wrap_hex(pkEm), "skEm:", wrap_hex(skEm), sep='\n')
    suite.verifyPublicKeysEqual(tv['pkEm'], pkEm)
    suite.verifyPrivateKeysEqual(tv['skEm'], skEm)

    print("ikmR:", wrap(tv['ikmR']))
    skR, pkR = suite.DeriveKeyPair(Cnv.fromhex(tv['ikmR']))
    pkRm = suite.SerializePublicKey(pkR)
    skRm = suite.SerializePrivateKey(skR)
    print("pkRm:", wrap_hex(pkRm), "skRm:", wrap_hex(skRm), sep='\n')
    suite.verifyPublicKeysEqual(tv['pkRm'], pkRm)
    suite.verifyPrivateKeysEqual(tv['skRm'], skRm)

    # Create Context intermediate values and outputs.
    # Includes KEM outputs enc and shared_secret used to create the context:
    #   shared_secret, enc = Encap(pkR)
    # Compute DH shared secret both ways and check they are equal
    dh = suite.DH(skE, pkR)
    print("dh:", Cnv.tohex(dh))
    dhR = suite.DH(skR, pkE)
    assert Cnv.tohex(dh) == Cnv.tohex(dhR), "sender and recipient DH values differ"

    # enc = SerializePublicKey(pkE)
    enc = suite.SerializePublicKey(pkE)
    print("enc:", wrap_hex(enc), sep='\n')
    _assert_hex_equal("enc", enc, tv['enc'])

    # The auth modes use AuthEncap, which also mixes in the sender's static
    # key, so the single-DH shared_secret below would not match. Stop here.
    if suite.mode not in (mode_base, mode_psk):
        print("NOT SUPPORTED: Mode " + str(suite.mode) + " is not supported here, just mode 0")
        return  # So just return and try the next one

    # kem_context = concat(enc, pkRm)
    kem_context = enc + pkRm  # NB in byte form
    # shared_secret = ExtractAndExpand(dh, kem_context)
    shared_secret = suite.ExtractAndExpand(dh, kem_context)
    print("shared_secret:", wrap_hex(shared_secret), sep='\n')
    _assert_hex_equal("shared_secret", shared_secret, tv['shared_secret'])

    # Check Decap works
    shared_secret_R = suite.Decap(enc, skR)
    _assert_hex_equal("Decap shared_secret", shared_secret_R, tv['shared_secret'])

    # We only do mode 0 from here on (the psk modes need psk and psk_id)
    if suite.mode != mode_base:
        print("NOT SUPPORTED: Mode " + str(suite.mode) + " is not supported here, just mode 0")
        return  # So just return and try the next one
    # Export-only suites have no AEAD (Nk, Nn are None in the tables)
    if suite.aeadAlg is None:
        print("NOT SUPPORTED: Export-only AEAD is not supported here")
        return

    # Derive the key_schedule_context:
    #   psk_id_hash = LabeledExtract("", "psk_id_hash", psk_id)
    #   info_hash = LabeledExtract("", "info_hash", info)
    #   key_schedule_context = concat(mode, psk_id_hash, info_hash)
    # NB from now on, make sure to include aead when calling Hpke.labeled_extract
    info = Cnv.fromhex(tv['info'])
    psk = b''
    psk_id = b''
    VerifyPSKInputs(suite.mode, psk, psk_id)  # <== Not necessary here, but do it anyway
    psk_id_hash = Hpke.labeled_extract(b'', "psk_id_hash", psk_id, suite.curveName, suite.aeadAlg)
    info_hash = Hpke.labeled_extract(b'', "info_hash", info, suite.curveName, suite.aeadAlg)
    key_schedule_context = I2OSP(suite.mode, 1) + psk_id_hash + info_hash
    print("key_schedule_context:", wrap_hex(key_schedule_context), sep='\n')
    _assert_hex_equal("key_schedule_context", key_schedule_context, tv['key_schedule_context'])

    # Key schedule:
    #   secret = LabeledExtract(shared_secret, "secret", psk)
    #   key = LabeledExpand(secret, "key", key_schedule_context, Nk)
    #   base_nonce = LabeledExpand(secret, "base_nonce", key_schedule_context, Nn)
    #   exporter_secret = LabeledExpand(secret, "exp", key_schedule_context, Nh)
    #   return Context<ROLE>(key, base_nonce, 0, exporter_secret)
    Nk = aeads[suite.aead_id]["Nk"]
    Nn = aeads[suite.aead_id]["Nn"]
    Nh = kdfs[suite.kdf_id]["Nh"]

    secret = Hpke.labeled_extract(shared_secret, "secret", psk, suite.curveName, suite.aeadAlg)
    print("secret:", wrap_hex(secret), sep='\n')
    _assert_hex_equal("secret", secret, tv['secret'])

    key = Hpke.labeled_expand(Nk, secret, "key", key_schedule_context, suite.curveName, suite.aeadAlg)
    print("key:", wrap_hex(key), sep='\n')
    _assert_hex_equal("key", key, tv['key'])

    base_nonce = Hpke.labeled_expand(Nn, secret, "base_nonce", key_schedule_context,
                                     suite.curveName, suite.aeadAlg)
    print("base_nonce:", wrap_hex(base_nonce), sep='\n')
    _assert_hex_equal("base_nonce", base_nonce, tv['base_nonce'])

    exporter_secret = Hpke.labeled_expand(Nh, secret, "exp", key_schedule_context,
                                          suite.curveName, suite.aeadAlg)
    print("exporter_secret:", wrap_hex(exporter_secret), sep='\n')
    _assert_hex_equal("exporter_secret", exporter_secret, tv['exporter_secret'])

    # Do encryptions: the nonce for the i-th message uses sequence number i
    for seq_num, e in enumerate(tv['encryptions']):
        print("sequence number:", seq_num)
        pt = Cnv.fromhex(e['pt'])
        print("pt:", Cnv.tohex(pt))
        aad = Cnv.fromhex(e['aad'])
        print("aad:", Cnv.tohex(aad))
        nonce = ComputeNonce(seq_num, base_nonce, Nn)
        print("nonce:", Cnv.tohex(nonce))
        # Sender seals (encrypts) the plaintext
        ct = Cipher.encrypt_aead(pt, key, nonce, suite.aeadAlg, aad)
        print("ct:", wrap_hex(ct))
        _assert_hex_equal("ct", ct, e['ct'])
        # Recipient opens (decrypts) the ciphertext
        dt = Cipher.decrypt_aead(ct, key, nonce, suite.aeadAlg, aad)
        _assert_hex_equal("decrypted pt", dt, e['pt'])


def main():
    """Check the cryptosyspki version, then verify every test vector in infile.

    :raises RuntimeError: if the installed cryptosyspki is older than PKI_MIN_VERSION
    """
    ver = Gen.version()
    print("cryptosyspki version =", ver)
    # Explicit check rather than assert, so it is not stripped under -O
    if ver < PKI_MIN_VERSION:
        raise RuntimeError(f"cryptosyspki version {ver} is too old; need {PKI_MIN_VERSION} or later")
    print("FILE:", infile)

    # Read in JSON test vectors (JSON text is UTF-8)
    with open(infile, encoding="utf-8") as f:
        tvs = json.load(f)
    print("len(test-vectors)=", len(tvs))

    # Verify each test vector
    ntvs = 0
    for tv in tvs:
        verifyTestVector(tv)
        ntvs += 1

    print(f"\nALL DONE. Processed {ntvs} test vectors.")


if __name__ == "__main__":
    main()