#include "InferStrategy.h"
#include "Pagmoprob.h"
#include "Simulation.h"
#include <pagmo/algorithm.hpp>
#include <pagmo/archipelago.hpp>
#include <random>
#include <RInside.h>
#include <boost/archive/text_oarchive.hpp>
#include <boost/serialization/map.hpp>
#include <boost/serialization/vector.hpp>
#include <boost/serialization/utility.hpp>
#include <algorithm>
#include <iostream>
#include <stdexcept>





// Compute the exponential moving average (alpha = 0.3) of the rewards in `data`.
//
// Column 2 of each row holds the reward; rewards above 1 are capped at 1 before
// smoothing. The EMA starts from 0 and entry i is
//   0.3 * min(reward_i, 1) + 0.7 * entry(i - 1).
//
// NOTE: no filtering by state happens here. Rows must already be restricted to
// the state of interest (check_ema and checkConsecutiveThreshold split `data`
// on column 1 before calling this). `state` is unused and is kept only so the
// signature stays unchanged for existing callers.
//
// @param data  matrix with at least 3 columns; column 2 is the reward.
// @param state unused.
// @return      vector with one EMA value per row of `data`.
arma::vec ema_rewards(arma::mat data, [[maybe_unused]] int state) {
  const double alpha = 0.3;
  arma::vec ema(data.n_rows, arma::fill::zeros);
  double ema_prev = 0.0;

  for (arma::uword i = 0; i < data.n_rows; ++i) {
    // Cap the reward at 1.
    const double r = std::min(data(i, 2), 1.0);

    const double ema_curr = alpha * r + (1 - alpha) * ema_prev;
    ema(i) = ema_curr;
    ema_prev = ema_curr;
  }

  return ema;
}


// Check whether the per-state reward EMAs look like a plausible learning run.
//
// Rows of `data` are split by the state label in column 1 (0 or 1), and the
// reward EMA of each subset is computed with ema_rewards(). The run is rejected
// (false is returned after logging the reason to stdout) when:
//   1. more than 100 entries in the second half of the state-0 EMA are < 0.4;
//   2. more than 90% of the state-0 EMA entries are > 0.99;
//   3. more than 100 entries in the second half of the state-1 EMA are < 0.6.
// Otherwise it returns true only if BOTH state EMAs contain `consecutive_count`
// consecutive entries >= `threshold` (a non-positive consecutive_count never
// succeeds).
//
// @param data              matrix whose column 1 is the state (0/1) and
//                          column 2 the reward.
// @param threshold         EMA level each state must reach.
// @param consecutive_count number of consecutive EMA entries that must reach
//                          `threshold`.
// @return true if all sanity checks pass and both states reach the threshold.
bool check_ema(arma::mat data, double threshold, int consecutive_count) {
  arma::mat dataS0 = data.rows(find(data.col(1) == 0));
  arma::mat dataS1 = data.rows(find(data.col(1) == 1));

  const arma::vec ema0 = ema_rewards(dataS0, 0);
  const arma::vec ema1 = ema_rewards(dataS1, 1);

  // Number of entries in the second half of `ema` that are strictly below `limit`.
  const auto countLowInSecondHalf = [](const arma::vec& ema, double limit) {
    return std::count_if(ema.begin() + ema.n_elem / 2, ema.end(),
                         [limit](double value) { return value < limit; });
  };

  // True once `ema` has `consecutive_count` consecutive entries >= threshold.
  const auto reachesThreshold = [threshold, consecutive_count](const arma::vec& ema) {
    int run = 0;
    for (const double value : ema) {
      if (value >= threshold) {
        if (++run == consecutive_count) {
          return true;
        }
      } else {
        run = 0;
      }
    }
    return false;
  };

  // CHECK 1: the state-0 EMA must not collapse in the second half of the experiment.
  const auto countS0 = countLowInSecondHalf(ema0, 0.4);
  if (countS0 > 100) {
    std::cout << "check_ema failed. S0 prob less than 0.4 for more than 100 trials" << std::endl;
    return false;
  }

  // CHECK 2: the state-0 EMA must not sit near 1 for most trials. The ratio is
  // computed in floating point; integer division would truncate it to 0 or 1.
  // An empty EMA has no trials near 1, so the check is skipped (no division by zero).
  if (ema0.n_elem > 0) {
    const auto countNearOne = std::count_if(ema0.begin(), ema0.end(),
                                            [](double value) { return value > 0.99; });
    const double fractionNearOne =
        static_cast<double>(countNearOne) / static_cast<double>(ema0.n_elem);
    if (fractionNearOne > 0.9) {
      std::cout << "check_ema failed.  S0 is close to 1 for most of trials" << std::endl;
      return false;
    }
  }

  // CHECK 3: the state-1 EMA must not collapse in the second half of the experiment.
  const auto countS1 = countLowInSecondHalf(ema1, 0.6);
  if (countS1 > 100) {
    std::cout << "check_ema failed.S1 prob less than 0.6 for more than 100 trials; countS1 = " << countS1 << std::endl;
    return false;
  }

  // Both states must reach the threshold for a sustained run.
  return reachesThreshold(ema0) && reachesThreshold(ema1);
}



// Compute the exponential moving average (alpha = 0.2) of Path5 occurrences.
//
// Column 0 of each row holds the chosen path. It is truncated to int, and the
// row counts as a Path5 occurrence when that value equals 4. The EMA starts
// from 0 and entry i is 0.2 * [path_i == 4] + 0.8 * entry(i - 1).
// No filtering by state happens here; restrict the rows beforehand if needed.
//
// @param data matrix whose column 0 is the path index.
// @return     vector with one EMA value per row of `data`.
arma::vec ema_path5(arma::mat data) {
  const double smoothing = 0.2;
  const int rowCount = static_cast<int>(data.n_rows);
  arma::vec ema(rowCount, arma::fill::zeros);

  double running = 0.0;
  for (int row = 0; row < rowCount; ++row) {
    const int path = static_cast<int>(data(row, 0));
    const double isPath5 = (path == 4) ? 1.0 : 0.0;
    running = smoothing * isPath5 + (1 - smoothing) * running;
    ema(row) = running;
  }
  return ema;
}


// Check whether the per-state Path5 (path index 4) EMA looks plausible.
//
// Rows of `data` are split by the state label in column 1 (0 or 1), and the
// Path5 EMA of each subset is computed with ema_path5(). Returns false (after
// logging to stdout) when:
//   - either state's EMA ever exceeds 0.95, or
//   - more than 20 entries in the FIRST half of the state-0 EMA exceed 0.3.
// Otherwise returns true only if the state-1 EMA stays >= 0.5 for 10
// consecutive entries.
//
// @param data matrix whose column 0 is the path index and column 1 the state.
// @return true if the Path5 pattern passes all checks.
bool check_path5(arma::mat data) {
  arma::mat dataS0 = data.rows(find(data.col(1) == 0));
  arma::mat dataS1 = data.rows(find(data.col(1) == 1));

  const arma::vec ema0 = ema_path5(dataS0);
  const arma::vec ema1 = ema_path5(dataS1);

  // CHECK 1: Path5 probability must never exceed 0.95 in either state.
  const auto exceeds095 = [](const arma::vec& ema) {
    return std::any_of(ema.begin(), ema.end(), [](double element) { return element > 0.95; });
  };
  const bool anyAbove095_ema0 = exceeds095(ema0);
  const bool anyAbove095_ema1 = exceeds095(ema1);
  if (anyAbove095_ema1 || anyAbove095_ema0) {
    std::cout << "check_path5 failed, anyGreaterThanPointEight_ema0: " << anyAbove095_ema0 << ", anyGreaterThanPointEight_ema1:" << anyAbove095_ema1 << std::endl;
    return false;
  }

  // CHECK 2: in the first half of the experiment, the state-0 Path5 EMA may
  // exceed 0.3 for at most 20 entries.
  const auto firstHalfEndS0 = ema0.begin() + ema0.n_elem / 2;
  const auto countS0 = std::count_if(ema0.begin(), firstHalfEndS0,
                                     [](double value) { return value > 0.3; });
  if (countS0 > 20) {
    std::cout << "check_path5 failed." << std::endl;
    return false;
  }

  // CHECK 3: the state-1 Path5 EMA must stay >= 0.5 for 10 consecutive entries.
  int run = 0;
  for (const double value : ema1) {
    if (value >= 0.5) {
      if (++run == 10) {
        return true;
      }
    } else {
      run = 0;
    }
  }
  return false;
}


// Detect whether the reward EMA of either state stays below `threshold` for a
// sustained stretch at or after `changepoint`.
//
// Rows of `data` are split by the state label in column 1 (0 or 1), and each
// subset's reward EMA is computed with ema_rewards(). Each EMA is scanned from
// index `changepoint` onward for at least `consecutiveCount` consecutive
// entries strictly below `threshold`.
//
// NOTE(review): `changepoint` indexes each per-state EMA, i.e. it counts rows of
// that state only, not rows of `data`. Confirm this is what callers expect.
// A negative `changepoint` scans nothing, so the check is false for that state.
//
// @param data             matrix whose column 1 is the state (0/1) and column 2
//                         the reward.
// @param threshold        EMA level that counts as "low" (strictly below).
// @param consecutiveCount required length of the low run.
// @param changepoint      first per-state EMA index to scan.
// @return true if either state's EMA contains such a run, false otherwise.
bool checkConsecutiveThreshold(arma::mat data, double threshold, int consecutiveCount, int changepoint) {
    arma::mat dataS0 = data.rows(find(data.col(1) == 0));
    arma::mat dataS1 = data.rows(find(data.col(1) == 1));

    const arma::vec ema0 = ema_rewards(dataS0, 0);
    const arma::vec ema1 = ema_rewards(dataS1, 1);

    // True if ema[changepoint..] has >= consecutiveCount consecutive entries below threshold.
    const auto hasLowRun = [threshold, consecutiveCount, changepoint](const arma::vec& ema) {
        if (changepoint < 0) {
            // Handle this explicitly instead of letting the index wrap to a huge unsigned value.
            return false;
        }
        int consecutiveRows = 0;
        for (arma::uword i = static_cast<arma::uword>(changepoint); i < ema.n_elem; ++i) {
            if (ema(i) < threshold) {
                if (++consecutiveRows >= consecutiveCount) {
                    return true;
                }
            } else {
                consecutiveRows = 0; // The run is broken; start counting again.
            }
        }
        return false;
    };

    return hasLowRun(ema0) || hasLowRun(ema1);
}



