// Particle filter algorithm for expert assignment problem
// Assume that there are N particles, M experts, and T trials
// Each particle m is a vector of length T, where m[t] is the index of the expert assigned to trial t
// Each expert i has a probability distribution p_i over the possible outcomes of each trial
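// The likelihood of a particle m given the observed outcomes y is the product over all t of p_{m[t]}(y_t), i.e. the probability that the expert assigned to trial t gives to the observed outcome y_t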
// The posterior probability of a particle m is proportional to its prior probability times its likelihood
// The prior probability of a particle m is assumed to be uniform over all possible assignments
// The resampling step is done using multinomial resampling with replacement; a systematic-resampling helper (systematicResampling) is also defined in this file

#include "ParticleFilter.h"
#include "Pagmoprob.h"
#include "BS_thread_pool.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <vector>

#include <pagmo/algorithm.hpp>
#include <pagmo/archipelago.hpp>
#include <pagmo/algorithms/de.hpp>
#include <pagmo/algorithms/nlopt.hpp>




using namespace std;



// Returns a uniformly distributed random integer in the closed range [0, 1]
// (a fair coin flip). The range is fixed; the function takes no bound argument.
//
// One Mersenne Twister engine per thread is seeded once from std::random_device
// and then reused. This avoids constructing and reseeding a new engine on every
// call, which is slow and wastes entropy.
int randint()
{
    static thread_local std::mt19937 gen{std::random_device{}()};
    std::uniform_int_distribution<> dis(0, 1);
    return dis(gen);
}

// Draws one index from the categorical distribution defined by the
// non-negative, unnormalized weights in p: index k is returned with
// probability p[k] / sum(p).
//
// Edge cases:
//   * empty p            -> returns 0 (std::discrete_distribution treats an
//                           empty weight list as a single weight of 1);
//   * all weights zero   -> a uniform draw over [0, p.size() - 1]. This
//                           matches the zero-sum policy of normalize();
//                           std::discrete_distribution would otherwise have
//                           undefined behavior because it requires sum > 0.
// Weights must be finite and non-negative.
//
// The engine is seeded once per thread and reused across calls.
int sample(std::vector<double> p)
{
    static thread_local std::mt19937 gen{std::random_device{}()};

    double total = 0.0;
    for (double w : p)
    {
        total += w;
    }

    if (!p.empty() && total == 0.0)
    {
        std::uniform_int_distribution<int> uniform(0, static_cast<int>(p.size()) - 1);
        return uniform(gen);
    }

    std::discrete_distribution<> dis(p.begin(), p.end());
    return dis(gen);
}

// Normalizes every row of `matrix` in place so that each row sums to one.
//
// A row whose elements sum to zero is set to the uniform distribution
// (1 / row length). This matches the zero-sum policy of normalize() and
// prevents a division by zero from filling the row with NaN.
// Empty rows are left untouched.
void normalizeRows(std::vector<std::vector<double>> &matrix)
{
    for (auto &row : matrix)
    {
        if (row.empty())
        {
            continue;
        }

        // Sum of the elements in the current row
        double rowSum = 0.0;
        for (double element : row)
        {
            rowSum += element;
        }

        if (rowSum == 0.0)
        {
            std::fill(row.begin(), row.end(), 1.0 / static_cast<double>(row.size()));
            continue;
        }

        for (double &element : row)
        {
            element /= rowSum;
        }
    }
}

// Normalizes row `t` of `matrix` in place so that it sums to one.
//
// If the row sums to zero it is set to the uniform distribution
// (1 / row length). This matches the zero-sum policy of normalize() and
// prevents a division by zero from filling the row with NaN.
// An out-of-range index or an empty row is reported on std::cerr and leaves
// the matrix unchanged.
void normalizeRow(std::vector<std::vector<double>>& matrix, size_t t) {
    if (t >= matrix.size() || matrix[t].empty()) {
        std::cerr << "Invalid row index or empty row." << std::endl;
        return;
    }

    std::vector<double>& row = matrix[t];

    double rowSum = 0.0;
    for (double element : row) {
        rowSum += element;
    }

    if (rowSum == 0.0) {
        std::fill(row.begin(), row.end(), 1.0 / static_cast<double>(row.size()));
        return;
    }

    for (double& element : row) {
        element /= rowSum;
    }
}


// Normalizes the weight vector p in place so that its entries sum to one.
//
// If the weights sum to exactly zero, every entry is reset to the uniform
// weight 1 / p.size() instead of dividing by zero.
// Throws std::runtime_error if any weight is NaN, either on entry or after
// normalization (e.g. when +inf and -inf weights cancel out).
void normalize(std::vector<double> &p)
{
    auto isNanWeight = [](double w) { return std::isnan(w); };

    if (std::any_of(p.begin(), p.end(), isNanWeight))
    {
        throw std::runtime_error("nan weights before normalizing. Check");
    }

    double total = 0;
    for (auto it = p.begin(); it != p.end(); ++it)
    {
        total += *it;
    }

    if (total != 0)
    {
        for (auto &w : p)
        {
            w /= total;
        }
    }
    else
    {
        // Degenerate weights: fall back to a uniform distribution.
        const int count = p.size();
        std::fill(p.begin(), p.end(), 1.0 / static_cast<double>(count));
    }

    if (std::any_of(p.begin(), p.end(), isNanWeight))
    {
        throw std::runtime_error("nan weights after normalizing. Check");
    }
}

// Systematic resampling: returns numParticles ancestor indices (stored as
// doubles) drawn according to particleWeights.
//
// A single offset u0 ~ U[0, W/N) is drawn, where W is the total weight and N
// the number of particles. Output j is the particle whose cumulative-weight
// interval [C[k-1], C[k]) contains u0 + j * W / N.
//
// Behavior:
//   * Weights need not be normalized; they are scaled by their total.
//   * Zero-weight particles are never selected (upper_bound is used, so a
//     point that lands on a cumulative boundary goes to the next particle).
//   * Indices are clamped to N-1 so floating-point rounding near the total
//     cannot yield an out-of-range index.
//   * Empty input returns an empty vector.
//   * An all-zero (or non-positive) total is treated as uniform weights,
//     matching the zero-sum policy of normalize().
// Throws std::runtime_error if the total weight is NaN or infinite.
// Weights are assumed non-negative.
std::vector<double> systematicResampling(const std::vector<double> &particleWeights)
{
    const int numParticles = static_cast<int>(particleWeights.size());
    std::vector<double> resampledParticles(numParticles, -1);
    if (numParticles == 0)
    {
        return resampledParticles;
    }

    // Compute cumulative weights
    std::vector<double> cumulativeWeights(numParticles);
    double totalWeight = 0.0;
    for (int i = 0; i < numParticles; ++i)
    {
        totalWeight += particleWeights[i];
        cumulativeWeights[i] = totalWeight;
    }

    if (!std::isfinite(totalWeight))
    {
        throw std::runtime_error("non-finite particle weights in systematicResampling");
    }

    if (totalWeight <= 0.0)
    {
        // Degenerate weights: resample as if every particle had weight 1.
        for (int i = 0; i < numParticles; ++i)
        {
            cumulativeWeights[i] = static_cast<double>(i + 1);
        }
        totalWeight = static_cast<double>(numParticles);
    }

    const double step = totalWeight / numParticles;

    static thread_local std::mt19937 gen{std::random_device{}()};
    std::uniform_real_distribution<double> uniformDist(0.0, step);
    const double u0 = uniformDist(gen);

    for (int j = 0; j < numParticles; ++j)
    {
        // Compute each point directly rather than accumulating, to avoid drift.
        const double u = u0 + j * step;
        auto it = std::upper_bound(cumulativeWeights.begin(), cumulativeWeights.end(), u);
        const int index = static_cast<int>(std::distance(cumulativeWeights.begin(), it));
        resampledParticles[j] = static_cast<double>(std::min(index, numParticles - 1));
    }

    return resampledParticles;
}

// Returns the column-wise mean of a row-major matrix: element j of the result
// is the average of matrix[i][j] over all rows i.
// The column count is taken from the first row; every row is assumed to have
// at least that many elements. An empty matrix yields an empty vector.
std::vector<double> colwise_mean(const std::vector<std::vector<double>>& matrix) {
    if (matrix.empty()) {
        return {};
    }

    const std::size_t rows = matrix.size();
    const std::size_t cols = matrix[0].size();

    std::vector<double> means(cols, 0.0);

    // Accumulate row by row so memory is read contiguously. Each column still
    // adds its rows in the same order, so the sums are bit-identical.
    for (const auto& row : matrix) {
        for (std::size_t j = 0; j < cols; ++j) {
            means[j] += row[j];
        }
    }

    for (double& m : means) {
        m /= static_cast<double>(rows);
    }

    return means;
}


// Writes v to standard output as a bracketed, comma-separated list followed by
// a newline (and a flush), e.g. "[1, 2.5, 3]". An empty vector prints "[]".
void print_vector(std::vector<double> v)
{
    std::cout << "[";
    const char *separator = "";
    for (double value : v)
    {
        std::cout << separator << value;
        separator = ", ";
    }
    std::cout << "]" << std::endl;
}



// Evaluates the M-step objective Q for one smoothed strategy trajectory under
// the parameter vector `params`:
//
//   Q = sum_t log L_t(smoothedTrajectory[t])
//     + sum_t log p_CRP(smoothedTrajectory[t+1] | history)
//
// where L_t is ParticleFilter::getSesLikelihood for session t and p_CRP comes
// from ParticleFilter::crpPrior2. Both are computed with a single
// ParticleFilter built from `params`.
//
// Parameters:
//   smoothedTrajectory - strategy index per session; must contain at least one
//                        entry per session (std::out_of_range otherwise).
//   params             - model parameters used to build the ParticleFilter.
//   pool               - currently unused; kept for interface compatibility.
// Returns the combined log-likelihood Q.
// Throws std::runtime_error if either accumulated term becomes NaN or infinite.
// Exceptions from crpPrior2 are logged and rethrown.
double M_step5(const RatData &ratdata, const MazeGraph &Suboptimal_Hybrid3, const MazeGraph &Optimal_Hybrid3, std::vector<int> smoothedTrajectory, std::vector<double> params, BS::thread_pool& pool)
{
    // Extract session data from rat data
    arma::mat allpaths = ratdata.getPaths();
    arma::vec sessionVec = allpaths.col(4);
    arma::vec uniqSessIdx = arma::unique(sessionVec);
    int sessions = uniqSessIdx.n_elem;

    // Initialize a particle filter for the given parameters
    auto pf = ParticleFilter(ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, params, 1, 1.0);

    // --- Likelihood term: sum of per-session log-likelihoods ---
    double local_I_3 = 0.0;
    for (int t = 0; t < sessions; t++)
    {
        double lik_i = pf.getSesLikelihood(smoothedTrajectory.at(t), t);

        // Floor zero likelihoods so log() stays finite
        if (lik_i == 0) {
            lik_i = 1e-6;
        }

        local_I_3 += std::log(lik_i);

        if (std::isnan(local_I_3) || std::isinf(local_I_3)) {
            throw std::runtime_error("Error in local_I_3 value");
        }
    }

    // --- Prior term: CRP log-probability of each next state given the history ---
    double local_I_2 = 0.0;
    std::vector<int> particleHistory_t = smoothedTrajectory;

    // NOTE(review): this loop stops at t = sessions - 3, so the CRP term for the
    // final transition (session sessions-2 -> sessions-1) is not included.
    // Confirm that this is intended before changing the bound to `sessions - 1`.
    for (int t = 0; t < sessions - 2; t++)
    {
        std::vector<double> crp_t;

        try
        {
            crp_t = pf.crpPrior2(particleHistory_t, t);
        }
        catch (const std::exception& e)
        {
            std::cout << "Error in loop_I2 in M_step5" << std::endl;
            std::cerr << e.what() << '\n';
            // crp_t is unusable at this point; continuing would index an empty vector.
            throw;
        }

        // CRP prior probability of the strategy actually chosen at t+1
        double p_crp_X_i_tplus1 = crp_t.at(particleHistory_t.at(t + 1));

        // Floor zero probabilities so log() stays finite
        if (p_crp_X_i_tplus1 == 0)
        {
            p_crp_X_i_tplus1 = std::numeric_limits<double>::min();
        }

        local_I_2 += std::log(p_crp_X_i_tplus1);

        if (std::isnan(local_I_2)) {
            throw std::runtime_error("Error nan I_2 value");
        }

        if (std::isinf(local_I_2)) {
            throw std::runtime_error("Error inf I_2 value");
        }
    }

    // Total Q-value: likelihood term plus CRP prior term
    return local_I_3 + local_I_2;
}


// Estimates the strategy (hidden state) of every session as the per-session
// MAP of the smoothed distribution. The distribution is approximated by running
// the conditional particle filter with ancestor sampling (cpf_as) as a Markov
// chain:
//   * 10,000 sweeps. Each sweep builds N ParticleFilters from `params`, runs
//     cpf_as conditioned on the previous trajectory x_cond, and samples a new
//     x_cond from the last session's filtered weights.
//   * Trajectories from sweeps >= 1000 (the burn-in) are kept.
// The kept trajectories give per-session frequencies of the 4 strategies
// (indices 0..3). For each session, the most frequent strategy is reported when
// it beats the runner-up by at least 0.1; otherwise -1 ("None") is reported.
//
// Prints the frequency table and the resulting sequence to stdout.
// Throws std::out_of_range if a sampled trajectory contains a strategy index
// outside 0..3, or is shorter than the number of sessions.
std::vector<int> stateEstimation(const RatData &ratdata, const MazeGraph &Suboptimal_Hybrid3, const MazeGraph &Optimal_Hybrid3, int N, std::vector<double> params, int l_truncate, BS::thread_pool& pool)
{
    // Extract the session data from rat data
    arma::mat allpaths = ratdata.getPaths();
    arma::vec sessionVec = allpaths.col(4);
    arma::vec uniqSessIdx = arma::unique(sessionVec);
    int sessions = uniqSessIdx.n_elem;

    const int numIterations = 10000; // total CPF-AS sweeps
    const int burnIn = 1000;         // sweeps discarded before collecting samples

    // Sampled trajectories and per-(strategy, session) counts
    std::vector<std::vector<int>> sampledSmoothedTrajectories;
    sampledSmoothedTrajectories.reserve(numIterations - burnIn);
    std::vector<std::vector<double>> stratCounts(4, std::vector<double>(sessions, 0.0));

    // Conditioning trajectory for cpf_as, initially strategy 0 everywhere
    std::vector<int> x_cond(sessions, 0);

    for (int i = 0; i < numIterations; i++)
    {
        std::vector<ParticleFilter> particleFilterVec;
        for (int k = 0; k < N; k++)
        {
            auto pf = ParticleFilter(ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, params, k, 1.0);
            particleFilterVec.push_back(pf);
        }

        // Conditional particle filter with ancestor sampling + smoothing
        auto [filteredWeights, loglik, smoothedTrajectories] = cpf_as(N, particleFilterVec, ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, x_cond, l_truncate, pool);

        // Sample the next conditioning trajectory from the final filtered weights
        int sampled_trajectory = sample(filteredWeights[sessions-1]);
        x_cond = smoothedTrajectories[sampled_trajectory];

        if (i >= burnIn)
        {
            sampledSmoothedTrajectories.push_back(x_cond);
        }
    }

    std::cout << "Generated smoothed trajectories for " << ratdata.getRat() << std::endl;

    // Count how often each strategy occurs in each session.
    // .at() turns an unexpected strategy index into an exception instead of UB.
    for (const auto& trajectory : sampledSmoothedTrajectories)
    {
        for (int t = 0; t < sessions; t++)
        {
            const std::size_t strat_t = static_cast<std::size_t>(trajectory.at(t));
            stratCounts.at(strat_t).at(t)++;
        }
    }

    // Convert counts to empirical probabilities
    const double numSamples = static_cast<double>(sampledSmoothedTrajectories.size());
    for (auto& row : stratCounts) {
        for (auto& element : row) {
            element /= numSamples;
        }
    }

    // Print the strategy probabilities for debugging
    for (const auto& row : stratCounts) {
        for (double element : row) {
            std::cout << element << " ";
        }
        std::cout << std::endl;
    }

    std::cout << "Smoothed map sequence: ";
    std::vector<int> map_seq;

    // Most probable strategy per session, if it is clearly ahead of the runner-up
    for (int t = 0; t < sessions; t++)
    {
        std::vector<double> stratProbs_t = {stratCounts[0][t], stratCounts[1][t], stratCounts[2][t], stratCounts[3][t]};
        auto max_it = std::max_element(stratProbs_t.begin(), stratProbs_t.end());
        size_t max_index = std::distance(stratProbs_t.begin(), max_it);

        std::vector<double> sortedVec = stratProbs_t;
        std::sort(sortedVec.begin(), sortedVec.end(), std::greater<double>());

        if (sortedVec[0] - sortedVec[1] >= 0.1)
        {
            std::cout << max_index << ", ";
            map_seq.push_back(max_index);
        }
        else
        {
            // Ambiguous session: report 'None' as -1
            std::cout << " None,";
            map_seq.push_back(-1);
        }
    }
    std::cout << std::endl;

    return map_seq;
}


// Stochastic Approximation EM (SAEM) estimation of the 9 model parameters.
//
// Runs up to 300 iterations. Each iteration:
//   1. Builds N ParticleFilters from the current `params` and runs the
//      conditional particle filter with ancestor sampling (cpf_as),
//      conditioned on the trajectory x_cond.
//   2. From iteration 100 onward:
//      - maximizes the PagmoProb objective with NLopt's Subplex via pagmo;
//        the objective is built from the current and previous iteration's
//        smoothed trajectories and filtered weights, with step size gamma;
//      - prints diagnostics: the largest relative parameter change and two
//        relative log-likelihood ratios (against the previous parameters and
//        those from two iterations earlier);
//      - replaces `params` with the champion;
//      - prints the filtering distribution over the 4 strategies.
//   3. Samples a new x_cond from the last session's filtered weights.
// The loop stops early once both relative log-likelihoods are below 1e-5 in
// absolute value and i > 120. stateEstimation is then run on the final
// parameters; it prints its results, and the returned sequence is not used.
//
// Returns the final parameter vector.
std::vector<double> SAEM(const RatData &ratdata, const MazeGraph &Suboptimal_Hybrid3, const MazeGraph &Optimal_Hybrid3, int N, BS::thread_pool& pool)
{
    // Initial parameter guess
    std::vector<double> params = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
    std::vector<std::vector<double>> params_iter;               // parameter values stored at the end of each iteration
    double gamma = 0.1;                                         // SAEM step size passed to PagmoProb
    std::vector<std::vector<int>> prevSmoothedTrajectories;     // smoothed trajectories from the previous iteration
    std::vector<std::vector<double>> prevFilteredWeights;       // filtered weights from the previous iteration

    // Extract the paths and session information from the data
    arma::mat allpaths = ratdata.getPaths();
    arma::vec sessionVec = allpaths.col(4);
    arma::vec uniqSessIdx = arma::unique(sessionVec);
    int sessions = uniqSessIdx.n_elem;
    std::vector<int> x_cond(sessions, 1); // initial conditioning trajectory: strategy 1 in every session

    int l_truncate = 5; // truncation length passed to cpf_as

    for (int i = 0; i < 300; i++)
    {
        // One ParticleFilter per particle, built from the current parameters
        std::vector<ParticleFilter> particleFilterVec;
        for (int k = 0; k < N; k++)
        {
            auto pf = ParticleFilter(ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, params, k, 1.0);
            particleFilterVec.push_back(pf);
        }

        auto [filteredWeights, loglik, smoothedTrajectories] = cpf_as(N, particleFilterVec, ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, x_cond, l_truncate, pool);

        // Parameter optimization starts after 100 burn-in iterations
        if (i >= 100)
        {
            std::cout << "i=" << i << ", E-step" << std::endl;

            PagmoProb pagmoprob(ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, N, i+1, gamma, smoothedTrajectories, filteredWeights, prevSmoothedTrajectories, prevFilteredWeights,  pool);
            std::cout << "Initialized problem class" << std::endl;

            problem prob{pagmoprob};

            // NLopt Subplex local optimizer, evolved over a population of 30
            pagmo::nlopt method("sbplx");
            method.set_xtol_abs(1e-3);

            pagmo::algorithm algo = pagmo::algorithm{method};
            pagmo::population pop(prob, 30);
            pop = algo.evolve(pop);

            std::vector<double> dec_vec_champion = pop.champion_x();
            double champion = pop.champion_f()[0];
            std::cout << "Final champion = " << champion << std::endl;

            std::cout << "dec_vec_champion: ";
            for (const auto &x : dec_vec_champion)
            {
                std::cout << std::setprecision(2) << std::scientific <<  x << " ";
            }
            std::cout << "\n";

            // Largest relative parameter change (diagnostic only)
            double maxStopCriteria = 0.0;
            for (size_t j = 0; j < params.size(); ++j) {
                double stopCriterion = std::abs(params[j] - dec_vec_champion[j])/(params[j]+0.01);
                if (stopCriterion > maxStopCriteria) {
                    maxStopCriteria = stopCriterion;
                }
            }
            std::cout << "i=" << i << ", max_stopping_criteria=" << maxStopCriteria << std::endl;

            // Relative log-likelihoods of the champion versus the previous
            // parameters (relLogLik1) and the parameters from two iterations ago
            // (relLogLik2). Q for the champion is evaluated once per particle and
            // reused for both ratios instead of being recomputed.
            const std::vector<double> &params_minus2 = params_iter[i-2];
            double relLogLik1 = 0;
            double relLogLik2 = 0;
            for (int k = 0; k < N; k++)
            {
                double Q_k = M_step5(ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, smoothedTrajectories[k], dec_vec_champion, pool);
                double Q_k_minus1 = M_step5(ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, smoothedTrajectories[k], params, pool);
                double Q_k_minus2 = M_step5(ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, smoothedTrajectories[k], params_minus2, pool);
                relLogLik1 += Q_k / Q_k_minus1;
                relLogLik2 += Q_k / Q_k_minus2;
            }
            relLogLik1 = std::log(relLogLik1/N);
            std::cout << "relLogLik1=" << std::fixed << std::setprecision(6) << relLogLik1 << std::endl;
            relLogLik2 = std::log(relLogLik2/N);
            std::cout << "relLogLik2=" << std::fixed << std::setprecision(6) << relLogLik2 << std::endl;

            params = dec_vec_champion;

            // Filtering distribution: total filtered weight of each strategy per session.
            // Each particle's sampled strategies are fetched once; every cell still
            // accumulates its particles in increasing index order.
            std::vector<std::vector<double>> filteringDist(4, std::vector<double>(sessions));
            for (int k = 0; k < N; k++)
            {
                const std::vector<int> chosenStrategy_pf = particleFilterVec[k].getOriginalSampledStrats();
                for (int ses = 0; ses < sessions; ses++)
                {
                    filteringDist[chosenStrategy_pf[ses]][ses] += filteredWeights[ses][k];
                }
            }

            std::cout << "filtering Dist=" << std::endl;
            for (const auto &row : filteringDist)
            {
                for (double num : row)
                {
                    std::cout << std::fixed << std::setprecision(2) << num << " ";
                }
                std::cout << std::endl;
            }

            // Convergence: both relative log-likelihoods (near) zero after iteration 120
            if (std::abs(relLogLik1) < 1e-5 && std::abs(relLogLik2) < 1e-5 && i > 120)
            {
                std::cout << "Terminate EM, likelihood converged after i=" << i  << std::endl;

                // Runs for its printed output; the returned sequence is currently unused.
                std::vector<int> inferred_seq = stateEstimation(ratdata, Suboptimal_Hybrid3, Optimal_Hybrid3, N, params, 5, pool);

                break;
            }
        }

        // Keep this iteration's results for the next PagmoProb
        prevSmoothedTrajectories = smoothedTrajectories;
        prevFilteredWeights = filteredWeights;

        params_iter.push_back(params);

        // Sample the next conditioning trajectory from the final filtered weights
        int sampled_trajectory = sample(filteredWeights[sessions-1]);
        x_cond = smoothedTrajectories[sampled_trajectory];
    }

    return params;
}
