#include <iostream>
#include <cmath>
#include <string>
#include <unordered_map>
#include <limits>
#include <algorithm>
#include "state.h"
#include "mdp.h"
#include "utils.h"
#include "gurobi_c++.h"
#include "cvar_value_iteration.h"
#include "cvar_hist.h"

/* Construct a CVaR value iteration solver.

Args:
    numPts_: number of points requested from getAlphaValues when building the
        alpha grid (getAlphaValues always appends alpha = 1 on top of these).
*/
CvarValueIteration::CvarValueIteration(int numPts_){
    // No value function exists until valueIteration() has been run.
    this->valueComputed = false;
    this->maxT = 0;

    // Grid of alpha values over which CVaR values are interpolated.
    this->alphaVals = getAlphaValues(numPts_);
}

/* Build the grid of alpha values used for the CVaR interpolation.

logspace comes from utils.h. It presumably (TODO confirm) returns numPts values
10^x with x evenly spaced over [0, 1], i.e. points in [1, 10]. Each point is
shifted down by 1 and divided by 10. Under that assumption this maps [1, 10]
onto [0, 0.9], with the points concentrated near 0. Alpha = 1 is then appended
as the last grid point.

Args:
    numPts: number of points requested from logspace.

Returns:
    the alpha grid: logspace's points after the shift/scale, followed by 1.0.
*/
std::vector<float> CvarValueIteration::getAlphaValues(int numPts){
    std::vector<float> grid = logspace(0.0, 1.0, numPts);
    for(float& a : grid){
        a -= 1.0;
        a /= 10.0;
    }
    grid.push_back(1.0);
    return grid;
}

std::tuple<std::string, std::unordered_map<State, float, StateHash>> CvarValueIteration::getOptimalAction(
    std::shared_ptr<MDP> pMDP,
    State currentState,
    float currentAlpha,
    GRBEnv env
){

    std::unordered_map<std::string, std::string> stateMap = currentState.getStateMapping();
    stateMap["alpha"] = std::to_string(currentAlpha);
    State currentAugState(stateMap);

    std::unordered_map<State, float, StateHash> bestPerturbation;
    float bestValue = -std::numeric_limits<float>::max();
    std::string bestAction;
    for(std::string act : pMDP->getEnabledActions(currentState)){
        GRBModel model = getCvarLP(currentAugState, act, cvarValueFunction, env, *pMDP);
        model.optimize();

        float Q = model.get(GRB_DoubleAttr_ObjVal);
        std::unordered_map<State, float, StateHash> perturbation;
        for(auto kv : pMDP->getTransitionProbs(currentState, act)){
            GRBVar pert = model.getVarByName(kv.first.toString());
            perturbation[kv.first] = pert.get(GRB_DoubleAttr_X);
        }
        if(Q > bestValue){
            bestValue = Q;
            bestAction = act;
            bestPerturbation = perturbation;
        }
    }

    return std::make_tuple(bestAction, bestPerturbation);
}

/* Execute an episode using the CVaR value function computed on the expected MDP.
Value iteration must be run on the expected MDP before calling this function.
Otherwise a message is printed and an empty history is returned.

At each step:
- The action and the adversary's perturbation are chosen by getOptimalAction on
  the expected MDP. Alpha values below 1e-6 are clamped to 1e-6 for planning.
- The successor is sampled from the true MDP.
- Alpha is multiplied by the perturbation of the sampled successor and capped at 1.

Args:
    pExpectedMDP: a pointer to the expected MDP used to compute optimal actions.
    pTrueMDP: a pointer to the true MDP, used to sample transitions during the
        episode.
    initialState: the initial state in the MDP.
    initialAlpha: the initial alpha value.

Returns:
    the history of states, actions, alpha values, expected-MDP transition
    probabilities, perturbations and expected-MDP rewards.
*/
CvarHist CvarValueIteration::executeEpisode(
        std::shared_ptr<MDP> pExpectedMDP,
        std::shared_ptr<MDP> pTrueMDP,
        State initialState,
        float initialAlpha
){
    if(!valueComputed){
        std::cout << "Cannot execute episode. Must run value iteration first." << std::endl;
        CvarHist emptyHist;
        return emptyHist;
    }

    CvarHist history;
    State currentState = initialState;
    float currentAlpha = initialAlpha;

    // Start from an empty environment so OutputFlag is set before start().
    // GRBEnv() starts immediately, which makes a later start() a no-op and
    // prints Gurobi's start-up output before the flag can silence it.
    GRBEnv env(true);
    env.set(GRB_IntParam_OutputFlag, 0);
    env.start();

    // NOTE(review): this loop makes maxT + 1 decisions (t = 0..maxT). However,
    // valueIteration only backs up t = 0..maxT-1 and fixes the value at
    // t = maxT to 0, so the last decision lies beyond the optimised horizon.
    // The loop counter also ignores the state's own "t" factor, presumably
    // because initialState is at t = 0. Confirm both are intended.
    for(int t = 0; t <= maxT; t++){
        // At or near alpha = 0, plan with alpha = 1e-6: this keeps the LP's
        // perturbation bound 1/alpha finite while optimising for the worst case
        // on zero-probability paths.
        const float planningAlpha = std::max(currentAlpha, 1e-6f);

        std::string action;
        std::unordered_map<State, float, StateHash> perturbation;
        std::tie(action, perturbation) = getOptimalAction(
                                                pExpectedMDP,
                                                currentState,
                                                planningAlpha,
                                                env
                                        );

        // Append to the history. The unclamped alpha is recorded.
        history.addTransition(
                currentState,
                action,
                currentAlpha,
                pExpectedMDP->getTransitionProbs(currentState, action),
                perturbation,
                pExpectedMDP->getReward(currentState, action));

        // Sample the next state from the true MDP using the chosen action.
        State nextState = pTrueMDP->sampleSuccessor(currentState, action);

        // Rescale alpha by the perturbation of the sampled successor. A successor
        // absent from the perturbation (not a successor in the expected MDP)
        // contributes a factor of 0, which drives alpha to 0.
        auto pertIt = perturbation.find(nextState);
        const float factor = (pertIt == perturbation.end()) ? 0.0f : pertIt->second;
        currentAlpha = std::min(currentAlpha * factor, 1.0f);

        currentState = nextState;
    }

    return history;
}

/* Execute an episode using the CVaR value function computed on the BAMDP.
Value iteration must be run on the BAMDP before calling this function.
Otherwise a message is printed and an empty history is returned.

The BAMDP state is built from two factors:
- "history": the string built from initialState.toString() and then getNewHistory.
- "t": the step counter.

At each step:
- The action and the adversary's perturbation are chosen by getOptimalAction on
  the BAMDP. Alpha values below 1e-6 are clamped to 1e-6 for planning.
- The successor is sampled from the true MDP and the history state is extended.
- Alpha is multiplied by the perturbation of the new history state and capped at 1.

Args:
    pBamdp: a pointer to the BAMDP used to compute optimal actions.
    pTrueMDP: a pointer to the true MDP, used to sample transitions during the
        episode.
    initialState: the initial state in the MDP. This state does not include the
        history.
    initialAlpha: the initial alpha value.

Returns:
    the history of BAMDP states, actions, alpha values, BAMDP transition
    probabilities, perturbations and BAMDP rewards.
*/
CvarHist CvarValueIteration::executeBamdpEpisode(
        std::shared_ptr<MDP> pBamdp,
        std::shared_ptr<MDP> pTrueMDP,
        State initialState,
        float initialAlpha
){
    if(!valueComputed){
        std::cout << "Cannot execute episode. Must run value iteration first." << std::endl;
        CvarHist emptyHist;
        return emptyHist;
    }

    std::unordered_map<std::string, std::string> stateMap;
    stateMap["history"] = initialState.toString();
    stateMap["t"] = "0";
    State initialHistoryState(stateMap);

    CvarHist history;
    State currentHistoryState = initialHistoryState;
    State currentState = initialState;
    float currentAlpha = initialAlpha;

    // Start from an empty environment so OutputFlag is set before start().
    // GRBEnv() starts immediately, which makes a later start() a no-op and
    // prints Gurobi's start-up output before the flag can silence it.
    GRBEnv env(true);
    env.set(GRB_IntParam_OutputFlag, 0);
    env.start();

    // NOTE(review): this loop makes maxT + 1 decisions (t = 0..maxT). However,
    // valueIteration only backs up t = 0..maxT-1 and fixes the value at
    // t = maxT to 0, so the last decision is taken at a history state whose
    // successors have t = maxT + 1. Confirm this extra step is intended.
    for(int t = 0; t <= maxT; t++){
        // If alpha is at or near zero, plan with alpha = 1e-6, i.e. switch to
        // optimising the worst case while keeping the LP bound 1/alpha finite.
        const float planningAlpha = std::max(currentAlpha, 1e-6f);

        std::string action;
        std::unordered_map<State, float, StateHash> perturbation;
        std::tie(action, perturbation) = getOptimalAction(
                                                pBamdp,
                                                currentHistoryState,
                                                planningAlpha,
                                                env
                                        );

        // Append to the history. The unclamped alpha is recorded.
        history.addTransition(
                currentHistoryState,
                action,
                currentAlpha,
                pBamdp->getTransitionProbs(currentHistoryState, action),
                perturbation,
                pBamdp->getReward(currentHistoryState, action));

        // Sample the next state from the true MDP using the chosen action.
        State nextState = pTrueMDP->sampleSuccessor(currentState, action);

        // Extend the history and advance the BAMDP time step.
        stateMap["history"] = getNewHistory(currentHistoryState, action, nextState);
        stateMap["t"] = std::to_string(std::stoi(currentHistoryState.getValue("t"))+1);
        State nextHistoryState(stateMap);

        // Rescale alpha by the perturbation of the reached history state. A state
        // absent from the perturbation contributes a factor of 0.
        auto pertIt = perturbation.find(nextHistoryState);
        const float factor = (pertIt == perturbation.end()) ? 0.0f : pertIt->second;
        currentAlpha = std::min(currentAlpha * factor, 1.0f);

        currentState = nextState;
        currentHistoryState = nextHistoryState;
    }

    return history;
}


/* Build (but do not solve) the CVaR Bellman LP for one state-action pair.

This follows Eq. 10 of Chow et al., "Risk-Sensitive and Robust Decision-Making:
a CVaR Optimization Approach" (2015). The LP is set up as follows:
- For every successor x', an adversary chooses a perturbation xi(x') in
  [0, 1/alpha], subject to sum_x' P(x'|x,a) * xi(x') = 1.
- The successor's value at alpha * xi(x') is obtained by piecewise-linear
  interpolation (SOS2 over the alphaVals grid) of the stored entries
  cvarValue[(x', alpha_i)]. cvarBackup stores alpha_i * value in these entries.
- The objective minimised is

      r(x, a) + (1/alpha) * sum_x' P(x'|x,a) * interp_x'(alpha * xi(x'))

Assumes alpha > 0. Callers skip (cvarBackup) or clamp (execute*Episode)
alpha = 0.

Args:
    currentAugState: state including an "alpha" factor.
    act: the action to evaluate.
    cvarValue: value function, read with operator[]. Missing entries are
        inserted with value 0.
    env: Gurobi environment the model is created in.
    m: the MDP supplying rewards and transition probabilities.

Returns:
    the unsolved model. Each perturbation variable is named
    nextState.toString(), so callers can read it back with getVarByName.
*/
GRBModel CvarValueIteration::getCvarLP(
    State currentAugState,
    std::string act,
    std::unordered_map<State, float, StateHash>& cvarValue,
    GRBEnv& env,
    MDP& m
){
    State currentState = augToNormalState(currentAugState);
    float currentAlpha = std::stof(currentAugState.getValue("alpha"));
    GRBModel model = GRBModel(env);
    float pertMax = 1.0/currentAlpha;
    GRBLinExpr probSum;

    // Initialise the Q-value to equal the reward for this state-action pair.
    GRBLinExpr qVal = m.getReward(currentState, act);

    // SOS weights only define the order (and hence the adjacency) of the
    // variables, and Gurobi requires them to be unique. Use strictly increasing
    // weights that follow the order of alphaVals, so adjacent lambdas correspond
    // to adjacent alpha breakpoints. The same weights serve every successor.
    std::vector<double> sosWeights(alphaVals.size());
    for(std::size_t k = 0; k < sosWeights.size(); k++){
        sosWeights[k] = static_cast<double>(k + 1);
    }

    for(auto pair : m.getTransitionProbs(currentState, act)){
        State nextState = pair.first;
        float nextProb = pair.second;

        // Separate variable for the perturbation of each successor state.
        GRBVar perturbation = model.addVar(0.0, pertMax, 0.0, GRB_CONTINUOUS, nextState.toString());
        probSum += perturbation * nextProb;

        // Lambdas are the coefficients of the piecewise-linear approximation.
        GRBLinExpr sumLambdaAlpha;
        GRBLinExpr sumLambdaValue;
        GRBLinExpr sumLambda;
        std::vector<GRBVar> lambdas;
        lambdas.reserve(alphaVals.size());

        // Only the "alpha" entry changes per breakpoint, so copy the mapping once.
        std::unordered_map<std::string, std::string> stateMap = nextState.getStateMapping();
        for(auto alpha : alphaVals){
            GRBVar lambda = model.addVar(0.0, 1.0, 0.0, GRB_CONTINUOUS);
            lambdas.push_back(lambda);
            sumLambda += lambda;

            // Interpolation for the value of the perturbation applied.
            sumLambdaAlpha += lambda * alpha / currentAlpha;

            // Interpolation for the value of the successor state.
            stateMap["alpha"] = std::to_string(alpha);
            State nextAugState = State(stateMap);
            sumLambdaValue += lambda * cvarValue[nextAugState];
        }

        // Lambdas must sum to 1; the SOS2 constraint makes the approximation
        // piecewise linear (at most two adjacent lambdas are non-zero).
        model.addConstr(sumLambda == 1.0);
        model.addConstr(sumLambdaAlpha == perturbation);
        model.addSOS(lambdas.data(), sosWeights.data(), static_cast<int>(lambdas.size()), GRB_SOS_TYPE2);

        // Equation 10. Risk-Sensitive and Robust Decision-Making: a
        // CVaR Optimization Approach, 2015
        qVal += sumLambdaValue * nextProb / currentAlpha;
    }
    model.addConstr(probSum == 1.0, "Probability sum");
    model.setObjective(qVal, GRB_MINIMIZE);
    return model;
}

/* Remove the "alpha" factor from an alpha-augmented state.

Args:
    augState: a state whose mapping may contain an "alpha" entry.

Returns:
    a State built from the same factor mapping with the "alpha" entry erased.
    All other factors are kept unchanged.
*/
State CvarValueIteration::augToNormalState(State augState){
    std::unordered_map<std::string, std::string> factors = augState.getStateMapping();
    const std::string alphaKey = "alpha";
    factors.erase(alphaKey);
    return State(factors);
}

/* Perform one CVaR Bellman backup for a single alpha-augmented state.

For every enabled action, the CVaR LP (getCvarLP) is solved. In that LP an
adversary picks the perturbation of the transition probabilities that minimises
the action's Q-value. The agent keeps the action with the largest such Q-value.
The results are then stored as follows:
- alpha * Q is written into cvarValue.
- The chosen action is recorded in cvarPolicy, keyed by the alpha string.
- The chosen perturbation is recorded in cvarPerturbationPolicy, keyed by the
  alpha string.

States with alpha == 0 are skipped.

Args:
    currentAugState: state including an "alpha" factor.
    cvarValue: value function, updated in place.
    env: Gurobi environment used to build the LPs.
    m: the MDP.
*/
void CvarValueIteration::cvarBackup(
    State currentAugState,
    std::unordered_map<State, float, StateHash>& cvarValue,
    GRBEnv env,
    MDP& m
){
    State currentState = augToNormalState(currentAugState);
    const std::string alphaKey = currentAugState.getValue("alpha");
    float currentAlpha = std::stof(alphaKey);

    // Nothing to back up at alpha = 0.
    if(currentAlpha == 0.0){
        return;
    }

    std::string bestAction;
    std::unordered_map<State, float, StateHash> bestPerturbation;
    float bestValue = -std::numeric_limits<float>::max();

    for(const auto& act : m.getEnabledActions(currentState)){
        // Q-value of this action against the worst-case adversary perturbation.
        GRBModel lp = getCvarLP(currentAugState, act, cvarValue, env, m);
        lp.optimize();
        const float q = lp.get(GRB_DoubleAttr_ObjVal);

        // Perturbation variables are named after their successor states.
        std::unordered_map<State, float, StateHash> perturbation;
        for(const auto& successor : m.getTransitionProbs(currentState, act)){
            perturbation[successor.first] = lp.getVarByName(successor.first.toString()).get(GRB_DoubleAttr_X);
        }

        // The agent keeps the action with the largest worst-case Q-value.
        if(q > bestValue){
            bestValue = q;
            bestAction = act;
            bestPerturbation = perturbation;
        }
    }

    // The values stored for the linear interpolation are alpha * value.
    // See page 6 of "Risk-Sensitive and Robust Decision-Making: a CVaR
    // Optimization Approach", 2015.
    cvarValue[currentAugState] = bestValue*currentAlpha;

    if(cvarPolicy.count(currentState) == 0){
        std::map<std::string, std::string> actionByAlpha;
        actionByAlpha[alphaKey] = bestAction;
        cvarPolicy[currentState] = actionByAlpha;

        std::map<std::string, std::unordered_map<State, float, StateHash>> perturbationByAlpha;
        perturbationByAlpha[alphaKey] = bestPerturbation;
        cvarPerturbationPolicy[currentState] = perturbationByAlpha;
    }else{
        cvarPolicy[currentState][alphaKey] = bestAction;
        cvarPerturbationPolicy[currentState][alphaKey] = bestPerturbation;
    }
}

/* Performs CVaR value iteration on the MDP provided. This assumes the problem
is finite horizon: every state must have a "t" state factor, and the process
exits with an error message otherwise. The agent's goal is assumed to be
maximising the reward.

The backup proceeds as follows:
- Every state is augmented with each alpha in alphaVals.
- The value at the last time step (maxT) is set to 0.
- cvarBackup is applied backwards from t = maxT - 1 down to t = 0.

Args:
    m: the MDP to perform value iteration upon.

Returns:
    the value function as a map from alpha-augmented states to alpha * value.
    The same map is stored in cvarValueFunction and valueComputed is set, so
    the execute*Episode methods can use it.
*/
std::unordered_map<State, float, StateHash> CvarValueIteration::valueIteration(MDP& m){
    std::vector<State> states = m.enumerateStates();
    std::unordered_map<State, float, StateHash> cvarValue;

    for(State& s : states){
        std::vector<std::string> sf = s.getStateFactors();
        if(std::find(sf.begin(), sf.end(), "t") == sf.end()){
            std::cerr << "Error: Need MDP states to have a t state factor for finite horizon CVaR value iteration" << std::endl;
            std::exit(-1);
        }
    }

    // Recompute the horizon for this MDP. Without the reset, a second call on
    // an MDP with a shorter horizon would keep the previous, larger maxT.
    maxT = 0;

    // Create the alpha-augmented states, bucketed by time step so each sweep
    // below visits only the states it backs up. Within a bucket, states keep
    // their enumeration order.
    std::unordered_map<int, std::vector<State>> augStatesByT;
    for(State& s : states){
        std::unordered_map<std::string, std::string> stateMap = s.getStateMapping();
        const int stateT = std::stoi(stateMap["t"]);
        if(stateT > maxT){
            maxT = stateT;
        }

        std::vector<State>& bucket = augStatesByT[stateT];
        for(float alpha : alphaVals){
            stateMap["alpha"] = std::to_string(alpha);
            bucket.push_back(State(stateMap));
        }
    }

    // Set the CVaR value to zero at the final time step.
    for(State& s : augStatesByT[maxT]){
        cvarValue[s] = 0.0;
    }

    // Start from an empty environment so OutputFlag is set before start().
    // GRBEnv() starts immediately, which makes a later start() a no-op and
    // prints Gurobi's start-up output before the flag can silence it.
    GRBEnv env(true);
    env.set(GRB_IntParam_OutputFlag, 0);
    env.start();

    // Loop backwards through the horizon.
    for(int t = maxT - 1; t >= 0; t--){
        std::cout << "t: " << t << std::endl;
        auto bucketIt = augStatesByT.find(t);
        if(bucketIt == augStatesByT.end()){
            continue;
        }
        for(State& s : bucketIt->second){
            cvarBackup(s, cvarValue, env, m);
        }
    }

    valueComputed = true;
    cvarValueFunction = cvarValue;
    return cvarValue;
}
