/*
    This file is part of the 'ears' package.
    The cookbook routines are Copyright (C) 1993 Tony Robinson
    The rest is Copyright (C) 1994,1995  Ralf Stephan <ralf@ark.franken.de>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
#ifdef R_DTW
#pragma implementation 
#include "r_dtw.h"
#include <malloc.h>
#include <stdio.h>
#include <stdlib.h>
#include <stream.h>
#include <fstream.h>
#include <algo.h>
#include "others/time.h"
#include "ears/exception.h"

//===============================DTW======================================
//------------------------cookbook routines-------------------------------
//                Copyright (C) 1993 Tony Robinson
//------------------------------------------------------------------------
// You can find the full package on 
// svr-ftp.eng.cam.ac.uk: /pub/comp.speech/sources/cookbook*.tar.Z
//------------------------------------------------------------------------

# define REALLY_VERY_BIG	(1e30)

/* Allocate an nitem0 x nitem1 matrix whose elements are `size` bytes each.
   A single malloc() block holds a table of nitem0 row pointers followed by
   the row data, so the result is indexed as m[row][col] and released with
   one free() of the returned pointer.  The data region begins right after
   the pointer table, i.e. at a multiple of sizeof(void*) past the malloc()
   base, so it is only guaranteed that much alignment.
   Returns NULL if the byte count would overflow size_t or malloc() fails. */
void **Smatrix(size_t nitem0, size_t nitem1, size_t size) {
  const size_t size_max = (size_t) -1;
  size_t rowbytes, databytes, nbyte, i;
  void **array0;
  char  *array1;

  /* reject dimensions whose total byte count would wrap around */
  if(nitem1 != 0 && size > size_max / nitem1)
    return(NULL);
  rowbytes = nitem1 * size;
  if(nitem0 != 0 && rowbytes > size_max / nitem0)
    return(NULL);
  databytes = nitem0 * rowbytes;
  if(nitem0 > (size_max - databytes) / sizeof(*array0))
    return(NULL);
  nbyte = nitem0 * sizeof(*array0) + databytes;

  /* malloc(0) may legally return NULL; always request at least one byte */
  if(nbyte == 0) nbyte = 1;

  array0 = (void**)malloc(nbyte);
  if(array0 == NULL)
    return(NULL);

  array1 = (char*) (array0 + nitem0);

  for(i = 0; i < nitem0; i++)
    array0[i] = array1 + i * rowbytes;

  return(array0);
}

/* Squared Euclidean distance between two arrays of nitem floats: the sum of
   the squared element differences.  No square root is taken, so the result
   is the square of the L2 norm of (vec0 - vec1), not the norm itself. */
inline float Sl2norm(const float *vec0, const float *vec1, int nitem) {
  float sum = 0.0;
  int i;

  for(i = 0; i < nitem; i++) {
    float diff = vec0[i] - vec1[i];
    sum += diff * diff;
    /* alternative L1 (absolute difference) metric, left disabled:
       sum += vec0[i] > vec1[i] ? vec0[i] - vec1[i] : vec1[i] - vec0[i]; */
  }

  return(sum);
}

float DTW::do_dtw (float **unkTemp, int nunkTemp, float **refTemp, 
                          int nrefTemp, int tempSize) const
{
  int i, j;
  float **localDist, **globalDist;
  float bestDist;

  localDist  = (float**) Smatrix(nunkTemp, nrefTemp, sizeof(**localDist));
  globalDist = (float**) Smatrix(nunkTemp, nrefTemp, sizeof(**globalDist));

  /* compute and store all the local distances */
  for(i = 0; i < nunkTemp; i++)
    for(j = 0; j < nrefTemp; j++)
      localDist[i][j] = Sl2norm(unkTemp[i], refTemp[j], tempSize);

  /* for the first frame the only possible match is at (0, 0) */
  globalDist[0][0] = localDist[0][0];
  for(j = 1; j < nrefTemp; j++)
    globalDist[0][j] = REALLY_VERY_BIG;

  /* in the second frame the only valid state is (1, 1) */
  globalDist[1][0] = REALLY_VERY_BIG;
  globalDist[1][1] = globalDist[0][0] + localDist[1][1];
  for(j = 2; j < nrefTemp; j++)
    globalDist[1][j] = REALLY_VERY_BIG;

  /* and do for the general case of the rest of the frames */
  for(i = 2; i < nunkTemp; i++) {
    
    /* */
    globalDist[i][0] = REALLY_VERY_BIG;
    globalDist[i][1] = globalDist[i - 1][0] + localDist[i][1];
  
    for(j = 2; j < nrefTemp; j++) {
      float topPath, midPath, botPath;

//      if(globalDist[i-2][j-1] < 0.0)
//	Spanic("globalDist[i-2][j-1] < 0.0\n: %d\t%d\t%f\n", i, j, globalDist[i-2][j-1]);

      topPath = globalDist[i-2][j-1] + localDist[i-1][j] + localDist[i][j];
      midPath = globalDist[i-1][j-1] + localDist[i][j];
      botPath = globalDist[i-1][j-2] + localDist[i][j-1] + localDist[i][j];

// Type II local constraints without slope weights

      /* find and store the smallest gloabal distance */
      if(topPath < midPath && topPath < botPath)
	globalDist[i][j] = topPath;
      else if (midPath < botPath)
	globalDist[i][j] = midPath;
      else
	globalDist[i][j] = botPath;
    }    
  }

  bestDist = globalDist[nunkTemp - 1][nrefTemp - 1];

  free(localDist);
  free(globalDist);

  return(bestDist);
}

//--------------------interface to cookbook routines----------------------

void DTW::write (ostream &o)
{
  // Emit one "<pattern file> <word code> " line per reference entry: the
  // fn/wc pair that read() expects for each entry.  read() additionally
  // expects a leading "DTW"/"NDTW" token, which is not written here.
  DTWList::const_iterator entry = list.begin();
  const DTWList::const_iterator stop = list.end();
  while (entry != stop)
  {
    o << (*entry)->fn << " " << (*entry)->wc << " " << endl;
    ++entry;
  }
}  

// when reading a DTW, 'list' is filled with DTW_entries that
// contain the reference patterns. 

void DTW::read (istream &i)
{
  string fn,r;
  i >> r;
  if (r == "DTW" || r == "NDTW") ; // gcc-2.7.2 workaround
  else
    throw(fatal_exception("Net file is not a DTW file!"));

  int wc;
  // Each entry is "<pattern file> <word code>".  Testing the stream state
  // (not just eof()) stops on input that does not parse instead of looping
  // forever, and still accepts a last entry with no trailing whitespace.
  while (i >> fn >> ws >> wc)
  {
    DTW_entry* e = new DTW_entry;
    e->fn = fn;
    e->wc = wc;
    ifstream patf (fn.c_str());
    e->p.read (patf);
    if (e->p.bad()) 
    {
      delete e;                       // not yet owned by 'list'
      string t="Bogus or nonexisting pattern file.  Please remove "+fn+"\nand train the word again!\n";
      throw(fatal_exception(t)); 
    }
    list.push_back(e);
    ++nref;
  }

  // Stopping anywhere but at end of input means an entry failed to parse.
  if (!i.eof())
    throw(fatal_exception("Malformed entry in DTW net file!"));
}

static int dist_type_compare(const void *foo, const void *bar) {
  return(((dist_type*)foo)->distance > ((dist_type*)bar)->distance ? 1 : -1);
}

// Match the pattern wp against every stored reference.  Fills globdist with
// one (distance, index, file, word code) record per reference, sorts it
// (lowest distance first, per dist_type's ordering) and stores the timer
// reading in ms_.  Returns the word code of the closest reference, or -1
// when there are no references or the best distance exceeds huh_.
// Throws fatal_exception for an empty input pattern.
int DTW::eval (const pattern& wp)
{
  if (wp.length()==0) 
    throw(fatal_exception("DTW::eval(): wp.length==0\n"));

  Time t;
  delete [] globdist;                 // deleting a null pointer is a no-op
  globdist = new dist_type [nref];

  int i=0;  
  best_hypothesis = REALLY_VERY_BIG;
  // i < nref keeps the writes inside the nref-sized globdist array.
  for (DTWList::const_iterator it = list.begin(); it!=list.end() && i<nref; ++it,++i) 
  {
    var_pattern& ref = (*it)->p;
     
    float best = do_dtw (wp.buf(), wp.length(), ref.buff(),
				 ref.len(), wp.coeff());
    if (best < best_hypothesis) best_hypothesis=best; 
    globdist[i].distance = best;
    globdist[i].index = i;
    globdist[i].fn = (*it)->fn;
    globdist[i].wc = (*it)->wc;
  }

  /* sort the filled global distances putting the lowest score first */
  sort(globdist, globdist+i);

  ms_ = t.stop();

  // Without any reference there is nothing to report: reject.
  if (i == 0)
    return -1;

  int res = globdist[0].wc;
  
  if (globdist[0].distance > huh_) res=-1;
  
  return res;
}

// Build a short human-readable report from the last eval(): up to three
// closest references (distance capped at 99.99, file name without its
// directory) followed by the DTW time.  globdist is released at the end, so
// a second call without a new eval() lists no references.
SList& DTW::deb_info()
{
  debug_l_.erase (debug_l_.begin(), debug_l_.end());
  debug_l_.push_back(" Dist. |  File");
  debug_l_.push_back("----------------------");
  const int num_items = 3;
  char s[128];
  // globdist holds nref records only while it is allocated by eval().
  for (int k=0; globdist != 0 && k<num_items && k<nref; k++)
  {
    string t = globdist[k].fn;
    t = t.substr(t.find_last_of('/')+1);               // strip leading path
    // The .100 precision caps the name so the line always fits into s.
    sprintf(s,"%5.2f  | %-20.100s",
            globdist[k].distance<=99.99 ? globdist[k].distance : 99.99,
            t.c_str());
    debug_l_.push_back(s);
  }

  debug_l_.push_back("");
  sprintf(s,"Pure DTW time:  %4d ms",int(ms_/1000));
  debug_l_.push_back(s);

  delete [] globdist;
  globdist=0;
  return debug_l_;
}

//================================== NDTW ==================================


// Banded DTW between an unknown utterance (nunkTemp frames) and one
// reference (nrefTemp frames), each frame tempSize floats.  Only cells in
// the band [from, to] derived from slope limits 1/2 and 2 are evaluated,
// using the scratch matrices locd (local distances) and glbd (accumulated
// distances) allocated by read().  Returns the accumulated distance at
// (nunkTemp-1, nrefTemp-1), or REALLY_VERY_BIG when the lengths admit no
// path or a whole row is already worse than best_hypothesis (pruning).
float NDTW::do_dtw(float **unkTemp, int nunkTemp, float **refTemp, int nrefTemp,
                 int tempSize) const
{
  int i, j;

  // The scratch matrices have mr*3 rows; longer utterances are truncated.
  if (nunkTemp > mr*3)
    nunkTemp = mr*3;
  // Length ratios outside (1/2, 2) admit no Type II path.
  if (nunkTemp*2 <= nrefTemp || nunkTemp >= nrefTemp*2)
    return REALLY_VERY_BIG;
  // Only the 1x1 case passes the check above with fewer than two frames on
  // a side; the recursion below needs rows and columns 0 and 1.
  if (nunkTemp < 2 || nrefTemp < 2) {
    if (nunkTemp == 1 && nrefTemp == 1)
      return Sl2norm(unkTemp[0], refTemp[0], tempSize);
    return REALLY_VERY_BIG;
  }

  /* for the first frame the only possible match is at (0, 0) */
  glbd[0][0] = Sl2norm(unkTemp[0], refTemp[0], tempSize);
  for(j = 1; j < nrefTemp; j++)
    glbd[0][j] = REALLY_VERY_BIG;

  /* in the second frame the only valid state is (1, 1) */
  locd[1][1] = Sl2norm(unkTemp[1], refTemp[1], tempSize);
  glbd[1][0] = REALLY_VERY_BIG;
  glbd[1][1] = glbd[0][0] + locd[1][1];
  for(j = 2; j < nrefTemp; j++)
    glbd[1][j] = REALLY_VERY_BIG;

  // [prevLo, prevHi]: columns of locd[i-1] filled during this call.  The
  // locd rows are reused between calls, so anything outside is stale.
  int prevLo = 1, prevHi = 1;

  /* and do for the general case of the rest of the frames */
  for(i = 2; i < nunkTemp; i++) {

    // Band of reference frames reachable at unknown frame i.
    int from = float(i)<0.67*float(2*nunkTemp-nrefTemp)? 1+i/2 : (i-nunkTemp)*2+nrefTemp+1;
    int to   = float(i)<0.34*float(2*nrefTemp-nunkTemp)? 2*i : (i-nunkTemp)/2+nrefTemp-1;
    if (from < 2) from = 2;                    // j-2 must be a valid column
    if (to > nrefTemp-1) to = nrefTemp-1;      // stay inside the row

    // Cells outside the band are unreachable.  Filling the whole rest of the
    // row means later rows never read values left over from earlier calls.
    for (j = 0; j < from && j < nrefTemp; j++)
      glbd[i][j] = REALLY_VERY_BIG;
    for (j = to+1; j < nrefTemp; j++)
      glbd[i][j] = REALLY_VERY_BIG;

    // Local distances this row reads: columns from-1 .. to (inclusive).
    for (j = from-1; j <= to; j++)
      locd[i][j] = Sl2norm(unkTemp[i], refTemp[j], tempSize);

    float best_in_this_row = REALLY_VERY_BIG;
    for(j = from; j <= to; j++) {

      // Local distance at (i-1, j); computed here if row i-1 did not.
      const float prevLocal = (j >= prevLo && j <= prevHi)
                              ? locd[i-1][j]
                              : Sl2norm(unkTemp[i-1], refTemp[j], tempSize);
      float topPath, midPath, botPath;

// Type II local constraints with slope weights (Rabiner: Fundamentals, p.223)
      topPath = glbd[i-2][j-1] + 0.7*(prevLocal + locd[i][j]);
      midPath = glbd[i-1][j-1] + locd[i][j];
      botPath = glbd[i-1][j-2] + 0.7*(locd[i][j-1] + locd[i][j]);

      /* find and store the smallest global distance */
      float best;
      if(topPath < midPath && topPath < botPath)
	best = topPath;
      else if (midPath < botPath)
	best = midPath;
      else
	best = botPath;
	
      if (best < best_in_this_row) best_in_this_row=best;
      glbd[i][j] = best;
    }
    prevLo = from-1;
    prevHi = to;

    // Accumulated distances never decrease along a path, so once every cell
    // of a row is worse than the best reference so far, give up.
    if (best_in_this_row > best_hypothesis) 
      return REALLY_VERY_BIG; 
  }

  return glbd[nunkTemp - 1][nrefTemp - 1];
}

// When reading an NDTW, 'list' is filled with DTW_entries that
// contain the reference patterns, and the scratch matrices used by
// NDTW::do_dtw are sized for the longest reference.

void NDTW::read (istream &i)
{
  string fn,r;
  i >> r;
  if (r == "DTW" || r == "NDTW") ; // gcc-2.7.2 workaround
  else
    throw(fatal_exception("Net file is not a DTW file!"));

  // Longest reference seen.  It is committed to 'mr' only when the scratch
  // matrices are resized, so mr always matches the live buffers (the
  // destructor frees mr*3 rows).
  int new_mr = mr;
  int wc;
  // Each entry is "<pattern file> <word code>".  Testing the stream state
  // (not just eof()) stops on input that does not parse instead of looping
  // forever, and still accepts a last entry with no trailing whitespace.
  while (i >> fn >> ws >> wc)
  {
    DTW_entry* e = new DTW_entry;
    e->fn = fn;
    e->wc = wc;
    ifstream patf (fn.c_str());
    e->p.read (patf);
    if (e->p.bad()) 
    {
      delete e;                       // not yet owned by 'list'
      string t="Bogus or nonexisting pattern file.  Please remove "+fn+"\nand train the word again!\n";
      throw(fatal_exception(t)); 
    }
    if (new_mr < e->p.len()) new_mr=e->p.len();
    list.push_back(e);
    ++nref;
  }

  // Stopping anywhere but at end of input means an entry failed to parse.
  if (!i.eof())
    throw(fatal_exception("Malformed entry in DTW net file!"));

  // Release the matrices of an earlier read() before resizing them.
  if (!empty)
  {
    for (int k=0; k<mr*3; k++)
    {
      delete [] locd[k];
      delete [] glbd[k];
    }
    delete [] locd;
    delete [] glbd;
    empty=1;
  }

  // mr*3 rows (the longest utterance do_dtw accepts) by mr columns.
  mr = new_mr;
  locd = new float*[mr*3];
  for (int k=0; k<mr*3; k++)
    locd[k] = new float[mr];
  glbd = new float*[mr*3];
  for (int k=0; k<mr*3; k++)
    glbd[k] = new float[mr];
  empty=0;
}

NDTW::~NDTW()
{
  // Nothing to release unless read() allocated the scratch matrices.
  if (empty)
    return;

  // Both matrices have mr*3 rows, each a separately new[]-ed array.
  const int rows = mr*3;
  for (int row = 0; row < rows; row++)
  {
    delete [] locd[row];
    delete [] glbd[row];
  }
  delete [] locd;
  delete [] glbd;
}  

#endif