
#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <future>
#include <memory>
#include <unordered_map>
#include <vector>

#include <tlx/thread_pool.hpp>

// 开新线程传值
// ev 每个函数重开一个副本，子过程结束后ev即可以被regret减去了
// cfr所有线程都是一个
// 这里的chance一直没有更新，有问题，记得要改

template<typename Poker>
class ParallelCS_MCCFR{
    using Poker_t = Poker;

    class ParallelChanceVariables;
    friend class ParallelChanceVariables;

    int sample_batch;
    int parallel_round;
    ParallelChanceVariables parallel_chance_variables;
    int thread_num;
    tlx::ThreadPool pool;
    std::shared_ptr<RegretStrategy<Poker_t>> strat[Poker_t::num_players];
    std::vector<int> parallel_chances;
public:
    ParallelCS_MCCFR(int sample_batch, int parallel_round, int thread_num, std::shared_ptr<RegretStrategy<Poker_t>> strat[Poker_t::num_players]) : sample_batch(sample_batch)
                                                                                                                                                 , parallel_round(parallel_round)
                                                                                                                                                 , parallel_chance_variables(this)
                                                                                                                                                 , thread_num(std::min(thread_num, Game<Poker_t>::num_chance_round[parallel_round]))
                                                                                                                                                 , pool(this->thread_num){ // 
        std::copy(strat, strat + Poker_t::num_players, this->strat);

        parallel_chances.clear();
        parallel_chances.reserve(Game<Poker_t>::num_chance_round[parallel_round]);
        for(int u = 0; u < Game<Poker_t>::num_chance; ++u){
            if (Game<Poker_t>::round[u] == parallel_round){
                parallel_chances.push_back(u);
            }
        }
        assert(parallel_chances.size() == Game<Poker_t>::num_chance_round[parallel_round]);
        printf("thread num: %d\n", this->thread_num);
    }

    void reset(int sample_batch, int parallel_round, int thread_num, std::shared_ptr<RegretStrategy<Poker_t>> strat[Poker_t::num_players]){
        this->sample_batch = sample_batch;
        this->parallel_round = parallel_round;
        parallel_chance_variables = ParallelChanceVariables(this);
        this->thread_num = std::min(thread_num, Game<Poker_t>::num_chance_round[parallel_round]);
        pool = tlx::ThreadPool(this->thread_num);
        std::copy(strat, strat + Poker_t::num_players, this->strat);

        parallel_chances.clear();
        parallel_chances.reserve(Game<Poker_t>::num_chance_round[parallel_round]);
        for(int u = 0; u < Game<Poker_t>::num_chance; ++u){
            if (Game<Poker_t>::round[u] == parallel_round){
                parallel_chances.push_back(u);
            }
        }
        assert(parallel_chances.size() == Game<Poker_t>::num_chance_round[parallel_round]);
    }

private:
    enum class TreeFlag {
        UPPER_TREE,
        LOWER_TREE,
    };

    struct ParallelChanceVariables{

        // double *reaches[Game<Poker_t>::num_chance];// double chance_reach[Game<Poker_t>::num_chance][Poker_t::num_players]
        std::unordered_map<int, double(*)[Poker_t::num_players]> reaches;
        std::unordered_map<int, double(*)[Poker_t::num_players]> evs;
        std::unordered_map<int, std::array<double, Poker_t::num_players>> cfrs;

    private:
        ParallelCS_MCCFR& host;

    public:
        ParallelChanceVariables(ParallelCS_MCCFR* p_host) : host(*p_host){
            // std::memset(reaches, 0, sizeof(double*) * Game<Poker_t>::num_chance);

            // auto temp = new double[Game<Poker_t>::num_chance_round[host.parallel_round]][Poker_t::num_players];
            
            // int cnt = 0;
            // for(int u = 0, chance_size = Game<Poker_t>::num_chance; u < chance_size; ++u) {
            //     if (Game<Poker_t>::round[u] == host.parallel_round){
            //         reaches[u] = temp[cnt++];
            //     }
            // }
            // assert(cnt == Game<Poker_t>::num_chance_round[host.parallel_round]);

            reaches.clear();
            cfrs.clear();
            evs.clear();
            for(int u = 0; u < Game<Poker_t>::num_chance; ++u){
                if (Game<Poker_t>::round[u] == host.parallel_round){
                    reaches.insert({u, new double[host.sample_batch][Poker_t::num_players]});
                    evs.insert({u, new double[host.sample_batch][Poker_t::num_players]});
                    cfrs.insert({u, {0.}});
                }
            }
            assert(reaches.size() == Game<Poker_t>::num_chance_round[host.parallel_round]);
        }

        ~ParallelChanceVariables(){
            // 利用目标层第一个chance一定是本层最前面的特性释放空间
            // for(int u = 0, chance_size = Game<Poker_t>::num_chance; u < chance_size; ++u) {
            //     if (reaches[u]){
            //         delete[] reaches[u];
            //         break;
            //     }
            // }
            for (auto& [key, value] : reaches){
                delete[] value;
            }

            for (auto& [key, value] : evs){
                delete[] value;
            }
        }

        void compute_reach_partial_recursive(int batch_idx, const Sequence<Poker_t>& seq, const uint64_t buckets[Poker_t::num_players][Poker_t::num_rounds], double reach[Poker_t::num_players]){
            int u = seq.get_id();
            int round = Game<Poker_t>::round[u];

            if (round >= host.parallel_round && seq.is_chance()) {
                // 到达目标层的chance复制reach
                std::memcpy(reaches[u][batch_idx], reach, sizeof(double[Poker_t::num_players]));
            }
            else if(round < host.parallel_round && !seq.is_terminal()) {
                // 少于目标层且不是终端节点，向前发展后继
                int player = Game<Poker_t>::whose_turn[u];
                int action_size = Game<Poker_t>::num_actions[u];

                /* get the probabilty tuple for each player */
                double *probability = new double [action_size];
                host.strat[player]->get_strategy(seq, buckets[player][round], probability);

                double old_reach = reach[player];
                for(int i=0; i<action_size; ++i) {

                    reach[player] = old_reach * probability[i];
                    compute_reach_partial_recursive(batch_idx, seq.do_action(i), buckets, reach);
                }
                reach[player] = old_reach;

                delete[] probability;
            }

            // if 逻辑中达不到的
            // seq.is_terminal()
            // round >= host.parallel_round && !seq.is_chance()

            // 因为约束达不到的
            // round > host.parallel_round && seq.is_chance() 因为在round == host.parallel_round 的seq都不会向下发展所以达不到
        }

    };

public:

    void update_regret( const uint64_t (*buckets)[Poker_t::num_players][Poker_t::num_rounds]
                      , const type::rank_t (*ranks)[Poker_t::num_players]
                      , double reach[Poker_t::num_players]
                      , double ev[Poker_t::num_players]
                      , double cfr[Poker_t::num_players]
                      ) {
        for (int b = 0; b < sample_batch; ++b) {
            parallel_chance_variables.compute_reach_partial_recursive(b, Sequence<Poker_t>(0), buckets[b], reach);
        }

        for(const int u : parallel_chances) {
            pool.enqueue([this, u, buckets, ranks, reach, ev, cfr](){
                std::fill(parallel_chance_variables.cfrs[u].begin(), parallel_chance_variables.cfrs[u].end(), 0);
                for(int b = 0; b < sample_batch; ++b){
                    update_task_recursive<TreeFlag::LOWER_TREE>( b
                                                               , Sequence<Poker_t>(u)
                                                               , buckets[b]
                                                               , ranks[b]
                                                               , 1.
                                                               , parallel_chance_variables.reaches[u][b]
                                                               , parallel_chance_variables.evs[u][b]
                                                               , parallel_chance_variables.cfrs[u].data());
                }
            });
        }
        pool.loop_until_empty();

        for (int b = 0; b < sample_batch; ++b) {
            double ret_ev[Poker_t::num_players];
            update_task_recursive<TreeFlag::UPPER_TREE>( b
                                                       , Sequence<Poker_t>(0)
                                                       , buckets[b]
                                                       , ranks[b]
                                                       , 1.
                                                       , reach
                                                       , ret_ev
                                                       , cfr
                                                       );
            for(int p = 0; p<Poker_t::num_players; ++p){
                ev[p] += ret_ev[p];
            }
        }
    }

private:

    template<TreeFlag flg>
    void update_task_recursive( int batch_idx
                              , const Sequence<Poker_t>& seq
                              , const uint64_t buckets[Poker_t::num_players][Poker_t::num_rounds]
                              , const type::rank_t ranks[Poker_t::num_players]
                              , double chance
                              , double reach[Poker_t::num_players]
                              , double ev[Poker_t::num_players]
                              , double cfr[Poker_t::num_players]
                              ) {

        if constexpr (flg == TreeFlag::UPPER_TREE){
        if (int u = seq.get_id(), round = Game<Poker_t>::round[u]; seq.is_chance() && round == parallel_round){
            std::copy(parallel_chance_variables.evs[u][batch_idx], parallel_chance_variables.evs[u][batch_idx] + Poker_t::num_players, ev);
            for(int p = 0; p < Poker_t::num_players; ++p){
                cfr[p] += 1. * parallel_chance_variables.cfrs[u][p] / sample_batch;
            }
            return;
        }}
        else if (flg == TreeFlag::LOWER_TREE){
            // int u = seq.get_id();
            // if(u == 27)
            //     printf("haha\n");
        }

        if (seq.is_terminal()) {

            double award[Poker_t::num_players];

            if (seq.is_fold()) {

                seq.deal_fold(award);
            }
            else {

                seq.deal_showdown(ranks, award);
            }

            for(int i = 0; i<Poker_t::num_players; ++i){

                double player_minus_probability = 1;
                for(int j=0; j<Poker_t::num_players; ++j){
                    player_minus_probability *= (j == i?1:reach[j]);
                }
                ev[i] = award[i] * player_minus_probability * chance;
            }
        }
        else if (double max_reach = *std::max_element(reach, reach+Poker_t::num_players); max_reach < overall_define::epsilon){

            std::memset(ev, 0, sizeof(double)*Poker_t::num_players);
        }
        else {

            std::memset(ev, 0, sizeof(double)*Poker_t::num_players);

            int u = seq.get_id();
            int player = Game<Poker_t>::whose_turn[u];
            int round = Game<Poker_t>::round[u];
            int action_size = Game<Poker_t>::num_actions[u];

            /* get the probabilty tuple for each player */
            double *probability = new double [action_size];
            strat[player]->get_strategy(seq, buckets[player][round], probability);

            /* first average the strategy for the player */
            // double * average_probability = strat[player]->get_average_probability(seq, buckets[player][round]);
            for(int a=0; a<action_size; ++a) {
                
                strat[player]->accumulate_probability(seq, buckets[player][round], a, reach[player]*probability[a]);
            }

            /* now compute the regret on each of our actions */
            double old_reach = reach[player];
            double* delta_regret = new double[action_size];
            double ret_ev[Poker_t::num_players];
            for(int i=0; i<action_size; ++i) {

                reach[player] = old_reach*probability[i];
                update_task_recursive<flg>(batch_idx, seq.do_action(i), buckets, ranks, chance, reach, ret_ev, cfr);
        
                delta_regret[i] = ret_ev[player];// todo: delta_regret[i] = ev[i] - expected;
                for (int j = 0; j < Poker_t::num_players; ++j){
                    if (j == player)
                        ev[j] += ret_ev[j]*probability[i]; //对于更新玩家i，补充\sigma_i(I,a)
                    else 
                        ev[j] += ret_ev[j]; //对于非更新玩家{-i}，\sigma_{-i}(I,a)在ev中已经包括了 ev = \pi_{-i}(h)\sum_{z}\pi(z|h)u(z)
                }
            }
        
            /* restore reachability value */
            reach[player] = old_reach;
        
            /* subtract off expectation */
            // double * regret = strat[player]->get_regret(seq, buckets[player][round]);
            for(int a=0; a<action_size; ++a) {
                    
                delta_regret[a] -= ev[player];
                strat[player]->accumulate_regret(seq, buckets[player][round], a, delta_regret[a]);
                // regret[i]       += delta_regret[i];
                cfr[player]     += std::max(0., delta_regret[a]);
            }

            //free
            delete[] probability;
            delete[] delta_regret;
        }
    }
};