#include "RcppArmadillo.h"
#include "aux_functions.h"

Rcpp::List model_latent_sparse(arma::mat X, arma::mat Y, arma::mat R, int k, int l, 
                        std::string method, int mcmc_samples, int burnin, 
                        bool verbose) 
{
  int n = X.n_rows;
  int p = X.n_cols;
  int m = Y.n_rows;
  int q = Y.n_cols;
  
  arma::mat Z = make_Z(X, Y, method);
  
  arma::uvec observed_vec = find_observed(R);
  arma::uvec unobserved_vec = find_unobserved(R);
  int observed_vec_size = observed_vec.size();
  int unobserved_vec_size = unobserved_vec.size();
  
  int var_dim;
  if (method == "bilinear"){
    var_dim = p * q;
  } else {
    var_dim = p + q;
  }
  
  arma::mat kappa = (R - 0.5 * (k + 1));
  arma::vec kappa_reduced = vectorise(kappa);
  kappa_reduced.shed_rows(unobserved_vec);
  
  arma::mat Z_reduced = Z;
  Z_reduced.shed_rows(unobserved_vec);
  arma::mat Z_reduced_t = Z_reduced.t();
  
  // set priors 
  arma::vec mu_B_0(var_dim, arma::fill::zeros);
  arma::mat Sigma_B_0(var_dim, var_dim, arma::fill::eye);
  arma::mat Sigma_B_0_inv = inv(Sigma_B_0);

  arma::vec mu_U_0(n*l, arma::fill::zeros);
  arma::mat Sigma_U_0(n*l, n*l, arma::fill::eye);
  arma::mat Sigma_U_0_inv = inv(Sigma_U_0);

  arma::vec mu_V_0(m*l, arma::fill::zeros);
  arma::mat Sigma_V_0(m*l, m*l, arma::fill::eye);
  arma::mat Sigma_V_0_inv = inv(Sigma_V_0);

  // saved posterior draws for Beta and R
  arma::cube B_hat(var_dim, 1, mcmc_samples - burnin, arma::fill::zeros);
  arma::cube U_hat(n, l, mcmc_samples - burnin, arma::fill::zeros);
  arma::cube V_hat(m, l, mcmc_samples - burnin, arma::fill::zeros);
  arma::cube R_hat(n, m, mcmc_samples - burnin, arma::fill::zeros);
    
  // single posterior draw for Beta and R
  arma::vec B_vec_est(var_dim, arma::fill::zeros);
  arma::vec U_vec_est(n*l, arma::fill::zeros);
  arma::vec V_vec_est(m*l, arma::fill::zeros);
  arma::mat R_pred_mat(n, m, arma::fill::zeros);
  
  // Bayesian horseshoe priors (induce sparsity on B, U and V)
  double zeta_B = 1.0;
  double zeta_U = 1.0;
  double zeta_V = 1.0;
  double tau_B = 1.0;
  double tau_U = 1.0;
  double tau_V = 1.0;
  arma::vec nu_B(var_dim, arma::fill::ones);
  arma::vec nu_U(n*l, arma::fill::ones);
  arma::vec nu_V(m*l, arma::fill::ones);
  arma::vec lambda_B(var_dim, arma::fill::randu);
  arma::vec lambda_U(n*l, arma::fill::randu);
  arma::vec lambda_V(m*l, arma::fill::randu);
  
  // identity matrices
  arma::mat I_m(m, m, arma::fill::eye);
  arma::mat I_n(n, n, arma::fill::eye);
  
  for(int iter = 0; iter < mcmc_samples; ++iter){
    if(verbose){
      Rcpp::Rcout << "This is iteration: " << iter + 1 << " out of " << mcmc_samples << "\n";
    }
    
    arma::mat U_est = vec_2_mat(U_vec_est, n, l);
    arma::mat V_est = vec_2_mat_byrow(V_vec_est, m, l);

    arma::mat U_kron_est = kron(I_m, U_est);
    arma::mat V_kron_est = kron(V_est, I_n);
    
    arma::mat U_kron_est_reduced = U_kron_est;
    arma::mat V_kron_est_reduced = V_kron_est;
    U_kron_est_reduced.shed_rows(unobserved_vec);
    V_kron_est_reduced.shed_rows(unobserved_vec);

    arma::vec omega(n*m, arma::fill::zeros);
    arma::vec u_v_prod = vectorise(U_est * V_est.t());
    for(int i = 0; i < observed_vec_size; ++i){
      int row_i = observed_vec(i);
      omega(row_i) = rcpp_pgdraw(k - 1, dot(Z.row(row_i), B_vec_est) + u_v_prod(row_i));
    }
    
    arma::vec omega_reduced = omega;
    omega_reduced.shed_rows(unobserved_vec);
    arma::mat Omega(observed_vec_size, observed_vec_size, arma::fill::zeros);
    Omega.diag() = omega.elem(observed_vec);
    arma::vec xi_reduced = kappa_reduced/omega_reduced;

    // sample vec(B)
    arma::mat Sigma_B = inv(Z_reduced_t * (Omega * Z_reduced) + Sigma_B_0_inv);
    arma::vec mu_B = Sigma_B * (((Z_reduced_t * Omega) * (xi_reduced - (U_kron_est_reduced * V_vec_est))) + Sigma_B_0_inv * mu_B_0);

    arma::mat mv_sample_B = mvrnorm_arma(1, mu_B, Sigma_B);
    B_vec_est = vectorise(mv_sample_B);

    for(int i = 0; i < var_dim; ++i){
      lambda_B(i) = 1.0/R::rgamma(1.0, 1.0/(1.0/nu_B(i) + pow(B_vec_est(i), 2)/(2.0*tau_B)));
      if(lambda_B(i) < 1e-9){
        lambda_B(i) = 1e-9;
      }
    }

    tau_B = 1.0/R::rgamma(0.5*(var_dim + 1.0), 1.0/(1.0/zeta_B + 0.5*sum(pow(B_vec_est, 2)/lambda_B)));
    if(tau_B < 1e-9){
      tau_B = 1e-9;
    } 
    
    for(int i = 0; i < var_dim; ++i){
      nu_B(i) = 1.0/R::rgamma(1.0, 1.0/(1.0 + 1.0/lambda_B(i)));
      if(nu_B(i) < 1e-9){
        nu_B(i) = 1e-9;
      }
    }

    zeta_B = 1.0/R::rgamma(1.0, 1.0/(1.0 + 1.0/tau_B));
    if(zeta_B < 1e-9){
      zeta_B = 1e-9;
    } 
    
    Sigma_B_0_inv.diag() = 1.0/(tau_B*lambda_B);

    // sample vec(U)
    arma::mat Sigma_U = inv(V_kron_est_reduced.t() * (Omega * V_kron_est_reduced) + Sigma_U_0_inv);
    arma::vec mu_U = Sigma_U * (((V_kron_est_reduced.t() * Omega) * (xi_reduced - (Z_reduced * B_vec_est))) + Sigma_U_0_inv * mu_U_0);

    arma::mat mv_sample_U = mvrnorm_arma(1, mu_U, Sigma_U);
    U_vec_est = vectorise(mv_sample_U);
    U_est = vec_2_mat(U_vec_est, n, l);
    U_kron_est = kron(I_m, U_est);
    U_kron_est_reduced = U_kron_est;
    U_kron_est_reduced.shed_rows(unobserved_vec);

    for(int i = 0; i < n*l; ++i){
      lambda_U(i) = 1.0/R::rgamma(1.0, 1.0/(1.0/nu_U(i) + pow(U_vec_est(i), 2)/(2.0*tau_U)));
      if(lambda_U(i) < 1e-9){
        lambda_U(i) = 1e-9;
      }
    }
    tau_U = 1.0/R::rgamma(0.5*(n*l+1.0), 1.0/(1.0/zeta_U + 0.5*sum(pow(U_vec_est, 2)/lambda_U)));
    if(tau_U < 1e-9){
      tau_U = 1e-9;
    }  
    
    for(int i = 0; i < n*l; ++i){
      nu_U(i) = 1.0/R::rgamma(1.0, 1.0/(1.0 + 1.0/lambda_U(i)));
      if(nu_U(i) < 1e-9){
        nu_U(i) = 1e-9;
      }
    }

    zeta_U = 1.0/R::rgamma(1.0, 1.0/(1.0 + 1.0/tau_U));
    if(zeta_U < 1e-9){
      zeta_U = 1e-9;
    }  
    
    Sigma_U_0_inv.diag() = 1.0/(tau_U*lambda_U);
    
    // sample vec(V)
    arma::mat Sigma_V = inv(U_kron_est_reduced.t() * (Omega * U_kron_est_reduced) + Sigma_V_0_inv);
    arma::vec mu_V = Sigma_V * (((U_kron_est_reduced.t() * Omega) * (xi_reduced - (Z_reduced * B_vec_est))) + Sigma_V_0_inv * mu_V_0);
    
    arma::mat mv_sample_V = mvrnorm_arma(1, mu_V, Sigma_V);
    V_vec_est = vectorise(mv_sample_V);
    V_est = vec_2_mat_byrow(V_vec_est, m, l);
    V_vec_est = vectorise(V_est.t());
    V_kron_est = kron(V_est, I_n);
    V_kron_est_reduced = V_kron_est;
    V_kron_est_reduced.shed_rows(unobserved_vec); 
    
    for(int i = 0; i < m*l; ++i){
      lambda_V(i) = 1.0/R::rgamma(1.0, 1.0/(1.0/nu_V(i) + pow(V_vec_est(i), 2)/(2.0*tau_V)));
      if(lambda_V(i) < 1e-9){
        lambda_V(i) = 1e-9;
      }
    }
    tau_V = 1.0/R::rgamma(0.5*(m*l+1.0), 1.0/(1.0/zeta_V + 0.5*sum(pow(V_vec_est, 2)/lambda_V)));
    if(tau_V < 1e-9){
      tau_V = 1e-9;
    }  
    
    for(int i = 0; i < m*l; ++i){
      nu_V(i) = 1.0/R::rgamma(1.0, 1.0/(1.0 + 1.0/lambda_V(i)));
      if(nu_V(i) < 1e-9){
        nu_V(i) = 1e-9;
      }
    }
    
    zeta_V = 1.0/R::rgamma(1.0, 1.0/(1.0 + 1.0/tau_V));
    if(zeta_V < 1e-9){
      zeta_V = 1e-9;
    }  
    
    Sigma_V_0_inv.diag() = 1.0/(tau_V*lambda_V);

    if(iter >= burnin){
      B_hat.slice(iter - burnin) = B_vec_est;
      U_hat.slice(iter - burnin) = U_est;
      V_hat.slice(iter - burnin) = V_est;
      
      // # predict NA in R
      arma::vec binom_samp(unobserved_vec_size, arma::fill::zeros);
      u_v_prod = vectorise(U_est * V_est.t());
      for(int i = 0; i < unobserved_vec_size; ++i){
        int row_i = unobserved_vec(i);
        double lin_pred_i = logit(dot(Z.row(row_i), B_vec_est) + u_v_prod(row_i));
        binom_samp(i) = R::rbinom(k - 1.0, lin_pred_i); // this seems to work okay
      }

      arma::vec R_pred(n*m, arma::fill::zeros);
      R_pred.elem(unobserved_vec) = binom_samp + 1; // seems to work too
      R_pred_mat = vec_2_mat(R_pred, n, m);

      R_hat.slice(iter - burnin) = R_pred_mat; // here it goes wrong
    }
  }
  return Rcpp::List::create(Rcpp::Named("B_hat") = B_hat,
                            Rcpp::Named("U_hat") = U_hat,
                            Rcpp::Named("V_hat") = V_hat,
                            Rcpp::Named("R_hat") = R_hat);
}
