#include "InverseRL.h"
#include <cmath>

/**
 * Dispatches to the likelihood routine that matches the strategy's learning
 * rule and returns the value it computes for the given session.
 *
 * Supported learning rules:
 *   "aca2" -> getAca2SessionLikelihood
 *   "arl"  -> getAvgRwdQLearningLik
 *   "drl"  -> getDiscountedRwdQlearningLik
 *
 * @param ratdata  Data passed through unchanged to the selected routine.
 * @param session  Session index passed through to the selected routine.
 * @param strategy Strategy whose getLearningRule() selects the routine.
 * @return The value returned by the selected routine.
 *
 * Terminates the process with EXIT_FAILURE if the learning rule is not one of
 * the above, or if the computed value is NaN.
 */
double computeTrajectoryLik(const RatData& ratdata, int session, Strategy& strategy)
{
    const std::string creditAssignment = strategy.getLearningRule();
    double lik = 0.0;
    if (creditAssignment == "aca2")
    {
      lik = getAca2SessionLikelihood(ratdata, session, strategy);
    }
    else if (creditAssignment == "arl")
    {
      lik = getAvgRwdQLearningLik(ratdata, session, strategy);
    }
    else if (creditAssignment == "drl")
    {
      lik = getDiscountedRwdQlearningLik(ratdata, session, strategy);
    }
    else
    {
      // Without this branch `lik` would be read uninitialised below.
      std::cerr << "Unknown learning rule '" << creditAssignment
                << "' in computeTrajectoryLik" << std::endl;
      std::exit(EXIT_FAILURE);
    }

    if (std::isnan(lik)) {
      std::cout << "Likelihood is nan. Check" << std::endl;
      std::exit(EXIT_FAILURE);
    }

    return lik;
}

/**
 * Dispatches to the simulation routine that matches the strategy's learning
 * rule and returns the pair of matrices it produces for the given session.
 *
 * Supported learning rules:
 *   "aca2" -> simulateAca2
 *   "arl"  -> simulateAvgRwdQLearning
 *   "drl"  -> simulateDiscountedRwdQlearning
 *
 * @param ratdata  Data passed through unchanged to the selected routine.
 * @param session  Session index passed through to the selected routine.
 * @param strategy Strategy whose getLearningRule() selects the routine.
 * @return The pair of matrices returned by the selected routine.
 *
 * Terminates the process with EXIT_FAILURE if the learning rule is not one of
 * the above, instead of silently returning empty matrices.
 */
std::pair<arma::mat, arma::mat> simulateTrajectory(const RatData& ratdata, int session, Strategy& strategy)
{
    const std::string creditAssignment = strategy.getLearningRule();
    std::pair<arma::mat, arma::mat> simData;
    if (creditAssignment == "aca2")
    {
      simData = simulateAca2(ratdata, session, strategy);
    }
    else if (creditAssignment == "arl")
    {
      simData = simulateAvgRwdQLearning(ratdata, session, strategy);
    }
    else if (creditAssignment == "drl")
    {
      simData = simulateDiscountedRwdQlearning(ratdata, session, strategy);
    }
    else
    {
      // Fail loudly: empty matrices would only surface as an error far downstream.
      std::cerr << "Unknown learning rule '" << creditAssignment
                << "' in simulateTrajectory" << std::endl;
      std::exit(EXIT_FAILURE);
    }

    return simData;
}

/**
 * Prints up to the first five rows of a matrix to stdout, one row per line,
 * preceded by a header line naming the matrix.
 *
 * @param matrix  Matrix whose leading rows are printed.
 * @param matname Label shown in the header line.
 */
void printFirst5Rows(const arma::mat& matrix, std::string matname) {
  std::cout << "Printing first 5 rows of " << matname << std::endl;
  // Clamp so matrices with fewer than 5 rows are not indexed out of bounds.
  const arma::uword rowsToPrint = matrix.n_rows < 5 ? matrix.n_rows : 5;
  for (arma::uword i = 0; i < rowsToPrint; ++i) {
    std::cout << "Row " << (i + 1) << ": ";
    for (arma::uword j = 0; j < matrix.n_cols; ++j) {
      std::cout << matrix(i, j) << " ";
    }
    std::cout << std::endl;
  }
}



/**
 * Walks the graph from `rootNode`, repeatedly letting the graph sample a child
 * of the current node, until a node with no outgoing edges is reached.
 *
 * @param strategy Currently unused by this function.
 * @param graph    Graph to walk; supplies out-edges, child sampling and node names.
 * @param rootNode Vertex the walk starts from (its own name is not recorded).
 * @return Names of the visited child nodes, in visiting order.
 */
std::vector<std::string> generatePathTrajectory(Strategy& strategy, BoostGraph* graph, BoostGraph::Vertex rootNode)
{
  std::vector<std::string> visitedNames;
  BoostGraph::Vertex current = rootNode;

  for (std::vector<BoostGraph::Edge> outEdges = graph->getOutGoingEdges(current);
       !outEdges.empty();
       outEdges = graph->getOutGoingEdges(current))
  {
    current = graph->sampleChild(current);
    visitedNames.push_back(graph->getNodeName(current));
  }

  return visitedNames;
}

/**
 * Computes the state reached after taking `action` in `curr_state`.
 *
 * Actions 4 and 5 leave the state unchanged. Any other action toggles
 * between states 0 and 1.
 *
 * @param curr_state Current state.
 * @param action     Action taken.
 * @return The next state, or -1 if a toggling action is taken from a state
 *         other than 0 or 1.
 */
int getNextState(int curr_state, int action)
{
  const bool keepsState = (action == 4) || (action == 5);
  if (keepsState)
  {
    return curr_state;
  }

  switch (curr_state)
  {
    case 0:
      return 1;
    case 1:
      return 0;
    default:
      return -1;
  }
}

/**
 * Samples a duration for one turn from the empirical turn-time table.
 *
 * Column usage of `hybridTurnTimes` in this function:
 *   col 2 - state, col 3 - turn id, col 4 - compared against a session
 *   threshold, col 5 - duration.
 *
 * Steps:
 *   1. Select the stage sub-table. Rows before the first row with
 *      col 4 > changepoint_ses form the early stage; that row and all later
 *      rows form the late stage. The early stage is used when
 *      session < changepoint_ses, otherwise the late stage.
 *      NOTE(review): this assumes rows are ordered by col 4 - confirm with
 *      the producer of the table.
 *   2. For non-optimal strategies, remap the incoming turn id to a
 *      (state, turn id) pair.
 *   3. If at least 5 matching rows exist, draw from an exponential
 *      distribution whose mean is their average duration. Otherwise return
 *      a fixed placeholder duration (or 0 if none applies).
 *
 * @param hybridTurnTimes Turn-time table (taken by value to keep the existing
 *                        signature).
 * @param hybridTurnId    Turn id to look up.
 * @param state           State to look up (may be overridden by the remap).
 * @param session         Session index used to pick the stage.
 * @param strategy        Supplies getName() and getOptimal().
 * @return Sampled or placeholder duration.
 */
double simulateTurnDuration(arma::mat hybridTurnTimes, int hybridTurnId, int state, int session, Strategy& strategy)
{
  const std::string strategy_name = strategy.getName();
  // Rough assumption that durations stabilize after 10 sessions, not related
  // to the changepoint in the EM inference.
  const int changepoint_ses = 10;

  // --- 1. Stage selection -------------------------------------------------
  const arma::uvec indices = arma::find(hybridTurnTimes.col(4) > changepoint_ses, 1, "first");

  // Default to the whole table. This is used when the requested stage has no
  // rows: no late rows at all, or no early rows before the first late row.
  // Reading indices(0) / using end = -1 in those cases previously threw.
  arma::uword start = 0;
  arma::uword end = hybridTurnTimes.n_rows - 1;
  if (!indices.is_empty())
  {
    const arma::uword firstLateRow = indices(0);
    if (session < changepoint_ses)
    {
      if (firstLateRow > 0)
      {
        end = firstLateRow - 1;
      }
    }
    else
    {
      start = firstLateRow;
    }
  }

  const arma::mat turnTimesMat_stage = hybridTurnTimes.rows(start, end);
  const arma::vec turnDurations_stage = turnTimesMat_stage.col(5);

  // --- 2. Turn-id remapping for non-optimal strategies ----------------------
  // Ids not listed below (e.g. 8) keep the caller's id and state.
  if (!strategy.getOptimal())
  {
    switch (hybridTurnId)
    {
      case 1:  state = 0; hybridTurnId = 1; break;
      case 2:  state = 0; hybridTurnId = 2; break;
      case 3:  state = 0; hybridTurnId = 3; break;
      case 4:  state = 0; hybridTurnId = 5; break;
      case 5:  state = 0; hybridTurnId = 6; break;
      case 6:  state = 0; hybridTurnId = 7; break;
      case 7:  state = 1; hybridTurnId = 8; break;
      case 9:  state = 1; hybridTurnId = 1; break;
      case 10: state = 1; hybridTurnId = 2; break;
      case 11: state = 1; hybridTurnId = 3; break;
      default: break;
    }
  }

  // --- 3. Duration lookup / sampling ----------------------------------------
  // Rows of the selected stage whose turn id (col 3) and state (col 2) match.
  const arma::uvec arma_idx = arma::find(turnTimesMat_stage.col(3) == hybridTurnId && turnTimesMat_stage.col(2) == state);

  if (arma_idx.n_elem < 5)
  {
    // Too few observations: fall back to fixed placeholder durations.
    const bool isHybrid3 = strategy_name == "aca2_Suboptimal_Hybrid3" || strategy_name == "aca2_Optimal_Hybrid3";
    const bool hasPlaceholder = hybridTurnId == 2 || hybridTurnId == 4 || hybridTurnId == 5 || hybridTurnId == 6;
    if (hasPlaceholder && (state == 0 || state == 1))
    {
      return isHybrid3 ? 100.0 : 50000.0;
    }
    if (hybridTurnId == 8 && state == 0) // special case rat_113
    {
      return 4000.0; // update later
    }
    return 0.0;
  }

  const arma::vec turnDurations_stage_turnid = turnDurations_stage.rows(arma_idx);
  const double mean_value = arma::mean(turnDurations_stage_turnid);

  // Seed the engine once per thread. Constructing a std::random_device and
  // engine on every call is slow. On platforms whose random_device is
  // deterministic, it would also return the same sample on every call.
  static thread_local std::default_random_engine generator(std::random_device{}());
  std::exponential_distribution<> distribution(1.0 / mean_value);

  return distribution(generator);
}

