/*
 * Copyright 1999, Alexander Feldman <alex@varna.net>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of Alexander Feldman nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY ALEXANDER FELDMAN AND CONTRIBUTORS ``AS IS''
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL ALEXANDER FELDMAN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include "rsa.hpp"

CRSAKey::CRSAKey()
{
	// An empty key object: no key material is held yet.
	// Every scalar member is initialized so that nothing (e.g. the copy
	// constructor or a flag query) ever reads an indeterminate value.
	wModulusSize = 0;
	fgHoldKey = false;
	fgEncryptOnly = false;
}

// Generate a fresh key pair with a modulus of wModulusSize bits and the
// public exponent wPublicKey.
//
// Throws BAD_RSA_MODULUSSIZE if wModulusSize is outside [512, 3072].
// Throws BAD_RSA_OPERATION if wPublicKey is not an odd value >= 3: an even
// (or zero) exponent can never be coprime to the even p-1 / q-1, so
// GenerateKeys() would loop forever, and an exponent of 1 makes encryption
// the identity.
CRSAKey::CRSAKey(Word wModulusSize, Word wPublicKey)
{
	if ((wModulusSize < 512) || (wModulusSize > 3072))
		throw(BAD_RSA_MODULUSSIZE);
	if ((wPublicKey < 3) || ((wPublicKey & 1) == 0))
		throw(BAD_RSA_OPERATION);
	CRSAKey::wModulusSize = wModulusSize;
	CRSAKey::PUB = CBigNumber(wPublicKey);
	GenerateKeys();
	fgHoldKey = true;
	fgEncryptOnly = false;
}

// Build a public-only key from a modulus n and a public exponent e.
// The resulting key can encrypt but not decrypt (fgEncryptOnly is set).
// Throws BAD_RSA_MODULUSSIZE when n's word-rounded bit size is outside
// [512, 3072].
CRSAKey::CRSAKey(const CBigNumber &cModulus, const CBigNumber &cPublic)
{
	N = cModulus;
	PUB = cPublic;
	wModulusSize = cModulus.GetWords() * BITSINWORD;

	if (!((wModulusSize >= 512) && (wModulusSize <= 3072)))
		throw(BAD_RSA_MODULUSSIZE);

	fgEncryptOnly = true;
	fgHoldKey = true;
}

// Build a private key from the modulus n, the private exponent d and the two
// primes p and q. The CRT values (dP, dQ, coefficient) are derived here.
// Throws BAD_RSA_MODULUSSIZE when n's word-rounded bit size is outside
// [512, 3072].
//
// NOTE(review): PUB is not assigned by this constructor, so WritePrivateKey()
// and Check() see whatever PUB holds by default - confirm whether callers are
// expected to supply the public exponent some other way.
CRSAKey::CRSAKey(const CBigNumber &cModulus,
					  const CBigNumber &cSecret,
					  const CProbablePrime &cFirstPrime,
					  const CProbablePrime &cSecondPrime)
{
	wModulusSize = cModulus.GetWords() * BITSINWORD;

	// Reject an unsupported modulus before any of the derived arithmetic.
	if ((wModulusSize < 512) || (wModulusSize > 3072))
		throw(BAD_RSA_MODULUSSIZE);

	N = cModulus;
	SEC = cSecret;
	P = cFirstPrime;
	Q = cSecondPrime;
	P1 = P - 1;
	Q1 = Q - 1;
	PHI = P1 * Q1;
	// CRT coefficient q^-1 mod p: the PKCS #1 orientation, and the same value
	// GenerateKeys() computes and Check() verifies. CRSABlock::Decrypt() must
	// recombine the CRT halves to match this orientation.
	U = CBigNumber::ModInv(Q, P);
	DP = SEC % P1;
	DQ = SEC % Q1;

	fgHoldKey = true;
	fgEncryptOnly = false;
}

// Copy constructor: duplicates all key material and the key-state flags.
//
// The flags are copied from the source rather than forced. Forcing
// fgEncryptOnly = false would turn a copy of a public-only key into a key
// that Decrypt() accepts despite having no primes.
// Throws BAD_RSA_MODULUSSIZE only when the source actually holds a key whose
// modulus size is outside [512, 3072]; copying an empty key is allowed.
CRSAKey::CRSAKey(const CRSAKey &cRSAKey)
{
	P = cRSAKey.P;
	Q = cRSAKey.Q;
	P1 = cRSAKey.P1;
	Q1 = cRSAKey.Q1;
	N = cRSAKey.N;
	PHI = cRSAKey.PHI;
	PUB = cRSAKey.PUB;
	SEC = cRSAKey.SEC;
	U = cRSAKey.U;
	DP = cRSAKey.DP;
	DQ = cRSAKey.DQ;

	fgHoldKey = cRSAKey.fgHoldKey;
	// Only consult the other scalars when the source holds a key, so an
	// empty source is never read for values it may not have initialized.
	wModulusSize = fgHoldKey ? cRSAKey.wModulusSize : 0;
	fgEncryptOnly = fgHoldKey ? cRSAKey.fgEncryptOnly : false;

	if (fgHoldKey && ((wModulusSize < 512) || (wModulusSize > 3072)))
		throw(BAD_RSA_MODULUSSIZE);
}

// Generate p and q (each wModulusSize / 2 bits as requested from SetRandom)
// such that PUB is coprime to p-1 and q-1, then derive n, phi, d, the CRT
// coefficient q^-1 mod p and the CRT exponents dP / dQ.
//
// NOTE(review): the prime search steps by 2, which presumes SetRandom(..., true)
// yields an odd starting value - confirm; an even start would never reach a
// prime. p == q is not excluded (negligible probability at these sizes).
void CRSAKey::GenerateKeys()
{
	for (;;) {
		P.SetRandom(wModulusSize / 2, true);
		while (!P.IsPrime())
			P += 2;
		P1 = P - 1;
		if (CBigNumber::GCD(PUB, P1).IsOne())
			break;
	}

	for (;;) {
		Q.SetRandom(wModulusSize / 2, true);
		while (!Q.IsPrime())
			Q += 2;
		Q1 = Q - 1;
		if (CBigNumber::GCD(PUB, Q1).IsOne())
			break;
	}

	N = P * Q;
	PHI = P1 * Q1;
	SEC = CBigNumber::ModInv(PUB, PHI);	// d = e^-1 mod (p-1)(q-1)
	U = CBigNumber::ModInv(Q, P);		// coefficient = q^-1 mod p
	DP = SEC % P1;
	DQ = SEC % Q1;
}

void CRSAKey::WritePrivateKey(int iOut, bool fgBase64)
{
	CDEREncodedBigNumber cVersion((Word)RSA_PRIVATE_KEY_VERSION);
	CDEREncodedBigNumber cModulus(N);
	CDEREncodedBigNumber cPublicExponent(PUB);
	CDEREncodedBigNumber cPrivateExponent(SEC);
	CDEREncodedBigNumber cPrime1(P);
	CDEREncodedBigNumber cPrime2(Q);
	CDEREncodedBigNumber cExponent1(DP);
	CDEREncodedBigNumber cExponent2(DQ);
	CDEREncodedBigNumber cCoefficient(U);

	CDEREncodedSequence cSequence;
	
	cSequence.AddPrimitive(cVersion);
	cSequence.AddPrimitive(cModulus);
	cSequence.AddPrimitive(cPublicExponent);
	cSequence.AddPrimitive(cPrivateExponent);
	cSequence.AddPrimitive(cPrime1);
	cSequence.AddPrimitive(cPrime2);
	cSequence.AddPrimitive(cExponent1);
	cSequence.AddPrimitive(cExponent2);
	cSequence.AddPrimitive(cCoefficient);

	if (true == fgBase64) {
		write_string(iOut, BEGIN_RSA_PRIVATE_KEY "\n");
		cSequence.WriteBase64(iOut, true);
		write_string(iOut, "\n" END_RSA_PRIVATE_KEY "\n");
	} else {
		cSequence.Write(iOut);
	}
}

void CRSAKey::WritePublicKey(int iOut, bool fgBase64)
{
	CDEREncodedBigNumber cModulus(N);
	CDEREncodedBigNumber cPublicExponent(PUB);

	CDEREncodedSequence cSequence;
	
	cSequence.AddPrimitive(cModulus);
	cSequence.AddPrimitive(cPublicExponent);

	if (true == fgBase64) {
		write_string(iOut, BEGIN_RSA_PUBLIC_KEY "\n");
		cSequence.WriteBase64(iOut, true);
		write_string(iOut, "\n" END_RSA_PUBLIC_KEY "\n");
	} else {
		cSequence.Write(iOut);
	}
}

void CRSAKey::ReadPrivateKey(int iIn, bool fgBase64)
{
	CDEREncodedSequence cSequence;
	if (true == fgBase64) {
		if (false == match_string(iIn, BEGIN_RSA_PRIVATE_KEY "\n"))
			throw(KEYFILE_ERROR);
		cSequence.ReadBase64(iIn);
		if (false == match_string(iIn, "\n" END_RSA_PRIVATE_KEY "\n"))
			throw(KEYFILE_ERROR);
	} else {
		cSequence.Read(iIn);
	}

	CDEREncodedBigNumber cVersion(cSequence);
	CDEREncodedBigNumber cModulus(cSequence);
	CDEREncodedBigNumber cPublicExponent(cSequence);
	CDEREncodedBigNumber cPrivateExponent(cSequence);
	CDEREncodedBigNumber cPrime1(cSequence);
	CDEREncodedBigNumber cPrime2(cSequence);
	CDEREncodedBigNumber cExponent1(cSequence);
	CDEREncodedBigNumber cExponent2(cSequence);
	CDEREncodedBigNumber cCoefficient(cSequence);

	N = cModulus;
	PUB = cPublicExponent;
	SEC = cPrivateExponent;
	P = cPrime1;
	Q = cPrime2;
	DP = cExponent1;
	DQ = cExponent2;
	U = cCoefficient;

	wModulusSize = cModulus.GetWords() * BITSINWORD;

	if ((wModulusSize < 512) || (wModulusSize > 3072))
		throw(BAD_RSA_MODULUSSIZE);

	fgHoldKey = true;
	fgEncryptOnly = false;
}

void CRSAKey::ReadPublicKey(int iIn, bool fgBase64)
{
	CDEREncodedSequence cSequence;
	if (true == fgBase64) {
		if (false == match_string(iIn, BEGIN_RSA_PUBLIC_KEY "\n"))
			throw(KEYFILE_ERROR);
		cSequence.ReadBase64(iIn);
		if (false == match_string(iIn, "\n" END_RSA_PUBLIC_KEY "\n"))
			throw(KEYFILE_ERROR);
	} else {
		cSequence.Read(iIn);
	}

	CDEREncodedBigNumber cModulus(cSequence);
	CDEREncodedBigNumber cPublicExponent(cSequence);

	N = cModulus;
	PUB = cPublicExponent;

	wModulusSize = cModulus.GetWords() * BITSINWORD;

	if ((wModulusSize < 512) || (wModulusSize > 3072))
		throw(BAD_RSA_MODULUSSIZE);

	fgHoldKey = true;
	fgEncryptOnly = true;
}

// Debug helper: print p, q, p - q, n, p - 1, q - 1, phi, e and d to stdout,
// each prefixed with its word-rounded bit size.
//
// The bit sizes are cast to unsigned long and printed with %lu. The type of
// GetWords() * BITSINWORD is not visible here, and %d is only correct if that
// product is exactly int.
void CRSAKey::Dump()
{
	printf("p [%lu bits] = ", (unsigned long)(P.GetWords() * BITSINWORD)); P.Dump();
	printf("q [%lu bits] = ", (unsigned long)(Q.GetWords() * BITSINWORD)); Q.Dump();
	printf("p - q [%lu bits] = ", (unsigned long)((P - Q).GetWords() * BITSINWORD)); (P - Q).Dump();
	printf("n [%lu bits] = ", (unsigned long)(N.GetWords() * BITSINWORD)); N.Dump();
	printf("p - 1 [%lu bits] = ", (unsigned long)(P1.GetWords() * BITSINWORD)); P1.Dump();
	printf("q - 1 [%lu bits] = ", (unsigned long)(Q1.GetWords() * BITSINWORD)); Q1.Dump();
	printf("(p - 1)(q - 1) [%lu bits] = ", (unsigned long)(PHI.GetWords() * BITSINWORD)); PHI.Dump();
	printf("Public key [%lu bits]: ", (unsigned long)(PUB.GetWords() * BITSINWORD)); PUB.Dump();
	printf("Secret key [%lu bits]: ", (unsigned long)(SEC.GetWords() * BITSINWORD)); SEC.Dump();
}

// Verify the internal consistency of a full private key and return a bit mask
// of RSA_* problem flags (RSA_OK when every check passes).
//
// Checked: primality of p and q, n == p*q, e in {3, 65537},
// d == e^-1 mod (p-1)(q-1), coefficient == q^-1 mod p, dP == d mod (p-1),
// dQ == d mod (q-1). All checks run; flags accumulate.
// Throws BAD_RSA_OPERATION for an empty or public-only key.
Word CRSAKey::Check()
{
	// fgHoldKey is tested first so fgEncryptOnly is only read for a held key.
	if (!fgHoldKey || fgEncryptOnly)
		throw(BAD_RSA_OPERATION);

	Word wFlags = RSA_OK;

	if (!P.IsPrime())
		wFlags |= RSA_PNOTPRIME;
	if (!Q.IsPrime())
		wFlags |= RSA_QNOTPRIME;
	if (N != P * Q)
		wFlags |= RSA_BADN;
	if ((PUB != 3) && (PUB != 65537))
		wFlags |= RSA_BADPUB;

	CBigNumber cPMinusOne = P - 1;
	CBigNumber cQMinusOne = Q - 1;

	if (SEC != CBigNumber::ModInv(PUB, cPMinusOne * cQMinusOne))
		wFlags |= RSA_BADSEC;
	if (U != CBigNumber::ModInv(Q, P))
		wFlags |= RSA_BADCOEFF;
	if (DP != (SEC % cPMinusOne))
		wFlags |= RSA_BADDP;
	if (DQ != (SEC % cQMinusOne))
		wFlags |= RSA_BADDQ;

	return wFlags;
}

// Create a block bound to a copy of cRSAKey; the data stays default-constructed.
CRSABlock::CRSABlock(const CRSAKey &cRSAKey)
{
	this->cKey = cRSAKey;
}

// Create a block bound to a copy of cRSAKey, holding a big number built from
// the wData-sized raw buffer at pvData.
CRSABlock::CRSABlock(const CRSAKey &cRSAKey, void *pvData, Word wData)
{
	this->cData = CBigNumber(pvData, wData);
	this->cKey = cRSAKey;
}

// Create a block bound to a copy of cRSAKey, holding a copy of cRSAData.
CRSABlock::CRSABlock(const CRSAKey &cRSAKey, const CBigNumber &cRSAData)
{
	this->cData = cRSAData;
	this->cKey = cRSAKey;
}

// Encrypt the block in place: data = data^e mod n.
//
// Throws BAD_RSA_OPERATION if the key holds nothing, or if the data is not
// smaller than the modulus. RSA is only reversible for m < n, and a larger
// block would be silently reduced mod n and never recovered by Decrypt().
void CRSABlock::Encrypt()
{
	if (!cKey.HoldKeyFlag())
		throw(BAD_RSA_OPERATION);

	if (!(cData < cKey.GetModulus()))
		throw(BAD_RSA_OPERATION);

	cData = CBigNumber::ModExp(cData, cKey.GetPublic(), cKey.GetModulus());
}

void CRSABlock::Decrypt()
{
	if (!cKey.HoldKeyFlag() || cKey.EncryptOnlyFlag())
		throw(BAD_RSA_OPERATION);

	CBigNumber P2 = CBigNumber::ModExp(cData % cKey.GetP(), cKey.GetDP(), cKey.GetP());
	CBigNumber Q2 = CBigNumber::ModExp(cData % cKey.GetQ(), cKey.GetDQ(), cKey.GetQ());
	if (Q2 < P2)
		Q2 += cKey.GetQ();
	cData = P2 + cKey.GetP() * (cKey.GetU() * (Q2 - P2) % cKey.GetQ());
}

// Debug helper: print the block's big-number value to stdout.
void CRSABlock::Dump()
{
	printf("d = ");
	cData.Dump();
}

void CRSABlock::Write(int iOut)
{
	CDEREncodedBigNumber cAll(cData);
	cAll.Write(iOut);
}

void CRSABlock::Read(int iIn)
{
	CDEREncodedBigNumber cAll;

	cAll.Read(iIn);

	cData = cAll;
}

// Replace the block's value with a big number built from the raw buffer
// pbData of wDataLength units.
void CRSABlock::SetData(Byte *pbData, Word wDataLength)
{
	CBigNumber cNewValue(pbData, wDataLength);
	this->cData = cNewValue;
}

// Return a byte pointer to the big number's storage, skipping element 0
// (the element GetDataSize() reports - presumably a length header; confirm).
Byte *CRSABlock::GetData()
{
	return (Byte *)&(cData.GetData()[1]);
}

// Return element 0 of the big number's storage.
// NOTE(review): this is presumably a size in words, while GetData() hands out
// a Byte pointer - confirm the unit callers expect.
Word CRSABlock::GetDataSize()
{
	return *cData.GetData();
}
