# $Id: hpke_test.py $
# $Date: 2023-10-24 11:479Z $
# ****************************** LICENSE ***********************************
# Copyright (C) 2023 David Ireland, DI Management Services Pty Limited.
# All rights reserved. <https://di-mgt.com.au> <https://cryptosys.net>
# The code in this module is licensed under the terms of the MIT license.
# @license MIT
# For a copy, see <http://opensource.org/licenses/MIT>
# **************************************************************************
"""
This is a "proof-of-concept" program to show that cryptosyspki v22.0 has all the features necessary to carry
out the required computations for RFC9180 "Hybrid Public Key Encryption".
It reads a cut-down version of the test-vectors.json file and computes and verifies the values for mode 0.
EC keys (skX, pkX) are stored in cryptosyspki ephemeral "internal" key string form (treat these as "blobs" valid only for this session).
The serialized form is in a byte array; use ``Cnv.tohex(key)`` to print.
There is some confusion over clamping X25519 and X448 serialized private keys - the RFC says they MUST be masked/clamped
but the test vectors do not do that. For this version, cryptosyspki does NOT clamp the serialized private keys.
See https://www.rfc-editor.org/errata_search.php?rfc=9180 https://www.rfc-editor.org/errata/eid7121
Requires installation of CryptoSys PKI v22.0 or later available from
https://cryptosys.net/pki/
and installation of the Python package ``cryptosyspki`` available from PyPI
https://pypi.org/project/cryptosyspki/
``pip install cryptosyspki``
"""
from cryptosyspki import *
import json
# Lowest acceptable value of Gen.version() (checked in main); the module
# docstring states CryptoSys PKI v22.0 or later is required.
PKI_MIN_VERSION = 220000
# JSON file of (cut-down) RFC9180 test vectors read by main().
infile = "test-vectors-1.json"

# LOOKUP TABLES
# RFC9180 KEM identifier -> cryptosyspki curve name used by that DHKEM.
curveNames = dict([
    (0x10, Hpke.CurveName.P_256),
    (0x11, Hpke.CurveName.P_384),
    (0x12, Hpke.CurveName.P_521),
    (0x20, Hpke.CurveName.X25519),
    (0x21, Hpke.CurveName.X448),
])
# RFC9180 AEAD identifier -> cryptosyspki AEAD algorithm.
# Export-only (0xFFFF) has no AEAD algorithm here, so it maps to None.
aeadAlgs = dict([
    (0x0001, Hpke.AeadAlg.AES_128_GCM),
    (0x0002, Hpke.AeadAlg.AES_256_GCM),
    (0x0003, Hpke.AeadAlg.CHACHA20_POLY1305),
    (0xFFFF, None),  # TODO: deal with 0xFFFF Export-only
])
# KEM dictionary RFC9180 Table 2
# Nsecret: KEM shared-secret length; Nenc: encapsulated-key length;
# Npk / Nsk: serialized public / private key lengths (all in bytes).
kems = {
    0x10: {"name": "DHKEM(P-256, HKDF-SHA256)", "Nsecret": 32, "Nenc": 65, "Npk": 65, "Nsk": 32},
    0x11: {"name": "DHKEM(P-384, HKDF-SHA384)", "Nsecret": 48, "Nenc": 97, "Npk": 97, "Nsk": 48},
    0x12: {"name": "DHKEM(P-521, HKDF-SHA512)", "Nsecret": 64, "Nenc": 133, "Npk": 133, "Nsk": 66},
    0x20: {"name": "DHKEM(X25519, HKDF-SHA256)", "Nsecret": 32, "Nenc": 32, "Npk": 32, "Nsk": 32},
    0x21: {"name": "DHKEM(X448, HKDF-SHA512)", "Nsecret": 64, "Nenc": 56, "Npk": 56, "Nsk": 56},
}
# KDF dictionary RFC9180 Table 3
# Nh: output size of the KDF's Extract step in bytes.
kdfs = {
    0x0001: {"name": "HKDF-SHA256", "Nh": 32},
    0x0002: {"name": "HKDF-SHA384", "Nh": 48},
    0x0003: {"name": "HKDF-SHA512", "Nh": 64},
}
# AEAD dictionary RFC9180 Table 5
# Nk: key length; Nn: nonce length; Nt: tag length (bytes).
# Export-only (0xFFFF) has no AEAD, hence the None entries.
aeads = {
    0x0001: {"name": "AES-128-GCM", "Nk": 16, "Nn": 12, "Nt": 16},
    0x0002: {"name": "AES-256-GCM", "Nk": 32, "Nn": 12, "Nt": 16},
    0x0003: {"name": "ChaCha20Poly1305", "Nk": 32, "Nn": 12, "Nt": 16},
    0xFFFF: {"name": "Export-only", "Nk": None, "Nn": None, "Nt": None},
}
# CONSTANTS
# HPKE mode identifiers (RFC9180 Table 1)
mode_base = 0x00
mode_psk = 0x01
mode_auth = 0x02
mode_auth_psk = 0x03
# Utility functions
def I2OSP(x, w):
    """Convert non-negative integer x to a w-length big-endian byte string.

    Implements I2OSP as described in [RFC8017] Section 4.1.

    :param x: non-negative integer to encode
    :param w: length of the output in bytes
    :return: bytearray of exactly w bytes
    :raises ValueError: if w or x is negative, or if x >= 256**w
        (RFC8017's "integer too large" error)
    """
    if w < 0:
        raise ValueError('length w must be a nonegative integer')
    if x < 0:
        raise ValueError('integer x must be non-negative')
    # Refuse to silently drop high-order bytes that do not fit in w bytes.
    if x >= 1 << (8 * w):
        raise ValueError('integer too large')
    # bytearray keeps the original return type for existing callers.
    return bytearray(x.to_bytes(w, 'big'))
def OS2IP(x):
    """Convert byte string x to a non-negative integer.

    Implements OS2IP as described in [RFC8017] Section 4.2, treating x as
    big-endian. An empty input yields 0.

    :param x: bytes-like object
    :return: the non-negative integer encoded by x
    """
    return int.from_bytes(x, 'big')
def xor_bytes(b1, b2):
    """Return the bytewise XOR of two equal-length byte strings.

    :param b1: first bytes-like operand
    :param b2: second bytes-like operand, same length as b1
    :return: bytes of the same length as the inputs
    :raises ValueError: if the inputs differ in length
    """
    # zip() would silently truncate to the shorter input, which for a nonce
    # computation would produce a short (wrong) nonce instead of an error.
    if len(b1) != len(b2):
        raise ValueError("xor_bytes: inputs must have equal length")
    return bytes(a ^ b for a, b in zip(b1, b2))
def ComputeNonce(seq, base_nonce, Nn):
    """Return the AEAD nonce for message sequence number seq.

    The sequence number is encoded as an Nn-byte big-endian string (via
    I2OSP) and XORed with base_nonce.

    :param seq: non-negative message sequence number
    :param base_nonce: base nonce from the key schedule, in bytes
    :param Nn: nonce length in bytes
    :return: the per-message nonce in bytes
    """
    return xor_bytes(base_nonce, I2OSP(seq, Nn))
def wrap(text, width=70):
    """Split text into newline-separated lines of at most width characters.

    :param text: string to display (e.g. a hex encoding)
    :param width: maximum characters per line (must be positive)
    :return: the wrapped string; "" for empty text
    """
    # Step all the way to len(text) so a final one-character chunk is kept.
    return "\n".join(text[i:i + width] for i in range(0, len(text), width))
def wrap_hex(b, width=70):
    """Return the bytes b as a lowercase hex string wrapped for display.

    :param b: byte string to display
    :param width: maximum characters per line, passed through to wrap()
    :return: lowercase hex string split into lines by wrap()
    """
    lowered = Cnv.tohex(b).lower()
    return wrap(lowered, width)
def VerifyPSKInputs(mode, psk, psk_id):
    """Check that the PSK inputs are consistent with the HPKE mode.

    Follows VerifyPSKInputs from RFC9180. An empty byte string means
    "no PSK" / "no PSK ID".

    :param mode: HPKE mode identifier (mode_base ... mode_auth_psk)
    :param psk: pre-shared key in bytes, or b"" for none
    :param psk_id: pre-shared key identifier in bytes, or b"" for none
    :raises ValueError: if only one of psk/psk_id is supplied, if a PSK is
        supplied in a mode that does not use one, or if a PSK is missing in
        a mode that requires one (ValueError subclasses Exception, so
        callers catching Exception are unaffected)
    """
    default_psk = b""
    default_psk_id = b""
    got_psk = psk != default_psk
    got_psk_id = psk_id != default_psk_id
    if got_psk != got_psk_id:
        raise ValueError("Inconsistent PSK inputs")
    if got_psk and mode in (mode_base, mode_auth):
        raise ValueError("PSK input provided when not needed")
    if not got_psk and mode in (mode_psk, mode_auth_psk):
        raise ValueError("Missing required PSK input")
class CipherSuite():
    """An HPKE ciphersuite (mode, KEM, KDF, AEAD) implemented with cryptosyspki.

    EC keys used by DeriveKeyPair/DH/Decap are cryptosyspki "internal" key
    strings, valid only for the current session; the Serialize*/Deserialize*
    methods convert between that form and byte strings.
    """

    def __init__(self, mode, kem_id, kdf_id, aead_id):
        """Record the suite identifiers and look up the cryptosyspki parameters.

        :param mode: HPKE mode identifier (mode_base ... mode_auth_psk)
        :param kem_id: KEM identifier; must be a key of ``curveNames``
        :param kdf_id: KDF identifier; used as a key of ``kdfs`` by name()
        :param aead_id: AEAD identifier; must be a key of ``aeadAlgs``
        :raises KeyError: if kem_id or aead_id is not in the lookup tables
        """
        self.mode = mode
        self.kem_id = kem_id
        self.kdf_id = kdf_id
        self.aead_id = aead_id
        self.curveName = curveNames[kem_id]
        # aeadAlgs maps 0xFFFF (Export-only) to None.
        self.aeadAlg = aeadAlgs[aead_id]  # TODO 0xFFFF

    def name(self):
        """Return "<KEM name>, <KDF name>, <AEAD name>" from the kems, kdfs and aeads tables."""
        s = f"{kems[self.kem_id]['name']}, {kdfs[self.kdf_id]['name']}, {aeads[self.aead_id]['name']}"
        return s

    def DeriveKeyPair(self, ikm):
        """Derive a key pair from input keying material.

        :param ikm: input keying material in bytes
        :return: (sk, pk) in internal key string form, valid only for this session
        """
        # In this implementation, the curve group defines all parameters to derive the key
        sk = Hpke.derive_private_key(ikm, self.curveName)
        pk = Ecc.publickey_from_private(sk)
        return sk, pk

    @staticmethod
    def SerializePublicKey(pk):
        """ Serialize the public key.
        :param pk: public key in internal string form
        :return: serialized key in bytes
        """
        keyhex = Ecc.query_key(pk, "publicKey")
        return Cnv.fromhex(keyhex)

    @staticmethod
    def SerializePrivateKey(sk):
        """ Serialize the private key.
        :param sk: private key in internal string form
        :return: serialized key in bytes
        """
        keyhex = Ecc.query_key(sk, "privateKey")
        return Cnv.fromhex(keyhex)

    def DeserializePublicKey(self, pk):
        """ Deserialize the public key.
        :param pk: public key in bytes form
        :return: Public key in internal string form
        """
        intkey = Ecc.read_key_by_curve(Cnv.tohex(pk), self.curveName, ispublic=True)
        return intkey

    def DeserializePrivateKey(self, sk):
        """ Deserialize the private key.
        :param sk: private key in bytes form
        :return: Private key in internal string form
        """
        intkey = Ecc.read_key_by_curve(Cnv.tohex(sk), self.curveName, ispublic=False)
        return intkey

    @staticmethod
    def verifyPrivateKeysEqual(skhex, sk):
        """Assert that the serialized private key sk matches the expected hex skhex.

        :param skhex: expected key, hex-encoded (compared case-insensitively)
        :param sk: serialized private key in bytes
        """
        # CAUTION: the sk returned by Hpke.derive_private_key is *not* clamped.
        # This is to match the test vectors, which are also not clamped.
        # Ref: https://www.rfc-editor.org/errata_search.php?rfc=9180
        #      https://www.rfc-editor.org/errata/eid7121
        # NOTE: to be changed if the test vectors are ever updated; for X25519/X448
        # the expected key would then need normalizing through
        # Ecc.read_key_by_curve() and Ecc.query_key(..., "privateKey") first.
        assert skhex.lower() == Cnv.tohex(sk).lower()

    @staticmethod
    def verifyPublicKeysEqual(pkhex, pk):
        """Assert that the serialized public key pk matches the expected hex pkhex.

        :param pkhex: expected key, hex-encoded (compared case-insensitively)
        :param pk: serialized public key in bytes
        """
        assert pkhex.lower() == Cnv.tohex(pk).lower()

    @staticmethod
    def validatePublicKey(pk):
        """Validate a public key before it is used in a DH computation.

        :param pk: public key in internal string form
        :raises ValueError: if cryptosyspki's "isValid" query on the key is falsy
        """
        # Raise explicitly rather than assert: assert statements are removed
        # under ``python -O``, which would silently skip this security check.
        if not Ecc.query_key(pk, "isValid"):
            raise ValueError("Invalid public key")

    def DH(self, skX, pkY):
        """Compute the DH shared secret between skX and a validated pkY.

        :param skX: private key in internal string form
        :param pkY: peer public key in internal string form
        :return: DH shared secret from Ecc.dh_shared_secret
        :raises ValueError: if pkY fails validation
        """
        # First, validate the public key provided
        self.validatePublicKey(pkY)
        return Ecc.dh_shared_secret(skX, pkY)

    def ExtractAndExpand(self, dh, kem_context):
        """Derive the KEM shared secret from a DH output (RFC9180 DHKEM).

        :param dh: DH shared secret in bytes
        :param kem_context: concat(enc, pkRm) in bytes
        :return: shared secret of kems[kem_id]['Nsecret'] bytes
        """
        # The AEAD is deliberately not passed here; only the later
        # key-schedule calls in verifyTestVector include it.
        eae_prk = Hpke.labeled_extract(b'', "eae_prk", dh, self.curveName)
        Nsecret = kems[self.kem_id]['Nsecret']
        return Hpke.labeled_expand(Nsecret, eae_prk, "shared_secret", kem_context, self.curveName)

    def Decap(self, enc, skR):
        """Recover the KEM shared secret as the recipient.

        :param enc: encapsulated key (serialized sender ephemeral public key) in bytes
        :param skR: recipient private key in internal string form
        :return: shared secret in bytes
        """
        pkE = self.DeserializePublicKey(enc)
        dh = self.DH(skR, pkE)
        pkRm = self.SerializePublicKey(Ecc.publickey_from_private(skR))
        kem_context = enc + pkRm  # In bytes
        return self.ExtractAndExpand(dh, kem_context)
def _assert_hex_equal(actual, expected_hex):
    """Assert that bytes `actual` hex-encode to `expected_hex`, ignoring case."""
    assert Cnv.tohex(actual).lower() == expected_hex.lower()


def _verify_key_pair(suite, tv, role):
    """Derive and check the key pair for role "E" (ephemeral) or "R" (recipient).

    Reads tv["ikm<role>"], prints the serialized keys, compares them with
    tv["pk<role>m"] / tv["sk<role>m"], and returns (sk, pk) in internal form.
    """
    ikm_hex = tv['ikm' + role]
    print(f"ikm{role}:", wrap(ikm_hex))
    sk, pk = suite.DeriveKeyPair(Cnv.fromhex(ikm_hex))
    pkm = suite.SerializePublicKey(pk)
    skm = suite.SerializePrivateKey(sk)
    print(f"pk{role}m:", wrap_hex(pkm), f"sk{role}m:", wrap_hex(skm), sep='\n')
    suite.verifyPublicKeysEqual(tv[f'pk{role}m'], pkm)
    suite.verifyPrivateKeysEqual(tv[f'sk{role}m'], skm)
    return sk, pk


def _verify_kem(suite, tv, skE, pkE, skR, pkR):
    """Check the KEM values dh, enc and shared_secret; return shared_secret.

    The sender side mirrors ``shared_secret, enc = Encap(pkR)`` using the
    derived ephemeral key pair; the recipient side uses suite.Decap().
    """
    # Compute DH shared secret both ways and check they are equal
    dh = suite.DH(skE, pkR)
    print("dh:", Cnv.tohex(dh))
    dhR = suite.DH(skR, pkE)
    assert Cnv.tohex(dh) == Cnv.tohex(dhR)
    # enc = SerializePublicKey(pkE)
    enc = suite.SerializePublicKey(pkE)
    print("enc:", wrap_hex(enc), sep='\n')
    _assert_hex_equal(enc, tv['enc'])
    # kem_context = concat(enc, pkRm), both in byte form
    kem_context = enc + suite.SerializePublicKey(pkR)
    shared_secret = suite.ExtractAndExpand(dh, kem_context)
    print("shared_secret:", wrap_hex(shared_secret), sep='\n')
    _assert_hex_equal(shared_secret, tv['shared_secret'])
    # Check Decap recovers the same secret on the recipient side
    _assert_hex_equal(suite.Decap(enc, skR), tv['shared_secret'])
    return shared_secret


def _verify_key_schedule(suite, tv, shared_secret):
    """Run and check the mode-0 key schedule; return (key, base_nonce).

    RFC9180 pseudocode being reproduced:
        psk_id_hash = LabeledExtract("", "psk_id_hash", psk_id)
        info_hash = LabeledExtract("", "info_hash", info)
        key_schedule_context = concat(mode, psk_id_hash, info_hash)
        secret = LabeledExtract(shared_secret, "secret", psk)
        key = LabeledExpand(secret, "key", key_schedule_context, Nk)
        base_nonce = LabeledExpand(secret, "base_nonce", key_schedule_context, Nn)
        exporter_secret = LabeledExpand(secret, "exp", key_schedule_context, Nh)
    """
    # NOTE(review): Export-only suites (aead_id 0xFFFF) have Nk/Nn of None and
    # suite.aeadAlg of None; this path does not handle them yet (TODO).
    info = Cnv.fromhex(tv['info'])
    psk = b''
    psk_id = b''
    VerifyPSKInputs(suite.mode, psk, psk_id)  # Not necessary here, but do it anyway
    # NB from now on, the AEAD is included when calling the Hpke.labeled_* functions
    psk_id_hash = Hpke.labeled_extract(b'', "psk_id_hash", psk_id, suite.curveName, suite.aeadAlg)
    info_hash = Hpke.labeled_extract(b'', "info_hash", info, suite.curveName, suite.aeadAlg)
    key_schedule_context = I2OSP(suite.mode, 1) + psk_id_hash + info_hash
    print("key_schedule_context:", wrap_hex(key_schedule_context), sep='\n')
    _assert_hex_equal(key_schedule_context, tv['key_schedule_context'])

    secret = Hpke.labeled_extract(shared_secret, "secret", psk, suite.curveName, suite.aeadAlg)
    print("secret:", wrap_hex(secret), sep='\n')
    _assert_hex_equal(secret, tv['secret'])

    aead = aeads[suite.aead_id]
    key = Hpke.labeled_expand(aead["Nk"], secret, "key", key_schedule_context,
                              suite.curveName, suite.aeadAlg)
    print("key:", wrap_hex(key), sep='\n')
    _assert_hex_equal(key, tv['key'])

    base_nonce = Hpke.labeled_expand(aead["Nn"], secret, "base_nonce", key_schedule_context,
                                     suite.curveName, suite.aeadAlg)
    print("base_nonce:", wrap_hex(base_nonce), sep='\n')
    _assert_hex_equal(base_nonce, tv['base_nonce'])

    exporter_secret = Hpke.labeled_expand(kdfs[suite.kdf_id]["Nh"], secret, "exp", key_schedule_context,
                                          suite.curveName, suite.aeadAlg)
    print("exporter_secret:", wrap_hex(exporter_secret), sep='\n')
    _assert_hex_equal(exporter_secret, tv['exporter_secret'])
    return key, base_nonce


def _verify_encryptions(suite, tv, key, base_nonce):
    """Seal and open each entry of tv['encryptions'] with sequence numbers 0, 1, 2, ..."""
    Nn = aeads[suite.aead_id]["Nn"]
    for seq_num, e in enumerate(tv['encryptions']):
        print("sequence number:", seq_num)
        pt = Cnv.fromhex(e['pt'])
        print("pt:", Cnv.tohex(pt))
        aad = Cnv.fromhex(e['aad'])
        print("aad:", Cnv.tohex(aad))
        nonce = ComputeNonce(seq_num, base_nonce, Nn)
        print("nonce:", Cnv.tohex(nonce))
        # Sender seals (encrypts) the plaintext
        ct = Cipher.encrypt_aead(pt, key, nonce, suite.aeadAlg, aad)
        print("ct:", wrap_hex(ct))
        _assert_hex_equal(ct, e['ct'])
        # Recipient opens (decrypts) the ciphertext
        dt = Cipher.decrypt_aead(ct, key, nonce, suite.aeadAlg, aad)
        _assert_hex_equal(dt, e['pt'])


def verifyTestVector(tv):
    """Compute and check the RFC9180 values in one test vector, printing as it goes.

    Key derivation and the KEM values (dh, enc, shared_secret) are checked for
    every mode; the key schedule and encryptions only for mode 0 (base).
    Checks are assert statements.

    :param tv: one test-vector dict loaded from the JSON file
    """
    print("\nTestVector:")
    suite = CipherSuite(tv['mode'], tv['kem_id'], tv['kdf_id'], tv['aead_id'])
    print(suite.name())
    print("mode:", suite.mode)
    print("kem_id:", suite.kem_id)
    print("kdf_id:", suite.kdf_id)
    print("aead_id:", suite.aead_id)
    print("info:", tv['info'])
    # Derive both key pairs from their ikm and check them
    skE, pkE = _verify_key_pair(suite, tv, 'E')
    skR, pkR = _verify_key_pair(suite, tv, 'R')
    shared_secret = _verify_kem(suite, tv, skE, pkE, skR, pkR)
    # We only do mode 0 here...
    if suite.mode != mode_base:
        print("NOT SUPPORTED: Mode " + str(suite.mode) + " is not supported here, just mode 0")
        return  # So just return and try the next one
    key, base_nonce = _verify_key_schedule(suite, tv, shared_secret)
    _verify_encryptions(suite, tv, key, base_nonce)
def main():
    """Check the cryptosyspki version, then verify every test vector in infile.

    :raises SystemExit: if Gen.version() is below PKI_MIN_VERSION
    """
    ver = Gen.version()
    print("cryptosyspki version =", ver)
    # Explicit check rather than assert, which is stripped under ``python -O``.
    if ver < PKI_MIN_VERSION:
        raise SystemExit(f"cryptosyspki version {ver} is too old: need {PKI_MIN_VERSION} or later")
    print("FILE:", infile)
    # Read in JSON test vectors; JSON text is UTF-8, so do not rely on the locale encoding.
    with open(infile, encoding="utf-8") as f:
        tvs = json.load(f)
    print("len(test-vectors)=", len(tvs))
    # Verify each test vector, counting only those that complete
    ntvs = 0
    for tv in tvs:
        verifyTestVector(tv)
        ntvs += 1
    print(f"\nALL DONE. Processed {ntvs} test vectors.")


if __name__ == "__main__":
    main()