/*

  ssholdfsm.c

  Author: Antti Huima <huima@ssh.fi>

  Copyright (c) 1999-2001 SSH Communications Security, Finland
  All rights reserved.

  Created Thu Aug 26 14:27:14 1999.

  */

#include "sshincludes.h"
#include "ssholdfsm.h"
#include "sshdebug.h"
#include "sshtimeouts.h"

/* Module name presumably picked up by the SSH_DEBUG() macros from
   sshdebug.h -- confirm against that header. */
#define SSH_DEBUG_MODULE "SshOldFSM"

/* Define this to get the thread ring dumps to stderr that are compiled in
   under #ifdef SSH_OLDFSM_SCHEDULER_DEBUG below. */
#undef SSH_OLDFSM_SCHEDULER_DEBUG

/* Number of buckets in each FSM's state-name cache (struct ssh_oldfsm,
   hash_table). */
#define SSH_OLDFSM_HASH_TABLE_SIZE 1001
/* Capacity of every thread's tdata stack and message handler stack. */
#define SSH_OLDFSM_STACK_SIZE 4

/* Initial values of the `flags' fields of a new FSM and a new thread. */
#define SSH_OLDFSM_INITIAL_FSM_FLAGS 0
#define SSH_OLDFSM_INITIAL_THREAD_FLAGS 0

/* Thread flags. */
#define SSH_OLDFSM_RUNNING 1    /* Control inside a step function. */
#define SSH_OLDFSM_IN_MESSAGE_HANDLER 2
                                /* Control inside a message handler.
                                   NOTE(review): not set anywhere in
                                   this file. */
#define SSH_OLDFSM_CALLBACK_FLAG 4
                                /* Opaque user flag; only manipulated by
                                   ssh_oldfsm_{set,drop,get}_callback_flag. */

/* FSM flags. */
#define SSH_OLDFSM_IN_SCHEDULER 1 /* Control inside scheduler. */
#define SSH_OLDFSM_SCHEDULER_SCHEDULED 2
                                /* A zero timeout that will run the
                                   scheduler has been registered and has
                                   not fired yet. */

/* Thread statuses.  The status tells which ring the thread is linked
   into.  A thread whose step function is currently executing keeps the
   status SSH_OLDFSM_T_ACTIVE even though the scheduler has taken it off
   the active ring for the duration of the call. */
typedef enum {
  SSH_OLDFSM_T_ACTIVE,           /* On the active list. */
  SSH_OLDFSM_T_SUSPENDED,        /* On the waiting_external list. */
  SSH_OLDFSM_T_WAITING_CONDITION /* On the waiting list of a condition var. */
} SshOldFSMThreadStatus;

/* One entry of an FSM's state-name cache.  It maps the address of a state
   name string to the matching item of the FSM's states array.  Entries in
   one bucket form a NULL-terminated singly linked list. */
typedef struct ssh_oldfsm_hash_chain {
  void *key_ptr;                /* State name pointer, compared by address
                                   only (never by string contents). */
  SshOldFSMStateMapItem state;  /* The resolved item in fsm->states. */
  struct ssh_oldfsm_hash_chain *next;
} SshOldFSMHashChain;

/* A thread of an FSM.  The ring functions handle threads through the
   generic SshOldFSMRingObject view, so `next' and `prev' must remain the
   first two members, in this order. */
struct ssh_oldfsm_thread {
  struct ssh_oldfsm_thread *next, *prev;
                                /* Ring pointers. The thread belongs always
                                   to exactly one ring, except while its
                                   step function runs: the scheduler then
                                   keeps it on no ring at all. */
  char *name;                   /* Optional name for debugging (owned,
                                   heap-allocated); NULL is shown as
                                   "unknown". */
  SshOldFSM fsm;                   /* The FSM the thread belongs to. */
  SshOldFSMThreadDestructor destructor;
                                /* A destructor for thread-specific data. */
  SshOldFSMStateMapItem current_state;
                                /* The current (next) state. */
  void *tdata_stack[SSH_OLDFSM_STACK_SIZE];
                                /* Stack of thread-specific data items.
                                   Entry 0 is allocated by
                                   ssh_oldfsm_spawn() and freed by
                                   delete_thread(); entries pushed later
                                   are never freed here. */
  int tdata_stack_size;         /* Actual stack size. */
  SshOldFSMCondition waited_condition;
                                /* A pointer to the condition variable
                                   the thread is waiting for, if any.
                                   Not cleared on wake-up; only meaningful
                                   while status is
                                   SSH_OLDFSM_T_WAITING_CONDITION. */
  SshOldFSMMessageHandler ehandler_stack[SSH_OLDFSM_STACK_SIZE];
                                /* Message handler stack.  The top entry
                                   receives thrown messages. */
  int ehandler_stack_size;      /* Stack size. */
  SshUInt32 flags;              /* SSH_OLDFSM_* thread flags. */
  SshOldFSMThreadStatus status;    /* Status. */
};

/* A condition variable.  Like threads, it is a ring member, so `next'
   and `prev' must remain its first two members. */
struct ssh_oldfsm_cond_var {
  struct ssh_oldfsm_cond_var *next, *prev;
                                /* The ring pointers of condition
                                   variables for the particular FSM. */
  SshOldFSM fsm;                   /* The FSM the variable belongs to. */
  SshOldFSMThread waiting;         /* The ring of threads waiting for
                                   this condition variable.  The ring root
                                   is the thread woken by the next
                                   signal. */
};

/* A message thrown with ssh_oldfsm_throw() that has not been delivered
   yet.  It is a ring member, so `next' and `prev' must stay first. */
typedef struct ssh_oldfsm_message_blob {
  struct ssh_oldfsm_message_blob *next, *prev;
                                /* The ring of message signals. */
  SshUInt32 message;            /* The message code passed to the handler. */
  SshOldFSMThread recipient;    /* Thread whose top message handler gets
                                   the message. */
} SshOldFSMMessageBlob;

/* The FSM object returned by ssh_oldfsm_allocate(). */
struct ssh_oldfsm {
  SshOldFSMHashChain *hash_table[SSH_OLDFSM_HASH_TABLE_SIZE];
                                /* The state names hash table (a cache of
                                   name pointer -> state item). */
  SshOldFSMStateMapItem states;   /* The states array.  Stored by
                                   reference; never copied or freed
                                   here. */
  int num_states;               /* Size of the array. */
  SshOldFSMDestructor destructor;  /* Destructor for FSM-specific data. */
  void *idata;                  /* The FSM-specific data pointer. */

  /* Every thread is either in the `active' ring, in the
     `waiting_external' ring, or in the `waiting' ring of exactly one
     condition variable (or on no ring while its step function runs). */

  SshOldFSMCondition conditions;
                                /* The ring of condition variables. */
  SshOldFSMThread active;          /* The ring of active threads,
                                   i.e. those that can be run.  The root
                                   is the next thread the scheduler
                                   runs. */
  SshOldFSMThread waiting_external;
                                /* The ring of threads that are
                                   suspended and are waiting for an
                                   external callback, i.e. an
                                   externally given
                                   ssh_oldfsm_continue(...). */
  SshOldFSMMessageBlob *messages;
                                /* Signalled messages.  The ring root is
                                   delivered first by the scheduler. */
  SshUInt32 flags;              /* SSH_OLDFSM_* FSM flags. */
  int num_threads;              /* For sanity check in ssh_oldfsm_destroy. */
};

/* Generic view of a ring member.  Threads, condition variables and message
   blobs all start with `next' and `prev' pointers, and the RING_* macros
   cast them to this type so one set of ring functions serves all three.
   NOTE(review): accessing those structs through this type relies on the
   identical leading layout and may not be strictly conforming under the
   C aliasing rules. */
typedef struct ssh_oldfsm_ring_object {
  struct ssh_oldfsm_ring_object *next, *prev;
  /* ... */
} SshOldFSMRingObject;

/** Ring functions. **/

#ifdef SSH_OLDFSM_SCHEDULER_DEBUG
/* Debug helper: write the names of the threads on the ring rooted at
   `thread' to `file', root first.  Printing stops after 19 entries, so a
   corrupted ring cannot make this loop forever. */
static void ssh_oldfsm_thread_ring_dump(SshOldFSMThread thread, FILE *file)
{
  SshOldFSMThread tail, cur;
  int index;

  if (thread == NULL)
    {
      fprintf(file, "[no threads]");
      return;
    }

  tail = thread->prev;
  SSH_ASSERT(tail != NULL);

  /* Start from the tail so that the first step lands on the root. */
  cur = tail;
  index = 1;
  for (;;)
    {
      cur = cur->next;
      fprintf(file, "%d. `%s' ", index,
              cur->name != NULL ? cur->name : "unknown");
      index++;
      if (cur == tail || index >= 20)
        break;
    }
}
#endif

/* Link `object' into the ring rooted at *root_ptr.  If the ring is empty,
   `object' becomes its only member and its root.  Otherwise `object' is
   inserted directly after the current root, and the root itself stays
   the same. */
static void ring_add(SshOldFSMRingObject **root_ptr,
                     SshOldFSMRingObject *object)
{
  SshOldFSMRingObject *root = *root_ptr;
  SshOldFSMRingObject *after;

  if (root == NULL)
    {
      object->next = object;
      object->prev = object;
      *root_ptr = object;
      return;
    }

  after = root->next;
  object->prev = root;
  object->next = after;
  root->next = object;
  after->prev = object;
}

/* Unlink `object' from the ring rooted at *root_ptr.  Removing the sole
   member empties the ring.  Removing the root moves the root to the
   following member.  The removed object's own link pointers are left
   untouched. */
static void ring_remove(SshOldFSMRingObject **root_ptr,
                        SshOldFSMRingObject *object)
{
  SshOldFSMRingObject *succ = object->next;
  SshOldFSMRingObject *pred = object->prev;

  if (succ == object)
    {
      *root_ptr = NULL;
      return;
    }

  succ->prev = pred;
  pred->next = succ;

  if (*root_ptr == object)
    *root_ptr = succ;
}

/* Advance the root of a non-empty ring by one member. */
static void ring_rotate(SshOldFSMRingObject **root_ptr)
{
  SshOldFSMRingObject *root = *root_ptr;

  SSH_ASSERT(root != NULL);
  *root_ptr = root->next;
}

/* Type-erasing wrappers around the ring functions.  Every ring member type
   (threads, condition variables, message blobs) begins with `next' and
   `prev' pointers, so it can be passed as an SshOldFSMRingObject.  `r' is
   the address of the ring root pointer and `o' the member.  The arguments
   are parenthesized so that any pointer expression can be passed
   safely. */
#define RING_ADD(r, o) \
  ring_add((SshOldFSMRingObject **)(r), (SshOldFSMRingObject *)(o))
#define RING_REMOVE(r, o) \
  ring_remove((SshOldFSMRingObject **)(r), (SshOldFSMRingObject *)(o))
#define RING_ROTATE(r) \
  ring_rotate((SshOldFSMRingObject **)(r))

/** Handling symbolic state names. **/

/* Map a state-name pointer to a bucket of the state-name cache.  The key
   is the pointer value itself, not the string contents, because
   map_state() caches by pointer identity.

   The pointer is converted to size_t rather than directly to SshUInt32.
   Converting a pointer to an integer type too narrow to represent it (as
   a 32-bit type is on common 64-bit platforms) is undefined behaviour,
   and compilers warn about it.  size_t is pointer-sized on mainstream
   flat-memory platforms. */
static SshUInt32 hash_func(char *name)
{
  return (SshUInt32)(((size_t)name) % SSH_OLDFSM_HASH_TABLE_SIZE);
}

/* Linearly search fsm->states for the item whose state_id equals `name'
   (string comparison).  An unknown name is a programming error and ends
   in ssh_fatal(), which does not return. */
static SshOldFSMStateMapItem find_state_by_name(SshOldFSM fsm, char *name)
{
  int idx = 0;

  while (idx < fsm->num_states)
    {
      SshOldFSMStateMapItem item = &(fsm->states[idx]);

      if (strcmp(item->state_id, name) == 0)
        return item;
      idx++;
    }

  ssh_fatal("find_state_by_name: cannot find a state w/ name `%s'.",
            name);
  /* ssh_fatal() does not return; this only keeps compilers quiet. */
  return NULL;
}

/* Resolve the state name `name' to its item in fsm->states.

   Results are cached in fsm->hash_table, keyed by the pointer value of
   `name'.  Repeated calls with the same string (typically the same
   literal) therefore skip the linear search in find_state_by_name().
   Distinct pointers to equal strings simply get separate cache entries.
   If a cache entry cannot be allocated, the lookup is still done, just
   not cached.  An unknown name is fatal. */
static SshOldFSMStateMapItem map_state(SshOldFSM fsm, char *name)
{
  SshUInt32 bucket = hash_func(name);
  SshOldFSMHashChain *entry;

  /* Fast path: this exact pointer has been resolved before. */
  for (entry = fsm->hash_table[bucket]; entry != NULL; entry = entry->next)
    {
      if (entry->key_ptr == name)
        return entry->state;
    }

  entry = ssh_malloc(sizeof(*entry));
  if (entry == NULL)
    {
      /* Out of memory for the cache entry: resolve without caching
         rather than fail. */
      return find_state_by_name(fsm, name);
    }

  entry->key_ptr = name;
  entry->next = fsm->hash_table[bucket];
  fsm->hash_table[bucket] = entry;
  entry->state = find_state_by_name(fsm, name);

  SSH_DEBUG(8, ("Added ptr %p ('%s') to hash table.", name, name));

  return entry->state;
}

/* Free every entry of the FSM's state-name cache.  The bucket heads are
   not reset; this is only called while the FSM itself is being
   destroyed. */
static void clear_hash_table(SshOldFSM fsm)
{
  int bucket;

  for (bucket = 0; bucket < SSH_OLDFSM_HASH_TABLE_SIZE; bucket++)
    {
      SshOldFSMHashChain *entry = fsm->hash_table[bucket];

      while (entry != NULL)
        {
          SshOldFSMHashChain *following = entry->next;

          ssh_free(entry);
          entry = following;
        }
    }
}

/** Create an FSM object. **/

/* Allocate a new FSM.

   `internal_data_size' bytes of zero-filled FSM-specific data are
   allocated.  They are returned by ssh_oldfsm_get_gdata() and
   ssh_oldfsm_get_gdata_fsm(), and are passed to `destructor' (if
   non-NULL) when the FSM is destroyed.  `states' (an array of
   `num_states' items) is stored by reference, not copied.  It is
   searched whenever a state name is resolved, so it must outlive the
   FSM.

   Returns NULL if memory allocation fails. */
SshOldFSM ssh_oldfsm_allocate(size_t internal_data_size,
                        SshOldFSMStateMapItem states,
                        int num_states,
                        SshOldFSMDestructor destructor)
{
  SshOldFSM fsm;
  int i;

  if ((fsm = ssh_malloc(sizeof(*fsm))) == NULL)
    return NULL;

  for (i = 0; i < SSH_OLDFSM_HASH_TABLE_SIZE; i++)
    fsm->hash_table[i] = NULL;
  fsm->states = states;
  fsm->num_states = num_states;
  fsm->destructor = destructor;

  if ((fsm->idata = ssh_calloc(1, internal_data_size)) == NULL)
    {
      ssh_free(fsm);
      return NULL;
    }

  fsm->conditions = NULL;
  fsm->active = NULL;
  fsm->waiting_external = NULL;
  fsm->messages = NULL;
  fsm->flags = SSH_OLDFSM_INITIAL_FSM_FLAGS;
  fsm->num_threads = 0;

  return fsm;
}

/** Move threads. **/

/* Transfer `thread' from the ring rooted at *from_ring to the ring rooted
   at *to_ring.  It is unlinked first because a thread may be a member of
   only one ring at a time. */
static void move_thread(SshOldFSMThread *from_ring,
                        SshOldFSMThread *to_ring,
                        SshOldFSMThread thread)
{
  SshOldFSMRingObject *member = (SshOldFSMRingObject *)thread;

  ring_remove((SshOldFSMRingObject **)from_ring, member);
  ring_add((SshOldFSMRingObject **)to_ring, member);
}

/** Delete a thread. **/

static void delete_thread(SshOldFSMThread thread)
{
  thread->fsm->num_threads--;

  ssh_free(thread->name);
  if (thread->destructor != NULL)
    (*(thread->destructor))(thread->tdata_stack[0]);
  ssh_free(thread->tdata_stack[0]);
  ssh_free(thread);
}

/** Internal dispatcher, scheduler, whatever. **/

/* Run the FSM until there is nothing left to do.

   Each iteration first delivers one pending message, if any, to the
   handler on top of the recipient's message handler stack.  A message
   whose top handler is NULL is dropped.  Only when no messages are
   pending is the thread at the root of the active ring taken off the
   ring and its current state's step function called.  The returned
   status decides where the thread goes next.  The loop ends when there
   are neither messages nor active threads.

   A call made while the scheduler is already running returns
   immediately.  The running loop re-reads fsm->messages and fsm->active
   on every iteration, so it picks up new work by itself. */
static void scheduler(SshOldFSM fsm)
{
  /* No recursive invocations! */
  if (fsm->flags & SSH_OLDFSM_IN_SCHEDULER)
    return;

  SSH_DEBUG(8, ("Entering the scheduler."));
  SSH_DEBUG_INDENT;

  fsm->flags |= SSH_OLDFSM_IN_SCHEDULER;

  while (1)
    {
      SshOldFSMThread thread;
      SshOldFSMStepStatus status;
      SshOldFSMMessageBlob *blob;

#ifdef SSH_OLDFSM_SCHEDULER_DEBUG
      ssh_oldfsm_thread_ring_dump(fsm->active, stderr);
      fprintf(stderr, "\n");
#endif

      /* Messages take priority over running threads. */
      if (fsm->messages != NULL)
        {
          /* Unlink the blob before calling the handler, so the handler
             may throw new messages without disturbing this one. */
          blob = fsm->messages;
          RING_REMOVE(&(fsm->messages), blob);

          if (blob->recipient->ehandler_stack
              [blob->recipient->ehandler_stack_size - 1] != NULL)
            {
              SSH_DEBUG(8, ("Delivering the message %u to thread `%s'.",
                            blob->message,
                            blob->recipient->name != NULL ?
                            blob->recipient->name : "unknown"));

              (*(blob->recipient->ehandler_stack
                 [blob->recipient->ehandler_stack_size - 1]))
                (blob->recipient, blob->message);
            }
          ssh_free(blob);

          continue;
        }

      if (fsm->active == NULL)
        {
          SSH_DEBUG_UNINDENT;
          SSH_DEBUG(6, ("No active threads so return from scheduler."));
          fsm->flags &= ~SSH_OLDFSM_IN_SCHEDULER;
          break;
        }

      /* While its step function runs, the thread is on no ring at all.
         Its status stays SSH_OLDFSM_T_ACTIVE. */
      thread = fsm->active;
      RING_REMOVE(&(fsm->active), thread);
      SSH_ASSERT(thread->status == SSH_OLDFSM_T_ACTIVE);

      SSH_ASSERT(!(thread->flags & SSH_OLDFSM_RUNNING));
      thread->flags |= SSH_OLDFSM_RUNNING;

      SSH_DEBUG(8, ("Thread continuing from state `%s' (%s).",
                    thread->current_state->state_id,
                    thread->current_state->descr));

      status = (*(thread->current_state->func))(thread);

      thread->flags &= ~SSH_OLDFSM_RUNNING;

      switch (status)
        {
        case SSH_OLDFSM_FINISH:
          SSH_DEBUG(8, ("Thread finished in state `%s'.",
                        thread->current_state->state_id));
          delete_thread(thread);
          break;

        case SSH_OLDFSM_SUSPENDED:
          /* Parked until ssh_oldfsm_continue() moves it back to the
             active ring. */
          SSH_DEBUG(8, ("Thread suspended in state `%s'.",
                        thread->current_state->state_id));
          thread->status = SSH_OLDFSM_T_SUSPENDED;
          RING_ADD(&(fsm->waiting_external), thread);
          break;

        case SSH_OLDFSM_WAIT_CONDITION:
          SSH_DEBUG(8, ("Thread waiting for a condition variable in "
                        "state `%s'.",
                        thread->current_state->state_id));
          /* Already added to the condition variable's ring
             (by ssh_oldfsm_condition_wait()). */
          break;

        case SSH_OLDFSM_CONTINUE:
          /* RING_ADD links the thread in right after the current root,
             and the rotation then makes it the root.
             NOTE(review): the continuing thread is therefore picked
             again on the very next iteration, ahead of other active
             threads.  Confirm this run-until-block behaviour is
             intended. */
          RING_ADD(&(fsm->active), thread);
          RING_ROTATE(&(fsm->active));
          break;

        default:
          SSH_NOTREACHED;
        }
    }
}

static void scheduler_callback(void *ctx)
{
  ((SshOldFSM)ctx)->flags &= ~SSH_OLDFSM_SCHEDULER_SCHEDULED;
  scheduler((SshOldFSM)ctx);
}

/* Arrange for the scheduler to run from a zero-length timeout.  Nothing
   is registered if the scheduler is running right now (its loop picks up
   new work itself) or if such a timeout is already pending. */
static void schedule_scheduler(SshOldFSM fsm)
{
  SshUInt32 busy = SSH_OLDFSM_IN_SCHEDULER | SSH_OLDFSM_SCHEDULER_SCHEDULED;

  if (fsm->flags & busy)
    return;

  fsm->flags |= SSH_OLDFSM_SCHEDULER_SCHEDULED;
  ssh_register_timeout(0L, 0L, scheduler_callback, (void *)fsm);
}

/* Make `thread' runnable again.

   - A suspended thread (its step returned SSH_OLDFSM_SUSPENDED) moves
     from the waiting_external ring to the active ring.
   - A thread waiting on a condition variable is detached from that
     condition's waiting ring and moved to the active ring.
   - An already active thread is left alone.  This includes a thread
     whose step function is executing right now.

   In the first two cases the scheduler is scheduled to run. */
void ssh_oldfsm_continue(SshOldFSMThread thread)
{
  SshOldFSM fsm = thread->fsm;

  if (thread->status == SSH_OLDFSM_T_SUSPENDED)
    {
      SSH_DEBUG(8, ("Reactivating a suspended thread."));
      thread->status = SSH_OLDFSM_T_ACTIVE;
      move_thread(&(fsm->waiting_external),
                  &(fsm->active),
                  thread);
      schedule_scheduler(fsm);
      return;
    }

  if (thread->status == SSH_OLDFSM_T_WAITING_CONDITION)
    {
      SSH_DEBUG(8, ("Reactivating a thread waiting for a condition variable "
                    "(detaching from the condition)."));
      thread->status = SSH_OLDFSM_T_ACTIVE;
      move_thread(&(thread->waited_condition->waiting),
                  &(fsm->active),
                  thread);
      schedule_scheduler(fsm);
      /* Return here.  The status is now ACTIVE, so falling through would
         log a misleading "already active (do nothing)" message. */
      return;
    }

  if (thread->status == SSH_OLDFSM_T_ACTIVE)
    {
      SSH_DEBUG(8, ("Reactivating an already active thread (do nothing)."));
      return;
    }

  SSH_NOTREACHED;
}


/* Create a new thread in `fsm' that starts in the state named
   `first_state' (an unknown name is fatal).

   A zero-filled block of `internal_data_size' bytes becomes the bottom of
   the thread's tdata stack (see ssh_oldfsm_get_tdata()).  When the thread
   is deleted, this block is passed to `destructor' (if non-NULL) and then
   freed.  `ehandler' is the bottom of the message handler stack.  If it
   is NULL, messages thrown to the thread are dropped by the scheduler.

   The thread is put on the active ring and the scheduler is scheduled.
   Returns NULL if memory allocation fails. */
SshOldFSMThread ssh_oldfsm_spawn(SshOldFSM fsm,
                           size_t internal_data_size,
                           char *first_state,
                           SshOldFSMMessageHandler ehandler,
                           SshOldFSMThreadDestructor destructor)
{
  SshOldFSMThread thread;

  SSH_DEBUG(8, ("Spawning a new thread starting from `%s'.",
                first_state));

  if ((thread = ssh_malloc(sizeof(*thread))) == NULL)
    return NULL;

  /* Zero-fill the thread data, as ssh_oldfsm_allocate() does for the FSM
     data.  The destructor then never sees indeterminate contents, even if
     no state has initialized them yet. */
  if ((thread->tdata_stack[0] = ssh_calloc(1, internal_data_size)) == NULL)
    {
      ssh_free(thread);
      return NULL;
    }

  thread->tdata_stack_size = 1;
  thread->ehandler_stack[0] = ehandler;
  thread->ehandler_stack_size = 1;
  thread->fsm = fsm;
  thread->destructor = destructor;
  thread->flags = SSH_OLDFSM_INITIAL_THREAD_FLAGS;
  thread->name = NULL;
  thread->waited_condition = NULL;

  fsm->num_threads++;
  ssh_oldfsm_set_next(thread, first_state);

  RING_ADD(&(fsm->active), thread);
  thread->status = SSH_OLDFSM_T_ACTIVE;

  schedule_scheduler(fsm);
  return thread;
}

/* Deferred part of ssh_oldfsm_destroy(); `ctx' is the FSM.  Every thread
   must already be gone, otherwise ssh_fatal() is called.  Runs the FSM
   data destructor, destroys the remaining condition variables, frees the
   state-name cache, and finally frees the FSM itself. */
static void destroy_callback(void *ctx)
{
  SshOldFSM fsm = (SshOldFSM)ctx;
  int leftover = fsm->num_threads;

  if (leftover > 0)
    ssh_fatal("Tried to destroy a FSM that has %d thread(s) left",
              leftover);

  if (fsm->destructor)
    fsm->destructor(fsm->idata);

  /* Destroying a condition unlinks it, which advances (or clears) the
     ring root. */
  for (; fsm->conditions != NULL; )
    ssh_oldfsm_condition_destroy(fsm->conditions);

  clear_hash_table(fsm);

  ssh_free(fsm->idata);
  ssh_free(fsm);

  SSH_DEBUG(8, ("FSM context destroyed."));
}

/* Request destruction of `fsm'.  The teardown itself (destroy_callback)
   runs later from a zero-length timeout, presumably so that this may be
   called from inside a step function of the FSM's own threads.  Every
   thread must have finished or been killed by the time it runs. */
void ssh_oldfsm_destroy(SshOldFSM fsm)
{
  void *context = (void *)fsm;

  ssh_register_timeout(0L, 0L, destroy_callback, context);
}

/* Return the FSM-specific data of the FSM that `thread' belongs to. */
void *ssh_oldfsm_get_gdata(SshOldFSMThread thread)
{
  SshOldFSM owner = thread->fsm;

  return owner->idata;
}

/* Return the FSM-specific data allocated by ssh_oldfsm_allocate(). */
void *ssh_oldfsm_get_gdata_fsm(SshOldFSM fsm)
{
  void *data = fsm->idata;

  return data;
}

/* Select the state named `next_state' as the one whose step function the
   scheduler calls the next time `thread' runs.  An unknown name is
   fatal. */
void ssh_oldfsm_set_next(SshOldFSMThread thread, char *next_state)
{
  SshOldFSMStateMapItem target = map_state(thread->fsm, next_state);

  thread->current_state = target;
}

/* Push `tdata' onto the thread's tdata stack.  It is then returned by
   ssh_oldfsm_get_tdata() until popped.  Pushed items are owned by the
   caller and are never freed by the FSM.  The stack holds
   SSH_OLDFSM_STACK_SIZE entries in total, including the one allocated by
   ssh_oldfsm_spawn(). */
void ssh_oldfsm_push_tdata(SshOldFSMThread thread, void *tdata)
{
  /* The slot about to be written is tdata_stack[tdata_stack_size], so it
     only needs to be a valid index.  The previous bound
     (SSH_OLDFSM_STACK_SIZE - 1) left the last slot permanently unusable. */
  SSH_ASSERT(thread->tdata_stack_size < SSH_OLDFSM_STACK_SIZE);
  thread->tdata_stack[thread->tdata_stack_size++] = tdata;
}

/* Pop and return the top of the thread's tdata stack.  Entry 0, the data
   allocated by ssh_oldfsm_spawn(), can never be popped. */
void *ssh_oldfsm_pop_tdata(SshOldFSMThread thread)
{
  SSH_ASSERT(thread->tdata_stack_size > 1);
  thread->tdata_stack_size--;
  return thread->tdata_stack[thread->tdata_stack_size];
}
/* Return the top of the thread's tdata stack without popping it. */
void *ssh_oldfsm_get_tdata(SshOldFSMThread thread)
{
  int top = thread->tdata_stack_size - 1;

  SSH_ASSERT(top >= 0);
  return thread->tdata_stack[top];
}

/* Push `ehandler' onto the thread's message handler stack.  The top
   handler is the one the scheduler calls for messages thrown to the
   thread; a NULL top handler makes the scheduler drop them.  The stack
   holds SSH_OLDFSM_STACK_SIZE entries in total, including the one given
   to ssh_oldfsm_spawn(). */
void ssh_oldfsm_push_ehandler(SshOldFSMThread thread,
                           SshOldFSMMessageHandler ehandler)
{
  /* The slot about to be written is ehandler_stack[ehandler_stack_size],
     so it only needs to be a valid index.  The previous bound
     (SSH_OLDFSM_STACK_SIZE - 1) left the last slot permanently unusable. */
  SSH_ASSERT(thread->ehandler_stack_size < SSH_OLDFSM_STACK_SIZE);
  thread->ehandler_stack[thread->ehandler_stack_size++] = ehandler;
}

/* Pop and return the top message handler.  The bottom handler, given to
   ssh_oldfsm_spawn(), can never be popped. */
SshOldFSMMessageHandler ssh_oldfsm_pop_ehandler(SshOldFSMThread thread)
{
  SSH_ASSERT(thread->ehandler_stack_size > 1);
  thread->ehandler_stack_size--;
  return thread->ehandler_stack[thread->ehandler_stack_size];
}

/* Return the FSM that `thread' belongs to. */
SshOldFSM ssh_oldfsm_get_fsm(SshOldFSMThread thread)
{
  SshOldFSM owner = thread->fsm;

  return owner;
}

/* Create a condition variable owned by `fsm', with no waiting threads.
   It is linked into the FSM's ring of conditions, so it is destroyed
   along with the FSM if still alive then.  Returns NULL on allocation
   failure. */
SshOldFSMCondition ssh_oldfsm_condition_create(SshOldFSM fsm)
{
  SshOldFSMCondition cond = ssh_malloc(sizeof(*cond));

  if (cond == NULL)
    return NULL;

  cond->fsm = fsm;
  cond->waiting = NULL;
  RING_ADD(&(fsm->conditions), cond);
  return cond;
}

/* Unlink `condition' from its FSM and free it.  No thread may be waiting
   on it. */
void ssh_oldfsm_condition_destroy(SshOldFSMCondition condition)
{
  SshOldFSM owner = condition->fsm;

  SSH_ASSERT(condition->waiting == NULL);
  RING_REMOVE(&(owner->conditions), condition);
  ssh_free(condition);
}

/* Put the currently running `thread' on the waiting ring of `condition'.
   The step function is then expected to return SSH_OLDFSM_WAIT_CONDITION;
   the scheduler does not re-link the thread for that status. */
void ssh_oldfsm_condition_wait(SshOldFSMThread thread,
                            SshOldFSMCondition condition)
{
  /* A thread can start to wait a condition only when it is running. */
  SSH_ASSERT(thread->flags & SSH_OLDFSM_RUNNING);
  SSH_ASSERT(thread->status == SSH_OLDFSM_T_ACTIVE);

  thread->waited_condition = condition;
  thread->status = SSH_OLDFSM_T_WAITING_CONDITION;
  RING_ADD(&(condition->waiting), thread);

#ifdef SSH_OLDFSM_SCHEDULER_DEBUG
  fprintf(stderr, "On the waiting list: ");
  ssh_oldfsm_thread_ring_dump(condition->waiting, stderr);
  fprintf(stderr, "\n");
#endif
}

/* Wake one thread waiting on `condition': the root of its waiting ring is
   marked active and moved to the active ring of the FSM of `thread'.
   Does nothing if nobody is waiting.  May only be called while the
   scheduler is running. */
void ssh_oldfsm_condition_signal(SshOldFSMThread thread,
                              SshOldFSMCondition condition)
{
  SshOldFSMThread woken;

  /* Signalling is not allowed outside the execution of the state
     machine. */
  SSH_ASSERT(thread->fsm->flags & SSH_OLDFSM_IN_SCHEDULER);
  SSH_DEBUG(8, ("Signalling a condition variable."));

  woken = condition->waiting;
  if (woken == NULL)
    {
      SSH_DEBUG(8, ("Waiting list empty."));
      return;
    }

#ifdef SSH_OLDFSM_SCHEDULER_DEBUG
  ssh_oldfsm_thread_ring_dump(condition->waiting, stderr);
  fprintf(stderr, "\n");
#endif

  SSH_ASSERT(woken->status == SSH_OLDFSM_T_WAITING_CONDITION);

  SSH_DEBUG(8, ("Ok, activating one of the waiting threads."));

  woken->status = SSH_OLDFSM_T_ACTIVE;
  move_thread(&(condition->waiting), &(thread->fsm->active), woken);

#ifdef SSH_OLDFSM_SCHEDULER_DEBUG
  fprintf(stderr, "On the waiting list: ");
  ssh_oldfsm_thread_ring_dump(condition->waiting, stderr);
  fprintf(stderr, "\n");

  fprintf(stderr, "Active: ");
  ssh_oldfsm_thread_ring_dump(thread->fsm->active, stderr);
  fprintf(stderr, "\n");
#endif
}

/* Wake every thread waiting on `condition' by signalling it until its
   waiting ring is empty. */
void ssh_oldfsm_condition_broadcast(SshOldFSMThread thread,
                                 SshOldFSMCondition condition)
{
  for (;;)
    {
      if (condition->waiting == NULL)
        break;
      ssh_oldfsm_condition_signal(thread, condition);
    }
}

/* Kill `thread' immediately.  It is unlinked from the ring its status says
   it is on and then deleted, which runs its tdata destructor.  Must not be
   called for a thread whose step function is currently executing. */
void ssh_oldfsm_kill_thread(SshOldFSMThread thread)
{
  SshOldFSM fsm = thread->fsm;

  SSH_ASSERT(!(thread->flags & SSH_OLDFSM_RUNNING));

  if (thread->status == SSH_OLDFSM_T_ACTIVE)
    {
      RING_REMOVE(&(fsm->active), thread);
    }
  else if (thread->status == SSH_OLDFSM_T_SUSPENDED)
    {
      RING_REMOVE(&(fsm->waiting_external), thread);
    }
  else if (thread->status == SSH_OLDFSM_T_WAITING_CONDITION)
    {
      RING_REMOVE(&(thread->waited_condition->waiting), thread);
    }
  else
    {
      SSH_NOTREACHED;
    }

  delete_thread(thread);
}

/* Queue `message' for delivery to the top message handler of `recipient'
   and schedule the scheduler.  `thread' identifies the FSM whose message
   queue is used.  Messages are delivered in the order they were thrown.
   If the queue entry cannot be allocated, the message is silently
   dropped. */
void ssh_oldfsm_throw(SshOldFSMThread thread,
                   SshOldFSMThread recipient,
                   SshUInt32 message)
{
  SshOldFSM fsm = thread->fsm;
  SshOldFSMMessageBlob *blob;

  if ((blob = ssh_malloc(sizeof(*blob))) == NULL)
    return;

  blob->recipient = recipient;
  blob->message = message;

  /* The scheduler always delivers the ring root first, so the queue's
     tail is root->prev.  ring_add() links a new member right after the
     root it is given.  Passing the current tail therefore appends the
     message and leaves the head unchanged.  (Adding after the root and
     rotating would instead make the newest message the head, and several
     pending messages would be delivered out of order.) */
  if (fsm->messages == NULL)
    {
      RING_ADD(&(fsm->messages), blob);
    }
  else
    {
      SshOldFSMMessageBlob *tail = fsm->messages->prev;

      RING_ADD(&tail, blob);
    }

  schedule_scheduler(fsm);
}

/* Set the debugging name of `thread' to a private copy of `name'.  A NULL
   `name' clears it; unnamed threads are shown as "unknown".

   The copy is made before the old name is freed.  Otherwise a call that
   passes the thread's current name (e.g. the result of
   ssh_oldfsm_get_thread_name()) would read freed memory. */
void ssh_oldfsm_set_thread_name(SshOldFSMThread thread, const char *name)
{
  char *copy = (name != NULL) ? ssh_strdup(name) : NULL;

  ssh_free(thread->name);
  thread->name = copy;
}

/* Return the debugging name of `thread', or NULL if it has none.  The
   string stays owned by the thread. */
const char *ssh_oldfsm_get_thread_name(SshOldFSMThread thread)
{
  const char *label = thread->name;

  return label;
}

/* Return the state_id of the state `thread' will execute next (or is
   executing now). */
const char *ssh_oldfsm_get_thread_current_state(SshOldFSMThread thread)
{
  SshOldFSMStateMapItem state = thread->current_state;

  return state->state_id;
}

/* Set the user-controlled callback flag of `thread'. */
void ssh_oldfsm_set_callback_flag(SshOldFSMThread thread)
{
  thread->flags = thread->flags | SSH_OLDFSM_CALLBACK_FLAG;
}

/* Clear the user-controlled callback flag of `thread'. */
void ssh_oldfsm_drop_callback_flag(SshOldFSMThread thread)
{
  SshUInt32 keep = ~(SshUInt32)SSH_OLDFSM_CALLBACK_FLAG;

  thread->flags = thread->flags & keep;
}

/* Return non-zero (1) if the callback flag of `thread' is set, else 0. */
Boolean ssh_oldfsm_get_callback_flag(SshOldFSMThread thread)
{
  return !!(thread->flags & SSH_OLDFSM_CALLBACK_FLAG);
}
