#include <algorithm>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
// #include "valprob.h"
using namespace std;

// Shorthand aliases for nested vectors of ints (VI, V2I, V3I) and of doubles (VD, V2D, V3D).
using VI = vector<int>;
using V2I = vector<VI>;
using V3I = vector<V2I>;
using VD = vector<double>;
using V2D = vector<VD>;
using V3D = vector<V2D>;

// V2D VALPROB;


// Polynomial energy alpha * |x - kappa|^beta; for beta > 0 it is zero at x == kappa
// and grows with the distance from kappa.
double energy_function_polynomial(double x, double alpha, double kappa, double beta) {
    const double distance = std::abs(x - kappa);
    const double shaped = std::pow(distance, beta);
    return alpha * shaped;
}

// Saturating Gaussian-shaped energy alpha * (1 - exp(-beta * (x - kappa)^2)):
// zero at x == kappa and, for beta > 0, approaching alpha as |x - kappa| grows.
double energy_function_exponential(double x, double alpha, double kappa, double beta) {
    const double offset = x - kappa;
    const double decay = std::exp(-beta * offset * offset);
    return alpha * (1.0 - decay);
}

/// Piecewise "energy" used to bias the per-step up-probability in ValProbComputer
/// and ValExpComputer.
///
/// @param x  evaluation point (the DP computers pass the empirical rate c / t)
/// @param r  shape parameter; alpha_r = (1 - r) / r, so r must lie in (0, 1)
///           (r == 1 divides by zero). synthesize_shield searches r in [0.001, 0.999].
/// @param L  lower limit of the target band
/// @param U  upper limit of the target band
/// @param p  base up-probability; selects which of the three shapes below is used
/// @return   energy value (see the per-branch comments for its shape)
///
/// NOTE(review): the zero of the energy (kappa below) is (1 + U) / 2 or L / 2,
/// whereas ValProbComputer / ValExpComputer switch their push direction at
/// (safetyU + limitU) / 2 or (safetyL + limitL) / 2. Confirm this mismatch is intended.
double energy_function_monotonic(double x, double r, double L, double U, double p) {
    int m = 2;
    double alpha_r = (1-r)/r;
    if (p < L) {
        // Base rate below the band.
        double kappa = (1.0+U)/2;
        double a_r = (1.0-r)*L + r*U;       // interpolates from L (r -> 0) to U (r -> 1)
        double C_r = (a_r - p)/(1.0 - p);   // chosen so that p + (1 - p) * C_r == a_r
        if (x < a_r) {
            // Equals C_r at x == a_r and tends to 1 as x decreases.
            return C_r + (1-C_r)*(1-std::exp((x-a_r)/alpha_r));
        }
        if (x > kappa) {
            // Zero at x == kappa, rising toward 1 as x increases.
            return 1-std::exp(-std::pow((x-kappa)/alpha_r, m));
        }
        // Between a_r and kappa: falls from C_r at a_r to 0 at kappa.
        return C_r*std::pow(1-(x-a_r)/(kappa-a_r), alpha_r);
    }
    if (p > U) {
        // Base rate above the band: mirror image of the branch above.
        double kappa = (L)/2;
        double a_r = (r)*L + (1.0-r)*U;     // interpolates from U (r -> 0) to L (r -> 1)
        double C_r = (p-a_r)/(p);
        if (x < kappa) {
            // Zero at x == kappa, rising toward 1 as x decreases.
            return 1-std::exp(-std::pow((x-kappa)/alpha_r, m));
        }
        if (x > a_r) {
            // Equals C_r at x == a_r and tends to 1 as x increases.
            return C_r + (1-C_r)*(1-std::exp((a_r-x)/alpha_r));
        }
        // Between kappa and a_r: 0 at kappa, C_r at a_r.
        return C_r*std::pow(1-(x-a_r)/(kappa-a_r), alpha_r);
    }
    // Base rate inside [L, U]: quadratic well around p, scaled by r / (1 - r), capped at 1.
    double beta_r = 2;
    double kappa = p;
    double res = (1/alpha_r)*std::pow(std::abs(x-kappa),beta_r);
    if (res > 1.0) return 1.0;
    return res;
}


/// Probability that a shielded run violates the safety band within T steps.
///
/// States are pairs (t, c) with 0 <= c <= t <= T: t steps taken, c of them
/// up-moves, so c / t is the empirical up-rate. From (t, c) the run moves to
/// (t + 1, c + 1) with probability f and to (t + 1, c) with probability 1 - f,
/// where f is the base rate p biased by energy_function_monotonic
/// (see step_up_probability). For t > tau, a state with c < safetyL * t or
/// c > safetyU * t is absorbing with value 1. Any other state at t == T has value 0.
class ValProbComputer {
public:
    /// @param T        horizon in steps; a negative horizon yields 0
    /// @param tau      violations are only checked for t > tau
    /// @param p        base up-probability
    /// @param limitL   lower limit of the target band
    /// @param limitU   upper limit of the target band
    /// @param safetyL  lower edge of the safety band (as a rate)
    /// @param safetyU  upper edge of the safety band (as a rate)
    /// @param r        energy shape parameter, see energy_function_monotonic
    /// @return value of state (0, 0): the probability of reaching a violating state
    double compute(int T, int tau, double p, double limitL, double limitU, double safetyL, double safetyU, double r) {
        T_ = T; tau_ = tau; p_ = p;
        limitL_ = limitL; limitU_ = limitU;
        safetyL_ = safetyL; safetyU_ = safetyU;
        r_ = r;
        // Rate at which the up-probability switches from being raised to being lowered.
        if (p_ < limitL_) kappa_ = (safetyU_+limitU_)/2.0;
        else if (p_ > limitU_) kappa_ = (safetyL_+limitL_)/2.0;
        else kappa_ = p_;
        if (T_ < 0) return 0.0;  // no steps to take (formerly indexed an empty table)

        // Evaluate bottom-up one time slice at a time: next[c] holds the value of
        // state (t + 1, c). Memory is O(T) and there is no recursion. A full
        // (T + 1)^2 table of doubles reaches gigabytes once T is in the tens of
        // thousands. Each cell uses the same arithmetic as a top-down recursion,
        // so results are unchanged.
        std::vector<double> next(T_ + 1), cur(T_ + 1);
        for (int c = 0; c <= T_; ++c) next[c] = outside_safety(T_, c) ? 1.0 : 0.0;
        for (int t = T_ - 1; t >= 0; --t) {
            for (int c = 0; c <= t; ++c) {
                if (outside_safety(t, c)) {
                    cur[c] = 1.0;  // absorbing: the violation has already happened
                    continue;
                }
                const double f = step_up_probability(t, c);
                cur[c] = f * next[c + 1] + (1 - f) * next[c];
            }
            next.swap(cur);
        }
        return next[0];
    }
private:
    // Parameters (set by compute)
    int T_ = 0, tau_ = 0;
    double p_ = 0.0, r_ = 0.5;
    double safetyL_ = 0.0, safetyU_ = 1.0, limitL_ = 0.0, limitU_ = 1.0;
    double kappa_ = 0.5;

    double zeta_local(double x) {
        // return (x - kappa_) * (x - kappa_);
        return energy_function_monotonic(x, r_, limitL_, limitU_, p_);
    }

    // True when t is past tau_ and c lies outside [t * safetyL_, t * safetyU_].
    bool outside_safety(int t, int c) const {
        return t > tau_ && ((c < t * safetyL_) || (c > t * safetyU_));
    }

    // Up-move probability from (t, c). Below kappa_ the energy raises p_;
    // at or above kappa_ it scales p_ down. At t == 0 the energy is fixed at 0.5.
    double step_up_probability(int t, int c) {
        const double energy = t > 0 ? zeta_local(double(c) / double(t)) : 0.5;
        return c < kappa_ * t ? p_ + (1 - p_) * energy : p_ * (1 - energy);
    }
};


/// Expected number of violating states visited by a shielded run of T steps.
///
/// Same chain as ValProbComputer, but violating states are not absorbing.
/// Every visited state (t, c) with t > tau and c < safetyL * t or
/// c > safetyU * t contributes 1, including states at t == T.
class ValExpComputer {
public:
    /// Parameters as for ValProbComputer::compute.
    /// @return expected count of violating states visited from (0, 0); 0 for T < 0
    double compute(int T, int tau, double p, double limitL, double limitU, double safetyL, double safetyU, double r) {
        T_ = T; tau_ = tau; p_ = p;
        limitL_ = limitL; limitU_ = limitU;
        safetyL_ = safetyL; safetyU_ = safetyU;
        r_ = r;
        if (p_ < limitL_) kappa_ = (safetyU_+limitU_)/2.0;
        else if (p_ > limitU_) kappa_ = (safetyL_+limitL_)/2.0;
        else kappa_ = p_;
        if (T_ < 0) return 0.0;  // no steps to take (formerly indexed an empty table)

        // Bottom-up over time slices with O(T) memory; next[c] is the value of (t + 1, c).
        std::vector<double> next(T_ + 1), cur(T_ + 1);
        for (int c = 0; c <= T_; ++c) next[c] = outside_safety(T_, c) ? 1.0 : 0.0;
        for (int t = T_ - 1; t >= 0; --t) {
            for (int c = 0; c <= t; ++c) {
                const double gamma = outside_safety(t, c) ? 1.0 : 0.0;
                const double f = step_up_probability(t, c);
                cur[c] = gamma + f * next[c + 1] + (1 - f) * next[c];
            }
            next.swap(cur);
        }
        return next[0];
    }
private:
    // Parameters (set by compute)
    int T_ = 0, tau_ = 0;
    double p_ = 0.0, r_ = 0.5;
    double safetyL_ = 0.0, safetyU_ = 1.0, limitL_ = 0.0, limitU_ = 1.0;
    double kappa_ = 0.5;

    double zeta_local(double x) {
        // return (x - kappa_) * (x - kappa_);
        return energy_function_monotonic(x, r_, limitL_, limitU_, p_);
    }

    // True when t is past tau_ and c lies outside [t * safetyL_, t * safetyU_].
    bool outside_safety(int t, int c) const {
        return t > tau_ && ((c < t * safetyL_) || (c > t * safetyU_));
    }

    // Up-move probability from (t, c); identical to ValProbComputer's.
    double step_up_probability(int t, int c) {
        const double energy = t > 0 ? zeta_local(double(c) / double(t)) : 0.5;
        return c < kappa_ * t ? p_ + (1 - p_) * energy : p_ * (1 - energy);
    }
};


/// Smallest horizon n >= 10 with
///     rminus^n / (1 - rminus) + rplus^n / (1 - rplus) <= epsilon,
/// i.e. the sum of two geometric-series tails starting at n drops to epsilon.
///
/// Doubles n from 10 until the bound holds, then binary-searches the last
/// doubling interval for the smallest such n.
///
/// @throws std::invalid_argument if rminus or rplus is not < 1 (or is NaN); the
///         tail bound is then undefined or never shrinks.
/// @throws std::overflow_error if no such n fits in an int (e.g. epsilon < 0).
int find_optimal_tailtime(double rminus, double rplus, double epsilon) {
    if (!(rminus < 1.0) || !(rplus < 1.0)) {
        throw std::invalid_argument("find_optimal_tailtime: rminus and rplus must be < 1");
    }
    const auto tail = [rminus, rplus](int n) {
        return std::pow(rminus, n) / (1 - rminus) + std::pow(rplus, n) / (1 - rplus);
    };

    int lo = 10;
    int hi = 10;
    while (tail(hi) > epsilon) {
        // Doubling further would overflow int (signed overflow is UB).
        if (hi > std::numeric_limits<int>::max() / 2) {
            throw std::overflow_error("find_optimal_tailtime: no horizon within int range meets epsilon");
        }
        lo = hi;
        hi *= 2;
    }
    // Invariant: tail(hi) <= epsilon; the smallest good n lies in [lo, hi].
    while (lo < hi) {
        const int mid = lo + (hi - lo) / 2;
        if (tail(mid) <= epsilon) {
            hi = mid;        // mid is good; answer is in [lo, mid]
        } else {
            lo = mid + 1;    // mid is bad; answer is in (mid, hi]
        }
    }
    return lo;
}

/// Bisects for the shield parameter r in [0.001, 0.999] whose violation
/// probability (ValProbComputer::compute) meets the budget delta.
///
/// The horizon T is the smallest n for which the geometric tail bound with ratios
/// exp(-K (safety - mustar)^2) drops below epsilon (see find_optimal_tailtime).
/// epsilon is also reused as the tolerance |d - delta| at which bisection stops.
///
/// @return 0.001 if that r already gives a probability below delta;
///         1.0 (after printing "FAIL") if even r = 0.999 exceeds delta;
///         otherwise the bisection midpoint that stopped the search.
/// NOTE(review): bisection assumes the probability is non-increasing in r, as the
/// two endpoint checks suggest — confirm.
double synthesize_shield(int tau, double p, double safetyL, double safetyU, double limitL, double limitU, double delta, double epsilon) {
    const double mustar = (limitL + limitU) / 2;
    const double K = 1.0 / 32.0;
    const double rminus = std::exp(-K * (safetyL - mustar) * (safetyL - mustar));
    const double rplus = std::exp(-K * (safetyU - mustar) * (safetyU - mustar));
    const int T = find_optimal_tailtime(rminus, rplus, epsilon);
    std::cout << "T: " << T << std::endl;

    ValProbComputer comp;
    const auto violation_probability = [&](double r) {
        return comp.compute(T, tau, p, limitL, limitU, safetyL, safetyU, r);
    };

    double r_l = 0.001;
    double r_u = 0.999;
    // Each evaluation is an O(T^2) DP, so the endpoint value is computed once.
    const double d_lower = violation_probability(r_l);
    if (d_lower < delta) {
        std::cout << d_lower << std::endl;
        return r_l;
    }
    if (violation_probability(r_u) > delta) {
        std::cout << "FAIL" << std::endl;
        return 1.0;
    }
    double m = (r_l + r_u) / 2.0;
    while (r_l < r_u) {
        m = (r_l + r_u) / 2.0;
        const double d = violation_probability(m);
        if (std::abs(d - delta) < epsilon) return m;
        const double prev_l = r_l;
        const double prev_u = r_u;
        if (d <= delta) r_u = m;
        else r_l = m;
        std::cout << "d: " << d << ", m: " << m << std::endl;
        // Once m rounds onto a bound the interval cannot shrink any further;
        // without this check the loop would spin forever.
        if (r_l == prev_l && r_u == prev_u) break;
    }
    return m;
}

/// Timing benchmark for synthesize_shield.
///
/// Sweeps p over {0.3, 0.4, 0.5, 0.6, 0.7} and epsilon over 0.1, 0.01, ...
/// (multiplying by 0.1 while epsilon > 1e-8), with fixed band/limit settings.
/// Each run writes one row to synthesis_results.csv with columns p, epsilon
/// (scientific) and elapsed_time (seconds, 6 decimals), and echoes progress to
/// stdout. If the CSV cannot be opened, an error is printed to stderr and
/// nothing is run.
void synthesis_experiments_paper() {
    // Alternative configuration used previously:
    // double safetyL=0.2, safetyU=0.8;
    // double limitL=0.4, limitU=0.6;
    // double p=0.7, delta=0.01, epsilon;

    const double safetyL = 0.3;
    const double safetyU = 0.7;
    const double limitL = 0.45;
    const double limitU = 0.55;
    const int tau = 20;
    const double delta = 0.1;

    std::ofstream csv("synthesis_results.csv");
    if (!csv) {
        std::cerr << "synthesis_experiments_paper: cannot open synthesis_results.csv" << std::endl;
        return;
    }
    csv << "p,epsilon,elapsed_time\n"; // header row

    // An integer counter keeps p free of accumulated rounding. step / 10.0 yields
    // the same five doubles as 0.3, 0.3 + 0.1, ... did.
    for (int step = 3; step <= 7; ++step) {
        const double p = step / 10.0;
        // epsilon is an accumulated product (not exactly 10^-k). It is kept that
        // way so the tested values, and hence the synthesized horizons, are unchanged.
        for (double epsilon = 0.1; epsilon > 1e-8; epsilon *= 0.1) {
            const auto t0 = std::chrono::steady_clock::now();
            // Only the running time is recorded; the synthesized value is unused.
            synthesize_shield(tau, p, safetyL, safetyU, limitL, limitU, delta, epsilon);
            const auto t1 = std::chrono::steady_clock::now();
            const double elapsed = std::chrono::duration<double>(t1 - t0).count();

            std::cout << "epsilon=" << epsilon << ", elapsed_time=" << elapsed << std::endl;

            // Float-format manipulators are sticky. Reset to defaultfloat before p,
            // otherwise every row after the first prints p as "0.300000".
            csv << std::defaultfloat << p << ","
                << std::scientific << epsilon << ","
                << std::fixed << std::setprecision(6) << elapsed << "\n";
        }
    }
}




/// Entry point.
///
/// With no arguments (or any argument other than --stdin), runs
/// synthesis_experiments_paper(). With --stdin, reads
/// "tau p safetyL safetyU limitL limitU delta epsilon" from standard input,
/// runs synthesize_shield on that instance and prints "Val: <result>".
/// This path was previously unreachable code after an early return.
int main(int argc, char* argv[]) {
    if (argc < 2 || std::string(argv[1]) != "--stdin") {
        synthesis_experiments_paper();
        return 0;
    }

    int tau = 0;
    double p = 0, safetyL = 0, safetyU = 0, limitL = 0, limitU = 0, delta = 0, epsilon = 0;
    if (!(std::cin >> tau >> p >> safetyL >> safetyU >> limitL >> limitU >> delta >> epsilon)) {
        std::cerr << "expected input: tau p safetyL safetyU limitL limitU delta epsilon" << std::endl;
        return 1;
    }
    const double m = synthesize_shield(tau, p, safetyL, safetyU, limitL, limitU, delta, epsilon);
    std::cout << "Val: " << m << std::endl;
    return 0;
}