#include "ParticleFilter.h"
#include "Pagmoprob.h"
#include "BS_thread_pool.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>

// Numerically stable log(sum_i exp(log_weights[i])).
//
// The largest entry is factored out before exponentiating so that large
// log-weights do not overflow and very negative ones do not all underflow to 0.
//
// Returns:
//   - -infinity for an empty input (the log of an empty sum, i.e. log(0));
//   - the maximum itself when it is +/-infinity (all terms zero, or an infinite
//     term dominates), instead of the NaN that `inf - inf` would otherwise produce;
//   - NaN if any entry is NaN.
double log_sum_exp(const std::vector<double>& log_weights) {
    if (log_weights.empty()) {
        // max_element would return end(); dereferencing it is undefined behaviour.
        return -std::numeric_limits<double>::infinity();
    }

    // Find the maximum log-weight to prevent overflow when exponentiating.
    const double max_log_weight = *std::max_element(log_weights.begin(), log_weights.end());

    // Short-circuit infinite maxima: `lw - max_log_weight` would evaluate inf - inf = NaN.
    if (std::isinf(max_log_weight)) {
        return max_log_weight;
    }

    // Sum exp(lw - max); every term is <= 1 and at least one term equals 1.
    double sum_exp_diff = 0.0;
    for (const double lw : log_weights) {
        sum_exp_diff += std::exp(lw - max_log_weight);
    }

    // Add the maximum back in log space.
    return max_log_weight + std::log(sum_exp_diff);
}

// Function to perform Conditional Particle Filtering (CPF) with ancestor sampling.
// This function returns a tuple containing the filtered weights, the log-likelihood, and the smoothed trajectories.
std::tuple<std::vector<std::vector<double>>, double, std::vector<std::vector<int>>> cpf_as(
    int N, std::vector<ParticleFilter>& particleFilterVec, const RatData& ratdata, 
    const MazeGraph& Suboptimal_Hybrid3, const MazeGraph& Optimal_Hybrid3, 
    std::vector<int> x_cond, int l_truncate, BS::thread_pool& pool)
{
    // Extract all paths and session indices from the rat data.
    arma::mat allpaths = ratdata.getPaths();
    arma::vec sessionVec = allpaths.col(4);
    arma::vec uniqSessIdx = arma::unique(sessionVec);
    int sessions = uniqSessIdx.n_elem;

    // Initialize particle weights uniformly.
    std::vector<double> w(N, 1.0 / (double) N);
    
    // Initialize data structures to store filtered weights and smoothed trajectories across sessions.
    std::vector<std::vector<double>> filteredWeights;
    std::vector<std::vector<int>> smoothedTrajectories;

    // Initialize variables to store delta and phi values for each session and particle.
    std::vector<std::vector<double>> delta(sessions, std::vector<double>(N, 0.0));
    std::vector<std::vector<double>> phi(sessions, std::vector<double>(N, 0.0));

    double loglik = 0;  // Variable to accumulate the log-likelihood.

    // Iterate over the sessions.
    for (int ses = 0; ses < sessions; ses++) {
        double ses_lik = 0;  // Variable to accumulate the session likelihood.

        if (ses == 0) {  // Initialize particles at the first session.
            // Submit a sequence of tasks to the thread pool to initialize the particles in parallel.
            BS::multi_future<void> initPf = pool.submit_sequence(0, N, 
                [&particleFilterVec, &delta, &phi, &ses, &N, &x_cond, &w](int i) {
                    // Define a uniform CRP (Chinese Restaurant Process) prior.
                    std::vector<double> crp_i = {0.25, 0.25, 0.25, 0.25};
                    std::vector<std::shared_ptr<Strategy>> strategies = particleFilterVec[i].getStrategies();
                    particleFilterVec[i].backUpStratCredits();  // Backup strategy credits.

                    // Calculate likelihoods for each strategy.
                    std::vector<double> likelihoods;
                    for (int k = 0; k < strategies.size(); k++) {
                        double lik = particleFilterVec[i].getSesLikelihood(k, ses);
                        if (lik == 0) {
                            lik = std::numeric_limits<double>::min();  // Prevent zero likelihood.
                        }
                        likelihoods.push_back(lik);
                    }

                    // Calculate the posterior probability of each strategy.
                    double sum = 0;
                    for (int k = 0; k < strategies.size(); k++) {
                        sum += crp_i[k] * likelihoods[k];
                    }    
                    std::vector<double> p;
                    for (int k = 0; k < strategies.size(); k++) {
                        p.push_back(crp_i[k] * likelihoods[k] / sum);
                    }

                    // Rollback credits after calculation.
                    particleFilterVec[i].rollBackCredits();

                    // Add CRP prior to the particle.
                    particleFilterVec[i].addCrpPrior(p, ses);

                    // Sample a strategy based on the calculated probabilities.
                    int sampled_strat = (i == N - 1) ? x_cond[0] : sample(p);

                    // Prevent zero probabilities.
                    if (p[sampled_strat] == 0) {
                        p[sampled_strat] = std::numeric_limits<double>::min();
                    }

                    // Update the particle with the sampled strategy and trajectory.
                    particleFilterVec[i].addAssignment(ses, sampled_strat);
                    particleFilterVec[i].addOriginalSampledStrat(ses, sampled_strat);
                    std::vector<int> particleHistory = particleFilterVec[i].getChosenStratgies();
                    particleFilterVec[i].addParticleTrajectory(particleHistory, ses);

                    // Calculate the likelihood of the chosen strategy and update the weight.
                    double lik = particleFilterVec[i].getSesLikelihood(sampled_strat, ses);
                    w[i] = lik * crp_i[sampled_strat] / p[sampled_strat];

                    // Update delta for the session.
                    delta[ses][i] = log(0.25) + log(lik);

                    particleFilterVec[i].backUpStratCredits();  // Backup strategy credits for the next session.
                }
            );

            initPf.wait();  // Wait for all particle initializations to complete.

            normalize(w);  // Normalize the particle weights.

            filteredWeights.push_back(w);  // Store the filtered weights for the session.

            continue;  // Move to the next session.
        }

        // Resampling step for sessions beyond the first.
        double weightSq = 0;
        for (int k = 0; k < N; k++) {
            weightSq += std::pow(w[k], 2);  // Calculate the squared weights.
        }
        double n_eff = 1 / weightSq;  // Calculate the effective sample size.

        if (1) {  // Always perform resampling.
            std::vector<double> resampledIndices = systematicResampling(w);  // Perform systematic resampling.

            // Backup chosen strategies and credits before resampling.
            for (int j = 0; j < N; j++) {
                particleFilterVec[j].backUpChosenStrategies();
                particleFilterVec[j].backUpStratCredits();
            }

            // Submit tasks to the thread pool to resample the particles in parallel.
            BS::multi_future<void> resample_particle_filter = pool.submit_sequence(0, N-1, 
                [&particleFilterVec, &ses, &resampledIndices](int i) {
                    int newIndex = resampledIndices[i];

                    // Backup the resampled particle's chosen strategies and strategy credits.
                    std::vector<std::pair<std::vector<double>, std::vector<double>>>& stratBackUps = particleFilterVec[newIndex].getStratCreditBackUps();
                    std::vector<signed int> chosenStrategyBackUp = particleFilterVec[newIndex].getChosenStrategyBackups();
                    
                    // Update the particle with the resampled strategies and credits.
                    particleFilterVec[i].setChosenStrategies(chosenStrategyBackUp);
                    particleFilterVec[i].setStratBackups(stratBackUps);
                }
            );

            resample_particle_filter.wait();  // Wait for all resampling tasks to complete.

            // Reset the particle weights to uniform after resampling.
            double initWeight = 1.0 / (double)N;
            std::fill(w.begin(), w.end(), initWeight);
        }

        // Ancestor sampling step.
        std::vector<double> ancestorProbs(N, 0.0);
        BS::multi_future<void> ancestorSampling = pool.submit_sequence(0, N, 
            [&particleFilterVec, &ses, &w, &x_cond, &sessions, &ancestorProbs, &l_truncate](int i) {
                particleFilterVec[i].backUpStratCredits();  // Backup strategy credits.

                std::vector<int> particleHistory_t_minus1 = particleFilterVec[i].getParticleTrajectories()[ses - 1];

                int l = std::min(sessions, (ses - 1 + l_truncate));
                double prod = log(1);
                for (int s = ses; s < l; s++) {
                    // Update particle history and calculate likelihoods for ancestor sampling.
                    std::vector<double> crp_i = particleFilterVec[i].crpPrior2(particleHistory_t_minus1, s - 1);
                    double lik = particleFilterVec[i].getSesLikelihood(x_cond[s], s);
                    prod += log(crp_i[x_cond[s]]) + log(lik);
                    particleHistory_t_minus1[s] = x_cond[s];
                }

                // Update the ancestor probabilities.
                ancestorProbs[i] = prod + log(w[i]);

                particleFilterVec[i].rollBackCredits();  // Rollback credits after calculation.
            }
        );

        ancestorSampling.wait();  // Wait for ancestor sampling tasks to complete.

        double lse = log_sum_exp(ancestorProbs);  // Compute the log-sum-exp of the ancestor probabilities.

        // Store the log-likelihood contribution for the session.
        ses_lik += lse;

        // Store the smoothed trajectory for each particle.
        for (int i = 0; i < N; i++) {
            smoothedTrajectories.push_back(particleFilterVec[i].getParticleTrajectories()[ses - 1]);
        }

        // Update the log-likelihood with the session's contribution.
        loglik += ses_lik;

        // Store the filtered weights for the session.
        filteredWeights.push_back(w);
    }

    // Return the filtered weights, log-likelihood, and smoothed trajectories as a tuple.
    return std::make_tuple(filteredWeights, loglik, smoothedTrajectories);
}

