
/**
 * Multithreaded CFR trainer over a batch of pre-sampled deals.
 *
 * For every train() call the batch of `sample_batch` samples (per-player
 * buckets and showdown ranks) is split into `thread_num` contiguous slices.
 * Each slice is walked by its own SingleCS_MCCFR worker on a tlx::ThreadPool.
 * Inside a sample every action of every decision node is traversed. Workers
 * accumulate regrets and average-strategy mass directly into the shared
 * RegretStrategy objects (instantiated with atomic_update = true). After all
 * workers finish, update_delta() is called on every player's strategy.
 *
 * Instances are neither copyable nor movable: every worker keeps a reference
 * back to its host.
 */
template<typename Poker>
class MultithreadCS_MCCFR{
public:
    using Poker_t = Poker;

    /**
     * @param sample_batch number of samples processed by each train() call
     * @param thread_num   number of worker threads (and batch slices)
     * @param strat        one strategy per player (Poker_t::num_players
     *                     entries); the shared_ptrs are copied
     */
    explicit MultithreadCS_MCCFR( uint64_t sample_batch
                                , int thread_num
                                , std::shared_ptr<RegretStrategy<Poker_t, /* atomic_update */ true>> strat[Poker_t::num_players])
        : sample_batch(sample_batch)
        , thread_num(thread_num)
        , pool(this->thread_num) {
        std::copy(strat, strat + Poker_t::num_players, this->strat);
        build_processors();
    }

    MultithreadCS_MCCFR(const MultithreadCS_MCCFR&) = delete;
    MultithreadCS_MCCFR& operator=(const MultithreadCS_MCCFR&) = delete;

    /**
     * Re-configures the trainer with a new batch size, thread count and set
     * of strategies. The old workers are freed, the thread pool is re-created
     * with the new size, and new workers are built over the new batch split.
     */
    void reset(uint64_t sample_batch, int thread_num, std::shared_ptr<RegretStrategy<Poker_t, /* atomic_update */ true>> strat[Poker_t::num_players]){
        // Must run while `thread_num` still holds the count the workers were built with.
        release_processors();

        this->sample_batch = sample_batch;
        this->thread_num = thread_num;
        rebuild_pool();

        std::copy(strat, strat + Poker_t::num_players, this->strat);
        build_processors();
    }

    /**
     * Runs one training pass over the whole sample batch using all workers.
     *
     * @param buckets per-sample bucket of every player in every round; must
     *                provide at least `sample_batch` entries
     * @param ranks   per-sample showdown rank of every player; same length
     * @param reach   initial reach of every player at the root
     * @param ev      out: per-player root values summed over all samples
     * @param cfv     out: per-player sum of positive regret increments
     */
    void train( const uint64_t (*buckets)[Poker_t::num_players][Poker_t::num_rounds] // batch_step
              , const type::rank_t (*ranks)[Poker_t::num_players]
              , const double reach[Poker_t::num_players]
              , double ev[Poker_t::num_players]
              , double cfv[Poker_t::num_players]){

        // One zero-initialised row of num_players values per worker; each
        // worker writes only its own row, so no synchronisation is needed.
        const std::size_t slots = static_cast<std::size_t>(thread_num) * Poker_t::num_players;
        auto thread_ev  = std::make_unique<double[]>(slots);
        auto thread_cfv = std::make_unique<double[]>(slots);

        std::fill(ev, ev + Poker_t::num_players, 0.);
        std::fill(cfv, cfv + Poker_t::num_players, 0.);

        for (int th = 0; th < thread_num; ++th) {
            const std::size_t row = static_cast<std::size_t>(th) * Poker_t::num_players;
            SingleCS_MCCFR* worker = processors[th];
            double* worker_ev  = thread_ev.get() + row;
            double* worker_cfv = thread_cfv.get() + row;
            pool.enqueue([worker, buckets, ranks, reach, worker_ev, worker_cfv]() {
                worker->train(buckets, ranks, reach, worker_ev, worker_cfv);
            });
        }

        pool.loop_until_empty();

        for (int p = 0; p < Poker_t::num_players; ++p){
            strat[p]->update_delta();
        }

        for (int th = 0; th < thread_num; ++th){
            const std::size_t row = static_cast<std::size_t>(th) * Poker_t::num_players;
            for (int p = 0; p < Poker_t::num_players; ++p){
                ev[p]  += thread_ev[row + p];
                cfv[p] += thread_cfv[row + p];
            }
        }
    }

    ~MultithreadCS_MCCFR() {
        release_processors();
    }

protected:
    struct SingleCS_MCCFR;
    friend struct SingleCS_MCCFR;

    // NOTE: declaration order matters - `pool` is constructed from `thread_num`.
    uint64_t sample_batch;
    int thread_num;
    tlx::ThreadPool pool;
    std::shared_ptr<RegretStrategy<Poker_t, /* atomic_update */ true>> strat[Poker_t::num_players];
    SingleCS_MCCFR **processors = nullptr;   // owns `thread_num` workers

    /**
     * Worker that trains one contiguous slice [batch_begin, batch_end) of the
     * sample batch. It reads the batch size and shared strategies from its host.
     */
    struct SingleCS_MCCFR{
    public:
        MultithreadCS_MCCFR& host;

        uint64_t batch_begin;
        uint64_t batch_end;

        // Not read anywhere in this struct; value-initialised so they are never indeterminate.
        uint64_t isomorphism_begin_round[Poker_t::num_rounds]{};
        uint64_t isomorphism_step_round[Poker_t::num_rounds]{};

        /**
         * Computes this worker's slice. The batch is cut into `thread_num`
         * pieces of ceil(sample_batch / thread_num) samples; trailing workers
         * may receive a shorter or empty slice.
         */
        explicit SingleCS_MCCFR(unsigned int thread_idx, unsigned int thread_num, MultithreadCS_MCCFR* owner): host(*owner){
            const uint64_t total = host.sample_batch;
            // ceil(total / thread_num) without the `total + thread_num - 1` overflow.
            const uint64_t step = total / thread_num + (total % thread_num != 0 ? 1 : 0);
            batch_begin = std::min<uint64_t>(static_cast<uint64_t>(thread_idx) * step, total);
            batch_end = batch_begin + std::min<uint64_t>(step, total - batch_begin);
        }

        // Declared but not defined or called in this part of the file.
        void update_accumulative_regret();

        /**
         * Walks the game tree once per sample in [batch_begin, batch_end).
         *
         * @param ev  out: per-player root values summed over the slice (reset to 0 first)
         * @param cfv out: per-player sum of positive regret increments (reset to 0 first)
         */
        void train( const uint64_t (*buckets)[Poker_t::num_players][Poker_t::num_rounds] // batch_step
                  , const type::rank_t (*ranks)[Poker_t::num_players]
                  , const double reach[Poker_t::num_players]
                  , double ev[Poker_t::num_players]
                  , double cfv[Poker_t::num_players]){
            std::fill(ev, ev + Poker_t::num_players, 0.);
            std::fill(cfv, cfv + Poker_t::num_players, 0.);

            double sample_reach[Poker_t::num_players];
            double sample_ev[Poker_t::num_players];
            for (uint64_t b = batch_begin; b < batch_end; ++b){
                // Every sample starts from the caller's root reach.
                std::copy(reach, reach + Poker_t::num_players, sample_reach);
                update_recursive(Sequence<Poker_t>(0), buckets[b], ranks[b], 1., sample_reach, sample_ev, cfv);
                for (int p = 0; p < Poker_t::num_players; ++p){
                    ev[p] += sample_ev[p];
                }
            }
        }

        /**
         * Recursive CFR update for one sample, starting at `seq`.
         *
         * @param seq     current betting sequence
         * @param buckets bucket of every player in every round for this sample
         * @param ranks   showdown rank of every player for this sample
         * @param chance  weight multiplied into terminal payoffs
         * @param reach   in/out: per-player reach; the acting player's entry is
         *                modified during recursion and restored before return
         * @param ev      out: per-player value of `seq`, weighted by the other
         *                players' reach and by `chance`
         * @param cfr     in/out: each acting player's positive regret
         *                increments are added to its entry
         */
        void update_recursive( const Sequence<Poker_t>& seq
                             , const uint64_t buckets[Poker_t::num_players][Poker_t::num_rounds] // batch_step
                             , const type::rank_t ranks[Poker_t::num_players] // batch_step
                             , double chance // double chance
                             , double reach[Poker_t::num_players] // batch_step
                             , double ev[Poker_t::num_players] // batch_step
                             , double cfr[Poker_t::num_players]
                             ) {
            if (seq.is_terminal()) {
                double award[Poker_t::num_players];
                if (seq.is_fold()) {
                    seq.deal_fold(award);
                } else {
                    seq.deal_showdown(ranks, award);
                }

                for (int i = 0; i < Poker_t::num_players; ++i){
                    // Counterfactual weight: every player's reach except player i's own.
                    double others_reach = 1.;
                    for (int j = 0; j < Poker_t::num_players; ++j){
                        if (j != i) {
                            others_reach *= reach[j];
                        }
                    }
                    ev[i] = award[i] * others_reach * chance;
                }
                return;
            }

            std::fill(ev, ev + Poker_t::num_players, 0.);

            // Prune: no player reaches this node with at least epsilon probability.
            if (*std::max_element(reach, reach + Poker_t::num_players) < overall_define::epsilon) {
                return;
            }

            const int u = seq.get_id();
            const int player = Game<Poker_t>::whose_turn[u];
            const int round = Game<Poker_t>::round[u];
            const int action_size = Game<Poker_t>::num_actions[u];
            auto& strategy = *host.strat[player];
            const auto bucket = buckets[player][round];

            // Current strategy of the acting player for (seq, bucket).
            auto probability = std::make_unique<double[]>(action_size);
            strategy.get_strategy(seq, bucket, probability.get());

            // Accumulate the average strategy, weighted by the acting player's own reach.
            for (int a = 0; a < action_size; ++a) {
                strategy.accumulate_probability(seq, bucket, a, reach[player] * probability[a]);
            }

            // Recurse into every action; action_value[a] is the acting player's value of action a.
            const double old_reach = reach[player];
            auto action_value = std::make_unique<double[]>(action_size);
            double child_ev[Poker_t::num_players];
            for (int a = 0; a < action_size; ++a) {
                reach[player] = old_reach * probability[a];
                update_recursive(seq.do_action(a), buckets, ranks, chance, reach, child_ev, cfr);

                action_value[a] = child_ev[player];
                for (int j = 0; j < Poker_t::num_players; ++j){
                    // The acting player's own reach is excluded from child_ev[player],
                    // so sigma(I,a) is applied here. For every other player, sigma(I,a)
                    // already entered child_ev[j] through reach[player].
                    ev[j] += (j == player) ? child_ev[j] * probability[a] : child_ev[j];
                }
            }

            reach[player] = old_reach;

            // Instantaneous regret: action value minus the node's expected value.
            for (int a = 0; a < action_size; ++a) {
                const double regret = action_value[a] - ev[player];
                strategy.accumulate_regret(seq, bucket, a, regret);
                cfr[player] += std::max(0., regret);
            }
        }
    };

private:
    // Allocates `thread_num` workers. The pointer array is null-initialised,
    // so a partially built set can still be released safely if a `new` throws.
    void build_processors() {
        processors = new SingleCS_MCCFR*[thread_num]();
        for (int th = 0; th < thread_num; ++th){
            processors[th] = new SingleCS_MCCFR(static_cast<unsigned int>(th), static_cast<unsigned int>(thread_num), this);
        }
    }

    // Frees every worker and the pointer array; relies on `thread_num` being the count used to build them.
    void release_processors() noexcept {
        if (processors == nullptr) {
            return;
        }
        for (int th = 0; th < thread_num; ++th){
            delete processors[th];
        }
        delete[] processors;
        processors = nullptr;
    }

    // Re-creates `pool` with `thread_num` workers. tlx::ThreadPool deletes its
    // copy assignment and has no move assignment, so the old pool is destroyed
    // and a new one is constructed in the same storage. The function is noexcept:
    // if construction throws, the member would be left destroyed, so terminating
    // is the only safe outcome.
    void rebuild_pool() noexcept {
        pool.~ThreadPool();
        ::new (static_cast<void*>(std::addressof(pool))) tlx::ThreadPool(static_cast<std::size_t>(thread_num));
    }
};