/* This file implements tests for solving example MDPs with value iteration */

#include <iostream>
#include <string>
#include "catch.h"
#include "multimodel_mdp.h"
#include "mdp_examples.h"
#include "value_iteration.h"
#include "utils.h"


TEST_CASE("Betting game example 1"){
    MDP testMDP = *makeBettingMDP(1.0, 5);
    std::unordered_map<State, float, StateHash> value;
    std::unordered_map<State, std::string, StateHash> policy;

    std::tie(value, policy) = valueIteration(testMDP, true);

    std::unordered_map<std::string, std::string> stateMap;
    stateMap["t"] = "0";
    stateMap["money"] = "10";
    State state(stateMap);
    REQUIRE(cmpf(value[state], 35.0));
}

TEST_CASE("Betting game example 2"){
    MDP testMDP = *makeBettingMDP(0.8, 5);
    std::unordered_map<State, float, StateHash> value;
    std::unordered_map<State, std::string, StateHash> policy;

    std::tie(value, policy) = valueIteration(testMDP, true);

    std::unordered_map<std::string, std::string> stateMap;
    stateMap["t"] = "0";
    stateMap["money"] = "10";
    State state(stateMap);
    REQUIRE(cmpf(value[state], 24.6016));
}

TEST_CASE("Betting game example 3"){
    MDP testMDP = *makeBettingMDP(0., 5);
    std::unordered_map<State, float, StateHash> value;
    std::unordered_map<State, std::string, StateHash> policy;

    std::tie(value, policy) = valueIteration(testMDP, true);

    std::unordered_map<std::string, std::string> stateMap;
    stateMap["t"] = "0";
    stateMap["money"] = "10";
    State state(stateMap);
    REQUIRE(cmpf(value[state], 10.0));
}

// Solves the MDP built by makeMedicalMDP(numDays, seed) with valueIteration and
// checks the value computed for the start state {t = "0", health = "5"}.
TEST_CASE("Medical decision making"){
    int numDays = 7;
    // The seed is set equal to numDays. The expected value asserted below was
    // presumably computed with exactly this seed, so change them together.
    int seed = numDays;
    // NOTE(review): if makeMedicalMDP returns an owning raw pointer, this
    // dereference-and-copy leaks the allocation — confirm its return type.
    MDP testMDP = *makeMedicalMDP(numDays, seed);
    std::unordered_map<State, float, StateHash> value;
    std::unordered_map<State, std::string, StateHash> policy;
    std::tie(value, policy) = valueIteration(testMDP, true);

    std::unordered_map<std::string, std::string> stateMap;
    stateMap["t"] = "0";
    stateMap["health"] = "5";
    const State startState(stateMap);

    // find() instead of operator[]: operator[] would default-insert 0.0 for a
    // missing state and hide the real failure behind a numeric mismatch.
    const auto it = value.find(startState);
    REQUIRE(it != value.end());
    REQUIRE(cmpf(it->second, 4.39669));
}


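// NOTE: disabled. As written it only prints the start-state value and asserts
// nothing; add a REQUIRE with the expected value before re-enabling it.
// TEST_CASE("TrafficMDP"){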
//     int horizon = 8;
//     MDP testMDP = *trafficMDP(horizon);
//     std::unordered_map<State, float, StateHash> value;
//     std::unordered_map<State, std::string, StateHash> policy;
//     std::tie(value, policy) = valueIteration(testMDP, true);
//
//     std::unordered_map<std::string, std::string> stateMap;
//     stateMap["t"] = "0";
//     stateMap["x"] = "1";
//     stateMap["y"] = "0";
//     stateMap["min"] = "0";
//     State state(stateMap);
//     std::cout << value[state] << std::endl;
// }
