# $Id: hpke_test.py $

# $Date: 2023-10-24 11:479Z $

# ****************************** LICENSE ***********************************
# Copyright (C) 2023 David Ireland, DI Management Services Pty Limited.
# All rights reserved. <https://di-mgt.com.au> <https://cryptosys.net>
# The code in this module is licensed under the terms of the MIT license.
# @license MIT
# For a copy, see <http://opensource.org/licenses/MIT>
# **************************************************************************

"""
This is a "proof-of-concept" program to show that cryptosyspki v22.0 has all the features necessary to carry
out the required computations for RFC9180 "Hybrid Public Key Encryption".
It reads a cut-down version of the test-vectors.json file and computes and verifies the values for mode 0.

EC keys (skX, pkX) are stored in cryptosyspki ephemeral "internal" key string form (treat these as "blobs" valid only for this session).
The serialized form is in a byte array; use ``Cnv.tohex(key)`` to print.
There is some confusion over clamping X25519 and X448 serialized private keys - the RFC says they MUST be masked/clamped
but the test vectors do not do that. For this version, cryptosyspki does NOT clamp the serialized private keys.
See https://www.rfc-editor.org/errata_search.php?rfc=9180 https://www.rfc-editor.org/errata/eid7121

Requires installation of CryptoSys PKI v22.0 or later available from
https://cryptosys.net/pki/
and installation of the Python package ``cryptosyspki`` available from PyPI
https://pypi.org/project/cryptosyspki/

    ``pip install cryptosyspki``
"""


from cryptosyspki import *
import json

# Minimum CryptoSys PKI version accepted by main(); compared against Gen.version().
PKI_MIN_VERSION = 220000
# Name of the JSON test-vector file opened by main() (a relative path).
infile = "test-vectors-1.json"

# LOOKUP TABLES
# Map a KEM identifier (the ``kem_id`` value of a test vector) to the
# cryptosyspki curve name; CipherSuite stores the result as ``curveName``.
curveNames = {
    0x10: Hpke.CurveName.P_256,
    0x11: Hpke.CurveName.P_384,
    0x12: Hpke.CurveName.P_521,
    0x20: Hpke.CurveName.X25519,
    0x21: Hpke.CurveName.X448,
}
# Map an AEAD identifier (the ``aead_id`` value of a test vector) to the
# cryptosyspki AEAD algorithm; CipherSuite stores the result as ``aeadAlg``.
# The export-only identifier 0xFFFF has no algorithm and maps to None.
aeadAlgs = {
    0x0001: Hpke.AeadAlg.AES_128_GCM,
    0x0002: Hpke.AeadAlg.AES_256_GCM,
    0x0003: Hpke.AeadAlg.CHACHA20_POLY1305,
    0xFFFF: None,  # TODO: deal with 0xFFFF (Export-only)
}

# KEM dictionary RFC9180 Table 2
# Keyed by kem_id. "name" is the display name used by CipherSuite.name().
# "Nsecret" is the shared-secret length requested in
# CipherSuite.ExtractAndExpand(). "Nenc", "Npk" and "Nsk" are not read
# anywhere in this file (presumably the encoded lengths in bytes from the
# same RFC table -- confirm against RFC9180 before relying on them).
kems = {
    0x10: {"name": "DHKEM(P-256, HKDF-SHA256)", "Nsecret": 32, "Nenc": 65, "Npk": 65, "Nsk": 32},
    0x11: {"name": "DHKEM(P-384, HKDF-SHA384)", "Nsecret": 48, "Nenc": 97, "Npk": 97, "Nsk": 48},
    0x12: {"name": "DHKEM(P-521, HKDF-SHA512)", "Nsecret": 64, "Nenc": 133, "Npk": 133, "Nsk": 66},
    0x20: {"name": "DHKEM(X25519, HKDF-SHA256)", "Nsecret": 32, "Nenc": 32, "Npk": 32, "Nsk": 32},
    0x21: {"name": "DHKEM(X448, HKDF-SHA512)", "Nsecret": 64, "Nenc": 56, "Npk": 56, "Nsk": 56},
}
# KDF dictionary RFC9180 Table 3
# Keyed by kdf_id. "name" is the display name used by CipherSuite.name();
# "Nh" is the length requested for the exporter secret in verifyTestVector().
# NOTE: the names are spelled "HKDF-..." as in the RFC table.
kdfs = {
    0x0001: {"name": "HKDF-SHA256", "Nh": 32},
    0x0002: {"name": "HKDF-SHA384", "Nh": 48},
    0x0003: {"name": "HKDF-SHA512", "Nh": 64},
}
# AEAD dictionary RFC9180 Table 5
# Keyed by aead_id. "name" is the display name used by CipherSuite.name().
# "Nk" and "Nn" are the lengths requested for the key and the base nonce in
# verifyTestVector(); "Nn" is also the width passed to ComputeNonce().
# "Nt" is not read anywhere in this file.
# The export-only entry (0xFFFF) carries None for every length.
aeads = {
    0x0001: {"name": "AES-128-GCM", "Nk": 16, "Nn": 12, "Nt": 16},
    0x0002: {"name": "AES-256-GCM", "Nk": 32, "Nn": 12, "Nt": 16},
    0x0003: {"name": "ChaCha20Poly1305", "Nk": 32, "Nn": 12, "Nt": 16},
    0xFFFF: {"name": "Export-only", "Nk": None, "Nn": None, "Nt": None},
}

# CONSTANTS
# HPKE mode identifiers. VerifyPSKInputs() uses them to decide whether a
# PSK must be present (mode_psk, mode_auth_psk) or absent (mode_base, mode_auth).
mode_base = 0x00
mode_psk = 0x01
mode_auth = 0x02
mode_auth_psk = 0x03


# Utility functions
def I2OSP(x, w):
    """ Convert non-negative integer x to a w-length,
        big-endian byte string, as described in [RFC8017].

    :param x: non-negative integer to encode
    :param w: required length of the result in bytes
    :return: bytearray of exactly w bytes, most significant byte first
    :raises ValueError: if w is negative, if x is negative, or if x does not
        fit in w bytes (RFC8017 "integer too large")
    """
    if w < 0:
        raise ValueError('length w must be a nonegative integer')
    if x < 0:
        raise ValueError('x must be a non-negative integer')
    # Anything left above the low 8*w bits cannot be represented; refuse
    # rather than silently dropping the high-order bytes.
    if x >> (8 * w):
        raise ValueError('integer too large')
    s = bytearray(w)
    # Fill from the least significant byte (last position) backwards.
    for i in range(w - 1, -1, -1):
        s[i] = x & 0xFF
        x >>= 8
    return s


def OS2IP(x):
    """Interpret byte string x as a non-negative integer ([RFC8017]).

    The first element of x is the most significant octet (big-endian).
    An empty input gives 0.
    """
    value = 0
    for octet in x:
        # Make room for the next octet, then add it in (zero octets add nothing).
        value <<= 8
        if octet:
            value += int(octet)
    return value


def xor_bytes(b1, b2):
    """Return the byte-wise XOR of two equal-length byte strings.

    :param b1: first operand (bytes or bytearray)
    :param b2: second operand, same length as b1
    :return: bytes object of the same length as the inputs
    :raises ValueError: if the inputs differ in length; zip() would otherwise
        silently truncate the result to the shorter input
    """
    if len(b1) != len(b2):
        raise ValueError("xor_bytes: inputs must have the same length")
    return bytes(a ^ b for a, b in zip(b1, b2))


def ComputeNonce(seq, base_nonce, Nn):
    """Return the nonce for sequence number seq.

    The sequence number is encoded by I2OSP as an Nn-byte big-endian string
    and combined with base_nonce by xor_bytes.
    """
    return xor_bytes(base_nonce, I2OSP(seq, Nn))


def wrap(text, width=70):
    """Simple textwrap to display string.

    Splits text into consecutive chunks of at most ``width`` characters and
    joins them with newlines. No whitespace handling is done.

    :param text: string to wrap
    :param width: maximum number of characters per line
    :return: the wrapped string; an empty input gives an empty string
    """
    # Step over the full length of text so the final chunk is always included,
    # even when it is a single character.
    return "\n".join(text[i:i + width] for i in range(0, len(text), width))


def wrap_hex(b, width=70):
    """Return byte array b as lower-case hex, line-wrapped by wrap() at width."""
    hex_text = Cnv.tohex(b).lower()
    return wrap(hex_text, width)


def VerifyPSKInputs(mode, psk, psk_id):
    """Check that psk and psk_id agree with each other and with mode.

    Follows the routine of the same name in RFC9180. An input counts as
    "provided" when it is not equal to the empty byte string.

    :param mode: one of mode_base, mode_psk, mode_auth, mode_auth_psk
    :param psk: pre-shared key bytes, or b"" when none
    :param psk_id: pre-shared key identifier bytes, or b"" when none
    :raises Exception: if only one of psk/psk_id is provided, if a PSK is
        provided for a non-PSK mode, or if it is missing for a PSK mode
    """
    empty = b""
    has_psk = psk != empty
    has_psk_id = psk_id != empty
    # The key and its identifier must be supplied together or not at all.
    if has_psk != has_psk_id:
        raise Exception("Inconsistent PSK inputs")

    if has_psk:
        if mode in (mode_base, mode_auth):
            raise Exception("PSK input provided when not needed")
    elif mode in (mode_psk, mode_auth_psk):
        raise Exception("Missing required PSK input")


class CipherSuite():
    """One HPKE configuration (mode, KEM, KDF, AEAD) with its KEM helper routines.

    Keys handled here come in two forms: cryptosyspki "internal" key strings
    (as returned by DeriveKeyPair / Deserialize*) and serialized byte strings
    (as returned by Serialize*).
    """

    def __init__(self, mode, kem_id, kdf_id, aead_id):
        """Store the identifiers and look up the matching cryptosyspki options.

        :param mode: HPKE mode identifier
        :param kem_id: key into ``curveNames`` / ``kems``
        :param kdf_id: key into ``kdfs``
        :param aead_id: key into ``aeadAlgs`` / ``aeads``
        """
        self.mode, self.kem_id = mode, kem_id
        self.kdf_id, self.aead_id = kdf_id, aead_id
        # Library-specific options resolved once from the lookup tables
        self.curveName = curveNames[kem_id]
        self.aeadAlg = aeadAlgs[aead_id]  # TODO 0xFFFF

    def name(self):
        """Return a string of the form
        "DHKEM(X25519, HKDF-SHA256), HKDF-SHA256, AES-128-GCM".
        """
        kem_name = kems[self.kem_id]['name']
        kdf_name = kdfs[self.kdf_id]['name']
        aead_name = aeads[self.aead_id]['name']
        return ", ".join((kem_name, kdf_name, aead_name))

    def DeriveKeyPair(self, ikm):
        """Derive a key pair from input keying material.

        In this implementation, the curve group defines all parameters needed
        to derive the key.

        :param ikm: input keying material in bytes
        :return: tuple (sk, pk) of keys in internal string form, valid only
            for this session
        """
        private_key = Hpke.derive_private_key(ikm, self.curveName)
        return private_key, Ecc.publickey_from_private(private_key)

    @staticmethod
    def SerializePublicKey(pk):
        """ Serialize the public key.
        :param pk: public key in internal string form
        :return: serialized key in bytes
        """
        return Cnv.fromhex(Ecc.query_key(pk, "publicKey"))

    @staticmethod
    def SerializePrivateKey(sk):
        """ Serialize the private key.
        :param sk: private key in internal string form
        :return: serialized key in bytes
        """
        return Cnv.fromhex(Ecc.query_key(sk, "privateKey"))

    def DeserializePublicKey(self, pk):
        """ Deserialize the public key.
        :param pk: public key in bytes form
        :return: Public key in internal string form
        """
        return Ecc.read_key_by_curve(Cnv.tohex(pk), self.curveName, ispublic=True)

    def DeserializePrivateKey(self, sk):
        """ Deserialize the private key.
        :param sk: private key in bytes form
        :return: Private key in internal string form
        """
        return Ecc.read_key_by_curve(Cnv.tohex(sk), self.curveName, ispublic=False)

    @staticmethod
    def verifyPrivateKeysEqual(skhex, sk):
        """Assert that hex string skhex equals serialized private key sk.

        The comparison ignores the case of the hex digits.

        :param skhex: expected private key, hex-encoded
        :param sk: private key in bytes form
        """
        # CAUTION: the sk returned by DerivePrivateKey is *not* clamped.
        # This is to match the test_vectors which are also not clamped
        # Ref: https://www.rfc-editor.org/errata_search.php?rfc=9180
        # https://www.rfc-editor.org/errata/eid7121
        # NOTE: to be changed if the test vectors are ever updated. For
        # X25519/X448 the expected value would then have to be passed through
        # Ecc.read_key_by_curve + Ecc.query_key(..., "privateKey") first.
        expected = skhex.lower()
        actual = Cnv.tohex(sk).lower()
        assert expected == actual

    @staticmethod
    def verifyPublicKeysEqual(pkhex, pk):
        """Assert that hex string pkhex equals serialized public key pk.

        :param pkhex: expected public key, hex-encoded
        :param pk: public key in bytes form
        """
        expected = pkhex.lower()
        actual = Cnv.tohex(pk).lower()
        assert expected == actual

    @staticmethod
    def validatePublicKey(pk):
        """Validate public key.

        Asserts that the "isValid" query on pk gives a truthy result.

        :param pk: public key in internal string form
        """
        is_valid = bool(Ecc.query_key(pk, "isValid"))
        assert is_valid

    def DH(self, skX, pkY):
        """Compute the Diffie-Hellman shared secret of skX and pkY.

        :param skX: private key in internal key form
        :param pkY: public key in internal key form; validated first
        :return: result of Ecc.dh_shared_secret
        """
        self.validatePublicKey(pkY)
        return Ecc.dh_shared_secret(skX, pkY)

    def ExtractAndExpand(self, dh, kem_context):
        """Derive the KEM shared secret from a DH result.

        Does a labeled extract ("eae_prk") on dh followed by a labeled expand
        ("shared_secret") over kem_context, with output length ``Nsecret``
        taken from the ``kems`` table.

        :param dh: DH shared secret
        :param kem_context: KEM context in bytes
        :return: the shared secret
        """
        prk = Hpke.labeled_extract(b'', "eae_prk", dh, self.curveName)
        return Hpke.labeled_expand(kems[self.kem_id]['Nsecret'], prk, "shared_secret", kem_context,
                                   self.curveName)

    def Decap(self, enc, skR):
        """Recover the shared secret on the recipient side.

        :param enc: serialized ephemeral public key in bytes
        :param skR: recipient private key in internal key form
        :return: the shared secret
        """
        ephemeral_pk = self.DeserializePublicKey(enc)
        dh = self.DH(skR, ephemeral_pk)
        # The context is the encapsulated key followed by the recipient's
        # own serialized public key, both in bytes.
        recipient_pk_bytes = self.SerializePublicKey(Ecc.publickey_from_private(skR))
        return self.ExtractAndExpand(dh, enc + recipient_pk_bytes)


def verifyTestVector(tv):
    """Recompute the values of one test vector and assert that they match.

    Only base mode (mode 0) is verified. A vector for any other mode is
    reported as not supported and skipped before any computation is done.
    The KEM computation below is the base-mode one, so comparing it against
    an auth-mode vector would raise AssertionError instead of skipping it.

    Every intermediate value is printed as it is computed.

    :param tv: one test-vector entry (dict) as loaded from the JSON file.
        Keys read: mode, kem_id, kdf_id, aead_id, info, ikmE, pkEm, skEm,
        ikmR, pkRm, skRm, enc, shared_secret, key_schedule_context, secret,
        key, base_nonce, exporter_secret and encryptions (a list of dicts
        with pt, aad and ct). Values other than the ids are hex strings;
        computed values are lower-cased before comparison with them.
    :return: None
    :raises AssertionError: if a computed value differs from the test vector
    """
    print("\nTestVector:")
    suite = CipherSuite(tv['mode'], tv['kem_id'], tv['kdf_id'], tv['aead_id'])
    print(suite.name())
    print("mode:", suite.mode)
    print("kem_id:", suite.kem_id)
    print("kdf_id:", suite.kdf_id)
    print("aead_id:", suite.aead_id)
    print("info:", tv['info'])

    # We only do mode 0 here...
    # Checked up front, before anything is compared against the vector.
    if suite.mode != mode_base:
        print("NOT SUPPORTED: Mode " + str(suite.mode) + " is not supported here, just mode 0")
        return  # So just return and try the next one

    # Derive keys from ikm...
    print("ikmE:", wrap(tv['ikmE']))
    skE, pkE = suite.DeriveKeyPair(Cnv.fromhex(tv['ikmE']))
    pkEm = suite.SerializePublicKey(pkE)
    skEm = suite.SerializePrivateKey(skE)
    print("pkEm:", wrap_hex(pkEm), "skEm:", wrap_hex(skEm), sep='\n')
    suite.verifyPublicKeysEqual(tv['pkEm'], pkEm)
    suite.verifyPrivateKeysEqual(tv['skEm'], skEm)
    print("ikmR:", wrap(tv['ikmR']))
    skR, pkR = suite.DeriveKeyPair(Cnv.fromhex(tv['ikmR']))
    pkRm = suite.SerializePublicKey(pkR)
    skRm = suite.SerializePrivateKey(skR)
    print("pkRm:", wrap_hex(pkRm), "skRm:", wrap_hex(skRm), sep='\n')
    suite.verifyPublicKeysEqual(tv['pkRm'], pkRm)
    suite.verifyPrivateKeysEqual(tv['skRm'], skRm)

    # Create Context intermediate values and outputs
    # Includes KEM outputs enc and shared_secret used to create the context
    # shared_secret, enc = Encap(pkR)

    # Compute DH shared secret both ways and check they are equal
    dh = suite.DH(skE, pkR)
    print("dh:", Cnv.tohex(dh))
    dhR = suite.DH(skR, pkE)
    assert (Cnv.tohex(dh) == Cnv.tohex(dhR))

    # enc = SerializePublicKey(pkE)
    enc = suite.SerializePublicKey(pkE)
    print("enc:", wrap_hex(enc), sep='\n')
    assert (Cnv.tohex(enc).lower() == tv['enc'])
    # kem_context = concat(enc, pkRm)
    # pkRm was already serialized above when the recipient key was checked.
    kem_context = enc + pkRm  # NB in byte form
    # shared_secret = ExtractAndExpand(dh, kem_context)
    shared_secret = suite.ExtractAndExpand(dh, kem_context)
    print("shared_secret:", wrap_hex(shared_secret), sep='\n')
    assert (Cnv.tohex(shared_secret).lower() == tv['shared_secret'])

    # Check Decap works
    shared_secret_R = suite.Decap(enc, skR)
    assert (Cnv.tohex(shared_secret_R).lower() == tv['shared_secret'])

    # Derive the key_schedule_context
    # NB from now on, make sure to include aead when calling Hpke.labeled_extract
    #
    #   psk_id_hash = LabeledExtract("", "psk_id_hash", psk_id)
    #   info_hash = LabeledExtract("", "info_hash", info)
    #   key_schedule_context = concat(mode, psk_id_hash, info_hash)
    info = Cnv.fromhex(tv['info'])
    psk = b''
    psk_id = b''
    VerifyPSKInputs(suite.mode, psk, psk_id)  # <== Not necessary here, but do it anyway
    psk_id_hash = Hpke.labeled_extract(b'', "psk_id_hash", psk_id, suite.curveName, suite.aeadAlg)
    info_hash = Hpke.labeled_extract(b'', "info_hash", info, suite.curveName, suite.aeadAlg)
    key_schedule_context = I2OSP(suite.mode, 1) + psk_id_hash + info_hash
    print("key_schedule_context:", wrap_hex(key_schedule_context), sep='\n')
    assert (Cnv.tohex(key_schedule_context).lower() == tv['key_schedule_context'])

    #   secret = LabeledExtract(shared_secret, "secret", psk)
    #   key = LabeledExpand(secret, "key", key_schedule_context, Nk)
    #   base_nonce = LabeledExpand(secret, "base_nonce",
    #                              key_schedule_context, Nn)
    #   exporter_secret = LabeledExpand(secret, "exp",
    #                                   key_schedule_context, Nh)
    #   return Context<ROLE>(key, base_nonce, 0, exporter_secret)
    secret = Hpke.labeled_extract(shared_secret, "secret", psk, suite.curveName, suite.aeadAlg)
    print("secret:", wrap_hex(secret), sep='\n')
    assert (Cnv.tohex(secret).lower() == tv['secret'])
    key = Hpke.labeled_expand(aeads[suite.aead_id]["Nk"], secret, "key", key_schedule_context, suite.curveName,
                              suite.aeadAlg)
    print("key:", wrap_hex(key), sep='\n')
    assert (Cnv.tohex(key).lower() == tv['key'])
    base_nonce = Hpke.labeled_expand(aeads[suite.aead_id]["Nn"], secret, "base_nonce", key_schedule_context,
                                     suite.curveName, suite.aeadAlg)
    print("base_nonce:", wrap_hex(base_nonce), sep='\n')
    assert (Cnv.tohex(base_nonce).lower() == tv['base_nonce'])
    exporter_secret = Hpke.labeled_expand(kdfs[suite.kdf_id]["Nh"], secret, "exp", key_schedule_context,
                                          suite.curveName, suite.aeadAlg)
    print("exporter_secret:", wrap_hex(exporter_secret), sep='\n')
    assert (Cnv.tohex(exporter_secret).lower() == tv['exporter_secret'])

    # Do encryptions
    Nn = aeads[suite.aead_id]["Nn"]
    aeadAlg = aeadAlgs[suite.aead_id]
    # The sequence number starts at 0 and goes up by one per encryption.
    for seq_num, e in enumerate(tv['encryptions']):
        print("sequence number:", seq_num)
        pt = Cnv.fromhex(e['pt'])
        print("pt:", Cnv.tohex(pt))
        aad = Cnv.fromhex(e['aad'])
        print("aad:", Cnv.tohex(aad))
        nonce = ComputeNonce(seq_num, base_nonce, Nn)
        print("nonce:", Cnv.tohex(nonce))

        # Sender seals (encrypts) the plaintext
        ct = Cipher.encrypt_aead(pt, key, nonce, aeadAlg, aad)
        print("ct:", wrap_hex(ct))
        assert Cnv.tohex(ct).lower() == e['ct']

        # Recipient opens (decrypts) the ciphertext
        dt = Cipher.decrypt_aead(ct, key, nonce, aeadAlg, aad)
        assert Cnv.tohex(dt).lower() == e['pt']


def main():
    """Check the library version, load the test vectors and verify each one.

    Reads the JSON file named by ``infile`` and calls verifyTestVector() on
    every entry, then prints how many were processed.

    :raises AssertionError: if the installed cryptosyspki version is below
        PKI_MIN_VERSION, or if a test vector fails verification
    """
    # Check min version
    ver = Gen.version()
    print("cryptosyspki version =", ver)
    assert ver >= PKI_MIN_VERSION
    print("FILE:", infile)
    ntvs = 0
    # Read in JSON test vectors.
    # Decode explicitly as UTF-8 instead of the platform's default encoding.
    with open(infile, encoding="utf-8") as f:
        tvs = json.load(f)
    print("len(test-vectors)=", len(tvs))

    # Verify each test vector
    # (to check a single vector while debugging, call verifyTestVector(tvs[0]))
    for tv in tvs:
        verifyTestVector(tv)
        ntvs += 1

    print(f"\nALL DONE. Processed {ntvs} test vectors.")


# Run the verification only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()