/*
  File: rsa-operation.c

  Authors:
        Mika Kojo <mkojo@ssh.fi>
        Tatu Ylonen <ylo@cs.hut.fi>
        Tero T Mononen <tmo@ssh.fi>

  Description:

        Take on the RSA operations and key definition, modified after
        Tatu Ylonen's original SSH implementation.

        Description of the RSA algorithm can be found e.g. from the
        following sources:

  - Bruce Schneier: Applied Cryptography.  John Wiley & Sons, 1994.
  - Jennifer Seberry and Josef Pieprzyk: Cryptography: An Introduction to
    Computer Security.  Prentice-Hall, 1989.
  - Man Young Rhee: Cryptography and Secure Data Communications.  McGraw-Hill,
    1994.
  - R. Rivest, A. Shamir, and L. M. Adleman: Cryptographic Communications
    System and Method.  US Patent 4,405,829, 1983.
  - Hans Riesel: Prime Numbers and Computer Methods for Factorization.
    Birkhauser, 1994.

  Copyright:
        Copyright (c) 1995-2001 SSH Communications Security Corp, Finland.
        All rights reserved.
*/

#include "sshincludes.h"
#include "sshmp.h"
#include "sshgenmp.h"
#include "sshencode.h"
#include "sshcrypt.h"
#include "sshpk.h"
#include "rsa.h"

#ifdef WITH_RSA
/* Explicit copying routines. */
#define SSH_DEBUG_MODULE "SshCryptoRSA"

/* Makes a deep copy of the RSA private key `op_src'.

   On success `*op_dest' points to a newly allocated SshRSAPrivateKey
   whose bit count and six integers (n, e, d, u, p, q) have been set
   from the source key.  If the allocation fails `*op_dest' is set to
   NULL; the function returns nothing, so this is the only failure
   indication the caller gets. */
void ssh_rsa_private_key_copy(void *op_src, void **op_dest)
{
  SshRSAPrivateKey *prv_dest, *prv_src = op_src;

  /* Define the output unconditionally so that the caller never reads
     an indeterminate pointer after an allocation failure. */
  *op_dest = NULL;

  if ((prv_dest = ssh_malloc(sizeof(*prv_dest))) == NULL)
    return;

  /* Copy bit counts. */
  prv_dest->bits = prv_src->bits;

  /* Initialize each destination integer from the source value. */
  ssh_mp_init_set(&prv_dest->n, &prv_src->n);
  ssh_mp_init_set(&prv_dest->e, &prv_src->e);
  ssh_mp_init_set(&prv_dest->d, &prv_src->d);
  ssh_mp_init_set(&prv_dest->u, &prv_src->u);
  ssh_mp_init_set(&prv_dest->p, &prv_src->p);
  ssh_mp_init_set(&prv_dest->q, &prv_src->q);

  *op_dest = (void *)prv_dest;
}

/* Makes a deep copy of the RSA public key `op_src'.

   On success `*op_dest' points to a newly allocated SshRSAPublicKey
   whose modulus n, public exponent e and bit count have been set from
   the source key.  If the allocation fails `*op_dest' is set to NULL;
   the function returns nothing, so this is the only failure indication
   the caller gets. */
void ssh_rsa_public_key_copy(void *op_src, void **op_dest)
{
  SshRSAPublicKey *pub_dest, *pub_src = op_src;

  /* Define the output unconditionally so that the caller never reads
     an indeterminate pointer after an allocation failure. */
  *op_dest = NULL;

  if ((pub_dest = ssh_malloc(sizeof(*pub_dest))) == NULL)
    return;

  ssh_mp_init_set(&pub_dest->n, &pub_src->n);
  ssh_mp_init_set(&pub_dest->e, &pub_src->e);
  pub_dest->bits = pub_src->bits;

  *op_dest = (void *)pub_dest;
}

/* Initialization functions. */

/* Allocates a fresh SshRSAInitCtx with every integer field (n, p, q,
   e, d, u) initialized to zero and the requested bit count set to
   zero.  Returns the context as an opaque pointer, or NULL if the
   allocation fails. */
void *ssh_rsa_private_key_init_action(void)
{
  SshRSAInitCtx *init_ctx = ssh_malloc(sizeof(*init_ctx));

  /* Nothing to set up when the allocation did not succeed. */
  if (init_ctx == NULL)
    return NULL;

  ssh_mp_init_set_ui(&init_ctx->n, 0);
  ssh_mp_init_set_ui(&init_ctx->p, 0);
  ssh_mp_init_set_ui(&init_ctx->q, 0);
  ssh_mp_init_set_ui(&init_ctx->e, 0);
  ssh_mp_init_set_ui(&init_ctx->d, 0);
  ssh_mp_init_set_ui(&init_ctx->u, 0);
  init_ctx->bits = 0;

  return (void *)init_ctx;
}

/* Creates the init context used for building a public key.  The very
   same context as for private keys is used, so this simply forwards
   to ssh_rsa_private_key_init_action and hands back its result. */
void *ssh_rsa_public_key_init_action(void)
{
  void *init_ctx = ssh_rsa_private_key_init_action();

  return init_ctx;
}


/* Releases an init context: clears all six integers it holds and
   then frees the context structure itself.  `context' must point to
   a valid SshRSAInitCtx; its fields are accessed unconditionally. */
void ssh_rsa_private_key_init_ctx_free(void *context)
{
  SshRSAInitCtx *init_ctx = context;

  /* Public components. */
  ssh_mp_clear(&init_ctx->n);
  ssh_mp_clear(&init_ctx->e);

  /* Private components. */
  ssh_mp_clear(&init_ctx->p);
  ssh_mp_clear(&init_ctx->q);
  ssh_mp_clear(&init_ctx->d);
  ssh_mp_clear(&init_ctx->u);

  ssh_free(init_ctx);
}

/* Special actions. */

/* Stores one value, taken from the argument list `ap', into the init
   context `context'.  `format' selects the field to be written and
   `type' tells whether a public or a private key is being built.

   When `type' has SSH_CRYPTO_TYPE_PUBLIC_KEY set, only the modulus
   (SSH_PKF_MODULO_N) and the public exponent (SSH_PKF_PUBLIC_E) are
   accepted; every other format is refused.

   Returns "i" after consuming an unsigned int (SSH_PKF_SIZE), "p"
   after consuming an SshMPInt, and NULL when the format is unknown or
   refused, in which case nothing is read from `ap'.  The returned
   string presumably tells the caller what kind of argument was
   consumed -- verify against the caller.  `input_context' is not
   used. */
const char *
ssh_rsa_action_put(void *context, va_list ap,
                   void *input_context,
                   SshCryptoType type,
                   SshPkFormat format)
{
  SshRSAInitCtx *init_ctx = context;
  SshMPInt value;

  /* For public keys everything but n and e is rejected up front. */
  if ((type & SSH_CRYPTO_TYPE_PUBLIC_KEY) &&
      format != SSH_PKF_MODULO_N &&
      format != SSH_PKF_PUBLIC_E)
    return NULL;

  switch (format)
    {
    case SSH_PKF_SIZE:
      init_ctx->bits = va_arg(ap, unsigned int);
      return "i";

    case SSH_PKF_MODULO_N:
      value = va_arg(ap, SshMPInt );
      ssh_mp_set(&init_ctx->n, value);
      return "p";

    case SSH_PKF_PUBLIC_E:
      value = va_arg(ap, SshMPInt );
      ssh_mp_set(&init_ctx->e, value);
      return "p";

    case SSH_PKF_PRIME_P:
      value = va_arg(ap, SshMPInt );
      ssh_mp_set(&init_ctx->p, value);
      return "p";

    case SSH_PKF_PRIME_Q:
      value = va_arg(ap, SshMPInt );
      ssh_mp_set(&init_ctx->q, value);
      return "p";

    case SSH_PKF_SECRET_D:
      value = va_arg(ap, SshMPInt );
      ssh_mp_set(&init_ctx->d, value);
      return "p";

    case SSH_PKF_INVERSE_U:
      value = va_arg(ap, SshMPInt );
      ssh_mp_set(&init_ctx->u, value);
      return "p";

    default:
      break;
    }

  /* Unknown format. */
  return NULL;
}

/* Put action for private keys: forwards to ssh_rsa_action_put with
   the key type fixed to SSH_CRYPTO_TYPE_PRIVATE_KEY and returns
   whatever that function returns. */
const char *
ssh_rsa_action_private_key_put(void *context, va_list ap,
                               void *input_context,
                               SshPkFormat format)
{
  const char *result;

  result = ssh_rsa_action_put(context, ap, input_context,
                              SSH_CRYPTO_TYPE_PRIVATE_KEY, format);
  return result;
}

/* Reads one value out of the private key `context' into a location
   taken from the argument list `ap'.  `format' selects the value:

     SSH_PKF_SIZE  - next argument is an `unsigned int *' that
                     receives the bit size of the modulus n;
     other known   - next argument is an SshMPInt that receives a
     formats         copy of p, q, n, d, u or e respectively.

   Returns "p" for every supported format and NULL for an unknown
   one, in which case nothing is read from `ap'.  `output_context' is
   not used. */
const char *
ssh_rsa_action_private_key_get(void *context, va_list ap,
                               void **output_context,
                               SshPkFormat format)
{
  SshRSAPrivateKey *key = context;
  unsigned int *bits_out;
  SshMPInt out;

  switch (format)
    {
    case SSH_PKF_SIZE:
      bits_out = va_arg(ap, unsigned int *);
      *bits_out = ssh_mp_bit_size(&key->n);
      return "p";

    case SSH_PKF_PRIME_P:
      out = va_arg(ap, SshMPInt );
      ssh_mp_set(out, &key->p);
      return "p";

    case SSH_PKF_PRIME_Q:
      out = va_arg(ap, SshMPInt );
      ssh_mp_set(out, &key->q);
      return "p";

    case SSH_PKF_MODULO_N:
      out = va_arg(ap, SshMPInt );
      ssh_mp_set(out, &key->n);
      return "p";

    case SSH_PKF_SECRET_D:
      out = va_arg(ap, SshMPInt );
      ssh_mp_set(out, &key->d);
      return "p";

    case SSH_PKF_INVERSE_U:
      out = va_arg(ap, SshMPInt );
      ssh_mp_set(out, &key->u);
      return "p";

    case SSH_PKF_PUBLIC_E:
      out = va_arg(ap, SshMPInt );
      ssh_mp_set(out, &key->e);
      return "p";

    default:
      break;
    }

  /* Unknown format. */
  return NULL;
}

/* Put action for public keys: forwards to ssh_rsa_action_put with
   the key type fixed to SSH_CRYPTO_TYPE_PUBLIC_KEY and returns
   whatever that function returns. */
const char *
ssh_rsa_action_public_key_put(void *context, va_list ap,
                              void *input_context,
                              SshPkFormat format)
{
  const char *result;

  result = ssh_rsa_action_put(context, ap, input_context,
                              SSH_CRYPTO_TYPE_PUBLIC_KEY, format);
  return result;
}

/* Reads one value out of the public key `context' into a location
   taken from the argument list `ap'.  `format' selects the value:

     SSH_PKF_SIZE      - next argument is an `unsigned int *' that
                         receives the bit size of the modulus n;
     SSH_PKF_MODULO_N  - next argument is an SshMPInt that receives n;
     SSH_PKF_PUBLIC_E  - next argument is an SshMPInt that receives e.

   Returns "p" for every supported format and NULL for any other one,
   in which case nothing is read from `ap'.  `output_context' is not
   used. */
const char *
ssh_rsa_action_public_key_get(void *context, va_list ap,
                              void **output_context,
                              SshPkFormat format)
{
  SshRSAPublicKey *key = context;
  unsigned int *bits_out;
  SshMPInt out;

  switch (format)
    {
    case SSH_PKF_SIZE:
      bits_out = va_arg(ap, unsigned int *);
      *bits_out = ssh_mp_bit_size(&key->n);
      return "p";

    case SSH_PKF_MODULO_N:
      out = va_arg(ap, SshMPInt );
      ssh_mp_set(out, &key->n);
      return "p";

    case SSH_PKF_PUBLIC_E:
      out = va_arg(ap, SshMPInt );
      ssh_mp_set(out, &key->e);
      return "p";

    default:
      break;
    }

  /* Unsupported format. */
  return NULL;
}


/* Destroys a private key: clears each of its six integers and then
   frees the key structure itself.  `private_key' must point to a
   valid SshRSAPrivateKey; its fields are accessed unconditionally. */

void ssh_rsa_private_key_free(void *private_key)
{
  SshRSAPrivateKey *key = private_key;

  /* Secret components. */
  ssh_mp_clear(&key->d);
  ssh_mp_clear(&key->u);
  ssh_mp_clear(&key->p);
  ssh_mp_clear(&key->q);

  /* Public components. */
  ssh_mp_clear(&key->n);
  ssh_mp_clear(&key->e);

  ssh_free(key);
}

/* Destroys a public key: clears the modulus and the public exponent
   and then frees the key structure itself.  `public_key' must point
   to a valid SshRSAPublicKey; its fields are accessed
   unconditionally. */
void ssh_rsa_public_key_free(void *public_key)
{
  SshRSAPublicKey *key = public_key;

  ssh_mp_clear(&key->n);
  ssh_mp_clear(&key->e);

  ssh_free(key);
}

/* Importing and exporting private keys. */

/* Serializes the private key into a buffer produced by
   ssh_encode_array_alloc.  The integers are written in the order
   e, d, n, u, p, q, which is the order ssh_rsa_private_key_import
   reads them back.

   `*buf' receives the buffer and `*length_return' the value returned
   by the encoder.  Always returns TRUE.

   NOTE(review): the encoder result is not checked; if
   ssh_encode_array_alloc can report failure (e.g. by returning 0),
   this function still returns TRUE -- confirm against the encoder's
   contract. */
Boolean ssh_rsa_private_key_export(const void *private_key,
                                   unsigned char **buf,
                                   size_t *length_return)
{
  const SshRSAPrivateKey *key = private_key;
  size_t encoded_len;

  encoded_len =
    ssh_encode_array_alloc(buf,
                     SSH_FORMAT_SPECIAL, ssh_mprz_encode_rendered, &key->e,
                     SSH_FORMAT_SPECIAL, ssh_mprz_encode_rendered, &key->d,
                     SSH_FORMAT_SPECIAL, ssh_mprz_encode_rendered, &key->n,
                     SSH_FORMAT_SPECIAL, ssh_mprz_encode_rendered, &key->u,
                     SSH_FORMAT_SPECIAL, ssh_mprz_encode_rendered, &key->p,
                     SSH_FORMAT_SPECIAL, ssh_mprz_encode_rendered, &key->q,
                     SSH_FORMAT_END);

  *length_return = encoded_len;
  return TRUE;
}

/* Builds a private key from the `len' bytes at `buf'.  The integers
   are decoded in the order e, d, n, u, p, q, and the key's bit count
   is then taken from the decoded modulus.

   Returns a non-zero value and stores the new key in `*private_key'
   on success.  If the key structure cannot be allocated,
   `*private_key' is set to NULL and zero is returned.  If
   ssh_decode_array returns 0, the partially built key is released
   and FALSE is returned; `*private_key' is left untouched in that
   case. */
Boolean ssh_rsa_private_key_import(const unsigned char *buf,
                                   size_t len,
                                   void **private_key)
{
  SshRSAPrivateKey *key;

  key = ssh_malloc(sizeof(*key));
  if (key != NULL)
    {
      /* All integers must be initialized before decoding into them. */
      ssh_mp_init(&key->n);
      ssh_mp_init(&key->e);
      ssh_mp_init(&key->d);
      ssh_mp_init(&key->u);
      ssh_mp_init(&key->p);
      ssh_mp_init(&key->q);

      if (ssh_decode_array(buf, len,
                       SSH_FORMAT_SPECIAL, ssh_mprz_decode_rendered, &key->e,
                       SSH_FORMAT_SPECIAL, ssh_mprz_decode_rendered, &key->d,
                       SSH_FORMAT_SPECIAL, ssh_mprz_decode_rendered, &key->n,
                       SSH_FORMAT_SPECIAL, ssh_mprz_decode_rendered, &key->u,
                       SSH_FORMAT_SPECIAL, ssh_mprz_decode_rendered, &key->p,
                       SSH_FORMAT_SPECIAL, ssh_mprz_decode_rendered, &key->q,
                       SSH_FORMAT_END) != 0)
        {
          /* Decoded fine; derive the key size from the modulus. */
          key->bits = ssh_mp_bit_size(&key->n);
        }
      else
        {
          /* Decoding failed; throw the half-built key away. */
          ssh_mp_clear(&key->n);
          ssh_mp_clear(&key->e);
          ssh_mp_clear(&key->u);
          ssh_mp_clear(&key->d);
          ssh_mp_clear(&key->p);
          ssh_mp_clear(&key->q);
          ssh_free(key);
          return FALSE;
        }
    }

  *private_key = (void *)key;
  return key != NULL;
}

/* Builds a public key from the `len' bytes at `buf'.  The integers
   are decoded in the order e, n, and the key's bit count is then
   taken from the decoded modulus.

   Returns a non-zero value and stores the new key in `*public_key'
   on success.  If the key structure cannot be allocated,
   `*public_key' is set to NULL and zero is returned.  If
   ssh_decode_array returns 0, the partially built key is released
   and FALSE is returned; `*public_key' is left untouched in that
   case. */
Boolean ssh_rsa_public_key_import(const unsigned char *buf,
                                  size_t len,
                                  void **public_key)
{
  SshRSAPublicKey *pub = ssh_malloc(sizeof(*pub));

  if (pub)
    {
      /* Both integers must be initialized before decoding into them. */
      ssh_mp_init(&pub->n);
      ssh_mp_init(&pub->e);

      if (ssh_decode_array(buf, len,
                       SSH_FORMAT_SPECIAL, ssh_mprz_decode_rendered, &pub->e,
                       SSH_FORMAT_SPECIAL, ssh_mprz_decode_rendered, &pub->n,
                       SSH_FORMAT_END) == 0)
        {
          ssh_mp_clear(&pub->n);
          ssh_mp_clear(&pub->e);
          /* Free the allocated key structure itself.  Passing `&pub'
             here would hand the address of the local pointer variable
             to the deallocator and leak the structure. */
          ssh_free(pub);
          return FALSE;
        }

      pub->bits = ssh_mp_bit_size(&pub->n);
    }
  *public_key = pub;
  return pub != NULL;
}

/* Serializes the public key into a buffer produced by
   ssh_encode_array_alloc.  The integers are written in the order
   e, n, which is the order ssh_rsa_public_key_import reads them
   back.

   `*buf' receives the buffer and `*length_return' the value returned
   by the encoder.  Always returns TRUE.

   NOTE(review): the encoder result is not checked; if
   ssh_encode_array_alloc can report failure (e.g. by returning 0),
   this function still returns TRUE -- confirm against the encoder's
   contract. */
Boolean ssh_rsa_public_key_export(const void *public_key,
                                  unsigned char **buf,
                                  size_t *length_return)
{
  const SshRSAPublicKey *key = public_key;
  size_t encoded_len;

  encoded_len =
    ssh_encode_array_alloc(buf,
                     SSH_FORMAT_SPECIAL, ssh_mprz_encode_rendered, &key->e,
                     SSH_FORMAT_SPECIAL, ssh_mprz_encode_rendered, &key->n,
                     SSH_FORMAT_END);

  *length_return = encoded_len;
  return TRUE;
}

#endif /* WITH_RSA */
