#include "Qhoeffding.h"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <utility>

/// Builds the learner and sizes every per-(step, state, action) table.
///
/// @param mdp             Environment; its H, S and A fields give the table dimensions.
/// @param c               Scale factor stored for the exploration bonus used in update_Q.
/// @param total_episodes  Number of episodes that learn() will run.
Qlearning_gen::Qlearning_gen(FiniteStateFiniteActionMDP& mdp, float c, int total_episodes)
    : mdp(mdp), c(c), total_episodes(total_episodes) {
    using FloatRow = std::vector<float>;
    using FloatGrid = std::vector<FloatRow>;
    using IntGrid = std::vector<std::vector<int>>;

    const auto horizon = mdp.H;
    const auto num_states = mdp.S;
    const auto num_actions = mdp.A;

    // Value tables: V_func carries one extra row (index H) that stays at zero.
    V_func.resize(horizon + 1, FloatRow(num_states, 0.0f));
    V_next.resize(horizon, FloatGrid(num_states, FloatRow(num_actions, 0.0f)));

    // Q-values at step h start at H - h for every state/action pair.
    // NOTE(review): presumably an upper bound on the return when per-step
    // rewards lie in [0, 1] -- confirm against the MDP definition.
    global_Q.resize(horizon, FloatGrid(num_states, FloatRow(num_actions)));
    for (int step = 0; step < horizon; ++step) {
        const float initial_q = static_cast<float>(horizon - step);
        for (int s = 0; s < num_states; ++s) {
            std::fill_n(global_Q[step][s].begin(), num_actions, initial_q);
        }
    }

    // Visit bookkeeping: N is the running count, n flags visits of the current episode.
    N.resize(horizon, IntGrid(num_states, std::vector<int>(num_actions, 0)));
    n.resize(horizon, IntGrid(num_states, std::vector<int>(num_actions, 0)));
}

/// Derives the greedy policy from the current Q-table.
///
/// @return An H x S x A tensor in which, for every (step, state), the entry of
///         the action with the largest Q-value is 1.0f and all others are 0.0f.
///         Ties resolve to the lowest action index (std::max_element returns
///         the first maximum).
std::vector<std::vector<std::vector<float>>> Qlearning_gen::choose_action() {
    using Row = std::vector<float>;
    using Grid = std::vector<Row>;

    std::vector<Grid> policy(mdp.H, Grid(mdp.S, Row(mdp.A, 0.0f)));

    for (int h = 0; h < mdp.H; ++h) {
        for (int s = 0; s < mdp.S; ++s) {
            const Row& q_row = global_Q[h][s];
            const auto top = std::max_element(q_row.begin(), q_row.end());
            const int greedy = static_cast<int>(std::distance(q_row.begin(), top));
            policy[h][s][greedy] = 1.0f;
        }
    }
    return policy;
}

/// Plays one episode of mdp.H steps with the current greedy policy.
///
/// Side effects on members, for every (step, state, action) visited:
///   - n[step][state][action] is set to 1 (consumed and cleared by update_Q);
///   - V_next[step][state][action] is set to V_func[step + 1][next_state].
///
/// @return A pair of (H x S x A tensor holding the reward observed at each
///         visited triple and 0 elsewhere, the state returned by mdp.reset()).
std::pair<std::vector<std::vector<std::vector<float>>>, int> Qlearning_gen::run_episode() {
    const auto actions_policy = choose_action();
    int state = mdp.reset();
    const int state_init = state;
    std::vector<std::vector<std::vector<float>>> rewards(
        mdp.H, std::vector<std::vector<float>>(mdp.S, std::vector<float>(mdp.A, 0.0f)));

    for (int step = 0; step < mdp.H; ++step) {
        // The policy row is one-hot, so its argmax is the chosen action.
        const auto& action_row = actions_policy[step][state];
        const int action = static_cast<int>(std::distance(
            action_row.begin(), std::max_element(action_row.begin(), action_row.end())));

        auto [next_state, reward] = mdp.step(action);

        n[step][state][action] = 1;

        // V_func has H + 1 rows, so step + 1 is always a valid index.
        V_next[step][state][action] = V_func[step + 1][next_state];

        rewards[step][state][action] = reward;
        state = next_state;
    }

    // Move the tensor into the result: the braced return would otherwise
    // deep-copy the whole H x S x A structure on every episode.
    return {std::move(rewards), state_init};
}

/// Applies one Q-learning update to every (h, s, a) flagged in n, then clears n.
///
/// For a flagged triple with updated visit count t = N[h][s][a]:
///   step size  = (H + 1) / (H + t)
///   bonus      = c * (H - h - 1) * sqrt(H / t)
///   Q[h][s][a] = (1 - step size) * Q[h][s][a]
///              + step size * (rewards[h][s][a] + V_next[h][s][a] + bonus)
///
/// @param rewards H x S x A tensor of rewards, indexed as [h][s][a].
void Qlearning_gen::update_Q(const std::vector<std::vector<std::vector<float>>>& rewards) {
    const int H = mdp.H;

    for (int h = 0; h < H; ++h) {
        for (int s = 0; s < mdp.S; ++s) {
            for (int a = 0; a < mdp.A; ++a) {
                // Only triples visited since the last update are touched.
                if (n[h][s][a] == 0) {
                    continue;
                }

                const int visits = ++N[h][s][a];
                const float step_size = static_cast<float>(H + 1) / (H + visits);
                const float ucb_bonus = c * (H - h - 1) * std::sqrt(static_cast<float>(H) / visits);

                float& q = global_Q[h][s][a];
                q = (1.0f - step_size) * q +
                    step_size * (rewards[h][s][a] + V_next[h][s][a] + ucb_bonus);
            }
        }
    }

    // Clear the per-episode visit flags so the next episode starts clean.
    for (auto& per_step : n) {
        for (auto& per_state : per_step) {
            for (int& flag : per_state) {
                flag = 0;
            }
        }
    }
}

std::tuple<
    std::vector<float>,
    std::vector<std::vector<std::vector<float>>>,
    std::vector<float>,
    std::vector<std::vector<std::vector<float>>>,
    std::vector<float>
> Qlearning_gen::learn() {
    float regret_cum = 0.0f;
    auto [best_value, best_policy, best_Q] = mdp.best_gen();

    std::vector<std::vector<std::vector<float>>> rewards(mdp.H, std::vector<std::vector<float>>(mdp.S, std::vector<float>(mdp.A, 0.0f)));
    for (int h = 0; h < mdp.H; ++h) {
        for (int s = 0; s < mdp.S; ++s) {
            V_func[h][s] = *std::max_element(global_Q[h][s].begin(), global_Q[h][s].end());
        }
    }
    auto actions_policy = choose_action();
    
    for (int episode = 0; episode < total_episodes; ++episode) {
        auto [run_reward, state_init] = run_episode();
        
        auto value_vec = mdp.value_gen(actions_policy);
        float current_value = value_vec[state_init];

        regret_cum += best_value[state_init] - current_value;
        regret.push_back(regret_cum / (episode + 1));
        raw_gap.push_back(best_value[state_init] - current_value);

        for (int h = 0; h < mdp.H; ++h) {
            for (int s = 0; s < mdp.S; ++s) {
                 auto max_it = std::max_element(actions_policy[h][s].begin(), actions_policy[h][s].end());
                 int a = std::distance(actions_policy[h][s].begin(), max_it);
                if (rewards[h][s][a] == 0.0f && run_reward[h][s][a] != 0.0f) {
                    rewards[h][s][a] = run_reward[h][s][a];
                }
            }
        }

        update_Q(rewards);
        actions_policy = choose_action();

        for (int h = 0; h < mdp.H; ++h) {
            for (int s = 0; s < mdp.S; ++s) {
                 float max_q = *std::max_element(global_Q[h][s].begin(), global_Q[h][s].end());
                 V_func[h][s] = std::min(static_cast<float>(mdp.H - h), max_q);
            }
        }
    }

    auto final_value_vec = mdp.value_gen(actions_policy);
    
    return {best_value, best_Q, final_value_vec, global_Q, raw_gap};
}