/*

  t-threads2.c

  Author: Santeri Paavolainen <santtu@ssh.com>

  Copyright (c) 2001 SSH Communications Security Oy <info@ssh.fi>

*/

#include "sshincludes.h"
#include "sshthread.h"
#include "sshthreadpool.h"
#include "sshcondition.h"
#include "sshregression.h"

#define SSH_DEBUG_MODULE "SshThreadTest"

#ifdef HAVE_THREADS

/* Condition variable and mutex shared by the thread tests; both are
   created (and checked) in main() before any test runs. */
static SshCondition condition;
static SshMutex mutex;

/* Number of threads (or thread-pool tasks) a test starts. */
#define NTHREADS 64

/* INIT_N_THREADS: store N new threads running FUNC in ARY[0] .. ARY[N-1].
   ARG is re-evaluated for every thread inside the loop, so it may refer
   to the loop index `__index' to give each thread its own argument.  If
   a thread cannot be created, the macro makes the *enclosing* function
   return FALSE.
   NOTE(review): that early return neither joins the threads created so
   far nor releases anything the caller holds (e.g. a locked mutex).
   Also, identifiers starting with `__' are reserved for the
   implementation. */
#define INIT_N_THREADS(N, ARY, FUNC, ARG)                       \
do {                                                            \
  int __index;                                                  \
  for (__index = 0; __index < (N); __index++)                   \
    if (!((ARY)[__index] = ssh_thread_create((FUNC), (ARG))))   \
      return FALSE;                                             \
} while (0)

/************************************************************************/
/* Thread body for the creation/join test: hands its start argument
   straight back as the thread's return value, so the joiner can check
   that it got the right thread's result. */
static void *thread_1_1(void *ptr)
{
  void *result = ptr;

  return result;
}

/* Creation & join test: starts NTHREADS threads running thread_1_1, each
   given the address of its own table[] slot, then joins them in reverse
   order and checks that each join returns that same address.
   Every thread that was successfully created is joined, even when a
   creation fails or a result mismatches, so no joinable thread is
   leaked.  Returns TRUE only if all NTHREADS threads were created and
   all results matched. */
Boolean test_1_1 ()
{
  SshThread threads[NTHREADS];
  int table[NTHREADS];
  int created, i;
  Boolean ok = TRUE;

  for (created = 0; created < NTHREADS; created++)
    if (!(threads[created] = ssh_thread_create(thread_1_1, &table[created])))
      {
        ok = FALSE;
        break;
      }

  /* Join in reverse creation order; keep going after a mismatch so
     that the remaining threads are still reaped. */
  for (i = created - 1; i >= 0; i--)
    if (ssh_thread_join(threads[i]) != &table[i])
      ok = FALSE;

  return ok;
}

/* Thread-pool task: reports completion by setting the Boolean that PTR
   points to TRUE.  Always returns NULL. */
void* thread_1_2 (void *ptr)
{
  *((Boolean *) ptr) = TRUE;
  return NULL;
}

/* Thread-pool test: creates a pool, starts NTHREADS thread_1_2 tasks
   (each pointed at its own acks[] slot), polls until every slot has been
   set, then destroys the pool.  Returns FALSE only if the pool cannot be
   created. */
Boolean test_1_2 ()
{
  SshThreadPool pool;
  int i;
  SshThreadPoolParamsStruct params = { 0, NTHREADS / 8 };
  Boolean acks[NTHREADS];

  if (!(pool = ssh_thread_pool_create(&params)))
    return FALSE;

  for (i = 0; i < NTHREADS; i++)
    {
      acks[i] = FALSE;
      ssh_thread_pool_start(pool, TRUE, thread_1_2, &acks[i]);
    }

  /* mutex test comes later, now avoid mutex & condvars.
     NOTE(review): these reads of acks[] are not synchronized with the
     writes in thread_1_2; this relies on the opaque ssh_sleep() call
     forcing acks[] to be re-read on each pass.  The loop has no timeout,
     so it never ends if a task is never run (the return value of
     ssh_thread_pool_start() is not checked). */
  while (1)
    {
      for (i = 0; i < NTHREADS; i++)
        if (!acks[i])
          break;

      if (i == NTHREADS)
        break;

      ssh_sleep(0, 1000);
    }

  ssh_thread_pool_destroy(pool);

  return TRUE;
}

/************************************************************************/

/* Worker for the condition-variable test.  It detaches itself, increments
   the int that PTR points to while holding `mutex', and then signals
   `condition' so that the waiting test re-checks the count. */
static void* thread_2_1 (void *ptr)
{
  int *count = ptr;

  ssh_thread_detach(ssh_thread_current());

  ssh_mutex_lock(mutex);
  *count += 1;
  ssh_mutex_unlock(mutex);

  /* Tell the waiter that one more worker has been accounted for. */
  ssh_condition_signal(condition);
  return NULL;
}

Boolean test_2_1 ()
{
  int counter;
  SshThread threads[NTHREADS];

  ssh_mutex_lock(mutex);

  counter = 0;

  INIT_N_THREADS(NTHREADS, threads, thread_2_1, &counter);

  /* Now, release mutex and broadcast on the cond var, wait until all
     have been accounted for */
  ssh_mutex_unlock(mutex);

  ssh_condition_broadcast(condition);

  ssh_mutex_lock(mutex);

  while (counter != NTHREADS)
    ssh_condition_wait(condition, mutex);

  ssh_mutex_unlock(mutex);

  return TRUE;
}

/* Limit for the shared counter in the ping-pong (test_2_2*) and
   free-running (test_2_3*) tests: the threads stop once it is reached. */
#define MAXROUNDS 0x4000

/* Even half of the ping-pong test.  PTR points to the shared int counter.
   While holding `mutex', the thread advances the counter from each even
   value to the next odd one, signalling `condition' and waiting for the
   odd partner in between.  It stops once the counter equals MAXROUNDS,
   signals once more so the partner can see the end, and returns NULL. */
void *test_2_2_a(void *ptr)
{
  int *turn = (int *) ptr;

  ssh_mutex_lock(mutex);
  for (;;)
    {
      /* Sleep until it is our (even) turn or the game is over. */
      while (*turn != MAXROUNDS && (*turn % 2) != 0)
        ssh_condition_wait(condition, mutex);

      SSH_DEBUG(15, ("even iteration %d", *turn));

      if (*turn == MAXROUNDS)
        break;

      *turn += 1;
      ssh_condition_signal(condition);
    }

  /* Wake the partner so it notices the end condition as well. */
  ssh_condition_signal(condition);
  ssh_mutex_unlock(mutex);

  return NULL;
}

/* Odd half of the ping-pong test.  PTR points to the shared int counter.
   While holding `mutex', the thread advances the counter from each odd
   value to the next even one, signalling `condition' and waiting for the
   even partner in between.  It stops once the counter equals MAXROUNDS,
   signals once more so the partner can see the end, and returns NULL. */
void *test_2_2_b(void *ptr)
{
  int *turn = (int *) ptr;

  ssh_mutex_lock(mutex);
  for (;;)
    {
      /* Sleep until it is our (odd) turn or the game is over. */
      while (*turn != MAXROUNDS && (*turn % 2) != 1)
        ssh_condition_wait(condition, mutex);

      SSH_DEBUG(15, ("odd iteration %d", *turn));

      if (*turn == MAXROUNDS)
        break;

      *turn += 1;
      ssh_condition_signal(condition);
    }

  /* Wake the partner so it notices the end condition as well. */
  ssh_condition_signal(condition);
  ssh_mutex_unlock(mutex);

  return NULL;
}

Boolean test_2_2 ()
{
  int counter = 0;
  SshThread a, b;

  if (!(a = ssh_thread_create(test_2_2_a, &counter)) ||
      !(b = ssh_thread_create(test_2_2_b, &counter)))
    return FALSE;

  ssh_thread_join(a);
  ssh_thread_join(b);

  if (counter != MAXROUNDS)
    return FALSE;

  return TRUE;
}

/* State shared by test_2_3() and its two worker threads. */

/* Conditions on which the even (a) and odd (b) threads announce that
   they are ready. */
SshCondition cond_2_3_a;
SshCondition cond_2_3_b;

/* Readiness flags set by the even (a) and odd (b) threads under `mutex'. */
Boolean rdy_2_3_a;
Boolean rdy_2_3_b;

/* Counter advanced alternately by the even and odd threads. */
SshUInt32 iter_2_3;

/* Even thread of the free-running test.
   It announces readiness through rdy_2_3_a / cond_2_3_a, then waits on
   `condition' for test_2_3() to release it.  After that it repeatedly
   moves iter_2_3 from even to odd, signalling `condition' and waiting for
   the odd thread in between.
   The waits are not wrapped in predicate loops: waking up while iter_2_3
   is odd is reported as a failure.
   Returns (void *) TRUE once an increment leaves iter_2_3 >= MAXROUNDS,
   or (void *) FALSE on an out-of-turn wakeup.  `mutex' is released on
   both paths; previously the failure path exited while still owning it. */
void *test_2_3_a (void *ptr)
{
  ssh_mutex_lock(mutex);
  rdy_2_3_a = TRUE;
  ssh_condition_signal(cond_2_3_a);
  ssh_condition_wait(condition, mutex);

  while (1)
    {
      /* even thread */
      if ((iter_2_3 % 2) != 0)
        {
          /* Never terminate while owning the mutex. */
          ssh_mutex_unlock(mutex);
          return (void*) FALSE;
        }

      iter_2_3++;

      ssh_condition_signal(condition);

      if (iter_2_3 >= MAXROUNDS)
        {
          ssh_mutex_unlock(mutex);
          return (void*) TRUE;
        }

      /* the other thread cannot be executing now, since we hold the
         mutex, thus wait */
      ssh_condition_wait(condition, mutex);
    }
}

/* Odd thread of the free-running test.
   It announces readiness through rdy_2_3_b / cond_2_3_b, waits on
   `condition' to be released, and then waits until iter_2_3 is odd.
   After that it repeatedly moves iter_2_3 from odd to even, signalling
   `condition' and waiting for the even thread in between.
   The in-loop wait is not wrapped in a predicate loop: waking up while
   iter_2_3 is even is reported as a failure.
   Returns (void *) TRUE once an increment leaves iter_2_3 >= MAXROUNDS,
   or (void *) FALSE on an out-of-turn wakeup.  `mutex' is released on
   both paths; previously the failure path exited while still owning it. */
void *test_2_3_b (void *ptr)
{
  ssh_mutex_lock(mutex);
  rdy_2_3_b = TRUE;
  ssh_condition_signal(cond_2_3_b);
  ssh_condition_wait(condition, mutex);

  /* notice: since we're odd thread, we must wait for even thread to
     proceed first */
  while ((iter_2_3 % 2) == 0)
    ssh_condition_wait(condition, mutex);

  while (1)
    {
      /* odd thread */
      if ((iter_2_3 % 2) != 1)
        {
          /* Never terminate while owning the mutex. */
          ssh_mutex_unlock(mutex);
          return (void*) FALSE;
        }

      iter_2_3++;

      ssh_condition_signal(condition);

      if (iter_2_3 >= MAXROUNDS)
        {
          ssh_mutex_unlock(mutex);
          return (void*) TRUE;
        }

      /* now, wait for the other */
      ssh_condition_wait(condition, mutex);
    }
}

Boolean test_2_3 ()
{
  SshThread a, b;

  ssh_mutex_lock(mutex);
  iter_2_3 = 0;
  rdy_2_3_a = rdy_2_3_b = FALSE;
  cond_2_3_a = ssh_condition_create(NULL, 0);
  cond_2_3_b = ssh_condition_create(NULL, 0);
  a = ssh_thread_create(test_2_3_a, NULL);
  b = ssh_thread_create(test_2_3_b, NULL);

  if (!cond_2_3_a || !cond_2_3_b || !a || !b)
    return FALSE;

  while (!rdy_2_3_a)
    ssh_condition_wait(cond_2_3_a, mutex);
  while (!rdy_2_3_b)
    ssh_condition_wait(cond_2_3_b, mutex);

  /* after both threads have synchronized with us, let them proceed */
  ssh_mutex_unlock(mutex);
  ssh_condition_broadcast(condition);

  if (ssh_thread_join(a) != (void*)TRUE ||
      ssh_thread_join(b) != (void*)TRUE)
    return FALSE;

  ssh_condition_destroy(cond_2_3_a);
  ssh_condition_destroy(cond_2_3_b);
  return TRUE;
}

/************************************************************************/
/* Regression driver for the threaded build.  It creates the shared
   condition variable and mutex, exiting with status 1 if either fails,
   and then runs the two test sections.  ssh_regression_finish() is not
   expected to return (see NOTREACHED below), so the final exit(1) is a
   safety net. */
int main(int ac, char *av[])
{
  ssh_regression_init(&ac, &av, "Threads", "santtu@ssh.com");

  condition = ssh_condition_create(NULL, 0);
  mutex = ssh_mutex_create(NULL, 0);
  if (!(condition && mutex))
    exit(1);

  /* Section 1: start threads (plain and pooled) and wait for them. */
  ssh_regression_section("Simple thread creation and termination");
  SSH_REGRESSION_TEST("Creation & join", test_1_1, ());
  SSH_REGRESSION_TEST("Thread pool", test_1_2, ());

  /* Section 2: coordinate threads with the shared mutex and condition
     variables. */
  ssh_regression_section("Mutexes and condition variables");
  SSH_REGRESSION_TEST("Condition variable synchronization", test_2_1, ());
  SSH_REGRESSION_TEST("Ping-pong synchronization", test_2_2, ());
  SSH_REGRESSION_TEST("Freerunning synchronization", test_2_3, ());

  ssh_regression_finish();

  /*NOTREACHED*/
  exit(1);
}
#else /* !HAVE_THREADS */
/* Build without HAVE_THREADS: there is nothing to test, so report
   success. */
int main(int ac, char *av[])
{
  (void) ac;
  (void) av;
  return 0;
}
#endif /* HAVE_THREADS */
