#pragma once
#include "overall_define.h"
#include <cstring>
#include <tlx/thread_pool.hpp>

/**
 * Regret and average-strategy accumulators for one player.
 *
 * Every table is indexed as [internal node id][bucket][action]. Rows are only
 * allocated for internal nodes where Game<Poker_t>::whose_turn equals the
 * player passed to the constructor; for every other node the row pointer is
 * null, so the accessors below must only be used on this player's nodes.
 *
 * When atomic_update is true, accumulate_regret()/accumulate_probability()
 * add into separate std::atomic "delta" tables, and update_delta() folds the
 * deltas into the main tables. When it is false they add directly into the
 * main tables and update_delta() is not available.
 *
 * NOTE: the class owns raw heap tables that the destructor frees, but it
 * declares no copy or move operations. The implicitly generated copy would
 * share those pointers with the source object, so instances must not be
 * copied.
 *
 * @tparam Poker          game description type (exposes num_rounds).
 * @tparam atomic_update  selects the atomic delta-table mode described above.
 */
template<typename Poker, bool atomic_update = false>
class RegretStrategy{
public:
    using Poker_t = Poker;

    /**
     * Allocates zero-filled tables for every internal node owned by `player`.
     * @param player         id compared against Game<Poker_t>::whose_turn.
     * @param round_buckets  number of buckets per round (copied into the object).
     */
    RegretStrategy(int player, const uint64_t round_buckets[Poker_t::num_rounds]);

    /// Returns the player id given to the constructor.
    int get_player() const;

    /// Returns the per-action accumulated regrets stored for (seq, bucket).
    const double *get_accumulative_regret(const Sequence<Poker_t>& seq, uint64_t bucket) const;

    /**
     * Writes the current strategy for (seq, bucket) into `strategy`
     * (one entry per action): the positive parts of the accumulated regrets
     * divided by their sum, or a uniform distribution when that sum does not
     * exceed overall_define::epsilon.
     */
    void get_strategy(const Sequence<Poker_t>& seq, uint64_t bucket, double* strategy) const;

    /// Returns the per-action accumulated probabilities stored for (seq, bucket).
    const double *get_accumulative_probability(const Sequence<Poker_t>& seq, uint64_t bucket) const;

    /**
     * Writes the normalised accumulated probabilities for (seq, bucket) into
     * `average_strategy` (one entry per action), or a uniform distribution
     * when their sum is below overall_define::epsilon.
     */
    void get_average_strategy(const Sequence<Poker_t>& seq, uint64_t bucket, double* average_strategy) const;

    /// Adds `regret` to the entry for (seq, bucket, action); targets the delta table in atomic mode.
    void accumulate_regret(const Sequence<Poker_t>& seq, uint64_t bucket, int action, double regret);

    /// Adds `probability` to the entry for (seq, bucket, action); targets the delta table in atomic mode.
    void accumulate_probability(const Sequence<Poker_t>& seq, uint64_t bucket, int action, double probability);

    /// Prints both main tables to stdout for this player's nodes.
    void debug_print_strategy()const;

    /// Returns the total number of (bucket, action) entries over this player's nodes.
    uint64_t get_strategy_size() const;

    /// Frees the tables owned by this object.
    ~RegretStrategy();

    /**
     * Adds every delta entry to the matching main-table entry and resets the
     * delta to zero. Only participates in overload resolution when the class
     * was instantiated with atomic_update == true (the defaulted parameter
     * exists solely to make the constraint dependent).
     */
    template<bool atomic_update_ = atomic_update>
    typename std::enable_if<atomic_update_ && (atomic_update_ == atomic_update), void>::type
    update_delta();

private:
    int player;                                   // player whose decision nodes are stored
    uint64_t round_buckets[Poker_t::num_rounds];  // bucket count per round
    type::value_t ***accumulative_regret = nullptr;// [internal][bucket in round][action]
    double ***accumulative_probability = nullptr;// [internal][bucket in round][action]
    std::atomic<type::value_t> ***accumulative_regret_delta = nullptr;// [internal][bucket in round][action]
    std::atomic<double> ***accumulative_probability_delta = nullptr;// [internal][bucket in round][action]
    uint64_t strategy_size;                       // sum of num_actions * buckets over owned nodes
};

/**
 * Builds the zero-filled [internal][bucket][action] tables for `player`.
 *
 * Rows are allocated only for internal nodes i with
 * Game<Poker_t>::whose_turn[i] == player; all other row pointers stay null.
 * A node in round r gets round_buckets[r] buckets, each holding
 * Game<Poker_t>::num_actions[i] entries. strategy_size accumulates
 * num_actions * buckets over the allocated nodes.
 *
 * With atomic_update == true the std::atomic delta tables are allocated with
 * the same shape; otherwise that code is discarded at compile time and the
 * delta pointers keep their null default.
 *
 * @param player         id compared against Game<Poker_t>::whose_turn.
 * @param round_buckets  bucket count per round; copied into the member array.
 */
template<typename Poker_t, bool atomic_update>
RegretStrategy<Poker_t, atomic_update>::RegretStrategy(int player, const uint64_t round_buckets[Poker_t::num_rounds]) : player(player){
    strategy_size = 0;
    std::memcpy(this->round_buckets, round_buckets, sizeof(uint64_t)*Poker_t::num_rounds);

    // Trailing "()" value-initialises: every row pointer starts out null.
    accumulative_regret = new type::value_t**[Game<Poker_t>::num_internal]();
    accumulative_probability = new double**[Game<Poker_t>::num_internal]();
    if constexpr (atomic_update){
        accumulative_regret_delta = new std::atomic<type::value_t>**[Game<Poker_t>::num_internal]();
        accumulative_probability_delta = new std::atomic<double>**[Game<Poker_t>::num_internal]();
    }

    for(int i = 0; i<Game<Poker_t>::num_internal; ++i){
        // Nodes where another player acts keep null rows.
        if (Game<Poker_t>::whose_turn[i] != player)
            continue;

        const uint64_t n = round_buckets[Game<Poker_t>::round[i]];
        const auto num_actions = Game<Poker_t>::num_actions[i];

        accumulative_regret[i] = new type::value_t*[n];
        accumulative_probability[i] = new double*[n];
        for(uint64_t j = 0; j<n; ++j){
            // Value-initialisation zero-fills the cells.
            accumulative_regret[i][j] = new type::value_t[num_actions]();
            accumulative_probability[i][j] = new double[num_actions]();
        }
        strategy_size += num_actions * n;

        if constexpr (atomic_update){
            accumulative_regret_delta[i] = new std::atomic<type::value_t>*[n];
            accumulative_probability_delta[i] = new std::atomic<double>*[n];
            for(uint64_t j = 0; j<n; ++j){
                // "()" guarantees zeroed atomics even where std::atomic's
                // default constructor leaves the value uninitialised (pre-C++20).
                accumulative_regret_delta[i][j] = new std::atomic<type::value_t>[num_actions]();
                accumulative_probability_delta[i][j] = new std::atomic<double>[num_actions]();
            }
        }
    }
}

/// Returns the entry count computed by the constructor
/// (sum of num_actions * buckets over this player's nodes).
template<typename Poker_t, bool atomic_update>
inline uint64_t RegretStrategy<Poker_t, atomic_update>::get_strategy_size() const{
    return this->strategy_size;
}

/// Returns the player id this object was constructed for.
template<typename Poker_t, bool atomic_update>
inline int RegretStrategy<Poker_t, atomic_update>::get_player() const{
    return this->player;
}

/**
 * Returns the row of accumulated regrets (one entry per action) stored for
 * the node identified by seq.get_id() and the given bucket.
 * No bounds or null checks are performed.
 */
template<typename Poker_t, bool atomic_update>
inline const double *RegretStrategy<Poker_t, atomic_update>::get_accumulative_regret(const Sequence<Poker_t>& seq, uint64_t bucket) const {
    const auto node_id = seq.get_id();
    const double *regret_row = accumulative_regret[node_id][bucket];
    return regret_row;
}

/**
 * Computes the current strategy for (seq, bucket) from the accumulated regrets.
 *
 * Each action receives max(0, regret) divided by the sum of those positive
 * parts. If that sum is not greater than overall_define::epsilon, every
 * action receives 1 / num_actions instead.
 *
 * @param seq       node whose id selects the table row.
 * @param bucket    bucket index within the node's round.
 * @param strategy  output buffer with room for num_actions doubles.
 */
template<typename Poker_t, bool atomic_update>
void RegretStrategy<Poker_t, atomic_update>::get_strategy(const Sequence<Poker_t>& seq, uint64_t bucket, double* strategy) const {
    const int node = seq.get_id();
    const int action_count = Game<Poker_t>::num_actions[node];
    const auto *regrets = accumulative_regret[node][bucket];

    // Sum of the positive parts of the regrets.
    double positive_total = 0;
    for (int a = 0; a < action_count; ++a){
        positive_total += std::max(0., regrets[a]);
    }

    // No meaningful positive regret: fall back to the uniform strategy.
    if (!(overall_define::epsilon < positive_total)){
        std::fill(strategy, strategy + action_count, 1./action_count);
        return;
    }

    for (int a = 0; a < action_count; ++a){
        strategy[a] = std::max(0., regrets[a]) / positive_total;
    }
}

/**
 * Returns the row of accumulated probabilities (one entry per action) stored
 * for the node identified by seq.get_id() and the given bucket.
 * No bounds or null checks are performed.
 */
template<typename Poker_t, bool atomic_update>
inline const double *RegretStrategy<Poker_t, atomic_update>::get_accumulative_probability(const Sequence<Poker_t>& seq, uint64_t bucket) const {
    const auto node_id = seq.get_id();
    const double *probability_row = accumulative_probability[node_id][bucket];
    return probability_row;
}


/**
 * Computes the average strategy for (seq, bucket) by normalising the
 * accumulated probabilities.
 *
 * If the sum of the accumulated probabilities is below
 * overall_define::epsilon, every action receives 1 / num_actions instead.
 *
 * @param seq               node whose id selects the table row.
 * @param bucket            bucket index within the node's round.
 * @param average_strategy  output buffer with room for num_actions doubles.
 */
template<typename Poker_t, bool atomic_update>
void RegretStrategy<Poker_t, atomic_update>::get_average_strategy(const Sequence<Poker_t>& seq, uint64_t bucket, double* average_strategy) const {
    const int node = seq.get_id();
    const int action_count = Game<Poker_t>::num_actions[node];
    const double *weights = accumulative_probability[node][bucket];

    double total = 0;
    for(int a = 0; a < action_count; ++a) {
        total += weights[a];
    }

    // Nothing accumulated yet (total is effectively zero): uniform strategy.
    if (total < overall_define::epsilon){
        std::fill(average_strategy, average_strategy + action_count, 1./action_count);
        return;
    }

    for(int a = 0; a < action_count; ++a){
        average_strategy[a] = weights[a] / total;
    }
}

/**
 * Adds `regret` to the entry for (seq, bucket, action).
 * In non-atomic mode the main regret table is updated directly; in atomic
 * mode the addition goes to the std::atomic delta table instead.
 */
template<typename Poker_t, bool atomic_update>
inline void RegretStrategy<Poker_t, atomic_update>::accumulate_regret(const Sequence<Poker_t>& seq, uint64_t bucket, int action, double regret){
    const auto node_id = seq.get_id();
    if constexpr (!atomic_update){
        accumulative_regret[node_id][bucket][action] += regret;
    }
    else {
        accumulative_regret_delta[node_id][bucket][action] += regret;
    }
}

/**
 * Adds `probability` to the entry for (seq, bucket, action).
 * In non-atomic mode the main probability table is updated directly; in
 * atomic mode the addition goes to the std::atomic delta table instead.
 */
template<typename Poker_t, bool atomic_update>
inline void RegretStrategy<Poker_t, atomic_update>::accumulate_probability(const Sequence<Poker_t>& seq, uint64_t bucket, int action, double probability){
    const auto node_id = seq.get_id();
    if constexpr (!atomic_update){
        accumulative_probability[node_id][bucket][action] += probability;
    }
    else {
        accumulative_probability_delta[node_id][bucket][action] += probability;
    }
}

/**
 * Frees every table owned by this object.
 *
 * Each table is a three-level structure [internal][bucket][action]; rows for
 * nodes that were never allocated are null and are skipped. The delta tables
 * are only released when atomic_update is true (they are never allocated
 * otherwise).
 */
template<typename Poker_t, bool atomic_update>
RegretStrategy<Poker_t, atomic_update>::~RegretStrategy(){
    // Releases one [internal][bucket][action] table from the innermost level
    // outwards and nulls the owning pointer. Each bucket array is deleted
    // before its pointer is dropped, so nothing is leaked.
    auto release_table = [this](auto ***&table){
        if (table == nullptr)
            return;

        for(int i = 0; i<Game<Poker_t>::num_internal; ++i){
            if (table[i] == nullptr)
                continue;

            const uint64_t n = round_buckets[Game<Poker_t>::round[i]];
            for(uint64_t j = 0; j<n; ++j){
                delete[] table[i][j];
            }
            delete[] table[i];
        }

        delete[] table;
        table = nullptr;
    };

    release_table(accumulative_regret);
    release_table(accumulative_probability);

    if constexpr (atomic_update){
        release_table(accumulative_regret_delta);
        release_table(accumulative_probability_delta);
    }
}

/**
 * Dumps the accumulated regret table and the accumulated probability table
 * to stdout.
 *
 * Only internal nodes where Game<Poker_t>::whose_turn equals this object's
 * player are printed. For each such node the node id and its
 * Sequence<Poker_t>::to_string() form are printed, followed by one line per
 * bucket listing the per-action values with nine decimals.
 */
template<typename Poker_t, bool atomic_update>
void RegretStrategy<Poker_t, atomic_update>::debug_print_strategy() const {
    // Prints one table as "<title>:{ ... }".
    auto print_table = [this](const char *title, auto ***table){
        printf("%s:{\n", title);
        for(int i = 0; i<Game<Poker_t>::num_internal; ++i){
            if (Game<Poker_t>::whose_turn[i] != player)
                continue;

            const int round = Game<Poker_t>::round[i];
            const int action_size = Game<Poker_t>::num_actions[i];

            printf("%d:%s\n", i, Sequence<Poker_t>(i).to_string().c_str());
            for (uint64_t j = 0; j < round_buckets[round]; ++j){
                // The bucket index is 64-bit unsigned; print it with a matching specifier.
                printf("\tbucket %llu", static_cast<unsigned long long>(j));
                for (int k = 0; k < action_size; ++k){
                    printf("%c %.09lf", k==0?':':',', table[i][j][k]);
                }
                printf("\n");
            }
        }
        printf("}\n");
    };

    print_table("accumulative_regret", accumulative_regret);
    print_table("accumulative_probability", accumulative_probability);
}

template<typename Poker_t, bool atomic_update>
template<bool atomic_update_>
std::enable_if<atomic_update_ && (atomic_update_ == atomic_update), void>::type
RegretStrategy<Poker_t, atomic_update>::update_delta(){
    static tlx::ThreadPool pool(overall_define::num_threads);

    for(int th = 0; th < overall_define::num_threads; ++th){
        pool.enqueue([th, this](){
            for(int u = th; u < Game<Poker_t>::num_internal; u += overall_define::num_threads){
                if(Game<Poker_t>::whose_turn[u] != player)
                    continue;
                
                int round = Game<Poker_t>::round[u];
                uint64_t bucket_size = round_buckets[round];
                int action_size = Game<Poker_t>::num_actions[u];

                for(uint64_t b = 0; b < bucket_size; ++b){
                    for(int a = 0; a < action_size; ++a){
                        accumulative_regret[u][b][a] += accumulative_regret_delta[u][b][a];
                        accumulative_regret_delta[u][b][a] = 0;

                        accumulative_probability[u][b][a] += accumulative_probability_delta[u][b][a];
                        accumulative_probability_delta[u][b][a] = 0;
                    }
                }
            }
        });
    }

    pool.loop_until_empty();
}