/*

  ssholdfsmstreams.c

  Author: Antti Huima <huima@ssh.fi>

  Copyright (c) 1999-2001 SSH Communications Security, Finland
  All rights reserved.

  Created Thu Sep  2 15:21:26 1999.

  */

#include "sshincludes.h"

#include "sshdebug.h"
#include "ssholdfsmstreams.h"
#include "ssholdfsm.h"
#include "sshstream.h"
#include "sshbuffer.h"

#define SSH_DEBUG_MODULE "SshStreamstub"

/* Bits stored in the `own_flags' field of the reader and writer thread
   data. */

/* Set by the read step when a read blocks; while it is set, an input
   available notification resumes the reader thread. */
#define SSH_OLDSTREAMSTUB_EXPECT_READ_NOTIFY       0x0001
/* Set by the write step when a write blocks; while it is set, a can
   output notification resumes the writer thread. */
#define SSH_OLDSTREAMSTUB_EXPECT_WRITE_NOTIFY      0x0002
/* Set by the writer's message handler on SEND_EOF or FINISH; the write
   step outputs EOF once `out_buf' is empty. */
#define SSH_OLDSTREAMSTUB_SENDING_EOF              0x0004
/* Set by the write step just before it outputs EOF on the stream. */
#define SSH_OLDSTREAMSTUB_EOF_SENT                 0x0008
/* NOTE(review): no code in this file appears to use this bit. */
#define SSH_OLDSTREAMSTUB_DRAINING                 0x0010

/* Messages thrown to the parent thread by the children's abort steps.
   The parent's message handler reacts by clearing its pointer to the
   corresponding child. */
#define SSH_OLDSTREAMSTUB_READER_DIED              0xf000
#define SSH_OLDSTREAMSTUB_WRITER_DIED              0xf001

/* Thread data of the parent thread, which supervises the reader and
   writer threads of one stream stub. */
typedef struct {
  /* Reader child thread; set to NULL by the parent's message handler
     when SSH_OLDSTREAMSTUB_READER_DIED arrives. */
  SshOldFSMThread reader;
  /* Writer child thread; set to NULL by the parent's message handler
     when SSH_OLDSTREAMSTUB_WRITER_DIED arrives. */
  SshOldFSMThread writer;
  /* The wrapped stream; destroyed in the parent's die step. */
  SshStream stream;
  /* Flags word supplied by the caller of ssh_oldstreamstub_spawn; the
     die step sets SSH_OLDSTREAMSTUB_FINISHED in it. */
  SshUInt32 *flags;
  /* Signalled by the die step if it is not NULL. */
  SshOldFSMCondition finished_condition;
} SshFSMStreamParentDataRec;

/* Thread data of the reader thread, which moves data from the stream
   into `in_buf'. */
typedef struct {
  /* The parent thread; receives SSH_OLDSTREAMSTUB_READER_DIED when the
     reader finishes. */
  SshOldFSMThread parent;

  SshStream stream;             /* The stream to access. */

  SshBuffer in_buf;            /* Buffer where data will be read to. */

  SshUInt32 in_buf_limit;       /* `in_buf' won't grow beyond this limit. */

  SshUInt32 *flags;             /* Shared flags; the reader sets
                                   SSH_OLDSTREAMSTUB_EOF_RECEIVED here. */

  SshOldFSMCondition got_more;     /* Condition the stub will signal when
                                   more data has been got in `in_buf'.
                                   May be NULL, in which case nothing is
                                   signalled. */
  SshOldFSMCondition in_buf_shrunk;
                                /* Condition that outside must signal when
                                   `in_buf' has shrunk.  The reader waits
                                   on it after filling `in_buf' up to
                                   `in_buf_limit'. */

  SshUInt32 own_flags;          /* Own flags
                                   (SSH_OLDSTREAMSTUB_EXPECT_READ_NOTIFY). */
} SshFSMStreamReaderDataRec;

/* Thread data of the writer thread, which moves data from `out_buf' to
   the stream. */
typedef struct {
  /* The parent thread; receives SSH_OLDSTREAMSTUB_WRITER_DIED when the
     writer finishes. */
  SshOldFSMThread parent;

  SshStream stream;             /* The stream to access. */

  SshBuffer out_buf;           /* Buffer where data will be written from. */

  SshUInt32 *flags;             /* Shared flags; the writer sets
                                   SSH_OLDSTREAMSTUB_OUTPUT_CLOSED here. */

  SshOldFSMCondition out_buf_shrunk;
                                /* Condition the stub will signal when
                                   data has been consumed from `out_buf'.
                                   May be NULL, in which case nothing is
                                   signalled. */
  SshOldFSMCondition data_present;
                                /* Condition that outside must signal when
                                   `out_buf' has got some [more] data. */
  SshUInt32 own_flags;          /* Own flags (EXPECT_WRITE_NOTIFY,
                                   SENDING_EOF, EOF_SENT). */
} SshFSMStreamWriterDataRec;

/* The reader thread. */

/* Message handler of the reader thread.  The only message the reader
   accepts is SSH_OLDSTREAMSTUB_ABORT, which redirects the thread to the
   "util/streamstub/abort-reader" state and resumes it.  Any other
   message is a programming error. */
static void reader_message_handler(SshOldFSMThread thread, SshUInt32 message)
{
  if (message == SSH_OLDSTREAMSTUB_ABORT)
    {
      ssh_oldfsm_set_next(thread, "util/streamstub/abort-reader");
      ssh_oldfsm_continue(thread);
    }
  else
    {
      SSH_NOTREACHED;
    }
}

/* Final step of the reader thread (presumably registered under the
   state name "util/streamstub/abort-reader" -- the state table is not
   in this file).  Tells the parent that the reader is gone and finishes
   the thread. */
SSH_OLDFSM_STEP(ssh_oldstreamstub_abort_reader)
{
  SSH_OLDFSM_TDATA(SshFSMStreamReaderDataRec *);
  /* The parent's message handler clears its `reader' pointer when it
     gets this message. */
  SSH_OLDFSM_THROW(tdata->parent, SSH_OLDSTREAMSTUB_READER_DIED);
  return SSH_OLDFSM_FINISH;
}

/* Main step of the reader thread.  Tries to read from the stream as
   many bytes as still fit in `in_buf' under `in_buf_limit' and then:

     - read blocks:   sets EXPECT_READ_NOTIFY and suspends until the
                      stream callback resumes the thread;
     - EOF:           sets SSH_OLDSTREAMSTUB_EOF_RECEIVED in the shared
                      flags, signals `got_more' and moves on to
                      "util/streamstub/abort-reader";
     - partial fill:  signals `got_more' and runs this step again;
     - buffer full:   signals `got_more' and waits for `in_buf_shrunk'.

   `got_more' is signalled only if it is not NULL. */
SSH_OLDFSM_STEP(ssh_oldstreamstub_read)
{
  SSH_OLDFSM_TDATA(SshFSMStreamReaderDataRec *);
  unsigned char *dst;
  size_t room;
  int nread;

  SSH_ASSERT(tdata->stream != NULL);
  SSH_ASSERT(tdata->in_buf != NULL);

  /* How much `in_buf' may still grow before it hits its limit. */
  room = tdata->in_buf_limit - ssh_buffer_len(tdata->in_buf);

  SSH_ASSERT(room > 0);

  /* Reserve all of it at the end of the buffer and read straight into
     the reserved area; whatever stays unused is dropped again below. */
  ssh_buffer_append_space(tdata->in_buf, &dst, room);
  nread = ssh_stream_read(tdata->stream, dst, room);

  if (nread < 0)
    {
      /* Blocking: nothing was read, so give back the whole
         reservation. */
      SSH_DEBUG(8, ("Read blocks."));
      ssh_buffer_consume_end(tdata->in_buf, room);
      tdata->own_flags |= SSH_OLDSTREAMSTUB_EXPECT_READ_NOTIFY;
      return SSH_OLDFSM_SUSPENDED;
    }
  else if (nread == 0)
    {
      /* EOF got. */
      SSH_DEBUG(8, ("Read returned EOF."));
      ssh_buffer_consume_end(tdata->in_buf, room);
      *(tdata->flags) |= SSH_OLDSTREAMSTUB_EOF_RECEIVED;
      if (tdata->got_more != NULL)
        {
          SSH_OLDFSM_CONDITION_SIGNAL(tdata->got_more);
        }
      SSH_OLDFSM_SET_NEXT("util/streamstub/abort-reader");
      return SSH_OLDFSM_CONTINUE;
    }
  else if ((size_t) nread < room)
    {
      /* Got less than was reserved: trim the unused tail and try
         again. */
      SSH_DEBUG(8, ("Read in %d bytes, continuing.", nread));
      ssh_buffer_consume_end(tdata->in_buf, room - nread);
      if (tdata->got_more != NULL)
        {
          SSH_OLDFSM_CONDITION_SIGNAL(tdata->got_more);
        }
      return SSH_OLDFSM_CONTINUE;
    }

  /* The read used up the whole reservation, so `in_buf' is now at its
     limit.  Wait until the consumer reports that it has shrunk. */
  SSH_DEBUG(8, ("Read in %d bytes, buffer full.", nread));
  if (tdata->got_more != NULL)
    {
      SSH_OLDFSM_CONDITION_SIGNAL(tdata->got_more);
    }
  SSH_OLDFSM_CONDITION_WAIT(tdata->in_buf_shrunk);
}

/* The writer thread. */

/* Message handler of the writer thread.

   SSH_OLDSTREAMSTUB_ABORT redirects the thread to the
   "util/streamstub/abort-writer" state and resumes it.
   SSH_OLDSTREAMSTUB_SEND_EOF and SSH_OLDSTREAMSTUB_FINISH are handled
   alike: SENDING_EOF is raised in the writer's own flags and the thread
   is resumed.  Any other message is a programming error. */
static void writer_message_handler(SshOldFSMThread thread, SshUInt32 message)
{
  SSH_OLDFSM_TDATA(SshFSMStreamWriterDataRec *);

  if (message == SSH_OLDSTREAMSTUB_ABORT)
    {
      ssh_oldfsm_set_next(thread, "util/streamstub/abort-writer");
      ssh_oldfsm_continue(thread);
    }
  else if (message == SSH_OLDSTREAMSTUB_SEND_EOF ||
           message == SSH_OLDSTREAMSTUB_FINISH)
    {
      tdata->own_flags |= SSH_OLDSTREAMSTUB_SENDING_EOF;
      ssh_oldfsm_continue(thread);
    }
  else
    {
      SSH_NOTREACHED;
    }
}

/* Final step of the writer thread (presumably registered under the
   state name "util/streamstub/abort-writer" -- the state table is not
   in this file).  Tells the parent that the writer is gone and finishes
   the thread. */
SSH_OLDFSM_STEP(ssh_oldstreamstub_abort_writer)
{
  SSH_OLDFSM_TDATA(SshFSMStreamWriterDataRec *);
  /* The parent's message handler clears its `writer' pointer when it
     gets this message. */
  SSH_OLDFSM_THROW(tdata->parent, SSH_OLDSTREAMSTUB_WRITER_DIED);
  return SSH_OLDFSM_FINISH;
}

SSH_OLDFSM_STEP(ssh_oldstreamstub_write)
{
  SSH_OLDFSM_TDATA(SshFSMStreamWriterDataRec *);
  int result;
  unsigned char *ptr;

  SSH_ASSERT(tdata->stream != NULL);
  SSH_ASSERT(tdata->out_buf != NULL);

  ptr = ssh_buffer_ptr(tdata->out_buf);

  if (ssh_buffer_len(tdata->out_buf) == 0 &&
      ((tdata->own_flags) & SSH_OLDSTREAMSTUB_SENDING_EOF))
    {
      SSH_ASSERT(!(tdata->own_flags & SSH_OLDSTREAMSTUB_EOF_SENT));
      SSH_DEBUG(8, ("Sending eof."));
      tdata->own_flags |= SSH_OLDSTREAMSTUB_EOF_SENT;
      ssh_stream_output_eof(tdata->stream);
      SSH_OLDFSM_SET_NEXT("util/streamstub/abort-writer");
      return SSH_OLDFSM_CONTINUE;
    }

  if (ssh_buffer_len(tdata->out_buf) == 0)
    {
      SSH_DEBUG(8, ("Nothing to write."));
      SSH_OLDFSM_CONDITION_WAIT(tdata->data_present);
    }

  result = ssh_stream_write(tdata->stream, ptr,
                            ssh_buffer_len(tdata->out_buf));

  if (result < 0)
    {
      /* Blocking. */
      SSH_DEBUG(8, ("Write blocks."));
      tdata->own_flags |= SSH_OLDSTREAMSTUB_EXPECT_WRITE_NOTIFY;
      return SSH_OLDFSM_SUSPENDED;
    }

  if (result == 0)
    {
      /* EOF got. */
      SSH_DEBUG(8, ("Write fails."));
      *(tdata->flags) |= SSH_OLDSTREAMSTUB_OUTPUT_CLOSED;
      if (tdata->out_buf_shrunk != NULL)
        SSH_OLDFSM_CONDITION_SIGNAL(tdata->out_buf_shrunk);
      SSH_OLDFSM_SET_NEXT("util/streamstub/abort-writer");
      return SSH_OLDFSM_CONTINUE;
    }

  /* result > 0 */
  SSH_DEBUG(8, ("Wrote %d bytes, continuing.", result));
  ssh_buffer_consume(tdata->out_buf, result);
  if (tdata->out_buf_shrunk != NULL)
    SSH_OLDFSM_CONDITION_SIGNAL(tdata->out_buf_shrunk);
  return SSH_OLDFSM_CONTINUE;
}

/* The parent thread. */

/* Message handler of the parent thread.

   SEND_EOF, FINISH and ABORT each redirect the parent to the state of
   the same name ("util/streamstub/send-eof", ".../finish", ".../abort")
   and resume it.  READER_DIED and WRITER_DIED clear the pointer to the
   child in question; once both pointers are NULL the parent is sent to
   "util/streamstub/die".  Any other message is a programming error. */
static void parent_message_handler(SshOldFSMThread thread, SshUInt32 message)
{
  SSH_OLDFSM_TDATA(SshFSMStreamParentDataRec *);

  /* Requests from the outside: just switch state and run. */
  if (message == SSH_OLDSTREAMSTUB_SEND_EOF)
    {
      ssh_oldfsm_set_next(thread, "util/streamstub/send-eof");
      ssh_oldfsm_continue(thread);
      return;
    }
  if (message == SSH_OLDSTREAMSTUB_FINISH)
    {
      ssh_oldfsm_set_next(thread, "util/streamstub/finish");
      ssh_oldfsm_continue(thread);
      return;
    }
  if (message == SSH_OLDSTREAMSTUB_ABORT)
    {
      ssh_oldfsm_set_next(thread, "util/streamstub/abort");
      ssh_oldfsm_continue(thread);
      return;
    }

  /* Death notices from the children. */
  if (message == SSH_OLDSTREAMSTUB_READER_DIED)
    {
      tdata->reader = NULL;
    }
  else if (message == SSH_OLDSTREAMSTUB_WRITER_DIED)
    {
      tdata->writer = NULL;
    }
  else
    {
      SSH_NOTREACHED;
    }

  /* Keep going for as long as at least one child is alive. */
  if (tdata->reader != NULL || tdata->writer != NULL)
    return;

  /* When both children are dead, prepare to die. */
  ssh_oldfsm_set_next(thread, "util/streamstub/die");
  ssh_oldfsm_continue(thread);
}

/* Final step of the parent thread; parent_message_handler selects it
   (as "util/streamstub/die") once both children are dead.  Marks the
   stub finished in the shared flags, signals `finished_condition' if
   one was given, destroys the stream and finishes the thread. */
SSH_OLDFSM_STEP(ssh_oldstreamstub_die)
{
  SSH_OLDFSM_TDATA(SshFSMStreamParentDataRec *);
  *(tdata->flags) |= SSH_OLDSTREAMSTUB_FINISHED;
  /* The condition is optional. */
  if (tdata->finished_condition != NULL)
    SSH_OLDFSM_CONDITION_SIGNAL(tdata->finished_condition);
  ssh_stream_destroy(tdata->stream);
  return SSH_OLDFSM_FINISH;
}

/* Parent step selected by parent_message_handler (as
   "util/streamstub/abort") when SSH_OLDSTREAMSTUB_ABORT is received.
   Throws SSH_OLDSTREAMSTUB_ABORT to each child that is still alive and
   then suspends.  Each child's message handler sends it to its abort
   state, from which it reports back with READER_DIED / WRITER_DIED;
   the parent goes on to die once both reports have arrived. */
SSH_OLDFSM_STEP(ssh_oldstreamstub_abort)
{
  SSH_OLDFSM_TDATA(SshFSMStreamParentDataRec *);
  if (tdata->writer != NULL)
    SSH_OLDFSM_THROW(tdata->writer, SSH_OLDSTREAMSTUB_ABORT);
  /* The reader must get its own ABORT.  Throwing to the writer here
     would leave the reader running forever and would use a NULL target
     whenever the writer is already dead. */
  if (tdata->reader != NULL)
    SSH_OLDFSM_THROW(tdata->reader, SSH_OLDSTREAMSTUB_ABORT);
  return SSH_OLDFSM_SUSPENDED;
}

/* Parent step selected by parent_message_handler (as
   "util/streamstub/finish") when SSH_OLDSTREAMSTUB_FINISH is received.
   The two children are treated differently: the writer gets FINISH,
   the reader gets ABORT.  The parent then suspends. */
SSH_OLDFSM_STEP(ssh_oldstreamstub_finish)
{
  SSH_OLDFSM_TDATA(SshFSMStreamParentDataRec *);
  /* The writer's message handler treats FINISH like SEND_EOF: EOF is
     output once `out_buf' has been written out. */
  if (tdata->writer != NULL)
    SSH_OLDFSM_THROW(tdata->writer, SSH_OLDSTREAMSTUB_FINISH);
  /* The reader is stopped right away. */
  if (tdata->reader != NULL)
    SSH_OLDFSM_THROW(tdata->reader, SSH_OLDSTREAMSTUB_ABORT);
  return SSH_OLDFSM_SUSPENDED;
}

/* Parent step selected by parent_message_handler (as
   "util/streamstub/send-eof") when SSH_OLDSTREAMSTUB_SEND_EOF is
   received.  Forwards the request to the writer, if it is still alive,
   and suspends.  The reader is not touched. */
SSH_OLDFSM_STEP(ssh_oldstreamstub_send_eof)
{
  SSH_OLDFSM_TDATA(SshFSMStreamParentDataRec *);
  if (tdata->writer != NULL)
    SSH_OLDFSM_THROW(tdata->writer, SSH_OLDSTREAMSTUB_SEND_EOF);
  return SSH_OLDFSM_SUSPENDED;
}

/* Called from the stream callback with the reader thread when input
   becomes available.  The reader is resumed only if its last read
   blocked, i.e. if EXPECT_READ_NOTIFY is set in its own flags; the flag
   is cleared before the thread is resumed.  Otherwise the notification
   is ignored. */
static void got_read_notify(SshOldFSMThread thread)
{
  SSH_OLDFSM_TDATA(SshFSMStreamReaderDataRec *);

  if (!(tdata->own_flags & SSH_OLDSTREAMSTUB_EXPECT_READ_NOTIFY))
    return;

  tdata->own_flags &= ~SSH_OLDSTREAMSTUB_EXPECT_READ_NOTIFY;
  ssh_oldfsm_continue(thread);
}

/* Called from the stream callback with the writer thread when the
   stream can take output again.  The writer is resumed only if its last
   write blocked, i.e. if EXPECT_WRITE_NOTIFY is set in its own flags;
   the flag is cleared before the thread is resumed.  Otherwise the
   notification is ignored. */
static void got_write_notify(SshOldFSMThread thread)
{
  SSH_OLDFSM_TDATA(SshFSMStreamWriterDataRec *);

  if (!(tdata->own_flags & SSH_OLDSTREAMSTUB_EXPECT_WRITE_NOTIFY))
    return;

  tdata->own_flags &= ~SSH_OLDSTREAMSTUB_EXPECT_WRITE_NOTIFY;
  ssh_oldfsm_continue(thread);
}

/* Called from the stream callback, once for each live child thread,
   when the stream reports a disconnect.  Throws SSH_OLDSTREAMSTUB_ABORT
   to `thread'; `thread' is also passed as the first argument of
   ssh_oldfsm_throw. */
static void got_disconnect(SshOldFSMThread thread)
{
  ssh_oldfsm_throw(thread, thread, SSH_OLDSTREAMSTUB_ABORT);
}

/* Notification callback installed on the stream by
   ssh_oldstreamstub_spawn.  `context' is the parent thread.  Each
   notification is passed on to whichever children are still alive:

     SSH_STREAM_INPUT_AVAILABLE  -> got_read_notify on the reader
     SSH_STREAM_CAN_OUTPUT       -> got_write_notify on the writer
     SSH_STREAM_DISCONNECTED     -> got_disconnect on the writer, then
                                    on the reader

   Other notifications are ignored. */
static void stream_callback(SshStreamNotification notification,
                            void *context)
{
  SshOldFSMThread parent = context;
  SshFSMStreamParentDataRec *pdata = ssh_oldfsm_get_tdata(parent);

  if (notification == SSH_STREAM_INPUT_AVAILABLE)
    {
      SSH_DEBUG(8, ("Got input available notification."));
      if (pdata->reader != NULL)
        {
          got_read_notify(pdata->reader);
        }
    }
  else if (notification == SSH_STREAM_CAN_OUTPUT)
    {
      SSH_DEBUG(8, ("Got can output notification."));
      if (pdata->writer != NULL)
        {
          got_write_notify(pdata->writer);
        }
    }
  else if (notification == SSH_STREAM_DISCONNECTED)
    {
      SSH_DEBUG(8, ("Got stream disconnected notification."));
      if (pdata->writer != NULL)
        {
          got_disconnect(pdata->writer);
        }
      if (pdata->reader != NULL)
        {
          got_disconnect(pdata->reader);
        }
    }
}

/* Step that does nothing except suspend the thread.  Presumably this is
   the step behind "util/streamstub/parent", the state the parent thread
   is spawned in -- the state table is not in this file.  All further
   state changes of the parent are made by parent_message_handler. */
SSH_OLDFSM_STEP(ssh_oldstreamstub_parent)
{
  return SSH_OLDFSM_SUSPENDED;
}

/* Creates a stream stub on `stream': a parent thread plus a reader and
   a writer thread in `fsm', with the stream's callback pointed at the
   parent.

   fsm                    the state machine the threads are spawned in
   stream                 the stream to read from and write to
   in_buf                 buffer the reader appends received data to
   out_buf                buffer the writer takes data to send from
   in_buf_limit           the reader does not fill `in_buf' beyond this
   stub_has_read_more     signalled by the reader (may be NULL)
   in_buf_has_shrunk      the reader waits on this when `in_buf' is full
   stub_has_written_some  signalled by the writer (may be NULL)
   out_buf_has_grown      the writer waits on this when `out_buf' is
                          empty
   stub_finished          signalled by the parent when it dies (may be
                          NULL)
   shared_flags           flags word shared by all three threads

   Returns the parent thread. */
SshOldFSMThread ssh_oldstreamstub_spawn(SshOldFSM fsm,
                                  SshStream stream,
                                  SshBuffer in_buf,
                                  SshBuffer out_buf,
                                  SshUInt32 in_buf_limit,
                                  SshOldFSMCondition stub_has_read_more,
                                  SshOldFSMCondition in_buf_has_shrunk,
                                  SshOldFSMCondition stub_has_written_some,
                                  SshOldFSMCondition out_buf_has_grown,
                                  SshOldFSMCondition stub_finished,
                                  SshUInt32 *shared_flags)
{
  SshOldFSMThread parent;
  SshOldFSMThread reader;
  SshOldFSMThread writer;
  SshFSMStreamParentDataRec *pdata;
  SshFSMStreamReaderDataRec *rdata;
  SshFSMStreamWriterDataRec *wdata;

  /* Spawn the three threads: parent first, then reader, then writer. */
  parent = ssh_oldfsm_spawn(fsm, sizeof(*pdata),
                            "util/streamstub/parent",
                            parent_message_handler, NULL);
  pdata = ssh_oldfsm_get_tdata(parent);

  reader = ssh_oldfsm_spawn(fsm, sizeof(*rdata),
                            "util/streamstub/read",
                            reader_message_handler, NULL);

  writer = ssh_oldfsm_spawn(fsm, sizeof(*wdata),
                            "util/streamstub/write",
                            writer_message_handler, NULL);

  ssh_oldfsm_set_thread_name(parent, "parent");
  ssh_oldfsm_set_thread_name(reader, "reader");
  ssh_oldfsm_set_thread_name(writer, "writer");

  /* Parent. */
  pdata->finished_condition = stub_finished;
  pdata->reader = reader;
  pdata->writer = writer;
  pdata->stream = stream;
  pdata->flags = shared_flags;

  /* Reader. */
  rdata = ssh_oldfsm_get_tdata(reader);
  rdata->parent = parent;
  rdata->stream = stream;
  rdata->in_buf = in_buf;
  rdata->in_buf_limit = in_buf_limit;
  rdata->flags = shared_flags;
  rdata->got_more = stub_has_read_more;
  rdata->in_buf_shrunk = in_buf_has_shrunk;
  rdata->own_flags = 0;

  /* Writer. */
  wdata = ssh_oldfsm_get_tdata(writer);
  wdata->parent = parent;
  wdata->stream = stream;
  wdata->out_buf = out_buf;
  wdata->flags = shared_flags;
  wdata->out_buf_shrunk = stub_has_written_some;
  wdata->data_present = out_buf_has_grown;
  wdata->own_flags = 0;

  SSH_DEBUG(8, ("Setting the stream callback."));

  /* From here on stream notifications arrive with the parent thread as
     their context. */
  ssh_stream_set_callback(stream, stream_callback, parent);

  return parent;
}
