/*
    This file is part of the 'ears' package.
    Copyright (C) 1994,1995  Ralf Stephan <ralf@ark.franken.de>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/

#ifdef R_BP
#pragma implementation 
#include <sys/types.h>
#include <stdlib.h>
#include <malloc.h>
#include <stream.h>
#include <fstream.h>
#include <iomanip.h>
#include <MLCG.h>          // this is old stuff
#include <Uniform.h>
#include "mymath.h"
#include "time.h"
#include "r_bp.h"
#include "exception.h"

//============================ BP ==========================================
// The net is not built by the constructor. It is built later, once the
// number of patterns and output neurons is known (see dimension()), or when
// a stored net is loaded (see read()).

// Allocates the activation vectors and both weight matrices.
// Reads nin, nhid and nout, which must already be set, and clears the
// `empty` flag that the destructor tests before releasing storage.
void BP::build()
{
  ahid = new float [nhid];          // hidden-layer activations
  aout = new float [nout];          // output-layer activations

  whi = new float* [nhid];          // weights: nhid rows of nin entries
  who = new float* [nhid];          // weights: nhid rows of nout entries
  for (int h=0; h<nhid; ++h)
  {
    whi[h] = new float [nin];
    who[h] = new float [nout];
  }

  empty = 0;
}

void BP::dimension (int words, int patterns)
{
  npat = patterns;
  nout = words;
  word_pattern dummy;
  trace = dummy.F_trace;
  coeff = 8;                        // beware: coeff is now hardcoded
  nin = trace*coeff;
  nhid = nout/2;                  // estimated
  if (nhid<5) nhid=5;

  build();
      
  pat = new float* [npat];         // input patterns
  for (int k=0; k<npat; k++)
    pat[k] = new float [nin];
  
  out = new float* [npat];         // output patterns
  for (int k=0; k<npat; k++)
    out[k] = new float [nout];
}

// Destructor: releases everything allocated by build() and dimension().
// Does nothing while the `empty` flag is still set.
BP::~BP()
{
  if (empty) return;

  delete [] ahid;
  delete [] aout;

  for (int h=0; h<nhid; ++h) delete [] whi[h];
  delete [] whi;
  for (int h=0; h<nhid; ++h) delete [] who[h];
  delete [] who;

  // pat and out are allocated only by dimension(); read() calls build()
  // alone.  NOTE(review): the null tests below presumably rely on the
  // constructor zeroing both pointers -- confirm in r_bp.h.
  if (pat)
  {
    for (int p=0; p<npat; ++p) delete [] pat[p];
    delete [] pat;
  }

  if (out)
  {
    for (int p=0; p<npat; ++p) delete [] out[p];
    delete [] out;
  }
}

void BP::get (const String& fn, int wc, int pc)
{
  int d;
  ifstream pf(fn);
  String r;
  pf >> r;
  if (r!="BP" && r!="BPMT")
    { String s = fn + " is not a BP net!";
      throw(fatal_exception(s)); }
  pf >> d >> d;                    // check this eventually
  for (int k=0; k<nin; k++)
    pf >> pat[pc][k];
  for (int k=0; k<nout; k++) out[pc][k]=0.0;
  out[pc][wc] = 1.0;
}

//----------------------------Training BP-------------------------------

void BP::init_weights()
{
  Time t;
  int seed = t.seed();
  MLCG RG(seed&0x0f,seed);
  Uniform Rnd (-0.1, +0.1, &RG);
  
  for (int k=0; k<nhid; k++)
  {
    for (int l=0; l<nin; l++) whi[k][l] = Rnd();
    for (int l=0; l<nout; l++) who[k][l] = Rnd();
  }
}

// Error term used by the training loops: desired value `t` minus the
// value `o` the net produced.  Two other variants are kept below, disabled.
inline float BP::errf (float t, float o)
{
  return t-o;
}
//{  return atanh(t-o); }
//{
//  float d=t-o;
//  if (d<0)
//    if (d<-thresh) return d;
//    else           return -0.01;
//  else
//    if (d>thresh)  return d;
//    else           return 0.01;
//}

// Trains the net on the stored patterns, updating the weights after every
// pattern.  Starts from fresh random weights and repeats whole epochs
// until the summed squared error drops to maxerr or every pattern is
// reproduced to within 0.1 on all outputs.  Progress goes to cerr whenever
// the number of learned patterns changes.
void BP::train()
{
  init_weights();
  cerr << "Working..." << endl;

  int epoch = 0;
  int prev_learned = 0;
  float err = 1e30;
  while (err>maxerr)
  {
    // Step size for this epoch, derived from the previous epoch's error
    // and never allowed below 0.2.
    theta = 1.0-1.0/err;
    if (theta<0.2) theta=0.2;
    err = 0.0;
    int learned = 0;

    for (int p=0; p<npat; ++p)
    {
      ain = pat[p];
      const float* target = out[p];

      for (int h=0; h<nhid; ++h)      // propagate to hidden layer
      {
        float sum=0;
        for (int i=0; i<nin; ++i)
          sum += ain[i]*whi[h][i];
        ahid[h] = 1/(1+exp(-sum));
      }

      int hits = 0;                   // outputs within 0.1 of their target
      for (int o=0; o<nout; ++o)      // propagate to output layer
      {
        float sum=0;
        for (int h=0; h<nhid; ++h)
          sum += ahid[h]*who[h][o];
        aout[o] = 1/(1+exp(-sum));
        err += 0.5*(target[o]-aout[o])*(target[o]-aout[o]);
        if (fabs(target[o]-aout[o])<0.1) ++hits;
      }
      if (hits==nout) ++learned;

      // Weight update.  The error sum passed back to the input weights is
      // accumulated with the already-updated hidden->output weights.  Both
      // a*(1-a) factors carry a constant offset of 0.02.
      for (int h=0; h<nhid; ++h)
      {
        float sum=0;
        for (int o=0; o<nout; ++o)
        {
          float e = errf(target[o],aout[o]);
          who[h][o] += theta*ahid[h]*(aout[o]*(1.0-aout[o])+0.02)*e;
          sum += e*who[h][o];
        }
        for (int i=0; i<nin; ++i)
          whi[h][i] += theta*ain[i]*(ahid[h]*(1.0-ahid[h])+0.02)*sum;
      }
    }

    ++epoch;
    if (learned!=prev_learned)
      cerr.form("Epoch: %4d  Error: %2.4f  Patterns learned: %3d/%3d\n",
                epoch,err,learned,npat);

    prev_learned = learned;
    if (learned==npat) break;         // everything learned: stop early
  }
}

// Writes the net to `o` as text: the tag "BP", the five header integers
// (nin, nhid, nout, trace, coeff) and then, for each hidden unit, its nin
// input weights followed by its nout output weights.  Every value is
// followed by a single blank.  read() expects this layout.
void BP::write (ostream& o)
{
  o << "BP\n"
    << nin   << " "
    << nhid  << " "
    << nout  << " "
    << trace << " "
    << coeff << " ";

  for (int h=0; h<nhid; ++h)
  {
    const float* in_row = whi[h];
    for (int i=0; i<nin; ++i)
      o << in_row[i] << " ";

    const float* out_row = who[h];
    for (int j=0; j<nout; ++j)
      o << out_row[j] << " ";
  }
}

// Loads a net from `i` in the layout produced by write(): a "BP" or "BPMT"
// tag, the header integers nin, nhid, nout, trace and coeff, then for each
// hidden unit its nin input weights followed by its nout output weights.
// Prints a message to cerr and exits with status 1 if the tag is wrong or
// the header cannot be read.
void BP::read  (istream& i) 
{
  String r;
  i >> r;
  if (r!="BP" && r!="BPMT")
    { cerr << "Not a BP net!" << endl; exit(1); }
  i >> nin >> nhid >> nout >> trace >> coeff;
  // build() sizes its allocations from these values, so refuse a header
  // that failed to parse or holds negative dimensions before allocating.
  if (!i || nin<0 || nhid<0 || nout<0)
    { cerr << "Corrupt BP net header!" << endl; exit(1); }
  build();
  for (int k=0; k<nhid; k++)
  {
    for (int l=0; l<nin; l++) 
      i >> whi[k][l];
    for (int l=0; l<nout; l++) 
      i >> who[k][l];
  }
}

int BP::eval (const pattern& wp)
{
  float pat[nin];
  float** p = wp.buf();

  for (int k=0; k<trace; k++)  
    memcpy(&pat[k*coeff],&(p[k][0]),wp.coeff()*sizeof(float));

  int l;
  for (int k=0; k<nhid; k++)          // propagate to hidden layer
  {
    float sum=0;
    for (l=0; l<nin; l++)
      sum += pat[l]*whi[k][l];
    ahid[k] = 1/(1+exp(-sum));
  }
   
  for (int k=0; k<nout; k++)          // propagate to output layer
  {
    float sum=0;
    for (l=0; l<nhid; l++)
      sum += ahid[l]*who[l][k];
    aout[k] = 1/(1+exp(-sum));
  }

  float bestv=0.0;
  int best=-1;
  for (int k=0; k<nout; k++)          
    if (bestv < aout[k]) { bestv=aout[k]; best=k; }
    
  if (debug_)
  {
    cerr << endl;
    for (int k=0; k<nout; k++)
      cerr << "|" << k << " " << setprecision(3) << aout[k] << endl;
  }
  if (bestv < 0.5) return -1;
  else             return best;
}

//============================ BPMT ==========================================

void BPMT::train()
{
  init_weights();

  int p,k,l, epoch=0, last_ps=0;
  float mthi[nhid][nin], mtho[nhid][nout];
  const float alpha = 0.7;
  for (k=0; k<nhid; k++)
  {
    for (l=0; l<nin; l++) mthi[k][l]=0.0;
    for (l=0; l<nout; l++) mtho[k][l]=0.0;
  }

  float err=1e30;
  while (err>maxerr)
  {
    theta = 1.0-1.0/err;
    if (theta<0.2) theta=0.2;
    err=0.0;
    int ps=0;
    for (p=0; p<npat; p++)
    {
      ain = pat[p];
      for (k=0; k<nhid; k++)          // propagate to hidden layer
      {
        float sum=0;
        for (l=0; l<nin; l++)
          sum += ain[l]*whi[k][l];
        ahid[k] = 1/(1+exp(-sum));
      }

      int os=0;
      for (k=0; k<nout; k++)          // propagate to output layer
      {
        float sum=0;
        for (l=0; l<nhid; l++)
          sum += ahid[l]*who[l][k];
        aout[k] = 1/(1+exp(-sum));
        err += 0.5*errf(out[p][k],aout[k])*errf(out[p][k],aout[k]);
        if (fabs(out[p][k]-aout[k])<0.1) ++os;
      }
      if (os==nout) ++ps;
      
      for (k=0; k<nhid; k++)
      {
        float sum=0;
        for (l=0; l<nout; l++) 
        {
          float e = errf(out[p][l],aout[l]);
          float d = theta*ahid[k]*(aout[l]*(1.0-aout[l])+0.02)*e + alpha*mtho[k][l];
          who[k][l] += d;
          mtho[k][l] = d;
          sum += e*who[k][l];
        }
        for (l=0; l<nin; l++)
        {
          float d = theta*ain[l]*(ahid[k]*(1.0-ahid[k])+0.02)*sum + alpha*mthi[k][l];
          whi[k][l] += d;
          mthi[k][l] = d;
        }
      }
    }
    ++epoch;
//    if (ps!=last_ps)
      cerr.form("Epoch: %4d  Error: %2.4f  Patterns learned: %3d/%3d\n",
                epoch,err,ps,npat);

    last_ps = ps;
    if (ps==npat) return;
  }
}

#endif