/*
 *
 * t-fsm2.c
 *
 * Author: Markku Rossi <mtr@ssh.fi>
 *
 * Copyright (c) 2001 SSH Communications Security, Finland
 *               All rights reserved.
 *
 * Regression test for FSM2.
 *
 */

#include "sshincludes.h"
#include "ssheloop.h"
#include "sshtimeouts.h"
#include "sshfsm.h"

#define SSH_DEBUG_MODULE "t-fsm2"

/********************** Prototypes for state functions **********************/

/* Forward declarations: the steps name each other through
   SSH_FSM_SET_NEXT and thread creation before their definitions. */

/* Waiting for a single thread to terminate. */
SSH_FSM_STEP(loop);
SSH_FSM_STEP(main_start);
SSH_FSM_STEP(main_loop_done);

/* Many threads waiting for one thread to terminate. */
SSH_FSM_STEP(waiter);
SSH_FSM_STEP(waiter_done);
SSH_FSM_STEP(main_wait_many);
SSH_FSM_STEP(main_wait_many_wait);

/* Condition variable signal / broadcast. */
SSH_FSM_STEP(wait_cond);
SSH_FSM_STEP(main_condition_test);
SSH_FSM_STEP(main_condition_test_signal);
SSH_FSM_STEP(main_condition_test_signal_done);
SSH_FSM_STEP(main_condition_test_broadcast_done);
SSH_FSM_STEP(main_condition_test_destroy_done);

/* Killing a suspended thread. */
SSH_FSM_STEP(kenny);
SSH_FSM_STEP(main_kill_thread);
SSH_FSM_STEP(main_kill_thread_do_kill);

/* Asynchronous calls. */
SSH_FSM_STEP(main_async_call);
SSH_FSM_STEP(main_async_call_done);
SSH_FSM_STEP(main_async_sync_call_done);

/* Messages thrown to suspended threads. */
SSH_FSM_STEP(wait_msg);
SSH_FSM_STEP(wait_msg_done);
SSH_FSM_STEP(main_msg);
SSH_FSM_STEP(main_msg_send);
SSH_FSM_STEP(main_msg_wait_done);

/* Testing FSM debugging.  You do not have to create the state array
   unless you want to get FSM level debugging from the state machine.
   To demonstrate this, we initialize just three states although our
   test case contains many more. */
static SshFSMStateDebugStruct state_array[] =
{
  /* Entries appear to be (name, description, step function).  There are
     no commas between entries; presumably SSH_FSM_STATE supplies the
     separator itself -- confirm in sshfsm.h. */
  SSH_FSM_STATE("loop", "Loop counter", loop)
  SSH_FSM_STATE("main-start", "Main start", main_start)
  SSH_FSM_STATE("main-loop-done", "Main loop done", main_loop_done)
};

/* Entry count of `state_array', handed to ssh_fsm_register_debug_names()
   in main(). */
static int num_states = SSH_FSM_NUM_STATES(state_array);


/******************************** Test cases ********************************/

/* Number of failed checks; main() turns it into the exit status.  All
   test state is file-local: this program is a single translation unit. */
static int errors = 0;

/* Exit flag for the condition waiters and completion flag for the
   asynchronous call tests. */
static int done;

/* Scratch counter shared by the main thread and the worker threads; each
   test resets it and checks its value afterwards. */
static int rounds;

/* Statically allocated thread used by the thread-wait and kill tests. */
static SshFSMThreadStruct thread1;

/* Number of worker threads each multi-thread test creates. */
#define NUM_THREADS 100

/* Worker bookkeeping: the creation loops leave `num_threads' at
   NUM_THREADS; waiter_done() and wait_msg_done() decrement it.  Finished
   workers clear their own slot in `threads'. */
static int num_threads = 0;
static SshFSMThread threads[NUM_THREADS];

/* Condition variable shared by the wait, condition and message tests;
   initialized in main(). */
static SshFSMConditionStruct cond;


/******************** Waiting for a thread to terminate *********************/

SSH_FSM_STEP(loop)
{
#ifdef DEBUG_LIGHT
  char *name = (char *) thread_context;
#endif /* DEBUG_LIGHT */

  SSH_DEBUG(SSH_D_LOWOK, ("%s: rounds=%d", name, rounds));

  if (--rounds <= 0)
    return SSH_FSM_FINISH;

  return SSH_FSM_YIELD;
}

/* First step of the main thread: starts the counting thread `thread1'
   (context "loop") with 100 rounds and waits for it to terminate. */
SSH_FSM_STEP(main_start)
{
  rounds = 100;
  ssh_fsm_thread_init(fsm, &thread1, loop, NULL, NULL, "loop");

  /* Resume in main_loop_done after `thread1' has terminated.  The step
     has no explicit return; SSH_FSM_WAIT_THREAD presumably expands to
     the step's return statement. */
  SSH_FSM_SET_NEXT(main_loop_done);
  SSH_FSM_WAIT_THREAD(&thread1);
}

SSH_FSM_STEP(main_loop_done)
{
  if (rounds != 0)
    errors++;

  SSH_DEBUG(SSH_D_NICETOKNOW, ("Thread wait works"));

  SSH_FSM_SET_NEXT(main_wait_many);

  return SSH_FSM_CONTINUE;
}

/**** Waiting for multiple threads which wait for a thread to terminate. ****/

/* Worker of the multi-waiter test: waits until `thread1' terminates and
   then continues in waiter_done.  The thread context carries the worker's
   index in `threads'. */
SSH_FSM_STEP(waiter)
{
#ifdef DEBUG_LIGHT
  /* Unpack through size_t: a direct pointer-to-int cast narrows on LP64
     platforms. */
  int idx = (int) (size_t) thread_context;
#endif /* DEBUG_LIGHT */

  SSH_DEBUG(SSH_D_LOWSTART, ("Waiter %d starting", idx));

  SSH_FSM_SET_NEXT(waiter_done);
  SSH_FSM_WAIT_THREAD(&thread1);
}

/* Final step of a multi-waiter worker: clears its slot in `threads',
   decrements `num_threads' and signals `cond' to wake the main thread
   waiting in main_wait_many_wait. */
SSH_FSM_STEP(waiter_done)
{
  /* The worker index was packed into the context pointer by
     main_wait_many; unpack through size_t to avoid a narrowing
     pointer-to-int cast. */
  int idx = (int) (size_t) thread_context;

  SSH_DEBUG(SSH_D_LOWOK, ("Waiter %d done", idx));

  threads[idx] = NULL;
  num_threads--;
  SSH_FSM_CONDITION_SIGNAL(&cond);

  return SSH_FSM_FINISH;
}

/* Restarts the counting thread `thread1' and creates NUM_THREADS waiter
   threads that all wait for it.  Each waiter gets its index as thread
   context.  `num_threads' is left at NUM_THREADS. */
SSH_FSM_STEP(main_wait_many)
{
  /* Create the loop thread. */
  rounds = 100;
  ssh_fsm_thread_init(fsm, &thread1, loop, NULL, NULL, "loop");

  /* Create some waiters.  The index is packed into the context pointer
     through size_t, avoiding an int-to-pointer cast of a different
     size. */
  for (num_threads = 0; num_threads < NUM_THREADS; num_threads++)
    {
      threads[num_threads] =
        ssh_fsm_thread_create(fsm, waiter, NULL, NULL,
                              (void *) (size_t) num_threads);
      SSH_ASSERT(threads[num_threads]);
    }

  /* And wait until the waiters are done. */
  SSH_FSM_SET_NEXT(main_wait_many_wait);
  return SSH_FSM_CONTINUE;
}

/* Waits on `cond' while any waiter is still outstanding, then checks that
   every waiter cleared its slot in `threads' before moving on to the
   condition variable tests. */
SSH_FSM_STEP(main_wait_many_wait)
{
  int slot;

  /* Waiters still running: block until one of them signals `cond'. */
  if (num_threads > 0)
    SSH_FSM_CONDITION_WAIT(&cond);

  SSH_DEBUG(SSH_D_NICETOKNOW, ("Waiters done."));

  for (slot = 0; slot < NUM_THREADS; slot++)
    {
      if (threads[slot] == NULL)
        continue;

      SSH_DEBUG(SSH_D_ERROR, ("Waiter %d not exited correctly", slot));
      errors++;
    }

  SSH_FSM_SET_NEXT(main_condition_test);
  return SSH_FSM_CONTINUE;
}


/*************************** Condition variables ****************************/

/* Worker of the condition variable test.  Every execution of this step
   bumps `rounds', which lets the main thread count wakeups.  The worker
   waits on `cond' until `done' is set, then clears its slot and exits. */
SSH_FSM_STEP(wait_cond)
{
  /* Worker index packed into the context by main_condition_test;
     unpack through size_t to avoid a narrowing pointer-to-int cast. */
  int idx = (int) (size_t) thread_context;

  SSH_DEBUG(SSH_D_NICETOKNOW, ("Waiter %d", idx));
  rounds++;

  if (!done)
    SSH_FSM_CONDITION_WAIT(&cond);

  SSH_DEBUG(SSH_D_NICETOKNOW, ("Waiter %d done", idx));
  threads[idx] = NULL;

  return SSH_FSM_FINISH;
}

/* Starts NUM_THREADS wait_cond workers, with each worker's index as its
   thread context, and yields.  main_condition_test_signal expects every
   worker to have run once by then. */
SSH_FSM_STEP(main_condition_test)
{
  done = 0;
  rounds = 0;

  /* Create waiters for condition.  The index travels through size_t to
     avoid an int-to-pointer cast of a different size. */
  SSH_DEBUG(SSH_D_NICETOKNOW, ("Starting waiters"));
  for (num_threads = 0; num_threads < NUM_THREADS; num_threads++)
    {
      threads[num_threads] =
        ssh_fsm_thread_create(fsm, wait_cond, NULL, NULL,
                              (void *) (size_t) num_threads);
      SSH_ASSERT(threads[num_threads]);
    }

  SSH_FSM_SET_NEXT(main_condition_test_signal);
  return SSH_FSM_YIELD;
}

/* Checks that every wait_cond worker has run once, then signals `cond'
   a single time.  main_condition_test_signal_done verifies the number of
   wakeups. */
SSH_FSM_STEP(main_condition_test_signal)
{
  int all_started = (rounds == NUM_THREADS);

  if (!all_started)
    {
      SSH_DEBUG(SSH_D_ERROR, ("Condition waiters have not started yet"));
      errors++;
    }

  /* From here on `rounds' counts wakeups caused by the signal. */
  rounds = 0;

  /* Wake up one. */
  SSH_DEBUG(SSH_D_NICETOKNOW, ("Signalling condition"));
  SSH_FSM_CONDITION_SIGNAL(&cond);

  SSH_FSM_SET_NEXT(main_condition_test_signal_done);
  return SSH_FSM_YIELD;
}

/* Verifies that the single signal woke exactly one worker, then
   broadcasts `cond' to wake every worker. */
SSH_FSM_STEP(main_condition_test_signal_done)
{
  /* Both 0 (signal lost) and >1 (too many woken) are failures; report the
     actual count rather than assuming too many threads woke. */
  if (rounds != 1)
    {
      SSH_DEBUG(SSH_D_ERROR, ("Signal woke up %d threads instead of one",
                              rounds));
      errors++;
    }
  rounds = 0;

  /* Wake them all. */

  SSH_DEBUG(SSH_D_NICETOKNOW, ("Broadcasting condition"));
  SSH_FSM_CONDITION_BROADCAST(&cond);

  SSH_FSM_SET_NEXT(main_condition_test_broadcast_done);

  return SSH_FSM_YIELD;
}

/* Verifies that the broadcast woke every worker, then sets `done' and
   broadcasts again so that each worker exits on its next run. */
SSH_FSM_STEP(main_condition_test_broadcast_done)
{
  int everyone_woke = (rounds == NUM_THREADS);

  if (!everyone_woke)
    {
      SSH_DEBUG(SSH_D_ERROR, ("Broadcast did not wake all threads"));
      errors++;
    }

  /* Fresh wakeup count; workers seeing `done' will finish. */
  rounds = 0;
  done = 1;

  /* Destroy them all. */
  SSH_DEBUG(SSH_D_NICETOKNOW, ("Destroying waiters"));
  SSH_FSM_CONDITION_BROADCAST(&cond);

  SSH_FSM_SET_NEXT(main_condition_test_destroy_done);
  return SSH_FSM_YIELD;
}

/* Verifies that the final broadcast reached every worker and that each
   worker cleared its slot in `threads', then starts the kill test. */
SSH_FSM_STEP(main_condition_test_destroy_done)
{
  int k;

  if (rounds != NUM_THREADS)
    {
      SSH_DEBUG(SSH_D_ERROR, ("Broadcast did not wake all threads"));
      errors++;
    }

  /* wait_cond clears its slot right before finishing, so a slot that is
     still set belongs to a worker that did not exit. */
  for (k = 0; k < NUM_THREADS; k++)
    {
      if (threads[k] == NULL)
        continue;

      SSH_DEBUG(SSH_D_ERROR, ("Waiter %d not exited correctly", k));
      errors++;
    }

  SSH_FSM_SET_NEXT(main_kill_thread);
  return SSH_FSM_CONTINUE;
}


/****************************** Killing thread ******************************/

SSH_FSM_STEP(kenny)
{
#ifdef DEBUG_LIGHT
  char *name = (char *) thread_context;
#endif /* DEBUG_LIGHT */

  SSH_DEBUG(SSH_D_NICETOKNOW, ("%s: suspending", name));
  return SSH_FSM_SUSPENDED;
}

static void
kenny_destructor(void *context)
{
#ifdef DEBUG_LIGHT
  char *name = (char *) context;
#endif /* DEBUG_LIGHT */

  SSH_DEBUG(SSH_D_NICETOKNOW, ("Oh my God, they killed %s!", name));
  rounds++;
}

/* Starts the `kenny' thread in the static `thread1' slot, with
   kenny_destructor as destructor and "Kenny" as context, and yields,
   presumably so kenny gets to run and suspend before being killed. */
SSH_FSM_STEP(main_kill_thread)
{
  /* kenny_destructor increments this; checked right after the kill. */
  rounds = 0;
  ssh_fsm_thread_init(fsm, &thread1, kenny, NULL, kenny_destructor, "Kenny");

  SSH_FSM_SET_NEXT(main_kill_thread_do_kill);
  return SSH_FSM_YIELD;
}

/* Kills `thread1' (kenny) and checks that its destructor ran exactly
   once, then moves on to the asynchronous call tests. */
SSH_FSM_STEP(main_kill_thread_do_kill)
{
  ssh_fsm_kill_thread(&thread1);

  /* Checked immediately after the kill: the test expects the destructor
     to have run inside ssh_fsm_kill_thread(). */
  if (rounds != 1)
    {
      SSH_DEBUG(SSH_D_ERROR, ("Killing thread did not call destructor"));
      errors++;
    }

  SSH_FSM_SET_NEXT(main_async_call);
  return SSH_FSM_YIELD;
}


/**************************** Asynchronous calls ****************************/

/* Completion callback for the asynchronous call tests.  `context' is the
   FSM thread that issued the SSH_FSM_ASYNC_CALL.  Marks `done' and lets
   that thread continue. */
static void
timeout_cb(void *context)
{
  SshFSMThread thread = context;

  SSH_DEBUG(SSH_D_NICETOKNOW, ("In timeout callback."));

  /* Set before continuing so the thread's next step observes it. */
  done = 1;
  SSH_FSM_CONTINUE_AFTER_CALLBACK(thread);
}

/* Issues a genuinely asynchronous call: registers a timeout (0, 500000 --
   presumably seconds and microseconds) whose callback, timeout_cb,
   completes the call. */
SSH_FSM_STEP(main_async_call)
{
  done = 0;
  SSH_FSM_SET_NEXT(main_async_call_done);
  SSH_FSM_ASYNC_CALL(ssh_register_timeout(0, 500000, timeout_cb, thread));
}

/* Checks the timeout-based asynchronous call.  Then issues a second
   "asynchronous" call whose callback runs immediately, inside the
   SSH_FSM_ASYNC_CALL itself, to test synchronous completion. */
SSH_FSM_STEP(main_async_call_done)
{
  if (!done)
    {
      SSH_DEBUG(SSH_D_ERROR, ("Asynchronous call did not set done"));
      errors++;
    }

  done = 0;
  SSH_FSM_SET_NEXT(main_async_sync_call_done);
  /* timeout_cb is invoked directly here rather than from a timeout. */
  SSH_FSM_ASYNC_CALL(timeout_cb(thread));
}

/* Checks that the synchronously completed asynchronous call set `done',
   then starts the message test. */
SSH_FSM_STEP(main_async_sync_call_done)
{
  if (done == 0)
    {
      SSH_DEBUG(SSH_D_ERROR,
                ("Synchronous asynchronous call did not set done"));
      errors++;
    }

  SSH_FSM_SET_NEXT(main_msg);

  return SSH_FSM_CONTINUE;
}


/********************************* Messages *********************************/

/* Worker of the message test: counts its start in `rounds' and suspends.
   message_handler later moves it to wait_msg_done. */
SSH_FSM_STEP(wait_msg)
{
#ifdef DEBUG_LIGHT
  /* Worker index packed into the context by main_msg; unpack through
     size_t to avoid a narrowing pointer-to-int cast. */
  int idx = (int) (size_t) thread_context;
#endif /* DEBUG_LIGHT */

  SSH_DEBUG(SSH_D_NICETOKNOW, ("Waiter %d started", idx));
  rounds++;

  return SSH_FSM_SUSPENDED;
}

/* Final step of a message worker: clears its slot, decrements
   `num_threads', counts itself in `rounds' and signals `cond' to wake
   main_msg_wait_done. */
SSH_FSM_STEP(wait_msg_done)
{
  /* Worker index packed into the context by main_msg; unpack through
     size_t to avoid a narrowing pointer-to-int cast. */
  int idx = (int) (size_t) thread_context;

  /* The %d conversion needs its matching argument. */
  SSH_DEBUG(SSH_D_NICETOKNOW, ("Waiter %d dying", idx));

  threads[idx] = NULL;
  num_threads--;
  rounds++;

  SSH_FSM_CONDITION_SIGNAL(&cond);

  return SSH_FSM_FINISH;
}

/* Message handler of the wait_msg workers.  On any message it sets the
   receiving thread's next state to wait_msg_done and continues the
   thread.  The thread's data is the worker index set up by main_msg. */
static void
message_handler(SshFSMThread thread, SshUInt32 message)
{
#ifdef DEBUG_LIGHT
  /* Unpack through size_t to avoid a narrowing pointer-to-int cast. */
  int idx = (int) (size_t) ssh_fsm_get_tdata(thread);
#endif /* DEBUG_LIGHT */

  /* SshUInt32 is not necessarily `int'; cast so the argument matches the
     conversion specifier. */
  SSH_DEBUG(SSH_D_NICETOKNOW, ("Waiter %d: got message %u",
                               idx, (unsigned int) message));

  /* SSH_FSM_SET_NEXT presumably acts on the local `thread' parameter;
     keep that name. */
  SSH_FSM_SET_NEXT(wait_msg_done);
  ssh_fsm_continue(thread);
}

/* Starts NUM_THREADS wait_msg workers, each with message_handler as its
   message handler and its index as thread context, then yields.
   main_msg_send expects all of them to have started by then. */
SSH_FSM_STEP(main_msg)
{
  rounds = 0;

  /* Create waiters.  The index is packed through size_t to avoid an
     int-to-pointer cast of a different size. */
  SSH_DEBUG(SSH_D_NICETOKNOW, ("Creating message waiters"));
  for (num_threads = 0; num_threads < NUM_THREADS; num_threads++)
    {
      threads[num_threads] =
        ssh_fsm_thread_create(fsm, wait_msg, message_handler, NULL,
                              (void *) (size_t) num_threads);
      SSH_ASSERT(threads[num_threads]);
    }

  SSH_FSM_SET_NEXT(main_msg_send);
  return SSH_FSM_YIELD;
}

/* Checks that every message worker has started, then throws message 42
   to each of them and continues to main_msg_wait_done. */
SSH_FSM_STEP(main_msg_send)
{
  int target;

  if (rounds != NUM_THREADS)
    {
      SSH_DEBUG(SSH_D_ERROR, ("All waiters have not started"));
      errors++;
    }

  /* From here on `rounds' counts workers reaching wait_msg_done. */
  rounds = 0;

  SSH_DEBUG(SSH_D_NICETOKNOW, ("Sending messages"));
  target = 0;
  while (target < num_threads)
    {
      SSH_FSM_THROW(threads[target], 42);
      target++;
    }

  SSH_FSM_SET_NEXT(main_msg_wait_done);
  return SSH_FSM_CONTINUE;
}

/* Last step of the main thread: waits on `cond' while message workers
   remain, checks that all of them reached wait_msg_done, and finishes the
   main thread. */
SSH_FSM_STEP(main_msg_wait_done)
{
  if (num_threads)
    SSH_FSM_CONDITION_WAIT(&cond);

  if (rounds != NUM_THREADS)
    {
      SSH_DEBUG(SSH_D_ERROR, ("Not all threads died: %d of %d",
                              rounds, NUM_THREADS));
      errors++;
    }

  return SSH_FSM_FINISH;
}

/*********************************** Main ***********************************/

int
main(int argc, char *argv[])
{
  SshFSM fsm;

  if (argc == 2)
    ssh_debug_set_level_string(argv[1]);

  ssh_event_loop_initialize();

  fsm = ssh_fsm_create(NULL);
  SSH_ASSERT(fsm);

  ssh_fsm_register_debug_names(fsm, state_array, num_states);

  /* Create a condition variable. */
  ssh_fsm_condition_init(fsm, &cond);

  ssh_fsm_thread_create(fsm, main_start, NULL, NULL, "main");

  ssh_event_loop_run();

  ssh_fsm_destroy(fsm);

  ssh_event_loop_uninitialize();

  return errors;
}
