/* $Id: ssh-keyscan.c,v 1.3 1995/11/30 11:55:18 dm Exp $ */

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdarg.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/uio.h>
#include <sys/resource.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <string.h>
#include <errno.h>
#include <ctype.h>
#include <netdb.h>
#include "queue.h"
#include <config.h>
#include <gmp.h>
#include <fdlim.h>

/*
 * Text describing the current errno value.  strerror() replaces the old
 * sys_errlist[] table, which is not portable and has been removed from
 * recent C libraries.  errno itself is declared by <errno.h> (and may be
 * a macro), so it must not be redeclared here.
 */
#define errmsg strerror (errno)

#include "linebuf.h"

/* Seconds a connection may go without receiving data before conloop()
   drops it (the deadline is pushed back on every read by contouch()). */
#define TIMEOUT 5
/* TCP port that tcpconnect() connects to. */
#define PORT 22
/* Upper bound applied to the descriptor limit obtained in main(). */
#define MAXMAXFD 256

/* Descriptor limit in use: the nfds argument to select() and the number
   of entries in fdcon[]. */
int maxfd;
/* Maximum number of simultaneous connections.  Keeps 10 descriptors in
   reserve -- presumably for stdio and -f input files; not documented. */
#define maxcon (maxfd - 10)

/* Program name (basename of argv[0]); prefix for warn() messages. */
char *prog;
/* Sockets of all live connections; the set conloop() passes to select(). */
fd_set read_wait;
/* Number of live connections (incremented by conalloc, decremented by confree). */
int ncon;

/*
 * Per-connection state.  One slot per file descriptor lives in fdcon[],
 * indexed by the socket's descriptor number.
 */
typedef struct Connection {
  /* Protocol phase of this slot; one of the CS_* values below. */
  unsigned char c_status;
/* Slot is free. */
#define CS_UNUSED 0
/* Connected; the next read is the server greeting (handled by congreet()). */
#define CS_CON 1
/* Reading the 4-byte packet length directly into c_plen. */
#define CS_SIZE 2
/* Reading the packet body into a heap buffer pointed to by c_data. */
#define CS_KEYS 3
  int c_fd;                     /* socket descriptor (equals the fdcon[] index) */
  int c_plen;                   /* packet length; also the receive buffer in CS_SIZE */
  int c_len;                    /* number of bytes expected in the current phase */
  int c_off;                    /* number of bytes received so far in this phase */
  char *c_name;                 /* host name as given (xstrdup()ed copy) */
  char *c_output_name;          /* optional name printed instead of c_name, or NULL */
  char *c_data;                 /* current receive buffer (&c_plen or heap memory) */
  struct timeval c_tv;          /* absolute deadline for the next data to arrive */
  TAILQ_ENTRY(Connection) c_link;   /* linkage in the timeout queue tq */
} con;
/* Entries are always (re)inserted at the tail with deadline now + TIMEOUT,
   so -- assuming the clock does not step backwards -- the head holds the
   earliest deadline. */
TAILQ_HEAD(conlist, Connection) tq;  /* Timeout Queue */

/* Connection table with maxfd entries, indexed by descriptor. */
con *fdcon;

/*
 * Write a printf-style message to stderr and terminate with exit status 1.
 * Unlike warn(), no program-name prefix is added; callers that want one
 * pass prog themselves.
 */
void fatal (const char *fmt, ...)
     __attribute__ ((noreturn, format (printf, 1, 2)));
void
fatal (const char *fmt, ...)
{
  va_list args;

  va_start (args, fmt);
  (void) vfprintf (stderr, fmt, args);
  va_end (args);

  exit (1);
}

/*
 * Write "PROG: " followed by a printf-style message to stderr, then return.
 * nexthost() also hands this function to Linebuf_alloc() as its error
 * reporter, so the signature must stay unchanged.
 */
void warn (const char *fmt, ...) __attribute__ ((format (printf, 1, 2)));
void
warn (const char *fmt, ...)
{
  va_list args;

  (void) fprintf (stderr, "%s: ", prog);

  va_start (args, fmt);
  (void) vfprintf (stderr, fmt, args);
  va_end (args);
}

/*
 * malloc() wrapper that never returns NULL: if the allocation fails the
 * program terminates through fatal().  The caller owns the result.
 */
void *
xmalloc (unsigned int size)
{
  void *p;

  /* malloc(0) may legitimately return NULL; request one byte instead so a
     zero-length request is not mistaken for memory exhaustion. */
  if (size == 0)
    size = 1;
  p = malloc (size);
  if (p == NULL)
    fatal ("%s: malloc of %u bytes failed\n", prog, size);
  return (p);
}

/*
 * Return a freshly xmalloc()ed copy of S, or NULL when S is NULL.
 * The caller owns the copy and must free() it.
 */
char *
xstrdup (char *s)
{
  char *copy;
  size_t n;

  if (!s)
    return (NULL);

  n = strlen (s) + 1;           /* include the terminating NUL */
  copy = xmalloc (n);
  return (memcpy (copy, s, n));
}

/*
 * Return the byte at *CPP and advance *CPP by one.  If *CPP has already
 * reached EOM, return 0 and set *CPP to EOM + 1, so that callers can
 * detect any overrun with a single "*CPP > EOM" test at the end.
 */
unsigned char
getbyte (char **cpp, char *eom)
{
  char *cur = *cpp;

  if (cur < eom) {
    *cpp = cur + 1;
    return ((unsigned char) *cur);
  }
  *cpp = eom + 1;
  return (0);
}

/*
 * Read a 16-bit big-endian (network byte order) integer at *CPP and
 * advance *CPP by 2.  If fewer than 2 bytes remain before EOM, return 0
 * and set *CPP to EOM + 1 so that the caller can detect the overrun with
 * "*CPP > EOM".
 */
unsigned short
getshort (char **cpp, char *eom)
{
  unsigned short n = 0;

  /* Compare the remaining length instead of computing *cpp + 2, which
     could point beyond the end of the buffer (undefined behaviour). */
  if (eom - *cpp < 2) {
    *cpp = eom + 1;
    return (0);
  }
  memcpy (&n, *cpp, 2);
  *cpp += 2;
  return (ntohs (n));
}

/*
 * Read a 32-bit big-endian (network byte order) integer at *CPP and
 * advance *CPP by 4.  If fewer than 4 bytes remain before EOM, return 0
 * and set *CPP to EOM + 1 so that the caller can detect the overrun with
 * "*CPP > EOM".
 */
unsigned int
getlong (char **cpp, char *eom)
{
  unsigned int n = 0;

  /* Compare the remaining length instead of computing *cpp + 4, which
     could point beyond the end of the buffer (undefined behaviour). */
  if (eom - *cpp < 4) {
    *cpp = eom + 1;
    return (0);
  }
  memcpy (&n, *cpp, 4);
  *cpp += 4;
  return (ntohl (n));
}

/*
 * Read a multiple-precision integer at *CPP into VAL.  The wire format is
 * a 16-bit bit count followed by (bits + 7) / 8 bytes of magnitude, most
 * significant byte first.  *CPP is advanced past the number.  On overrun
 * the get* helpers yield zero bytes and leave *CPP past EOM, which the
 * caller must check.  (Adapted from ssh.)
 */
void
getmp (char **cpp, MP_INT *val, char *eom)
{
  int i, bits, bytes;
  char *str;

  bits = getshort (cpp, eom);
  bytes = (bits + 7) / 8;
  if (bytes == 0) {
    /* mpz_set_str() does not accept an empty string (it returns -1), so
       VAL would not reliably become zero; set it explicitly. */
    mpz_set_ui (val, 0);
    return;
  }
  /* Two hex digits per byte plus the terminating NUL. */
  str = xmalloc (2 * bytes + 1);
  for (i = 0; i < bytes; i++)
    sprintf (&str[2 * i], "%02x", getbyte (cpp, eom));
  mpz_set_str (val, str, 16);
  free (str);
}

/*
 * Advance *CPP by N bytes.  If fewer than N bytes remain before EOM (or
 * *CPP is already past EOM), set *CPP to EOM + 1, which is the same
 * overrun marker the get* helpers use.
 */
void
skipn (char **cpp, int n, char *eom)
{
  /* Check the remaining length first so the pointer is never moved past
     EOM + 1 (forming such a pointer is undefined behaviour). */
  if (eom - *cpp < n)
    *cpp = eom + 1;
  else
    *cpp += n;
}

/*
 * Decode a public-key packet received from HOST and print one line on
 * stdout:
 *     NAME BITS EXPONENT MODULUS
 * NAME is OUTPUT_NAME when non-NULL and HOST otherwise.  The numbers are
 * printed in decimal.
 *
 * KD points at the raw receive buffer: 8 - (LEN & 7) padding bytes
 * followed by LEN bytes of payload, matching the buffer conread()
 * allocates.  The payload must start with type byte 2.  After an 8-byte
 * field it holds two (bits, exponent, modulus) groups.  The first group is
 * parsed and discarded; the second is printed.  In the SSH-1
 * SMSG_PUBLIC_KEY layout these are the server key and the host key.
 *
 * Malformed or truncated packets are reported with warn(), and nothing is
 * printed for them.
 */
void
keyprint (char *host, char *output_name, char *kd, int len)
{
  int b;
  MP_INT e, m;
  char *eom;

  /* Skip the leading padding; EOM then marks the end of the buffer. */
  kd += 8 - (len & 7);
  eom = kd + len;
  if (getbyte (&kd, eom) != 2) {
    warn ("%s: invalid packet type\n", host);
    return;
  }

  /* Initialise only after the early return so that path cannot leak. */
  mpz_init (&e);
  mpz_init (&m);

  skipn (&kd, 8, eom);

  /* First key: read and discarded. */
  getlong (&kd, eom);
  getmp (&kd, &e, eom);
  getmp (&kd, &m, eom);

  /* Second key: the one that is printed. */
  b = getlong (&kd, eom);
  getmp (&kd, &e, eom);
  getmp (&kd, &m, eom);

  /* Any overrun in the helpers above leaves KD past EOM. */
  if (kd > eom)
    warn ("%s: packet too short\n", host);
  else {
    printf ("%s %d ", output_name ? output_name : host, b);
    mpz_out_str (stdout, 10, &e);
    printf (" ");
    mpz_out_str (stdout, 10, &m);
    printf ("\n");
  }

  /* mpz_init() allocates storage that must be released with mpz_clear(). */
  mpz_clear (&e);
  mpz_clear (&m);
}

/*
 * Store the IPv4 address of HNAME in SA->sin_addr.  HNAME may be a
 * dotted-quad literal or a host name to resolve with gethostbyname().
 * Only sin_addr is written; the caller fills in the other fields.
 * Returns 0 on success, or -1 (after a warn()) if no IPv4 address was found.
 */
int
setaddr (struct sockaddr_in *sa, char *hname)
{
  struct hostent *hp;

  sa->sin_addr.s_addr = inet_addr (hname);
  /* A numeric literal (other than 0.0.0.0 or the INADDR_NONE error value)
     needs no lookup.  isdigit() requires an unsigned char value. */
  if (isdigit ((unsigned char) hname[0]) && sa->sin_addr.s_addr != 0
      && sa->sin_addr.s_addr != INADDR_NONE)
    return (0);

  hp = gethostbyname (hname);
  /* Cast the size to int so a bogus negative h_length cannot pass the
     check through unsigned conversion; accept IPv4 results only. */
  if (hp == NULL || hp->h_addrtype != AF_INET
      || hp->h_length < (int) sizeof (sa->sin_addr)) {
    warn ("could not get host address for `%s'\n", hname);
    return (-1);
  }
  memcpy (&sa->sin_addr, hp->h_addr_list[0], sizeof (sa->sin_addr));
  return (0);
}

/*
 * Start a non-blocking TCP connection to port PORT on HNAME.
 * Returns the socket descriptor, whose connect may still be in progress.
 * Returns -1 if HNAME could not be resolved or if connect() failed
 * immediately (after a warning); in that case the socket is closed.
 * Failures to create or set up the socket are fatal.
 */
int
tcpconnect (char *hname)
{
  struct sockaddr_in src, dst;
  int s, flags;

  /* Zero the whole structures: sin_zero (and sin_len on BSD) must not
     contain garbage when passed to bind()/connect(). */
  memset (&dst, 0, sizeof (dst));
  if (setaddr (&dst, hname) < 0)
    return (-1);
  dst.sin_family = AF_INET;
  dst.sin_port = htons (PORT);

  memset (&src, 0, sizeof (src));
  src.sin_family = AF_INET;
  src.sin_addr.s_addr = htonl (INADDR_ANY);
  src.sin_port = htons (0);

  s = socket (AF_INET, SOCK_STREAM, 0);
  if (s < 0)
    fatal ("socket: %s\n", errmsg);
  /* Add O_NONBLOCK to the existing flags instead of replacing them. */
  flags = fcntl (s, F_GETFL, 0);
  if (flags < 0)
    fatal ("F_GETFL: %s\n", errmsg);
  if (fcntl (s, F_SETFL, flags | O_NONBLOCK) < 0)
    fatal ("F_SETFL: %s\n", errmsg);
  if (bind (s, (struct sockaddr *) &src, sizeof (src)) < 0)
    fatal ("bind: %s\n", errmsg);
  if (connect (s, (struct sockaddr *) &dst, sizeof (dst)) < 0
      && errno != EINPROGRESS) {
    /* The connection can never succeed; do not hand back a dead socket. */
    warn ("connect (`%s'): %s\n", hname, errmsg);
    close (s);
    return (-1);
  }

  return (s);
}

int
conalloc (char *hname)
{
  int s;
  char *p;

  p = strtok (hname, " \t\n");
  if (p == NULL)
    return (-1);

  s = tcpconnect (p);
  if (s < 0)
    return (-1);
  if (s >= maxfd)
    fatal ("conalloc: fdno %d too high\n", s);
  if (fdcon[s].c_status)
    fatal ("conalloc: attempt to reuse fdno %d\n", s);

  fdcon[s].c_fd = s;
  fdcon[s].c_status = CS_CON;
  fdcon[s].c_name = xstrdup (p);
  p = strtok (NULL, " \t\n");
  fdcon[s].c_output_name = p ? xstrdup (p) : NULL;
  fdcon[s].c_data = (char *) &fdcon[s].c_plen;
  fdcon[s].c_len = 4;
  fdcon[s].c_off = 0;
  gettimeofday (&fdcon[s].c_tv, NULL);
  fdcon[s].c_tv.tv_sec += TIMEOUT;
  TAILQ_INSERT_TAIL (&tq, &fdcon[s], c_link);
  FD_SET (s, &read_wait);
  ncon++;
  return (s);
}

void
confree (int s)
{
  close (s);
  if (s >= maxfd || fdcon[s].c_status == CS_UNUSED)
    fatal ("confree: attempt to free bad fdno %d\n", s);
  free (fdcon[s].c_name);
  if (fdcon[s].c_status == CS_KEYS)
    free (fdcon[s].c_data);
  fdcon[s].c_status = CS_UNUSED;
  TAILQ_REMOVE (&tq, &fdcon[s], c_link);
  FD_CLR (s, &read_wait);
  ncon--;
}

void
contouch (int s)
{
  TAILQ_REMOVE (&tq, &fdcon[s], c_link);
  gettimeofday (&fdcon[s].c_tv, NULL);
  fdcon[s].c_tv.tv_sec += TIMEOUT;
  TAILQ_INSERT_TAIL (&tq, &fdcon[s], c_link);
}

/*
 * Handle the server greeting on connection S, which must be in CS_CON.
 * The greeting must arrive in a single read of at most 80 bytes ending in
 * a newline.  It is echoed back unchanged, and the connection moves to
 * CS_SIZE.  On any failure the connection is freed; a read failing with
 * ECONNREFUSED is freed without a warning.
 */
void
congreet (int s)
{
  char buf[80];
  int n;
  con *c = &fdcon[s];

  n = read (s, buf, sizeof (buf));
  if (n < 0) {
    if (errno != ECONNREFUSED)
      warn ("read (%s): %s\n", c->c_name, errmsg);
    confree (s);
    return;
  }
  /* n == 0 means the peer closed without sending a greeting; testing it
     first also keeps buf[n - 1] from reading buf[-1]. */
  if (n == 0 || buf[n - 1] != '\n') {
    warn ("%s: bad greeting\n", c->c_name);
    confree (s);
    return;
  }
  if (write (s, buf, n) != n) {
    warn ("write (%s): %s\n", c->c_name, errmsg);
    confree (s);
    return;
  }
  c->c_status = CS_SIZE;
  contouch (s);
}

/*
 * Handle readability on connection S.
 *
 * In CS_CON the greeting is delegated to congreet().  Otherwise as much of
 * the current phase's data as is available is read:
 *   CS_SIZE - the 4-byte packet length (network byte order) into c_plen.
 *             Once it is complete, a buffer for padding + payload is
 *             allocated and the state switches to CS_KEYS.
 *   CS_KEYS - the packet body.  Once it is complete, the key is printed
 *             and the connection is freed.
 * Read errors, end-of-file and implausible packet lengths free the
 * connection.  Otherwise the connection's timeout is refreshed.
 */
void
conread (int s)
{
  /* Sanity bound on the peer-supplied packet length.  It rejects negative
     or absurd values and keeps the c_len computation from overflowing. */
  const unsigned int maxplen = 256 * 1024;
  unsigned int plen;
  int n;
  con *c = &fdcon[s];

  if (c->c_status == CS_CON) {
    congreet (s);
    return;
  }

  n = read (s, c->c_data + c->c_off, c->c_len - c->c_off);
  if (n < 0) {
    warn ("read (%s): %s\n", c->c_name, errmsg);
    confree (s);
    return;
  }
  if (n == 0) {
    /* End of file.  Without this, select() keeps reporting the socket
       readable while contouch() keeps postponing its timeout, so the scan
       would never finish. */
    warn ("%s: connection closed\n", c->c_name);
    confree (s);
    return;
  }
  c->c_off += n;

  if (c->c_off == c->c_len)
    switch (c->c_status) {
    case CS_SIZE:
      plen = ntohl ((unsigned int) c->c_plen);
      if (plen == 0 || plen > maxplen) {
        warn ("%s: bad packet length %u\n", c->c_name, plen);
        confree (s);
        return;
      }
      c->c_plen = (int) plen;
      /* Payload is preceded by 8 - (plen & 7) bytes of padding. */
      c->c_len = c->c_plen + 8 - (c->c_plen & 7);
      c->c_off = 0;
      c->c_data = xmalloc (c->c_len);
      c->c_status = CS_KEYS;
      break;
    case CS_KEYS:
      keyprint (c->c_name, c->c_output_name, c->c_data, c->c_plen);
      confree (s);
      return;
    default:
      fatal ("conread: invalid status %d\n", c->c_status);
    }

  contouch (s);
}

void
conloop (void)
{
  fd_set r, e;
  struct timeval seltime, now;
  int i;
  con *c;

  gettimeofday (&now, NULL);
  c = tq.tqh_first;

  if (c && (c->c_tv.tv_sec > now.tv_sec
	    || (c->c_tv.tv_sec == now.tv_sec
		&& c->c_tv.tv_usec > now.tv_usec))) {
    seltime = c->c_tv;
    seltime.tv_sec -= now.tv_sec;
    seltime.tv_usec -= now.tv_usec;
    if ((int) seltime.tv_usec < 0) {
      seltime.tv_usec += 1000000;
      seltime.tv_sec--;
    }
  }
  else
    seltime.tv_sec = seltime.tv_usec = 0;

  r = e = read_wait;
  select (maxfd, &r, NULL, &e, &seltime);
  for (i = 0; i < maxfd; i++)
    if (FD_ISSET (i, &e)) {
      warn ("%s: exception!\n", fdcon[i].c_name);
      confree (i);
    }
    else if (FD_ISSET (i, &r))
      conread (i);

  c = tq.tqh_first;
  while (c && (c->c_tv.tv_sec < now.tv_sec
	       || (c->c_tv.tv_sec == now.tv_sec
		   && c->c_tv.tv_usec < now.tv_usec))) {
    int s = c->c_fd;
    c = c->c_link.tqe_next;
    confree (s);
  }
}

/*
 * Return the next host specification to scan, or NULL when none remain.
 *
 * Arguments are consumed left to right.  The position is kept in static
 * state across calls (AN for argv, LB for an open -f file):
 *   host        returned as is.
 *   -- host     only the single argument after "--" is returned
 *               literally; later arguments are parsed normally again.
 *   -fFILE, -f FILE
 *               lines from FILE are returned one per call.  FILE "-" is
 *               passed to Linebuf_alloc() as NULL, presumably meaning
 *               stdin; TODO confirm in linebuf.h.
 *   -other      reported with warn() and skipped.
 * A trailing "-f" with no filename is reported and returns NULL.
 * main() passes each result to conalloc(), which tokenizes it in place
 * with strtok().
 */
char *
nexthost (int argc, char **argv)
{
  static int an = 1;            /* index of the next argv entry to examine */
  static Linebuf *lb;           /* open -f input, or NULL */

  for (;;) {
    if(!lb) {
      /* Not reading a file: examine the next command-line argument. */
      if (an >= argc)
	return (NULL);
      if (argv[an][0] != '-')
	return (argv[an++]);
      if (!strcmp (argv[an], "--")) {
	if (++an >= argc)
	  return (NULL);
	return (argv[an++]);
      }
      else if (!strncmp (argv[an], "-f", 2)) {
	char *fname;
	/* The filename may be attached ("-ffile") or the next argument. */
	if (argv[an][2])
	  fname = &argv[an++][2];
	else if (++an >= argc) {
	  warn ("missing filename for `-f'\n");
	  return (NULL);
	}
	else
	  fname = argv[an++];
	if (!strcmp (fname, "-"))
	  fname = NULL;
	/* NOTE(review): if Linebuf_alloc() returns NULL (presumably when the
	   file cannot be opened), the loop simply moves on to the next
	   argument. */
	lb = Linebuf_alloc (fname, warn);
      }
      else
	warn ("ignoring invalid option `%s'\n", argv[an++]);
    }
    else {
      /* Reading a file: return its next line.  When getline() returns NULL
         (presumably end of input), release the buffer and go back to argv. */
      char *line;
      line = getline (lb);
      if (line)
	return (line);
      Linebuf_free (lb);
      lb = NULL;
    }
  }
}

/*
 * Entry point.  Arguments are host names, "-- host" or "-f file"; see
 * nexthost().  Up to maxcon connections are kept open at once, and a key
 * line is printed for every host whose key packet is received.
 */
int
main (int argc, char **argv)
{
  char *host = NULL;

  TAILQ_INIT (&tq);

  if ((prog = strrchr (argv[0], '/')))
    prog++;
  else
    prog = argv[0];

  if (argc < 2)
    fatal ("usage: %s { [--] host | -f file } ...\n", prog);

  /* NOTE(review): fdlim_get(1) / fdlim_get(0) look like the hard and soft
     descriptor limits respectively -- confirm against fdlim.h. */
  maxfd = fdlim_get (1);
  if (maxfd < 0)
    fatal ("%s: fdlim_get: bad value\n", prog);
  if (maxfd > MAXMAXFD)
    maxfd = MAXMAXFD;
  if (maxcon <= 0)
    fatal ("%s: not enough file descriptors\n", prog);
  if (maxfd > fdlim_get (0))
    fdlim_set (maxfd);

  /* xmalloc() does not clear memory, and conalloc() treats a nonzero
     c_status as a slot still in use, so every slot must start out as
     CS_UNUSED (0). */
  fdcon = xmalloc (maxfd * sizeof (con));
  memset (fdcon, 0, maxfd * sizeof (con));

  /* Keep the connection table full while hosts remain. */
  do {
    while (ncon < maxcon) {
      host = nexthost (argc, argv);
      if (host == NULL)
        break;
      conalloc (host);
    }
    conloop ();
  } while (host);

  /* No more hosts: drain until every connection finishes or times out. */
  while (ncon > 0)
    conloop ();

  return (0);
}
