/*

  ssh-pam-client.c

  Author: Sami Lehtinen <sjl@ssh.com>

  Copyright (C) 2000 SSH Communications Security Corp, Helsinki, Finland
  All rights reserved.

  PAM (Pluggable Authentication Modules) authentication, server side.
  The PAM library expects the conversation function to block, so this
  program does not use the event loop.
*/

#include "sshincludes.h"
#include "sshbuffer.h"
#include "sshencode.h"
#include "sshpamserver.h"
#include "auth-pam-common.h"
#include "sshbufaux.h"
#include "sshgetput.h"
#include "sshdsprintf.h"
#include <security/pam_appl.h>

/* Debug module name; presumably picked up by SSH_DEBUG()/SSH_TRACE(). */
#define SSH_DEBUG_MODULE "SshPAMClient"

/* This helper exchanges packets with sshd2 over its standard input and
   standard output. */
#define READ_FD  fileno(stdin)
#define WRITE_FD fileno(stdout)

/* Upper bound on the length field of an incoming packet; recv_packet()
   rejects larger values. */
#define BUFFER_MAX_SIZE 50000

/* Framed packet I/O.  Both return 0 on success and 1 on failure. */
int recv_packet(int fd, SshPacketType *packet_type, SshBuffer packet);
int send_packet(int fd, SshPacketType packet_type, SshBuffer packet);

/* Shorthands bound to READ_FD / WRITE_FD. */
#define RECV_PACKET(packet_type_p, packet) \
recv_packet(READ_FD, (packet_type_p), (packet))

#define SEND_PACKET(packet_type, packet) \
send_packet(WRITE_FD, (packet_type), (packet))

/* Shared body of SEND_ERROR and SEND_OP_ERROR.

   Encodes (char error_type, uint32-length string error_message) and
   sends it as a packet of type `packet_type'.  error_message is
   evaluated exactly once; a NULL message is sent as the empty string
   rather than being passed to strlen().  The locals carry an ssh_pam_
   prefix so they cannot capture identifiers used in the arguments.
   The result of SEND_PACKET is not checked. */
#define SSH_PAM_SEND_ERROR_PACKET(packet_type, error_type, error_message) \
do                                                                        \
{                                                                         \
  SshBuffer ssh_pam_err_buf_;                                             \
  const char *ssh_pam_err_msg_ = (error_message);                         \
                                                                          \
  if (ssh_pam_err_msg_ == NULL)                                           \
    ssh_pam_err_msg_ = "";                                                \
  ssh_pam_err_buf_ = ssh_buffer_allocate();                               \
  ssh_encode_buffer(ssh_pam_err_buf_,                                     \
                    SSH_FORMAT_CHAR, (unsigned int)(error_type),          \
                    SSH_FORMAT_UINT32_STR, ssh_pam_err_msg_,              \
                    strlen(ssh_pam_err_msg_),                             \
                    SSH_FORMAT_END);                                      \
  SEND_PACKET((packet_type), ssh_pam_err_buf_);                           \
  ssh_buffer_free(ssh_pam_err_buf_);                                      \
} while (0)

/* Sends an SSH_PAM_ERROR packet (used for protocol-level failures). */
#define SEND_ERROR(error_type, error_message)                           \
  SSH_PAM_SEND_ERROR_PACKET(SSH_PAM_ERROR, (error_type), (error_message))

/* Sends an SSH_PAM_OP_ERROR packet (used when a PAM call fails). */
#define SEND_OP_ERROR(error_type, error_message)                        \
  SSH_PAM_SEND_ERROR_PACKET(SSH_PAM_OP_ERROR, (error_type), (error_message))

/* Sends an SSH_PAM_OP_SUCCESS packet whose payload is empty. */
#define SEND_OP_SUCCESS                                 \
do                                                      \
{                                                       \
  SshBuffer ssh_pam_ok_buf_ = ssh_buffer_allocate();    \
  SEND_PACKET(SSH_PAM_OP_SUCCESS, ssh_pam_ok_buf_);     \
  ssh_buffer_free(ssh_pam_ok_buf_);                     \
} while (0)

/* Sends an SSH_PAM_OP_SUCCESS_WITH_PAYLOAD packet whose body is
   `payload' encoded as a uint32-length string.

   `payload' is evaluated exactly once and must be a NUL-terminated
   string or NULL.  A NULL payload is sent as the empty string, because
   calling strlen(NULL) is undefined behaviour. */
#define SEND_OP_SUCCESS_WITH_PAYLOAD(payload)                           \
do                                                                      \
{                                                                       \
  SshBuffer ssh_pam_pl_buf_;                                            \
  const char *ssh_pam_pl_str_ = (const char *)(payload);                \
                                                                        \
  if (ssh_pam_pl_str_ == NULL)                                          \
    ssh_pam_pl_str_ = "";                                               \
  ssh_pam_pl_buf_ = ssh_buffer_allocate();                              \
  ssh_encode_buffer(ssh_pam_pl_buf_,                                    \
                    SSH_FORMAT_UINT32_STR, ssh_pam_pl_str_,             \
                    strlen(ssh_pam_pl_str_),                            \
                    SSH_FORMAT_END);                                    \
  SEND_PACKET(SSH_PAM_OP_SUCCESS_WITH_PAYLOAD, ssh_pam_pl_buf_);        \
  ssh_buffer_free(ssh_pam_pl_buf_);                                     \
} while (0)

/* Pairs each SSH_PAM_* request code with a printable name for
   diagnostics.  Entry i describes packet type i only if the SSH_PAM_*
   codes are consecutive from 1 in the order listed below; searching by
   the packet_type field does not depend on that. */
typedef struct SshPAMPacketTypeRec
{
  SshPacketType packet_type;
  const char *type_str;         /* Points at a string literal. */
} *SshPAMPacketType, SshPAMPacketTypestruct;

const SshPAMPacketTypestruct pam_packet_types[] =
{
  {0,                     "INVALID_PACKET"},
  {SSH_PAM_START,         "SSH_PAM_START"},
  {SSH_PAM_AUTHENTICATE,  "SSH_PAM_AUTHENTICATE"},
  {SSH_PAM_ACCT_MGMT,     "SSH_PAM_ACCT_MGMT"},
  {SSH_PAM_OPEN_SESSION,  "SSH_PAM_OPEN_SESSION"},
  {SSH_PAM_CLOSE_SESSION, "SSH_PAM_CLOSE_SESSION"},
  {SSH_PAM_SETCRED,       "SSH_PAM_SETCRED"},
  {SSH_PAM_CHAUTHTOK,     "SSH_PAM_CHAUTHTOK"},
  {SSH_PAM_END,           "SSH_PAM_END"},
  {SSH_PAM_SET_ITEM,      "SSH_PAM_SET_ITEM"},
  {SSH_PAM_GET_ITEM,      "SSH_PAM_GET_ITEM"}
};

/* State handed to PAM as conv.appdata_ptr.  It carries no data yet;
   `dummy' is only a placeholder member. */
typedef struct SshPAMContextRec { int dummy; }
  *SshPAMContext, SshPAMContextStruct;

/* PAM conversation callback.

   Relays the num_msg messages from PAM to the peer as one
   SSH_PAM_CONVERSATION_MSG packet: a uint32 count, then for each message
   a style byte and a uint32-length string.  It then blocks until the
   matching SSH_PAM_CONVERSATION_RESP packet arrives: a uint32 count, then
   for each response a retcode byte and a uint32-length string.

   On success, *resp is set to a calloc()ed array of num_msg responses,
   each holding a calloc()ed string.  These come from the C library
   allocator because PAM releases them with free().

   Returns PAM_SUCCESS, or PAM_CONV_ERR on an I/O, protocol or allocation
   failure.  On failure everything allocated here is released and *resp
   is left untouched.  appdata_ptr is not used. */
int ssh_pam_conversation_function(int num_msg,
                                  const struct pam_message **msg,
                                  struct pam_response **resp,
                                  void *appdata_ptr)
{
  SshBuffer packet;
  struct pam_response *reply = NULL;
  SshPacketType packet_type;
  SshUInt32 num_resp;
  unsigned int resp_retcode;
  char *resp_msg = NULL;
  size_t resp_len;
  int i;

  packet = ssh_buffer_allocate();

  /* PAM wants to free() the responses, and this conflicts with
     our use of malloc() in sshmalloc.c. So we have to use the
     C library's calloc() and free() directly. */
#undef calloc
#undef free

  /* Nothing can be allocated or answered for a non-positive count. */
  if (num_msg <= 0)
    goto resp_error;

  reply = calloc((size_t) num_msg, sizeof(*reply));
  if (reply == NULL)
    goto resp_error;

  buffer_put_int(packet, (SshUInt32) num_msg);

  for (i = 0; i < num_msg; i++)
    {
      /* Be defensive about modules that pass a NULL message text. */
      const char *text = (msg[i]->msg != NULL) ? msg[i]->msg : "";

      buffer_put_char(packet, ssh_pam_msg_style_to_ssh(msg[i]->msg_style));
      buffer_put_uint32_string(packet, text, (SshUInt32) strlen(text));
    }

  if (SEND_PACKET(SSH_PAM_CONVERSATION_MSG, packet) != 0)
    {
      ssh_warning("ssh_pam_conversation_function: write failed.");
      goto resp_error;
    }

  ssh_buffer_clear(packet);

  if (RECV_PACKET(&packet_type, packet) != 0)
    {
      ssh_warning("ssh_pam_conversation_function: read failed.");
      goto resp_error;
    }

  if (packet_type != SSH_PAM_CONVERSATION_RESP)
    {
      ssh_warning("ssh_pam_conversation_function: expecting "
                  "SSH_PAM_CONVERSATION_RESP packet, got packet "
                  "of type %d.", packet_type);
      goto resp_error;
    }

  if (!ssh_decode_buffer(packet,
                         SSH_FORMAT_UINT32, &num_resp,
                         SSH_FORMAT_END))
    {
      ssh_warning("ssh_pam_conversation_function: error decoding packet.");
      goto resp_error;
    }

  if (num_resp != (SshUInt32) num_msg)
    {
      ssh_warning("ssh_pam_conversation_function: num_resp (%d) != num_msg "
                  "(%d).", (int)num_resp, num_msg);
      goto resp_error;
    }

  for (i = 0; i < num_msg; i++)
    {
      if (!ssh_decode_buffer(packet,
                             SSH_FORMAT_CHAR, &resp_retcode,
                             SSH_FORMAT_UINT32_STR, &resp_msg, NULL,
                             SSH_FORMAT_END))
        {
          /* NOTE(review): assumes a failed decode hands back nothing
             that we must free. */
          resp_msg = NULL;
          ssh_warning("Malformed SSH_PAM_CONVERSATION_RESP packet.");
          goto resp_error;
        }
      if (!ssh_pam_resp_retcode_to_pam(resp_retcode, &reply[i].resp_retcode))
        {
          ssh_warning("Invalid resp_retcode from client..");
          goto resp_error;
        }

      resp_len = strlen(resp_msg);
      reply[i].resp = calloc(resp_len + 1, sizeof(char));
      if (reply[i].resp == NULL)
        goto resp_error;
      /* calloc() zeroed the buffer, so the copy is NUL-terminated. */
      memcpy(reply[i].resp, resp_msg, resp_len);

      /* The response may be a password; wipe our copy before freeing. */
      memset(resp_msg, 0, resp_len);
      ssh_xfree(resp_msg);
      resp_msg = NULL;
    }

  *resp = reply;
  ssh_buffer_free(packet);
  return PAM_SUCCESS;

 resp_error:
  /* reply came from calloc(), so never-filled slots hold NULL and
     free(NULL) is a no-op. */
  if (reply != NULL)
    {
      for (i = 0; i < num_msg; i++)
        free(reply[i].resp);
      free(reply);
    }
  if (resp_msg != NULL)
    {
      memset(resp_msg, 0, strlen(resp_msg));
      ssh_xfree(resp_msg);
    }
  ssh_buffer_free(packet);
  return PAM_CONV_ERR;
}

/* Reads one framed packet from fd.

   Wire format: a 32-bit length (decoded with SSH_GET_32BIT), followed by
   that many bytes, the first of which is the packet type.  The type is
   stored in *packet_type and the remaining bytes are appended to
   `packet'.  A length of 0 leaves no room for the type byte, so it is
   rejected, as is any length above BUFFER_MAX_SIZE.  Reads interrupted
   by a signal (EINTR) are retried.

   Returns 0 on success, 1 on EOF, read error or invalid length. */
int recv_packet(int fd, SshPacketType *packet_type, SshBuffer packet)
{
  ssize_t ret_val;
#define SSH_PAM_BUF_SIZE 1024
  unsigned char buf[SSH_PAM_BUF_SIZE];
  size_t read_bytes = 0;
  size_t packet_len = 0;
  size_t want;

  /* Read incoming packet length.  Short reads must continue at
     buf + read_bytes, otherwise later bytes overwrite earlier ones. */
  while (read_bytes < 4)
    {
      ret_val = read(fd, buf + read_bytes, 4 - read_bytes);
      if (ret_val < 0 && errno == EINTR)
        continue;
      if (ret_val <= 0)
        {
          if (ret_val == 0)
            SSH_TRACE(0, ("read gave EOF."));
          else
            SSH_TRACE(0, ("read gave error '%s'.", strerror(errno)));
          return 1;
        }

      read_bytes += (size_t) ret_val;
    }

  packet_len = SSH_GET_32BIT(buf);

  /* The payload must contain at least the packet type byte. */
  if (packet_len < 1 || packet_len > BUFFER_MAX_SIZE)
    {
      SSH_TRACE(0, ("invalid packet len in incoming packet: %lu.",
                    (unsigned long) packet_len));
      return 1;
    }

  read_bytes = 0;

  while (read_bytes < packet_len)
    {
      want = packet_len - read_bytes;
      if (want > sizeof(buf))
        want = sizeof(buf);

      ret_val = read(fd, buf, want);
      if (ret_val < 0 && errno == EINTR)
        continue;
      if (ret_val <= 0)
        {
          if (ret_val == 0)
            SSH_TRACE(0, ("read gave EOF."));
          else
            SSH_TRACE(0, ("read gave error '%s'.", strerror(errno)));
          return 1;
        }

      if (read_bytes == 0)
        {
          /* First payload byte is the packet type; keep the rest. */
          *packet_type = SSH_GET_8BIT(buf);
          ssh_buffer_append(packet, buf + 1, (size_t) ret_val - 1);
        }
      else
        {
          ssh_buffer_append(packet, buf, (size_t) ret_val);
        }

      read_bytes += (size_t) ret_val;
    }

  return 0;
}

/* Writes one framed packet to fd: a uint32 length (payload length + 1),
   the packet type as a single byte, then the contents of `packet'.
   `packet' itself is not modified.  Writes interrupted by a signal
   (EINTR) are retried, and partial writes are continued.

   Returns 0 once every byte has been written, 1 on a write error.  The
   temporary encoding buffer is freed on both paths. */
int send_packet(int fd, SshPacketType packet_type, SshBuffer packet)
{
  SshBuffer buffer;
  size_t packet_len, written_bytes = 0;
  ssize_t ret_val;
  int status = 0;

  buffer = ssh_buffer_allocate();
  /* The decoder stores SSH_FORMAT_UINT32 values into SshUInt32
     variables, and the encoder is assumed to fetch them the same way.
     The length is therefore narrowed explicitly instead of passing a
     size_t through the variadic argument list. */
  ssh_encode_buffer(buffer,
                    SSH_FORMAT_UINT32, (SshUInt32)(ssh_buffer_len(packet) + 1),
                    SSH_FORMAT_CHAR, (unsigned int)packet_type,
                    SSH_FORMAT_DATA, ssh_buffer_ptr(packet),
                    ssh_buffer_len(packet),
                    SSH_FORMAT_END);

  packet_len = ssh_buffer_len(buffer);

  SSH_TRACE(2, ("Final packet len: %lu.", (unsigned long) packet_len));

  while (written_bytes < packet_len)
    {
      ret_val = write(fd, ssh_buffer_ptr(buffer) + written_bytes,
                      packet_len - written_bytes);
      if (ret_val < 0)
        {
          if (errno == EINTR)
            continue;
          SSH_TRACE(0, ("write gave error '%s'.", strerror(errno)));
          status = 1;
          break;
        }
      SSH_DEBUG(3, ("Wrote %ld bytes.", (long) ret_val));
      written_bytes += (size_t) ret_val;
    }

  ssh_buffer_free(buffer);
  return status;
}


int main(int argc,char **argv)
{
  pam_handle_t *pamh = NULL;
  SshPacketType packet_type;
  SshBuffer packet;
  struct pam_conv conv;  
  int pam_status = PAM_SUCCESS, ret_val;
  SshPAMContext ssh_pam_context;
  size_t packet_len;
  char *service_name, *user_name;
  SshUInt32 flags;
  void *item;
  
  ssh_pam_context = ssh_xcalloc(1, sizeof(*ssh_pam_context));
  conv.conv = ssh_pam_conversation_function;
  conv.appdata_ptr = ssh_pam_context;
  
  packet = ssh_buffer_allocate();

  SSH_TRACE(2, ("Starting PAM. Waiting for sshd2 to send a packet."));

  while(1)
    {
      
      ssh_buffer_clear(packet);
      
      if (RECV_PACKET(&packet_type, packet) != 0)
        {
          ssh_warning("read failed.");
          return 1;
        }

      SSH_TRACE(2, ("Received packet %s.", packet_type));
      
      pam_status = PAM_SUCCESS;
      
      switch (packet_type)
        {
#define SEND_RESULT                                             \
do {                                                            \
  if (ret_val == PAM_SUCCESS)                                   \
    /* Notify service of pam_start() success. */                \
    SEND_OP_SUCCESS;                                            \
  else                                                          \
    SEND_OP_ERROR(ret_val, pam_strerror(pamh, ret_val));        \
} while(0)

        case SSH_PAM_START:          
          packet_len = ssh_buffer_len(packet);
          
          if (ssh_decode_buffer(packet,
                                SSH_FORMAT_UINT32_STR, &service_name, NULL,
                                SSH_FORMAT_UINT32_STR, &user_name, NULL,
                                SSH_FORMAT_END) != packet_len)
            {
              SEND_ERROR(SSH_PAM_PROTOCOL_ERROR,
                         "Malformed SSH_PAM_START packet.");
              return 1;
            }
          
          ret_val = pam_start(service_name, user_name,
                              &conv, &pamh);
          SEND_RESULT;
          break;

#define CALL_PAM_FUNC_WITH_FLAGS(func)                          \
do {                                                            \
                                                                \
  packet_len = ssh_buffer_len(packet);                          \
                                                                \
  if (ssh_decode_buffer(packet,                                 \
                        SSH_FORMAT_UINT32, &flags,              \
                        SSH_FORMAT_END) != packet_len)          \
    {                                                           \
      char *error_msg;                                          \
      ssh_dsprintf(&error_msg, "Malformed %s packet.",          \
                   pam_packet_types[packet_type].type_str);     \
                                                                \
      SEND_ERROR(SSH_PAM_PROTOCOL_ERROR,                        \
                 error_msg);                                    \
      ssh_xfree(error_msg);                                     \
                                                                \
      return 1;                                                 \
    }                                                           \
                                                                \
  ret_val = (func)(pamh, (int)flags);                           \
                                                                \
  SEND_RESULT;                                                  \
} while (0)
            
        case SSH_PAM_AUTHENTICATE:
          CALL_PAM_FUNC_WITH_FLAGS(pam_authenticate);
          break;

        case SSH_PAM_ACCT_MGMT:
          CALL_PAM_FUNC_WITH_FLAGS(pam_acct_mgmt);
          
          break;

        case SSH_PAM_OPEN_SESSION:
          CALL_PAM_FUNC_WITH_FLAGS(pam_open_session);
          break;

        case SSH_PAM_SETCRED:
          CALL_PAM_FUNC_WITH_FLAGS(pam_setcred);
          break;

        case SSH_PAM_CLOSE_SESSION:
          CALL_PAM_FUNC_WITH_FLAGS(pam_close_session);
          break;
          
        case SSH_PAM_CHAUTHTOK:
          CALL_PAM_FUNC_WITH_FLAGS(pam_chauthtok);
          break;

        case SSH_PAM_END:
          CALL_PAM_FUNC_WITH_FLAGS(pam_end);
          break;

        case SSH_PAM_SET_ITEM:
          packet_len = ssh_buffer_len(packet);
          
          if (ssh_decode_buffer(packet,
                                SSH_FORMAT_UINT32, &flags,
                                SSH_FORMAT_UINT32_STR, (char **)&item, NULL,
                                SSH_FORMAT_END) != packet_len)
            {
              SEND_ERROR(SSH_PAM_PROTOCOL_ERROR,
                         "Malformed SSH_PAM_SET_ITEM packet.");
              return 1;
            }
          
          ret_val = pam_set_item(pamh, (int)flags, item);

          SEND_RESULT;
          break;
          
        case SSH_PAM_GET_ITEM:
          packet_len = ssh_buffer_len(packet);
          
          if (ssh_decode_buffer(packet,
                                SSH_FORMAT_UINT32, &flags,
                                SSH_FORMAT_END) != packet_len)
            {
              SEND_ERROR(SSH_PAM_PROTOCOL_ERROR,
                         "Malformed SSH_PAM_GET_ITEM packet.");
              return 1;
            }
          
          ret_val = pam_get_item(pamh, (int)flags, (const void **)&item);

          if (ret_val == PAM_SUCCESS)
            SEND_OP_SUCCESS_WITH_PAYLOAD(item);
          else
            SEND_OP_ERROR(ret_val, pam_strerror(pamh, ret_val));
          
          break;
          
        default:
          /* D'oh. */
          /* Send protocol error. */
          {
            char *error_msg;                                          
            ssh_dsprintf(&error_msg, "Unknown packet type %d.",
                         packet_type);
            SEND_ERROR(SSH_PAM_PROTOCOL_ERROR,
                       error_msg);
            ssh_xfree(error_msg);            
            return 1;
          }
        }
    }
  
  return 0;
}
