#include "AMB.h"
#include <algorithm>
#include <cmath>
#include <numeric>
#include <limits>

/// Builds the learner for `mdp` with bonus scale `c`, to be run for
/// `total_episodes` episodes by learn().
///
/// Sizes every table from the MDP dimensions (H steps, S states, A actions):
///  - VU / VL : (H + 1) x S, all zero; the extra row at index H is never
///              written by learn() (presumably the post-horizon value).
///  - QU      : H x S x A, step h initialised to H - h (optimistic bound).
///  - QL      : H x S x A, all zero.
///  - N / n   : H x S x A visit counters (persistent / per-episode), zero.
///  - A_valid : H x S x A, every action initially allowed (1).
///  - G       : H x S "decided" flags, zero.
///  - episode_state / episode_action : H + 1 trajectory slots, zero.
Qlearning_gen_AMB::Qlearning_gen_AMB(FiniteStateFiniteActionMDP& mdp, float c, int total_episodes)
    : mdp(mdp), c(c), total_episodes(total_episodes) {

    using FloatRow = std::vector<float>;
    using FloatGrid = std::vector<FloatRow>;
    using IntRow = std::vector<int>;
    using IntGrid = std::vector<IntRow>;

    const auto H = mdp.H;
    const auto S = mdp.S;
    const auto A = mdp.A;

    // Value bounds per (step, state), with one extra step row.
    VU.resize(H + 1, FloatRow(S, 0.0f));
    VL.resize(H + 1, FloatRow(S, 0.0f));

    // Q bounds per (step, state, action).
    QU.resize(H, FloatGrid(S, FloatRow(A)));
    QL.resize(H, FloatGrid(S, FloatRow(A, 0.0f)));
    for (int h = 0; h < H; ++h) {
        const float remaining_steps = static_cast<float>(H - h);
        for (int s = 0; s < S; ++s) {
            std::fill_n(QU[h][s].begin(), A, remaining_steps);
        }
    }

    // Visit counters: N accumulates across episodes, n marks the current one.
    N.resize(H, IntGrid(S, IntRow(A, 0)));
    n.resize(H, IntGrid(S, IntRow(A, 0)));

    // Action-elimination mask and per-(step, state) decided flags.
    A_valid.resize(H, IntGrid(S, IntRow(A, 1)));
    G.resize(H, IntRow(S, 0));

    // Trajectory buffers for the episode in progress.
    episode_state.resize(H + 1, 0);
    episode_action.resize(H + 1, 0);
}

/// Builds a deterministic one-hot policy table indexed [step][state][action].
///
/// For each (step, state), selects the action that is still allowed
/// (A_valid > 0) and has the widest confidence gap QU - QL. Ties go to the
/// lowest action index. If no action is allowed, action 0 is chosen. The
/// selected entry is 1.0f and every other entry is 0.0f. Reads only QU, QL,
/// A_valid and the MDP dimensions; it modifies no member state.
std::vector<std::vector<std::vector<float>>> Qlearning_gen_AMB::choose_action() {
    using Row = std::vector<float>;
    using Table = std::vector<Row>;

    std::vector<Table> policy(mdp.H, Table(mdp.S, Row(mdp.A, 0.0f)));

    for (int h = 0; h < mdp.H; ++h) {
        for (int s = 0; s < mdp.S; ++s) {
            const auto& upper = QU[h][s];
            const auto& lower = QL[h][s];
            const auto& allowed = A_valid[h][s];

            int chosen = 0;
            float widest_gap = -std::numeric_limits<float>::infinity();
            for (int a = 0; a < mdp.A; ++a) {
                if (allowed[a] <= 0) {
                    continue;  // eliminated action
                }
                const float gap = upper[a] - lower[a];
                if (gap > widest_gap) {  // strict: first maximum wins
                    widest_gap = gap;
                    chosen = a;
                }
            }
            policy[h][s][chosen] = 1.0f;
        }
    }
    return policy;
}

/// Plays one episode of mdp.H steps using the policy from choose_action().
///
/// Side effects:
///  - resets the MDP, then calls mdp.step once per step;
///  - records the trajectory in episode_state[0..H] (index H holds the state
///    reached after the last step) and episode_action[0..H-1];
///  - sets n[h][s][a] = 1 for every visited (step, state, action).
///
/// @return a reward table [h][s][a] that is zero except at the visited
///         triples, paired with the episode's initial state.
std::pair<std::vector<std::vector<std::vector<float>>>, int> Qlearning_gen_AMB::run_episode() {
    using Table = std::vector<std::vector<std::vector<float>>>;

    // The policy is fixed for the whole episode.
    const auto policy = choose_action();

    const int start = mdp.reset();
    int current = start;
    Table observed(mdp.H, std::vector<std::vector<float>>(mdp.S, std::vector<float>(mdp.A, 0.0f)));

    episode_state[0] = start;

    for (int h = 0; h < mdp.H; ++h) {
        // One-hot row: the position of its maximum is the chosen action.
        const auto& row = policy[h][current];
        const int action = std::distance(row.begin(), std::max_element(row.begin(), row.end()));

        auto [next_state, reward] = mdp.step(action);

        episode_state[h] = current;
        episode_action[h] = action;
        n[h][current][action] = 1;
        observed[h][current][action] = reward;

        current = next_state;
    }
    episode_state[mdp.H] = current;

    return {observed, start};
}

/// Returns the first step i with step < i < mdp.H whose recorded state
/// episode_state[i] is not yet decided (G[i][episode_state[i]] == 0).
/// Returns mdp.H if every later step of the trajectory is decided, or if
/// `step` is the last step (or beyond).
int Qlearning_gen_AMB::first_undecided_state(int step) {
    int next = step + 1;
    while (next < mdp.H && G[next][episode_state[next]] != 0) {
        ++next;  // skip decided steps
    }
    return next < mdp.H ? next : mdp.H;
}

/// Updates the Q upper/lower bounds along the trajectory recorded by run_episode().
///
/// For each step h of the trajectory (processed from H-1 down to 0), with
/// (s, a) = (episode_state[h], episode_action[h]), when the triple was visited
/// this episode and state s at step h is not yet decided:
///  - increments the persistent count N[h][s][a];
///  - uses step size (H + 1) / (H + N) and bonus 2 * c * (H - h - 1) * sqrt(H / N);
///  - sums rewards[t][s_t][a_t] for t in [h, h'), where h' = first_undecided_state(h);
///  - bootstraps from VU / VL at (h', episode_state[h']); VU/VL are not modified here;
///  - blends into QU (clipped above at H) and QL (clipped below at 0).
/// Then clears the per-episode visit marks `n` and the trajectory buffers.
///
/// @param rewards reward table indexed [h][s][a]; only entries on the recorded
///                trajectory are read.
void Qlearning_gen_AMB::update_QAMB(const std::vector<std::vector<std::vector<float>>>& rewards) {
    const int H = mdp.H;

    for (int h = H - 1; h >= 0; --h) {
        const int s = episode_state[h];
        const int a = episode_action[h];

        const bool visited = n[h][s][a] != 0;
        const bool decided = G[h][s] == 1;
        if (!visited || decided) {
            continue;
        }

        const int visits = ++N[h][s][a];
        const float alpha = static_cast<float>(H + 1) / (H + visits);
        const float bonus = 2 * c * (H - h - 1) * std::sqrt(static_cast<float>(H) / visits);

        // Multi-step segment: collect rewards until the next undecided step.
        const int landing_step = first_undecided_state(h);
        float segment_reward = 0.0f;
        for (int t = h; t < landing_step; ++t) {
            segment_reward += rewards[t][episode_state[t]][episode_action[t]];
        }

        const int landing_state = episode_state[landing_step];
        const float upper_tail = VU[landing_step][landing_state];
        const float lower_tail = VL[landing_step][landing_state];

        const float upper_blend = (1.0f - alpha) * QU[h][s][a] + alpha * (segment_reward + upper_tail + bonus);
        QU[h][s][a] = std::min(static_cast<float>(H), upper_blend);

        const float lower_blend = (1.0f - alpha) * QL[h][s][a] + alpha * (segment_reward + lower_tail - bonus);
        QL[h][s][a] = std::max(0.0f, lower_blend);
    }

    // Clear per-episode bookkeeping so the next run_episode() starts fresh.
    for (auto& per_step : n) {
        for (auto& per_state : per_step) {
            std::fill(per_state.begin(), per_state.end(), 0);
        }
    }
    std::fill(episode_state.begin(), episode_state.end(), 0);
    std::fill(episode_action.begin(), episode_action.end(), 0);
}

std::tuple<
    std::vector<float>,
    std::vector<std::vector<std::vector<float>>>,
    std::vector<float>,
    std::vector<std::vector<std::vector<float>>>,
    std::vector<float>
> Qlearning_gen_AMB::learn() {
    float regret_cum = 0.0f;
    auto [best_value, best_policy, best_Q] = mdp.best_gen();

    std::vector<std::vector<std::vector<float>>> rewards(mdp.H, std::vector<std::vector<float>>(mdp.S, std::vector<float>(mdp.A, 0.0f)));
    
    for (int h = 0; h < mdp.H; ++h) {
        for (int s = 0; s < mdp.S; ++s) {
            VU[h][s] = *std::max_element(QU[h][s].begin(), QU[h][s].end());
            VL[h][s] = *std::max_element(QL[h][s].begin(), QL[h][s].end());
        }
    }
    auto actions_policy = choose_action();
    
    std::vector<float> last_value_vec;

    for (int episode = 0; episode < total_episodes; ++episode) {
        auto [run_reward, state_init] = run_episode();
        
        last_value_vec = mdp.value_gen(actions_policy);
        float current_value = last_value_vec[state_init];

        regret_cum += best_value[state_init] - current_value;
        regret.push_back(regret_cum / (episode + 1));
        raw_gap.push_back(best_value[state_init] - current_value);

        // Update the global reward table
        for (int h = 0; h < mdp.H; ++h) {
            int s = episode_state[h];
            int a = episode_action[h];
            if (rewards[h][s][a] == 0.0f && run_reward[h][s][a] != 0.0f) {
                rewards[h][s][a] = run_reward[h][s][a];
            }
        }
        
        update_QAMB(rewards);
        actions_policy = choose_action();

        // Update VU and VL from QU and QL
        for (int h = 0; h < mdp.H; ++h) {
            for (int s = 0; s < mdp.S; ++s) {
                 VU[h][s] = *std::max_element(QU[h][s].begin(), QU[h][s].end());
                 VL[h][s] = *std::max_element(QL[h][s].begin(), QL[h][s].end());
            }
        }

        // Action elimination step
        for (int h = 0; h < mdp.H; ++h) {
            for (int s = 0; s < mdp.S; ++s) {
                for (int a = 0; a < mdp.A; ++a) {
                    if (QU[h][s][a] < VL[h][s]) {
                        A_valid[h][s][a] = 0;
                    }
                }
            }
        }
        
        // Mark decided states
        for (int h = 0; h < mdp.H; ++h) {
            for (int s = 0; s < mdp.S; ++s) {
                if (std::accumulate(A_valid[h][s].begin(), A_valid[h][s].end(), 0) == 1) {
                    G[h][s] = 1;
                }
            }
        }
    }
    
    return {best_value, best_Q, last_value_vec, QU, raw_gap};
}