#include "Env.h"

#include <algorithm>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <sstream>
#include <stdexcept>
#include <utility>

// Builds a random tabular MDP with H steps, S states and A actions.
//
// For every (step, state, action) triple the next-state distribution is
// obtained by drawing S values from an Exp(1) distribution and dividing each
// by their sum; the reward is then drawn uniformly from [0, 1).
// All randomness is taken from the generator passed in as `gen`.
FiniteStateFiniteActionMDP::FiniteStateFiniteActionMDP(int H, int S, int A, std::mt19937& gen)
    : H(H), S(S), A(A), t(0), gen(gen) {

    // Allocate zero-filled tables: P is H x S x A x S, R is H x S x A.
    const std::vector<float> next_state_row(S);
    const std::vector<std::vector<float>> transitions_per_state(A, next_state_row);
    P.resize(H, std::vector<std::vector<std::vector<float>>>(S, transitions_per_state));

    const std::vector<float> reward_row(A);
    R.resize(H, std::vector<std::vector<float>>(S, reward_row));

    std::exponential_distribution<float> unit_exponential(1.0f);
    std::uniform_real_distribution<float> unit_uniform(0.0f, 1.0f);

    for (int step = 0; step < H; ++step) {
        for (int from = 0; from < S; ++from) {
            for (int act = 0; act < A; ++act) {
                auto& row = P[step][from][act];

                // Draw unnormalised weights and keep a running total.
                float total = 0.0f;
                for (float& weight : row) {
                    weight = unit_exponential(gen);
                    total += weight;
                }

                // Rescale so the row sums to one.
                for (float& weight : row) {
                    weight /= total;
                }

                R[step][from][act] = unit_uniform(gen);
            }
        }
    }
}

// Builds an MDP whose transition and reward tables are read from CSV files.
//
// transition_file: one line per (h, s, a) triple, with `a` varying fastest,
//     then `s`, then `h`. Each line holds the comma-separated values
//     P[h][s][a][0], P[h][s][a][1], ...
// reward_file: one line per (h, s) pair, with `s` varying fastest. Each line
//     holds the comma-separated values R[h][s][0], R[h][s][1], ...
//     NOTE: this differs from the reward file written by save_env(), which
//     puts a single reward on each line.
//
// Entries that the files do not supply keep their zero initial value. Every
// line read (including an empty one) advances the row position.
//
// Throws std::runtime_error if a file cannot be opened or if it supplies a
// value that falls outside the H x S x A (x S) tables. Exceptions raised by
// std::stof for unparsable values propagate unchanged.
FiniteStateFiniteActionMDP::FiniteStateFiniteActionMDP(int H, int S, int A,
                                                       const std::string& transition_file,
                                                       const std::string& reward_file)
    // Fallback generator seeded with 0. It is not used while loading, but
    // reset() and step() draw from `gen`.
    // NOTE(review): it is allocated with `new` and never freed in this file;
    // fixing that needs an owning member in the class declaration.
    : H(H), S(S), A(A), t(0), gen(*(new std::mt19937(0)))
{
    P.resize(H, std::vector<std::vector<std::vector<float>>>(S, std::vector<std::vector<float>>(A, std::vector<float>(S))));
    R.resize(H, std::vector<std::vector<float>>(S, std::vector<float>(A)));

    // Load transition
    std::ifstream infile_P(transition_file);
    if (!infile_P.is_open()) throw std::runtime_error("Cannot open transition file");

    std::string line;
    int h = 0;
    int s = 0;
    int a = 0;
    while (std::getline(infile_P, line)) {
        std::stringstream ss(line);
        std::string val;
        int s2 = 0;
        while (std::getline(ss, val, ',')) {
            // Reject the value before it would be written outside the table.
            if (h >= H || s >= S || a >= A || s2 >= S)
                throw std::runtime_error("Transition file has more values than the H*S*A x S table can hold");
            P[h][s][a][s2++] = std::stof(val);
        }
        // Advance (h, s, a) like an odometer: a fastest, then s, then h.
        if (++a == A) { a = 0; if (++s == S) { s = 0; ++h; } }
    }

    // Load reward
    std::ifstream infile_R(reward_file);
    if (!infile_R.is_open()) throw std::runtime_error("Cannot open reward file");

    h = 0;
    s = 0;
    while (std::getline(infile_R, line)) {
        std::stringstream ss(line);
        std::string val;
        int col = 0;
        while (std::getline(ss, val, ',')) {
            // Reject the value before it would be written outside the table.
            if (h >= H || s >= S || col >= A)
                throw std::runtime_error("Reward file has more values than the H*S x A table can hold");
            R[h][s][col++] = std::stof(val);
        }
        if (++s == S) { s = 0; ++h; }
    }
}

// Starts a new episode: rewinds the step counter to 0, draws the starting
// state uniformly from {0, ..., S - 1} using `gen`, and returns that state.
int FiniteStateFiniteActionMDP::reset() {
    t = 0;
    std::uniform_int_distribution<> start_state_dist(0, S - 1);
    return state = start_state_dist(gen);
}

std::pair<int, float> FiniteStateFiniteActionMDP::step(int action) {
    int s = state;
    std::discrete_distribution<int> dist(P[t][s][action].begin(), P[t][s][action].end());
    state = dist(gen);
    float reward = R[t][s][action];
    t++;
    return {state, reward};
}

// Writes the transition and reward tables to CSV files inside `folder`.
//
// Files written:
//   <folder>/env_<n>_p.csv : one line per (h, s, a), a varying fastest, with
//                            the S comma-separated values of P[h][s][a].
//   <folder>/env_<n>_r.csv : one line per (h, s, a), a varying fastest, with
//                            the single value R[h][s][a].
//
// `folder` (including any missing parent directories) is created if needed.
// Values are written with enough significant digits for a float to be read
// back exactly.
//
// Throws std::filesystem::filesystem_error if the directory cannot be
// created and std::runtime_error if a file cannot be opened or written.
void FiniteStateFiniteActionMDP::save_env(int n, const std::string& folder) {
    // create_directories also handles nested paths whose parents are missing.
    std::filesystem::create_directories(folder);

    const std::string prefix = folder + "/env_" + std::to_string(n);
    // The default stream precision (6 digits) loses float bits on a round trip.
    constexpr int float_digits = std::numeric_limits<float>::max_digits10;

    std::ofstream pfile(prefix + "_p.csv");
    if (!pfile.is_open()) throw std::runtime_error("Cannot open transition file for writing");
    pfile.precision(float_digits);
    for (int h = 0; h < H; ++h)
        for (int s = 0; s < S; ++s)
            for (int a = 0; a < A; ++a) {
                for (int s2 = 0; s2 < S; ++s2)
                    pfile << P[h][s][a][s2] << (s2 + 1 == S ? "\n" : ",");
            }
    pfile.flush();
    if (pfile.fail()) throw std::runtime_error("Failed while writing transition file");

    std::ofstream rfile(prefix + "_r.csv");
    if (!rfile.is_open()) throw std::runtime_error("Cannot open reward file for writing");
    rfile.precision(float_digits);
    for (int h = 0; h < H; ++h)
        for (int s = 0; s < S; ++s)
            for (int a = 0; a < A; ++a)
                rfile << R[h][s][a] << "\n";
    rfile.flush();
    if (rfile.fail()) throw std::runtime_error("Failed while writing reward file");
}

// Reads the transition and reward tables back from the CSV files
// <folder>/env_<n>_p.csv and <folder>/env_<n>_r.csv.
//
// The transition file is read as H*S*A rows of S values, with one separator
// character consumed between values of a row (the character itself is not
// validated). The reward file is read as H*S*A whitespace-separated values in
// (h, s, a) order, a varying fastest. The existing dimensions of P and R are
// used; the tables are not resized.
//
// Throws std::runtime_error if a file cannot be opened or if it ends early or
// contains a value that cannot be parsed. In the latter case the table being
// read may already be partially overwritten.
void FiniteStateFiniteActionMDP::load_env(int n, const std::string& folder) {
    const std::string prefix = folder + "/env_" + std::to_string(n);

    std::ifstream pfile(prefix + "_p.csv");
    if (!pfile.is_open()) throw std::runtime_error("Cannot open transition file");
    for (int h = 0; h < H; ++h)
        for (int s = 0; s < S; ++s)
            for (int a = 0; a < A; ++a)
                for (int s2 = 0; s2 < S; ++s2) {
                    pfile >> P[h][s][a][s2];
                    if (s2 + 1 != S) {
                        char separator;
                        pfile >> separator;
                    }
                }
    // Once an extraction fails the stream stays failed, so one check suffices.
    if (pfile.fail()) throw std::runtime_error("Transition file is truncated or malformed");

    std::ifstream rfile(prefix + "_r.csv");
    if (!rfile.is_open()) throw std::runtime_error("Cannot open reward file");
    for (int h = 0; h < H; ++h)
        for (int s = 0; s < S; ++s)
            for (int a = 0; a < A; ++a)
                rfile >> R[h][s][a];
    if (rfile.fail()) throw std::runtime_error("Reward file is truncated or malformed");
}

// Evaluates a policy by backward induction over the horizon.
//
// actions: H x S x A table; actions[h][s][a] is the weight given to action a
//     in state s at step h (presumably the policy's action probabilities —
//     confirm with callers). It is indexed for every h < H, s < S, a < A.
//
// For h = H-1 down to 0:
//     Q[h][s][a] = R[h][s][a] + sum_{s2} P[h][s][a][s2] * V[h+1][s2]
//     V[h][s]    = sum_a actions[h][s][a] * Q[h][s][a]
// with V[H][s] = 0.
//
// Returns {Q, V}, where Q is H x S x A and V is (H+1) x S.
std::tuple<std::vector<std::vector<std::vector<float>>>, std::vector<std::vector<float>>>
FiniteStateFiniteActionMDP::full_value_gen(const std::vector<std::vector<std::vector<float>>>& actions) {
    std::vector<std::vector<std::vector<float>>> Q(H, std::vector<std::vector<float>>(S, std::vector<float>(A, 0.0f)));
    std::vector<std::vector<float>> V(H + 1, std::vector<float>(S, 0.0f));

    for (int h = H - 1; h >= 0; --h) {
        for (int s = 0; s < S; ++s) {
            for (int a = 0; a < A; ++a) {
                float expected = 0.0f;
                for (int s2 = 0; s2 < S; ++s2)
                    expected += P[h][s][a][s2] * V[h + 1][s2];
                Q[h][s][a] = R[h][s][a] + expected;
            }

            float v_val = 0.0f;
            for (int a = 0; a < A; ++a)
                v_val += actions[h][s][a] * Q[h][s][a];
            V[h][s] = v_val;
        }
    }

    // Move the tables into the result instead of deep-copying them.
    return {std::move(Q), std::move(V)};
}

// Evaluates a policy by backward induction and returns only the step-0 values.
//
// actions: H x S x A table; actions[h][s][a] is the weight given to action a
//     in state s at step h (presumably the policy's action probabilities —
//     confirm with callers). It is indexed for every h < H, s < S, a < A.
//
// Uses the recursion
//     q          = R[h][s][a] + sum_{s2} P[h][s][a][s2] * V[h+1][s2]
//     V[h][s]    = sum_a actions[h][s][a] * q
// starting from V[H][s] = 0.
//
// Returns a vector of S values, one per starting state (all zero when H is 0).
std::vector<float> FiniteStateFiniteActionMDP::value_gen(const std::vector<std::vector<std::vector<float>>>& actions) {
    // Only the values of the following step are needed at any time, so two
    // rows are kept instead of full Q and V tables.
    std::vector<float> next_values(S, 0.0f);     // V[h + 1]; starts as V[H] = 0
    std::vector<float> current_values(S, 0.0f);  // V[h], filled in below

    for (int h = H - 1; h >= 0; --h) {
        for (int s = 0; s < S; ++s) {
            float v_val = 0.0f;
            for (int a = 0; a < A; ++a) {
                float expected = 0.0f;
                for (int s2 = 0; s2 < S; ++s2)
                    expected += P[h][s][a][s2] * next_values[s2];
                const float q_val = R[h][s][a] + expected;
                v_val += actions[h][s][a] * q_val;
            }
            current_values[s] = v_val;
        }
        // The row just computed becomes "next" for step h - 1.
        next_values.swap(current_values);
    }

    return next_values;
}

// Computes a greedy policy by backward induction over the horizon.
//
// For h = H-1 down to 0:
//     Q[h][s][a] = R[h][s][a] + sum_{s2} P[h][s][a][s2] * V[h+1][s2]
//     V[h][s]    = max_a Q[h][s][a]
// with V[H][s] = 0. The maximising action (the lowest index on ties) gets
// weight 1 in the returned action table; all other entries stay 0.
//
// Requires A >= 1 whenever H and S are non-zero, since the maximum of an
// empty action set is undefined.
//
// Returns {V[0], actions, Q}: the S step-0 values, the H x S x A one-hot
// action table, and the H x S x A Q table.
std::tuple<std::vector<float>, std::vector<std::vector<std::vector<float>>>, std::vector<std::vector<std::vector<float>>>>
FiniteStateFiniteActionMDP::best_gen() {
    std::vector<std::vector<std::vector<float>>> Q(H, std::vector<std::vector<float>>(S, std::vector<float>(A, 0.0f)));
    std::vector<std::vector<float>> V(H + 1, std::vector<float>(S, 0.0f));
    std::vector<std::vector<std::vector<float>>> actions(H, std::vector<std::vector<float>>(S, std::vector<float>(A, 0.0f)));

    for (int h = H - 1; h >= 0; --h) {
        for (int s = 0; s < S; ++s) {
            for (int a = 0; a < A; ++a) {
                float expected = 0.0f;
                for (int s2 = 0; s2 < S; ++s2)
                    expected += P[h][s][a][s2] * V[h + 1][s2];
                Q[h][s][a] = R[h][s][a] + expected;
            }

            // max_element returns the first of several equal maxima, which
            // matches a left-to-right scan with a strict '>' comparison.
            const auto& q_row = Q[h][s];
            const auto best_it = std::max_element(q_row.begin(), q_row.end());
            const auto best_a = best_it - q_row.begin();

            actions[h][s][best_a] = 1.0f;
            V[h][s] = *best_it;
        }
    }

    // Move the tables into the result instead of deep-copying them.
    return {std::move(V[0]), std::move(actions), std::move(Q)};
}
