#ifndef cvar_value_iteration
#define cvar_value_iteration
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "state.h"
#include "mdp.h"
#include "gurobi_c++.h"
#include "cvar_hist.h"

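/* Class to implement CVaR value iteration, with helpers to run episodes on an
expected MDP (executeEpisode) or on a Bayes adaptive MDP (executeBamdpEpisode).

Values are computed over augmented states that pair an MDP state with a risk
level alpha; each backup is built as a Gurobi linear program (getCvarLP).

Attributes:
    alphaVals: the discretized alpha (risk level) values.
    valueComputed: presumably whether valueIteration has already been run.
    maxT: presumably the maximum number of time steps — TODO confirm.
*/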
class CvarValueIteration {
private:
    std::vector<float> alphaVals;
    bool valueComputed;
    int maxT;
    std::unordered_map<State, float, StateHash> cvarValueFunction;
    std::vector<float> getAlphaValues(int numPts);
    State augToNormalState(State augState);
    std::unordered_map<State, std::map<std::string, std::string>, StateHash> cvarPolicy;
    std::unordered_map<State, std::map<std::string, std::unordered_map<State, float, StateHash>>, StateHash> cvarPerturbationPolicy;

    void cvarBackup(
        State currentAugState,
        std::unordered_map<State, float, StateHash>& cvarValue,
        GRBEnv env,
        MDP& m
    );

    GRBModel getCvarLP(
        State currentAugState,
        std::string act,
        std::unordered_map<State, float, StateHash>& cvarValue,
        GRBEnv& env,
        MDP& m
    );

public:
    CvarValueIteration(int numPts_);
    std::unordered_map<State, float, StateHash> valueIteration(MDP& m);

    std::tuple<std::string, std::unordered_map<State, float, StateHash>> getOptimalAction(
        std::shared_ptr<MDP> pMDP,
        State currentState,
        float currentAlpha,
        GRBEnv env
    );

    CvarHist executeEpisode(
        std::shared_ptr<MDP> pExpectedMDP,
        std::shared_ptr<MDP> pTrueMDP,
        State initialState,
        float initialAlpha
    );

    CvarHist executeBamdpEpisode(
            std::shared_ptr<MDP> pBamdp,
            std::shared_ptr<MDP> pTrueMDP,
            State initialState,
            float initialAlpha
    );

    std::unordered_map<State, std::map<std::string, std::string>, StateHash> getCvarPolicy(){
        return cvarPolicy;
    }

    std::unordered_map<State, std::map<std::string, std::unordered_map<State, float, StateHash>>, StateHash> getCvarPerturbationPolicy(){
        return cvarPerturbationPolicy;
    }

};

#endif
