#include <algorithm>
#include <chrono>
#include <cmath>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
// #include "valprob.h"
using namespace std;

// Shorthand container aliases used throughout this file.
// Note: despite the "D" suffix, VD / V2D / V3D hold float, not double.
using VI = std::vector<int>;
using V2I = std::vector<VI>;
using V3I = std::vector<V2I>;
using VD = std::vector<float>;
using V2D = std::vector<VD>;
using V3D = std::vector<V2D>;

// V2D VALPROB;


// Polynomial energy alpha * |x - kappa|^beta. The result is not clamped here.
float energy_function_polynomial(float x, float alpha, float kappa, float beta) {
    const float distance = std::abs(x - kappa);
    const float scaled = std::pow(distance, beta);
    return alpha * scaled;
}

// Exponential energy alpha * (1 - exp(-beta * (x - kappa)^2)): zero at x == kappa and
// approaching alpha as x moves away from kappa. The result is not clamped here.
float energy_function_exponential(float x, float alpha, float kappa, float beta) {
    const float offset = x - kappa;
    const float exponent = -beta * offset * offset;
    return alpha * (1.0 - std::exp(exponent));
}

// Piecewise energy evaluated at x, whose shape depends on where the coin bias p lies
// relative to the target interval [L, U]. The parameter r shapes every piece through
// alpha_r = (1 - r) / r.
//   x - evaluation point (ValExpComputer passes the frequency of ones c/t)
//   r - shape parameter
//   L, U - target interval bounds
//   p - coin bias
// NOTE(review): r == 0 makes alpha_r infinite and r == 1 makes it 0, which divides by
// zero below; callers presumably keep r strictly inside (0, 1) — confirm.
float energy_function_monotonic(float x, float r, float L, float U, float p) {
    // Exponent of the tail beyond kappa.
    int m = 2;
    float alpha_r = (1-r)/r;
    if (p < L) {
        // Bias below the target interval. kappa is the midpoint of [U, 1]; a_r moves
        // from L (r = 0) to U (r = 1).
        float kappa = (1.0+U)/2;
        float a_r = (1.0-r)*L + r*U;
        float C_r = (a_r - p)/(1.0 - p);
        // Left of a_r: equals C_r at x == a_r and rises towards 1 as x decreases.
        if (x < a_r) { 
            return C_r + (1-C_r)*(1-exp((x-a_r)/alpha_r));
        }
        // Right of kappa: 0 at x == kappa, rising towards 1 as x grows.
        if (x > kappa) {
            return 1-exp(-pow((x-kappa)/alpha_r, m));
        }
        // Between a_r and kappa: falls from C_r (x == a_r) to 0 (x == kappa).
        return C_r*pow(1-(x-a_r)/(kappa-a_r), alpha_r);
    }
    if (p > U) {
        // Bias above the target interval; mirror image of the branch above. kappa is
        // the midpoint of [0, L]; a_r moves from U (r = 0) to L (r = 1).
        float kappa = (L)/2;
        float a_r = (r)*L + (1.0-r)*U;
        float C_r = (p-a_r)/(p);
        // Left of kappa: 0 at x == kappa, rising towards 1 as x decreases.
        if (x < kappa) {
            return 1-exp(-pow((x-kappa)/alpha_r, m));
        }
        // Right of a_r: equals C_r at x == a_r and rises towards 1 as x grows.
        if (x > a_r) {
            return C_r + (1-C_r)*(1-exp((a_r-x)/alpha_r));
        }
        // Between kappa and a_r: falls from C_r (x == a_r) to 0 (x == kappa).
        return C_r*pow(1-(x-a_r)/(kappa-a_r), alpha_r);
    }
    // Bias inside [L, U]: quadratic well centred on p, scaled by 1/alpha_r and capped at 1.
    float beta_r = 2;
    float kappa = p;
    float res = (1/alpha_r)*pow(abs(x-kappa),beta_r);
    if (res > 1.0) return 1.0;
    return res;
}

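/*
Parameters:
- L, U (limitL, limitU in the code): lower and upper bounds of the target interval
- safetyL, safetyU: lower and upper bounds of the safety interval for the frequency of ones
- p: coin bias
- tau: leaving the safety interval only counts at times t > tau
- T: final time horizon
- zeta: energy function

ValProb(t,c) = probability of ever leaving the safety interval (at some time t' with
t <= t' <= T and t' > tau), given that t coins have been seen, c of which were ones.
*/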


// New class that encapsulates the probability computation with memoization,
// and releases memory after compute() finishes.
class ValProbComputer {
public:
    float compute(int T, int tau, float p, float limitL, float limitU, float safetyL, float safetyU, int energy_type) {
        T_ = T; tau_ = tau; p_ = p;
        limitL_ = limitL; limitU_ = limitU;
        safetyL_ = safetyL; safetyU_ = safetyU; 
        energy_type_ = energy_type;
        // r_ = r;
        if (p_ < limitL_) kappa_ = (safetyU_+limitU_)/2.0;
        else if (p_ > limitU_) kappa_ = (safetyL_+limitL_)/2.0;
        else kappa_ = p_;
        memo_ = V2D(T_ + 1, VD(T_ + 1, -1));
        float ans = valProb(0, 0);
        // Cleanup memo to free memory before returning
        V2D empty;
        memo_.swap(empty); // releases capacity
        return ans;
    }
private:
    // Parameters
    int T_ = 0, tau_ = 0;
    float p_ = 0.0f, r_ = 0.5f;
    float safetyL_ = 0.0f, safetyU_ = 1.0f, limitL_ = 0.0f, limitU_ = 1.0f;
    float kappa_ = 0.5f;
    int energy_type_ = 0;
    // Memo table
    V2D memo_;

    float zeta_local(float x)  { 
        // return (x - kappa_) * (x - kappa_);
        // return energy_function_monotonic(x, r_, limitL_, limitU_, p_);
        if (energy_type_ == 1) {
            float kappa = 0.4; float alpha = 2.7; float beta = 2;
            float aux = energy_function_polynomial(x,alpha, kappa, beta);
            if (aux > 1) return 1;
            return aux;
        }
        if (energy_type_ == 2) {
            float kappa = 0.4; float alpha = 1; float beta = 128;
            float aux = energy_function_exponential(x,alpha, kappa, beta);
            if (aux > 1) return 1;
            return aux;
        }
        if (energy_type_ == 3) {
            if ((x <= 0.4) || (x >= 0.6)) return 1;
            return 0;
        }
        return 0.0;
    }

    float valProb(int t, int c) {
        float &res = memo_[t][c];
        if (res != -1) return res;
        if (t > tau_) {
            if ((c < t * safetyL_) || (c > t * safetyU_)) return res = 1.0f;
        }
        if (t == T_) return res = 0.0f;
        float energy = t > 0 ? zeta_local(float(c) / float(t)) : 0.5f;
        float f = c < kappa_ * t ? p_ + (1 - p_) * energy : p_ * (1 - energy);
        return res = f * valProb(t + 1, c + 1) + (1 - f) * valProb(t + 1, c);
    }
};


class ValExpComputer {
public:
    float compute(int T, int tau, float p, float limitL, float limitU, float safetyL, float safetyU, float r) {
        T_ = T; tau_ = tau; p_ = p; 
        limitL_ = limitL; limitU_ = limitU;
        safetyL_ = safetyL; safetyU_ = safetyU; 
        r_ = r;
        if (p_ < limitL_) kappa_ = (safetyU_+limitU_)/2.0;
        else if (p_ > limitU_) kappa_ = (safetyL_+limitL_)/2.0;
        else kappa_ = p_;
        memo_ = V2D(T_ + 1, VD(T_ + 1, -1));
        float ans = valExp(0, 0);
        // Cleanup memo to free memory before returning
        V2D empty;
        memo_.swap(empty); // releases capacity
        return ans;
    }
private:
    // Parameters
    int T_ = 0, tau_ = 0;
    float p_ = 0.0f, r_ = 0.5f;
    float safetyL_ = 0.0f, safetyU_ = 1.0f, limitL_ = 0.0f, limitU_ = 1.0f;
    float kappa_=0.5f;
    // Memo table
    V2D memo_;

    float zeta_local(float x)  { 
        // return (x - kappa_) * (x - kappa_); 
        return energy_function_monotonic(x, r_, limitL_, limitU_, p_);
    }

    float valExp(int t, int c) {
        float &res = memo_[t][c];
        if (res != -1) return res;
        float gamma = 0;
        if (t > tau_) {
            if ((c < t * safetyL_) || (c > t * safetyU_)) gamma=1;
        }
        if (t == T_) return res = gamma;
        float energy = t > 0 ? zeta_local(float(c) / float(t)) : 0.5f;
        float f = c < kappa_ * t ? p_ + (1 - p_) * energy : p_ * (1 - energy);
        return res = gamma + f * valExp(t + 1, c + 1) + (1 - f) * valExp(t + 1, c);
    }
};

// Free-function wrappers for other files to call.

// Returns ValProbComputer().compute(...) for the given parameters.
// NOTE(review): despite its name, `r` is forwarded as ValProbComputer's integer
// `energy_type` and truncated toward zero. For example, r = 0.2 selects energy type 0,
// i.e. zero energy. The conversion is spelled out so it is no longer silent. Callers
// presumably pass a whole-number energy type (1, 2 or 3) — confirm.
float compute_val_prob(int T, int tau, float p, float limitL, float limitU, float safetyL, float safetyU, float r) {
    ValProbComputer comp;
    const int energy_type = static_cast<int>(r);
    return comp.compute(T, tau, p, limitL, limitU, safetyL, safetyU, energy_type);
}

// Returns ValExpComputer().compute(...) for the given parameters. `r` is the shape
// parameter that ValExpComputer forwards to the monotonic energy function.
float compute_val_exp(int T, int tau, float p, float limitL, float limitU, float safetyL, float safetyU, float r) {
    return ValExpComputer().compute(T, tau, p, limitL, limitU, safetyL, safetyU, r);
}



// Bound term for a single horizon T:
//   exp(-T (L - mustar)^2 / 32) + exp(-T (U - mustar)^2 / 32).
float bound_prob_T(float L, float U, float mustar, int T) {
    const float K = 1.0 / 32.0;
    const float rate = -K * T;
    const double below = std::exp(rate * std::pow(L - mustar, 2));
    const double above = std::exp(rate * std::pow(U - mustar, 2));
    return below + above;
}

// Sums bound_prob_T over every horizon t = tau, tau + 1, ..., T. Returns 0 when tau > T.
float bound_prob_interval(float L, float U, float mustar, int tau, int T) {
    float total = 0.0f;
    int horizon = tau;
    while (horizon <= T) {
        total += bound_prob_T(L, U, mustar, horizon);
        ++horizon;
    }
    return total;
}

// Closed-form geometric tail of the per-horizon bound: for each of L and U, with
// ratio = exp(-(bound - mustar)^2 / 32), adds ratio^T / (1 - ratio).
// ratio == 1 (bound == mustar) divides by zero.
float bound_prob_tail(float L, float U, float mustar, int T) {
    const float K = 1.0 / 32.0;
    auto tail = [&](float bound) -> double {
        const float ratio = std::exp(-K * std::pow(bound - mustar, 2));
        return std::pow(ratio, T) / (1.0 - ratio);
    };
    return tail(L) + tail(U);
}

void experiment_paper_bounds_v_exact(VI vecTau, std::ostream& out) {
    int T = 15000;
    float p = 0.65;
    float safetyL = 0.3; float safetyU = 0.7;
    float limitL = 0.45; float limitU = 0.55;
    
    // float r = 0.2;
    for (int energy_type=1; energy_type <= 4; ++energy_type) {
        for (int i = 0; i < vecTau.size(); ++i) {
            int tau = vecTau[i];
            ValProbComputer comp;
            auto t0 = std::chrono::steady_clock::now();
            float prob_exact = comp.compute(T, tau, p, limitL, limitU, safetyL, safetyU, energy_type);
            auto t1 = std::chrono::steady_clock::now();
            double elapsed = std::chrono::duration_cast<std::chrono::duration<double>>(t1 - t0).count();
            // float mustar = (L+U)/2.0;
            float mustar = 0.5;
            if (energy_type == 1) mustar = 0.47;
            if (energy_type == 2) mustar = 0.58;
            if (energy_type == 3) mustar = 0.6;
            if (energy_type == 4) mustar = 0.65; // these numbers are from the figure in the paper
            float prob_bound_tail = bound_prob_interval(safetyL, safetyU, mustar, tau, T);
            cout.setf(ios::fixed);
            cout << "r=" << energy_type << "T=" << T << ", tau=" << tau << ", Exact prob: " << prob_exact << ", Prob Bound: " << prob_bound_tail << " (computed in " << elapsed << " seconds)" << endl;
            // write CSV row: experiment,param,exact_prob,bound,elapsed_sec
            out.setf(ios::fixed);
            out << energy_type << "," << T << "," << tau << "," << prob_exact << "," << prob_bound_tail << "," << elapsed << "\n";
        }
    }
}

void experiment_paper_times(VI vecT, std::ostream& out) {
    int tau = 100;
    float p = 0.65;
    float safetyL = 0.3; float safetyU = 0.7;
    float limitL = 0.45; float limitU = 0.55;
    // float r = 0.35;
    for (int energy_type=1; energy_type <= 4; ++energy_type) {
        for (int i = 0; i < vecT.size(); ++i) {
            int T = vecT[i];
            ValProbComputer comp;
            auto t0 = std::chrono::steady_clock::now();
            float prob_exact = comp.compute(T, tau, p, limitL, limitU, safetyL, safetyU, energy_type);
            auto t1 = std::chrono::steady_clock::now();
            double elapsed = std::chrono::duration_cast<std::chrono::duration<double>>(t1 - t0).count();
            // float mustar = (L+U)/2.0;
            float mustar = 0.5;
            if (energy_type == 1) mustar = 0.47;
            if (energy_type == 2) mustar = 0.58;
            if (energy_type == 3) mustar = 0.6;
            if (energy_type == 4) mustar = 0.65; // these numbers are from the figure in the paper
            float prob_bound_tail = bound_prob_tail(safetyL, safetyU, mustar, T);
            cout.setf(ios::fixed); 
            cout << "r=" << energy_type << "T=" << T << ", Exact prob: " << prob_exact << ", Tail Bound: " << prob_bound_tail << " (computed in " << elapsed << " seconds)" << endl;
            // write CSV row: experiment,param,exact_prob,bound,elapsed_sec
            out.setf(ios::fixed);
            out << energy_type << "," << T << "," << tau << "," << prob_exact << "," << prob_bound_tail << "," << elapsed << "\n";
            // cout << "Prob: " << prob << " (computed in " << elapsed << " seconds)" << endl;
        }
    }
}

// #ifdef VALPROB_STANDALONE
// Returns "<prefix>YYYYmmdd_HHMMSS.csv", stamped with the current local time.
static std::string timestamped_csv_name(const std::string& prefix) {
    const std::time_t tnow = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
    std::ostringstream fname;
    fname << prefix;
    // std::localtime returns nullptr on failure, and passing that to put_time is UB.
    // In that case, fall back to raw epoch seconds.
    if (const std::tm* local = std::localtime(&tnow)) {
        fname << std::put_time(local, "%Y%m%d_%H%M%S");
    } else {
        fname << tnow;
    }
    fname << ".csv";
    return fname.str();
}

// Runs both paper experiments and writes one timestamped CSV file per experiment.
// Returns 1 if an output file cannot be opened or written.
int main() {
    // Experiment 1: exact probability vs. tail bound for horizons T = 1000, 2000, ..., 15000.
    VI vecT;
    for (int t = 1000; t <= 15000; t += 1000) vecT.push_back(t);

    const std::string fname_times = timestamped_csv_name("prob_results_times_");
    std::ofstream ofs_times(fname_times);
    if (!ofs_times) {
        std::cerr << "Failed to open output file: " << fname_times << "\n";
        return 1;
    }
    ofs_times << "etype,T,tau,exact_prob,tail_bound,elapsed_sec\n";  // CSV header
    experiment_paper_times(vecT, ofs_times);
    ofs_times.close();
    if (!ofs_times) {
        std::cerr << "Failed to write output file: " << fname_times << "\n";
        return 1;
    }

    // Experiment 2: exact probability vs. interval bound for tau = 100, 1100, ..., 13100.
    // This file is named when it is opened, i.e. after experiment 1 has finished.
    const std::string fname_bound = timestamped_csv_name("prob_results_bound_");
    std::ofstream ofs_bound(fname_bound);
    if (!ofs_bound) {
        std::cerr << "Failed to open output file: " << fname_bound << "\n";
        return 1;
    }
    ofs_bound << "etype,T,tau,exact_prob,interval_bound,elapsed_sec\n";  // CSV header

    VI vecTau;
    for (int t = 100; t <= 14000; t += 1000) vecTau.push_back(t);
    experiment_paper_bounds_v_exact(vecTau, ofs_bound);
    ofs_bound.close();
    if (!ofs_bound) {
        std::cerr << "Failed to write output file: " << fname_bound << "\n";
        return 1;
    }
    return 0;
}