/* $Id: TestSc14nPKI.c $
 * Last updated:
 *   $Date: 2019-12-13 21:23 $
 *   $Version: 2.1.0 $
 */
/* Some tests using the Sc14n C/C++ interface with CryptoSys PKI.
 * Please report any bugs to <http://cryptosys.net/contact/>
 */
/******************************* LICENSE ***********************************
* Copyright (C) 2017-19 David Ireland, DI Management Services Pty Limited.
* All rights reserved. <www.di-mgt.com.au> <www.cryptosys.net>
* The code in this module is licensed under the terms of the MIT license,
* unless otherwise marked.
* For a copy, see <http://opensource.org/licenses/MIT>
****************************************************************************
*/
#if _MSC_VER >= 1100
/* Detect memory leaks in MSVC++ */
#define _CRTDBG_MAP_ALLOC
#include <stdlib.h>
#include <crtdbg.h>
#else
#include <stdlib.h>
#endif
#include <stdio.h>
#include <string.h>
#include "diSc14n.h"
#include "diCrPKI.h"

/* 
 * Requires `Sc14n` and `CryptoSys PKI` to be installed on your system,
 * available from <http://cryptosys.net/sc14n/> and <http://cryptosys.net/pki/>.
 * In particular, `diSc14n.dll` and `diCrPKI.dll` must be in your library path.
 * You must link to `diSc14n.lib` and `diCrPKI.lib`. In MSVC++ IDE, use
 * Project > Configuration Properties > Linker > Input > Additional Dependencies 
 * and add the .lib file paths, e.g.
 *     Additional Dependencies = $(OutDir)diSc14n.lib;$(OutDir)diCrPKI.lib;%(AdditionalDependencies)
 * Using command-line:
 *     CL TestSc14nPKI.c /link ..\Release\diSc14n.lib  diCrPKI.lib
 *
 * Test files, e.g. `olamundo.xml`, are in `sc14n-testfiles.zip`. These must be in the CWD.
 */

#ifdef NDEBUG
/* Make sure assertion testing is turned on */
#undef NDEBUG
#endif
#include <assert.h>

// DEBUGGING UTILS
// Trace output is compiled in only when _DEBUG is defined and NO_DPRINTF is not.
// Comment out the next line to enable debugging output; leave it in to disable it.
#define NO_DPRINTF
#if (!defined(_DEBUG) || defined(NO_DPRINTF))
/* Tracing disabled: every DPRINTFn(...) expands to nothing */
#define DPRINTF0(s)
#define DPRINTF1(s, a1)
#define DPRINTF2(s, a1, a2)
#else
/* Tracing enabled: forward the format string and arguments to printf */
#define DPRINTF0(s) printf(s)
#define DPRINTF1(s, a1) printf(s, a1)
#define DPRINTF2(s, a1, a2) printf(s, a1, a2)
#endif


// HARD-CODED PRIVATE KEY AND CERTIFICATE (FOR OUR CONVENIENCE IN TESTING)
// Alice's PKCS8 encrypted key and X.509 certificate
// from RFC 4134 "Examples of S/MIME Messages"
// Private key password is "password"
// NOTE: these are test fixtures only; the password is stored here in plain text.
static const char *myPassword = "password";
// Encrypted private key as a PEM-style string. The adjacent string literals are
// concatenated by the compiler with no line breaks between them, so the whole
// value (header, base64 body and footer) is one continuous line.
static const char *myPriKey = "-----BEGIN ENCRYPTED PRIVATE KEY-----"
	"MIICojAcBgoqhkiG9w0BDAEDMA4ECFleZ90vhGrRAgIEAASCAoA9rti16XVH"
	"K4AJVe1CNf61NIpIogu/Xs4Yn4hXflvewiOwe6/9FkxBXLbhKdbQWn1Z4p3C"
	"njVns2VYEO/qpJR3LciHMwp5dsqedUVVia//CqFHtEV9WfvCKWgmlkkT1YEm"
	"1aChZnPP5i6IhwVT9qvFluTZhvVmjW0YyF86OrOp0uxxVic7phPbnPrOMelf"
	"ZPc3A3EGpzDPkxN+o0obw87tUgCL+s0KtUOr3c6Si4KQ3IQjrjZxQF4Se3t/"
	"4PEpqUl5EpYiCx9q5uqb0Lr1kWiiQ5/inZm5ETc+qO+ENcp0KjnX523CATYd"
	"U5iOjl/X9XZeJrMpOCXogEuhmLPRauYP1HEWnAY/hLW93v10QJXY6ALlbkL0"
	"sd5WU8Ces7T04b/p4/12yxqYqV68QePyfHpegdraDq3vRfopSwrUxtL9cisP"
	"jsQcJ5FL/SfloFbmld4CKIjMsromsEWqo6rfo3JqNizgTVIIWExy3jDT9VvK"
	"d9ADH0g3JCbuFzaWVOZMmZ0wlo28PKkLQ8FkW8CG/Lq/Q/bHLPM+sPdLN+ke"
	"gpA6fvL4wpku4ST7hmeN1vWbRLlCfuFijux77hdM7knO9/MawICsA4XdzR78"
	"p0C2hJlc6p46IWZaINQXGstTbJMh+mJ7i1lrbG2kvZ2Twf9R+RaLp2mPHjb1"
	"+P+3f2L3tOoC31oJ18u/L1MXEWxLEZHB0+ANg+N/0/icwImcI0D+wVN2puU4"
	"m58j81sGZUEAB3aFEbPxoX3y+qYlOnt1OfdY7WnNdyr9ZzI09fkrTvujF4LU"
	"nycqE+MXerf0PxkNu1qv9bQvCoH8x3J2EVdMxPBtH1Fb7SbE66cNyh//qzZo"
	"B9Je"
	"-----END ENCRYPTED PRIVATE KEY-----";
// X.509 certificate as a PEM-style string, built the same way (one continuous line).
// It is not referenced by any function visible in this file.
static const char *myCert = "-----BEGIN CERTIFICATE-----"
	"MIICLDCCAZWgAwIBAgIQRjRrx4AAVrwR024uxBCzsDANBgkqhkiG9w0BAQUFADAS"
	"MRAwDgYDVQQDEwdDYXJsUlNBMB4XDTk5MDkxOTAxMDg0N1oXDTM5MTIzMTIzNTk1"
	"OVowEzERMA8GA1UEAxMIQWxpY2VSU0EwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJ"
	"AoGBAOCJczmN2PX16Id2OX9OsAW7U4PeD7er3H3HdSkNBS5tEt+mhibU0m+qWCn8"
	"l+z6glEPMIC+sVCeRkTxLLvYMs/GaG8H2bBgrL7uNAlqE/X3BQWT3166NVbZYf8Z"
	"f8mB5vhs6odAcO+sbSx0ny36VTq5mXcCpkhSjE7zVzhXdFdfAgMBAAGjgYEwfzAM"
	"BgNVHRMBAf8EAjAAMA4GA1UdDwEB/wQEAwIGwDAfBgNVHSMEGDAWgBTp4JAnrHgg"
	"eprTTPJCN04irp44uzAdBgNVHQ4EFgQUd9K00bdMioqjzkWdzuw8oDrj/1AwHwYD"
	"VR0RBBgwFoEUQWxpY2VSU0FAZXhhbXBsZS5jb20wDQYJKoZIhvcNAQEFBQADgYEA"
	"PnBHqEjME1iPylFxa042GF0EfoCxjU3MyqOPzH1WyLzPbrMcWakgqgWBqE4lradw"
	"FHUv9ceb0Q7pY9Jkt8ZmbnMhVN/0uiVdfUnTlGsiNnRzuErsL2Tt0z3Sp0LF6DeK"
	"tNufZ+S9n/n+dO/q+e5jatg/SyUJtdgadq7rm9tJsCI="
	"-----END CERTIFICATE-----";


/* PKI HELPER FUNCTIONS */

/** Read private key into internal key string, valid only for this session.
@param prikey Private key, passed unchanged to RSA_ReadAnyPrivateKey
@param password Password for the private key
@param pstatus Receives 0 on success, the negative library error code on failure,
       or -1 if memory could not be allocated
@returns Pointer to string buffer containing key string or NULL on error
@remark Caller to free allocated memory
*/
char *pki_rsaReadPrivateKey(const char *prikey, const char *password, int *pstatus)
{
	long nchars;
	char *buf;

	// First call with a NULL buffer only queries the required length
	nchars = RSA_ReadAnyPrivateKey(NULL, 0, prikey, password, 0);
	if (nchars < 0) {
		*pstatus = (int)nchars;
		return NULL;
	}
	// One extra byte for the terminator
	buf = malloc((size_t)nchars + 1);
	if (!buf) {
		*pstatus = -1;
		return NULL;
	}
	// Second call fills the buffer; it can still fail, so do not return garbage
	nchars = RSA_ReadAnyPrivateKey(buf, nchars, prikey, password, 0);
	if (nchars < 0) {
		free(buf);
		*pstatus = (int)nchars;
		return NULL;
	}
	*pstatus = 0;
	return buf;
}

/** Compute digest value in base64 form.
@param digestbuf Buffer to receive digest value
@param bufsize Size of buffer in bytes (must be at least 1)
@param s String to be digested
@param digalg Digest algorithm flag (0 = SHA-1)
@return Number of characters in digestbuf or negative error code
        (-1 if digestbuf is NULL or bufsize is zero)
*/
int pki_hashStringToBase64(char *digestbuf, size_t bufsize, const char *s, long digalg)
{
	unsigned char b[PKI_MAX_HASH_BYTES];
	long nb, nc;

	// Guard: `bufsize - 1` below would wrap around to a huge value for an empty buffer
	if (!digestbuf || bufsize == 0) return -1;

	// Compute digest value in raw bytes
	nb = HASH_Bytes(b, sizeof(b), s, strlen(s), digalg);
	if (nb < 0) return (int)nb;
	// Convert to base64 encoded value, leaving one byte of the buffer spare
	nc = CNV_B64StrFromBytes(digestbuf, bufsize - 1, b, nb);
	return (int)nc;
}

/** Compute base64-encoded signature value from base64-encoded digest value.
@param digval Base64-encoded digest value to be signed
@param prikey Private key, passed unchanged to SIG_SignData
@param password Password for the private key
@param pstatus Receives 0 on success, the negative library error code on failure,
       or -1 if the decoded digest does not fit or memory could not be allocated
@returns Pointer to allocated string containing the signature value, or NULL on error
@remark Caller to free allocated memory
*/
char *pki_SigValFromDigVal(const char *digval, const char *prikey, const char *password, int *pstatus)
{
	unsigned char b[PKI_MAX_HASH_BYTES];
	long nb, nchars;
	char *buf;

	// Convert base64-encoded value to raw bytes
	nb = CNV_BytesFromB64Str(b, sizeof(b), digval);
	if (nb < 0) {
		// Decoding failed: do not pass a negative length on to the signer
		*pstatus = (int)nb;
		return NULL;
	}
	if (nb > (long)sizeof(b)) {
		// Defensive: never let the signer read past the end of b
		*pstatus = -1;
		return NULL;
	}
	// Sign the digest value; first call with a NULL buffer only queries the required length
	nchars = SIG_SignData(NULL, 0, b, nb, prikey, password, "sha1WithRSAEncryption", PKI_SIG_USEDIGEST);
	if (nchars < 0) {
		*pstatus = (int)nchars;
		return NULL;
	}
	buf = malloc((size_t)nchars + 1);
	if (!buf) {
		*pstatus = -1;
		return NULL;
	}
	nchars = SIG_SignData(buf, nchars, b, nb, prikey, password, "sha1WithRSAEncryption", PKI_SIG_USEDIGEST);
	if (nchars < 0) {
		free(buf);
		*pstatus = (int)nchars;
		return NULL;
	}
	*pstatus = 0;
	return buf;
}

/** Extract XML-style <RSAKeyValue> from RSA private key.
@param prikey Private key, read via pki_rsaReadPrivateKey
@param password Password for the private key
@param pstatus Receives 0 on success, the negative library error code on failure,
       or -1 if memory could not be allocated
@returns Pointer to allocated string containing the XML key value, or NULL on error
@remark Caller to free allocated memory
*/
char *pki_KeyValFromPriKey(const char *prikey, const char *password, int *pstatus)
{
	char *buf = NULL;
	char *keystr;
	long nchars;

	// On failure pstatus has already been set by the reader
	keystr = pki_rsaReadPrivateKey(prikey, password, pstatus);
	if (!keystr) return NULL;

	// Form XML RSAKeyValue:
	// CAUTION: do not include your private key data, just the public key
	nchars = RSA_ToXMLString(NULL, 0, keystr, PKI_XML_EXCLPRIVATE);
	if (nchars < 0) {
		*pstatus = (int)nchars;
		goto done;
	}
	buf = malloc((size_t)nchars + 1);
	if (!buf) {
		*pstatus = -1;
		goto done;
	}
	nchars = RSA_ToXMLString(buf, nchars, keystr, PKI_XML_EXCLPRIVATE);
	if (nchars < 0) {
		free(buf);
		buf = NULL;
		*pstatus = (int)nchars;
		goto done;
	}
	*pstatus = 0;

done:
	// The internal key string is released on every path
	free(keystr);
	return buf;
}

/* FILE UTILITIES */

/** Read a binary file into a null-terminated string.
*  @param fname Name of file to read
*  @return Pointer to allocated string buffer or NULL on error
*          (file cannot be opened, its size cannot be determined, or out of memory)
*  @remark Caller must free allocated memory
*/
static unsigned char *file_to_string(const char *fname) 
{
	FILE *fp;
	char *buf = NULL;
	long flen;
	size_t nread;

	fp = fopen(fname, "rb");
	if (!fp) return NULL;

	/* Find the file length by seeking to the end */
	if (fseek(fp, 0, SEEK_END) != 0) goto done;
	flen = ftell(fp);
	if (flen < 0) goto done;
	if (fseek(fp, 0, SEEK_SET) != 0) goto done;

	buf = malloc((size_t)flen + 1);
	if (!buf) goto done;

	/* Terminate after the bytes actually read, which may be fewer than flen */
	nread = fread(buf, 1, (size_t)flen, fp);
	buf[nread] = '\0';

done:
	/* Single exit so the file handle is closed on every path */
	fclose(fp);
	return (unsigned char *)buf;
}

/** Write a new binary file from a string.
*  @param fname Name of file to create (an existing file is overwritten)
*  @param s Null-terminated string to write; the terminator is not written
*  @return 0 on success; nonzero if an argument is NULL, the file cannot be created,
*          not all bytes could be written, or the file cannot be closed
*/
static int file_from_string(const char *fname, const char *s)
{
	FILE *fpo;
	size_t len, nwritten;
	int r;

	if (!fname || !s) return -1;

	fpo = fopen(fname, "wb");
	if (!fpo) return -1;
	len = strlen(s);
	nwritten = fwrite(s, 1, len, fpo);
	/* Always close, even after a short write */
	r = fclose(fpo);
	/* A short write is a failure even if the close succeeded */
	if (nwritten != len) return -1;

	return r;	/* 0 = success */
}


/// <summary>
/// Create a XML-DSIG signed file given proforma XML document
/// </summary>
/// <param name="outfile">Name of output file to create</param>
/// <param name="basefile">Name of input XML document</param>
/// <param name="prikey">PKCS8 encrypted private key file or PEM-string</param>
/// <param name="password">Password for private key</param>
/// <returns>Zero (0) on success otherwise nonzero error code
/// (a negative library error code, or -1 for a file or memory allocation failure)</returns>
/// <remarks>Input XML document is expected to be enveloped-signature with single reference URI="",
/// C14N method REC-xml-c14n-20010315, signature method xmldsig#rsa-sha1, and digest method xmldsig#sha1.
/// KeyValue is expected to be in RSAKeyValue form.
/// Items to be replaced should be marked "%digval%", "%sigval%" and "%keyval%".
/// </remarks>
int MakeSignedXml(const char *outfile, const char *basefile, const char *prikey, const char *password)
{
	/* declaration for 3rd-party function with code below */
	char *replace_str(const char *str, const char *old, const char *news);
	int status = -1;	/* pessimistic default; set to 0 only at the very end */
	long r, nchars;
	char *buf = NULL;
	char *newsi = NULL;
	char digval[SC14N_MAX_DIGEST_CHARS + 1];
	char digval_si[SC14N_MAX_DIGEST_CHARS + 1];
	char *sigval = NULL;
	char *keyval = NULL;
	char *xmlstr = NULL;
	char *newxml1 = NULL;
	char *newxml2 = NULL;
	char *newxml3 = NULL;

	// 1. Compute digest value of body excluding <Signature> element
	// (this assumes Reference URI="" and DigestMethod is SHA-1)
	r = C14N_File2Digest(digval, sizeof(digval) - 1, basefile, "Signature", "", SC14N_TRAN_OMITBYTAG);
	if (r < 0) {
		status = (int)r;
		goto clean_up;
	}
	DPRINTF1("DIGVAL=%s\n", digval);

	// 2. Extract the SignedInfo element into memory
	// Note %digval% parameter to be completed
	// (first call with a NULL buffer only queries the required length)
	nchars = C14N_File2String(NULL, 0, basefile, "SignedInfo", "", SC14N_TRAN_SUBSETBYTAG);
	if (nchars < 0) {
		status = (int)nchars;
		goto clean_up;
	}
	buf = malloc((size_t)nchars + 1);
	if (!buf) {
		status = -1;
		goto clean_up;
	}
	nchars = C14N_File2String(buf, nchars, basefile, "SignedInfo", "", SC14N_TRAN_SUBSETBYTAG);
	if (nchars < 0) {
		status = (int)nchars;
		goto clean_up;
	}
	DPRINTF1("SIGNEDINFO (BASE):\n%s\n", buf);

	// 3. Insert the required DigestValue we prepared earlier
	newsi = replace_str(buf, "%digval%", digval);
	if (!newsi) {
		// replace_str returns NULL when it cannot allocate memory
		status = -1;
		goto clean_up;
	}
	DPRINTF0("SIGNEDINFO (COMPLETED):\n");
	DPRINTF1("%s\n", newsi);

	// 4. Compute the digest value of this string
	r = pki_hashStringToBase64(digval_si, sizeof(digval_si), newsi, PKI_HASH_SHA1);
	if (r < 0) {
		status = (int)r;
		goto clean_up;
	}
	DPRINTF1("SHA1(signedinfo)=%s\n", digval_si);

	// 5. Compute signature from this digest value
	// (on failure the helper has already stored the error code in status)
	sigval = pki_SigValFromDigVal(digval_si, prikey, password, &status);
	if (!sigval) {
		fprintf(stderr, "ERROR: failed to create signature value");
		goto clean_up;
	}
	DPRINTF1("SIGVAL:\n%s\n", sigval);

	// 6. Get the RSA Key Value in required XML form
	// NB We can extract the public key value from the RSA private key
	keyval = pki_KeyValFromPriKey(prikey, password, &status);
	if (!keyval) {
		fprintf(stderr, "ERROR: failed to create Key value");
		goto clean_up;
	}
	DPRINTF1("KEYVAL:\n%s\n", keyval);

	// 7. Compose the output file by substituting the correct values
	// (Note we make no other checks of the input XML - that's up to you)

	// 7.1 Read in the base XML file
	xmlstr = (char *)file_to_string(basefile);
	if (!xmlstr) {
		fprintf(stderr, "ERROR: cannot read file '%s'", basefile);
		status = -1; 
		goto clean_up;
	}
	// 7.2 Substitute %% values
	// Each step needs the previous result, so stop at the first allocation failure
	status = -1;
	newxml1 = replace_str(xmlstr, "%digval%", digval);
	if (!newxml1) goto clean_up;
	newxml2 = replace_str(newxml1, "%sigval%", sigval);
	if (!newxml2) goto clean_up;
	newxml3 = replace_str(newxml2, "%keyval%", keyval);
	if (!newxml3) goto clean_up;

	// 7.3 Write out new string to file
	r = file_from_string(outfile, newxml3);
	DPRINTF1("file_from_string() returns %d (expected 0)\n", r);
	if (r != 0) {
		fprintf(stderr, "ERROR: failed to create file '%s'", outfile);
		status = (int)r;
		goto clean_up;
	}

	// If we got here all is well
	status = 0;

clean_up:
	// free(NULL) is a no-op, so pointers that were never allocated are safe here
	free(buf);
	free(newsi);
	free(sigval);
	free(keyval);
	free(xmlstr);
	free(newxml1);
	free(newxml2);
	free(newxml3);

	return status;
}

/* DO THE BUSINESS... */
/* Prints the library versions, then signs each test document in turn and reports the result. */
int main(void)
{
	/* One row per document to sign: { input proforma file, output signed file } */
	static const char *const jobs[][2] = {
		{ "olamundo-base.xml", "olamundo-new-signed.xml" },
		// Input XML contains Chinese characters UTF-8-encoded
		{ "daiwei-base.xml", "daiwei-new-signed.xml" },
		// Input XML contains Chinese characters as character entities
		// Note that digest value and signature value should be identical to previous one
		{ "daiwei-ents-base.xml", "daiwei-ents-new-signed.xml" },
	};
	const size_t njobs = sizeof(jobs) / sizeof(jobs[0]);
	size_t i;
	long rc;

	/* MSVC memory leak checking stuff */
#if _MSC_VER >= 1100
	_CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);
	_CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE);
	_CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDOUT);
	_CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE);
	_CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDOUT);
	_CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE);
	_CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDOUT);
#endif

	/* General information about the core library DLLs */
	// If either of these fail, the package is not installed properly...
	printf("Sc14n Version=%ld\n", SC14N_Gen_Version());
	printf("CrPKI Version=%ld\n", PKI_Version(0, 0));

	/* Sign every document with the hard-coded test key and report each outcome */
	for (i = 0; i < njobs; i++) {
		const char *src = jobs[i][0];
		const char *dst = jobs[i][1];

		printf("FILE: %s\n", src);
		rc = MakeSignedXml(dst, src, myPriKey, myPassword);
		printf("MakeSignedXML->'%s' returns %ld (expecting 0)\n", dst, rc);
	}

	printf("\nALL DONE.\n");

	return 0;
}


/* THIRD-PARTY CODE ********************************************************************************/

/* Ref: http://creativeandcritical.net/str-replace-c/
* Description:	Replaces in the string str all the occurrences of the source string old
* with the destination string news. The parameters old and news may be of any length,
* and their lengths are allowed to differ.
* None of the three parameters may be NULL.
*
* Returns:	The post-replacement string, or NULL if memory for the new string could not be allocated.
* Does not modify the original string. The memory for the returned post-replacement
* string may be deallocated with the standard library function free() when it is no longer required.
*
* Licence: Public domain. You may use this code in any way you see fit,
* optionally crediting its author (me, Laird Shaw, with assistance from comp.lang.c).
* http://creativeandcritical.net/contact/
*/

/* Returns a newly allocated copy of str with every occurrence of old replaced by news,
 * or NULL if memory could not be allocated. If old is the empty string there is nothing
 * to search for, so an unmodified copy of str is returned. Caller frees the result.
 */
char *replace_str(const char *str, const char *old, const char *news)
{
	char *ret, *r;
	const char *p, *q;
	size_t oldlen = strlen(old);
	size_t count, retlen, newlen = strlen(news);

	if (oldlen == 0) {
		/* strstr() matches an empty needle at every position and p would never
		 * advance, so the loops below would spin forever: just copy str. */
		retlen = strlen(str);
		if ((ret = malloc(retlen + 1)) == NULL)
			return NULL;
		memcpy(ret, str, retlen + 1);
		return ret;
	}

	if (oldlen != newlen) {
		/* First pass: count matches so the result can be sized exactly */
		for (count = 0, p = str; (q = strstr(p, old)) != NULL; p = q + oldlen)
			count++;
		/* p now points just past the last match; unsigned wrap-around in
		 * (newlen - oldlen) cancels out when the result is shorter than str */
		retlen = (size_t)(p - str) + strlen(p) + count * (newlen - oldlen);
	}
	else
		retlen = strlen(str);

	if ((ret = malloc(retlen + 1)) == NULL)
		return NULL;

	/* Second pass: copy the text before each match, then the replacement */
	for (r = ret, p = str; (q = strstr(p, old)) != NULL; p = q + oldlen) {
		size_t l = (size_t)(q - p);
		memcpy(r, p, l);
		r += l;
		memcpy(r, news, newlen);
		r += newlen;
	}
	/* Copy the tail after the last match, including the terminator */
	strcpy(r, p);

	return ret;
}
/*****************************************************************************************************/