#include <iostream>
#include <fstream>
#include <string>
#include <cmath>
#include <boost/cstdfloat.hpp>
#include "mdp.h"
#include "catch.h"
#include "multimodel_mdp.h"
#include "mdp_examples.h"
#include "hist.h"
#include "utils.h"
#include "bamdp_solver.h"
#include "bamcp_solver.h"
#include "posterior_sampling_solver.h"
#include "hardcoded_solver.h"
#include "bamcp_maxprob_solver.h"
#include "bamdp_cvar_solver.h"
#include "bamcp_threshold_solver.h"
#include "mcts_cvar_sg.h"
#include "mcts_cvar_sg_offline.h"
#include "finite_mdp_belief.h"
#include "bamdp_cvar_decision_node.h"
#include "mcts_bamdp_cvar_sg.h"
#include "cvar_value_iteration.h"
#include "value_iteration.h"
#include "domains.h"
#include "run_algorithms.h"
#include "cvar_expected_mdp_policy.h"
#include "agent_expected_mdp_policy.h"

#include "known_polytope_generators.h"
#include "random_walks/random_walks.hpp"

#include "volume/volume_sequence_of_balls.hpp"
#include "volume/volume_cooling_gaussians.hpp"
#include "volume/volume_cooling_balls.hpp"
#include "sampling/sampling.hpp"

/// Thin driver that forwards to runPGTabular with a fixed batch size of 1000.
///
/// @param dom         domain to run on; forwarded unchanged.
/// @param folder      output directory forwarded to runPGTabular.
/// @param alphas      risk levels forwarded to runPGTabular (how they are
///                    iterated is decided there — not visible here).
/// @param evalRepeats number of evaluation repeats; forwarded unchanged.
/// @param evalTrials  evaluation schedule of trial counts; forwarded unchanged.
/// @param lr          learning rate; defaults to 0.001.
void pg_tabular(
    domain dom,
    std::string folder,
    std::vector<float> alphas,
    int evalRepeats,
    std::vector<int> evalTrials,
    float lr = 0.001
){
    // Same batch size for every tabular policy-gradient run.
    int tabularBatch = 1000;
    runPGTabular(folder, dom, alphas, evalRepeats, evalTrials, tabularBatch, lr);
}

/// Thin driver that forwards to runPGApprox with a fixed batch size of 1000.
///
/// @param dom            domain to run on; forwarded unchanged.
/// @param folder         output directory forwarded to runPGApprox.
/// @param alphas         risk levels forwarded to runPGApprox.
/// @param evalRepeats    number of evaluation repeats; forwarded unchanged.
/// @param evalTrials     evaluation schedule of trial counts; forwarded unchanged.
/// @param lr             learning rate; defaults to 0.01.
/// @param initCvarPolicy flag forwarded to runPGApprox (defaults to true);
///                       presumably selects a CVaR-based policy initialisation
///                       — TODO confirm in runPGApprox.
void pg_approx(
    domain dom,
    std::string folder,
    std::vector<float> alphas,
    int evalRepeats,
    std::vector<int> evalTrials,
    float lr = 0.01,
    bool initCvarPolicy = true
){
    // Same batch size for every approximate policy-gradient run.
    int approxBatch = 1000;
    runPGApprox(folder, dom, alphas, evalRepeats, evalTrials,
                approxBatch, lr, initCvarPolicy);
}


/// Thin driver for offline MCTS that forwards to runMCTSOffline.
///
/// The driver fixes three values: bias = 2.0, widening = 0.2 and
/// batch size = 1000.
///
/// @param dom         domain to run on; forwarded unchanged.
/// @param folder      output directory forwarded to runMCTSOffline.
/// @param alphas      risk levels forwarded to runMCTSOffline.
/// @param evalRepeats number of evaluation repeats; forwarded unchanged.
/// @param evalTrials  evaluation schedule of trial counts; forwarded unchanged.
/// @param strat       strategy name forwarded verbatim (default "bayesOpt").
/// @param optim       optimisation mode forwarded verbatim (default "worst_case").
void mcts_offline(
    domain dom,
    std::string folder,
    std::vector<float> alphas,
    int evalRepeats,
    std::vector<int> evalTrials,
    std::string strat = "bayesOpt",
    std::string optim="worst_case"
){
    int offlineBatch = 1000;
    float offlineWidening = 0.2;
    float mctsBias = 2.0;

    runMCTSOffline(folder, dom, alphas, evalRepeats, evalTrials,
                   offlineBatch, mctsBias, offlineWidening, strat, optim);
}

/// Online MCTS driver with no rollout policy (runMCTS receives "none").
///
/// The bias is fixed at 2.0. Every other argument is forwarded unchanged
/// to runMCTS.
///
/// @param dom         domain to run on.
/// @param folder      output directory forwarded to runMCTS.
/// @param alphas      risk levels forwarded to runMCTS.
/// @param widening    widening value forwarded to runMCTS.
/// @param evalRepeats number of evaluation repeats.
/// @param trialsList  per-run integer triples; their field meanings are
///                    defined by runMCTS (not visible here).
/// @param strat       strategy name (default "bayesOpt").
/// @param optim       optimisation mode (default "worst_case").
void mcts_online(
    domain dom,
    std::string folder,
    std::vector<float> alphas,
    float widening,
    int evalRepeats,
    std::vector<std::tuple<int, int, int>> trialsList,
    std::string strat = "bayesOpt",
    std::string optim="worst_case"
){
    std::string rollout = "none";
    float mctsBias = 2.0;

    runMCTS(folder, dom, alphas, evalRepeats, trialsList,
            mctsBias, widening, strat, optim, rollout);
}

/// Online MCTS driver that asks runMCTS for the "cvar" rollout policy.
///
/// The bias is fixed at 2.0. Every other argument is forwarded unchanged
/// to runMCTS.
///
/// @param dom         domain to run on.
/// @param folder      output directory forwarded to runMCTS.
/// @param alphas      risk levels forwarded to runMCTS.
/// @param widening    widening value forwarded to runMCTS.
/// @param evalRepeats number of evaluation repeats.
/// @param trialsList  per-run integer triples; their field meanings are
///                    defined by runMCTS (not visible here).
/// @param strat       strategy name (default "bayesOpt").
/// @param optim       optimisation mode (default "worst_case").
void mcts_online_cvar_rollout_pol(
    domain dom,
    std::string folder,
    std::vector<float> alphas,
    float widening,
    int evalRepeats,
    std::vector<std::tuple<int, int, int>> trialsList,
    std::string strat = "bayesOpt",
    std::string optim ="worst_case"
){
    std::string rollout = "cvar";
    float mctsBias = 2.0;

    runMCTS(folder, dom, alphas, evalRepeats, trialsList,
            mctsBias, widening, strat, optim, rollout);
}

/// Online MCTS driver that asks runMCTS for the "expected_value" rollout policy.
///
/// The bias is fixed at 2.0. Every other argument is forwarded unchanged
/// to runMCTS.
///
/// @param dom         domain to run on.
/// @param folder      output directory forwarded to runMCTS.
/// @param alphas      risk levels forwarded to runMCTS.
/// @param widening    widening value forwarded to runMCTS.
/// @param evalRepeats number of evaluation repeats.
/// @param trialsList  per-run integer triples; their field meanings are
///                    defined by runMCTS (not visible here).
/// @param strat       strategy name (default "bayesOpt").
/// @param optim       optimisation mode (default "worst_case").
void mcts_online_expected_rollout_pol(
    domain dom,
    std::string folder,
    std::vector<float> alphas,
    float widening,
    int evalRepeats,
    std::vector<std::tuple<int, int, int>> trialsList,
    std::string strat = "bayesOpt",
    std::string optim="worst_case"
){
    std::string rollout = "expected_value";
    float mctsBias = 2.0;

    runMCTS(folder, dom, alphas, evalRepeats, trialsList,
            mctsBias, widening, strat, optim, rollout);
}

/// Thin driver that forwards every argument unchanged to runBamdpCvarVI.
///
/// @param dom          domain to run on.
/// @param folder       output directory forwarded to runBamdpCvarVI.
/// @param alphas       risk levels forwarded to runBamdpCvarVI.
/// @param evalRepeats  number of evaluation repeats.
/// @param numInterpPts interpolation-point count forwarded unchanged (default 20).
void cvar_vi_bamdp(
    domain dom,
    std::string folder,
    std::vector<float> alphas,
    int evalRepeats,
    int numInterpPts=20
){
    runBamdpCvarVI(folder, dom, alphas, evalRepeats, numInterpPts);
}

/// Thin driver that forwards every argument unchanged to runExpectedMDPCvarVI.
///
/// @param dom          domain to run on.
/// @param folder       output directory forwarded to runExpectedMDPCvarVI.
/// @param alphas       risk levels forwarded to runExpectedMDPCvarVI.
/// @param evalRepeats  number of evaluation repeats.
/// @param numInterpPts interpolation-point count forwarded unchanged (default 20).
void cvar_vi_expected_mdp(
    domain dom,
    std::string folder,
    std::vector<float> alphas,
    int evalRepeats,
    int numInterpPts=20
){
    runExpectedMDPCvarVI(folder, dom, alphas, evalRepeats, numInterpPts);
}

// Builds the evaluation schedule {0, increment, 2*increment, ...} up to and
// including trialsMax. Returns an empty schedule in two cases: when increment
// is non-positive (a plain `trials += increment` loop would never terminate)
// and when trialsMax is negative.
static std::vector<int> pgEvalSchedule(int increment, int trialsMax)
{
    std::vector<int> schedule;
    if (increment <= 0 || trialsMax < 0) {
        return schedule;
    }
    schedule.reserve(static_cast<std::vector<int>::size_type>(trialsMax / increment) + 1);
    // A long long counter keeps `trials += increment` from overflowing int
    // when trialsMax is close to INT_MAX.
    for (long long trials = 0; trials <= trialsMax; trials += increment) {
        schedule.push_back(static_cast<int>(trials));
    }
    return schedule;
}

/// Runs the experiment suite on the small betting-game domain, in this order:
/// 1. BAMDP CVaR value iteration.
/// 2. Expected-MDP CVaR value iteration.
/// 3. Online MCTS with the "cvar" rollout policy.
/// 4. Approximate policy gradient, evaluated every 20000 trials from 0 to
///    2000000 inclusive.
///
/// NOTE: `folder` is a placeholder path and must be edited before running.
void betting_game_experiment(){
    const int offlineRepeats = 2000;
    const int mctsRepeats = 2000;
    const std::vector<float> alphas{0.03, 0.2, 1.0};
    // One online configuration; the tuple fields are interpreted by runMCTS.
    std::vector<std::tuple<int, int, int>> onlineTrials;
    onlineTrials.emplace_back(100000, 25000, 0);
    const float widening = 0.2;
    const std::string folder = "/path/to/save/results";
    domain bgs = bettingGameSmall();

    cvar_vi_bamdp(bgs, folder, alphas, offlineRepeats);
    cvar_vi_expected_mdp(bgs, folder, alphas, offlineRepeats);
    mcts_online_cvar_rollout_pol(bgs, folder, alphas, widening, mctsRepeats, onlineTrials);

    const float lr = 0.001;
    const int evalIncrement = 20000;
    const int evalTrialsMax = 2000000;
    const std::vector<int> evalTrials = pgEvalSchedule(evalIncrement, evalTrialsMax);
    pg_approx(bgs, folder, alphas, offlineRepeats, evalTrials, lr);
}

/// Runs the experiment suite on the traffic domain, in this order:
/// 1. Online MCTS with the "cvar" rollout policy.
/// 2. Expected-MDP CVaR value iteration.
/// 3. Approximate policy gradient, evaluated every 20000 trials from 0 to
///    2000000 inclusive.
///
/// NOTE: `folder` is a placeholder path and must be edited before running.
void traffic_experiment(){
    const std::string folder = "/path/to/save/results";
    domain traffic = trafficDomain();
    // One online configuration; the tuple fields are interpreted by runMCTS.
    std::vector<std::tuple<int, int, int>> onlineTrials;
    onlineTrials.emplace_back(100000, 25000, 0);
    const std::vector<float> alphas{0.03, 0.2, 1.0};
    const int mctsRepeats = 2000;
    const int offlineRepeats = 2000;
    const float widening = 0.2;

    mcts_online_cvar_rollout_pol(traffic, folder, alphas, widening, mctsRepeats, onlineTrials);
    cvar_vi_expected_mdp(traffic, folder, alphas, offlineRepeats);

    const float lr = 0.001;
    const int evalIncrement = 20000;
    const int evalTrialsMax = 2000000;
    const std::vector<int> evalTrials = pgEvalSchedule(evalIncrement, evalTrialsMax);
    pg_approx(traffic, folder, alphas, offlineRepeats, evalTrials, lr);
}

/// Entry point: runs each experiment in sequence, traffic first and then the
/// betting game.
int main()
{
    using Experiment = void (*)();
    const Experiment experiments[] = {traffic_experiment, betting_game_experiment};
    for (Experiment run : experiments) {
        run();
    }
    return 0;
}
