#pragma once

#include "template/domain_independent/player_terminal_reach.hpp"
#include "template/domain_independent/strategy.hpp"
#include "template/algorithm/iteration/iteration_util.hpp"
// #include "util/combinatorics.h"
#include <boost/log/trivial.hpp>
#include <forward_list>

// Forward declaration of the class template that is defined right below.
template<typename>
class MultithreadExpectValue;

template<typename Poker>
class MultithreadExpectValue{

public:
    using Poker_t = Poker;
    explicit MultithreadExpectValue(unsigned int);
    ~MultithreadExpectValue();

    void compute_evs(const PlayerTerminalReach<Poker_t>* const player_terminal_reaches[Poker_t::num_players], double evs[Poker_t::num_players]);

protected:
    void clear_();
    void init_();
    void malloc_round_(unsigned int round);
    void free_round_(unsigned int round);
    void push_terminal_(unsigned int round);
    void compress_chance_(unsigned int round);
    void push_chance_(unsigned int round);
    void handle_terminal_(const PlayerTerminalReach<Poker_t>* const player_terminal_reaches[Poker_t::num_players], type::card_t holes[Poker_t::num_players][Poker_t::hole_len[Poker_t::num_rounds-1]], type::card_t board[Poker_t::board_len[Poker_t::num_rounds-1]], int round);

protected:
    struct SingleExpectValue;
    friend struct SingleExpectValue;

    const unsigned int thread_num;
    SingleExpectValue **processors;
    std::shared_ptr<tlx::ThreadPool> sp_pool;

    typedef std::atomic<double> async_awards_t[Poker_t::num_players];
    async_awards_t* seqs_awards[Game<Poker_t>::num_total];//[sequence][isomorphism]
    std::vector<int> terminals_round[Poker_t::num_rounds];
    std::vector<int> chances_round[Poker_t::num_rounds];
    
    // debug field
    async_awards_t test_awards[Game<Poker_t>::num_total]{{0,0}};

    struct SingleExpectValue{
    public:
        MultithreadExpectValue& host;

        uint64_t isomorphism_begin_round[Poker_t::num_rounds];
        uint64_t isomorphism_step_round[Poker_t::num_rounds];

        typedef double awards_t[Poker_t::num_players];
        awards_t* seqs_awards[Game<Poker_t>::num_total];
        int unhandled_successor[Game<Poker_t>::num_total];//[Game<Poker_t>::num_total]
        std::forward_list<int> unhandled_seqs_round[Poker_t::num_rounds];
        std::vector<int> tohandle_seqs_round[Poker_t::num_rounds];

    public:
        explicit SingleExpectValue(unsigned int thread_idx, unsigned int thread_num, MultithreadExpectValue* host): host(*host){
            for(int r = 0; r<Poker_t::num_rounds; ++r){
                const uint64_t isomorphism_size = Hand<Poker_t>::get_isomorphism_size(r, 0);
                const uint64_t step = (isomorphism_size + thread_num - 1) / thread_num;
                uint64_t real_step_begin = thread_idx * step;
                uint64_t real_step_end = (real_step_begin + step > isomorphism_size) ? isomorphism_size : real_step_begin + step;
                isomorphism_begin_round[r] = real_step_begin;
                isomorphism_step_round[r] = real_step_end > real_step_begin? real_step_end - real_step_begin : 0;
            }

            std::memset(seqs_awards, 0, sizeof(awards_t*[Game<Poker_t>::num_total]));
        }

        void clear(){

            for(int i = 0; i < Game<Poker_t>::num_total; ++i){
                if(seqs_awards[i])
                    delete[] seqs_awards[i];
            }
            std::memset(seqs_awards, 0, sizeof(awards_t*[Game<Poker_t>::num_total]));

            for(int i = 0; i < Poker_t::num_rounds; ++i){
                unhandled_seqs_round[i].clear();
                tohandle_seqs_round[i].clear();
            }
        }

        void iterate_round(int round){
            do {
                iterate_round_seqs_(round);
            } while(update_and_still_in_round_(round));
        }
    private:
        void iterate_round_seqs_(int round){
            uint64_t isomorphism_step = isomorphism_step_round[round];

            for(const int u : tohandle_seqs_round[round]) {
                if (u < Game<Poker_t>::num_chance && Game<Poker_t>::round[u] == round)
                    continue;
                
                int parent = Game<Poker_t>::parent[u];

                if(!seqs_awards[parent]){
                    seqs_awards[parent] = new awards_t[isomorphism_step];
                    std::memset(seqs_awards[parent], 0, sizeof(awards_t) * isomorphism_step);
                }

                for(uint64_t i = 0; i<isomorphism_step; ++i){
                    for(int p = 0; p<Poker_t::num_players; ++p) {
                        seqs_awards[parent][i][p] += seqs_awards[u][i][p];
                        //debug field;
                        host.test_awards[u][p] += seqs_awards[u][i][p];
                    }
                }

                delete[] seqs_awards[u];
                seqs_awards[u] = nullptr;
            }
        }

        bool update_and_still_in_round_(int round){
            // 层迭代之后，迭代节点的后继就减少了
            for(const int u : tohandle_seqs_round[round]) {
                if (u < Game<Poker_t>::num_chance && Game<Poker_t>::round[u] == round)
                    continue; // 不能减少本轮次前端chance

                int parent = Game<Poker_t>::parent[u];
                --unhandled_successor[parent];
                assert(unhandled_successor[parent]>=0);
            }

            // 本轮次后继节点为0的点进入下次迭代
            tohandle_seqs_round[round].clear();
            for(auto bt = unhandled_seqs_round[round].before_begin(), it = std::next(bt), et = unhandled_seqs_round[round].end(); it!=et; it = std::next(bt)) {
                int u = *it;
                if(unhandled_successor[u] == 0){
                    tohandle_seqs_round[round].push_back(u);
                    unhandled_seqs_round[round].erase_after(bt); 
                }
                else{
                    ++bt;
                }
            }

            // 本轮次该处理的seq没了就结束iterate
            return !tohandle_seqs_round[round].empty();
        }
    };

};

/**
 * Creates the thread pool and one SingleExpectValue per thread, and groups
 * the chance and terminal sequence ids by round.
 *
 * @param thread_num number of pool threads and of workers to create
 */
template<typename Poker>
MultithreadExpectValue<Poker>::MultithreadExpectValue(unsigned int thread_num): thread_num(thread_num){
    sp_pool = std::make_shared<tlx::ThreadPool>(thread_num);

    // No host-side award buffer exists yet.
    std::memset(seqs_awards, 0, sizeof(seqs_awards));

    processors = new SingleExpectValue*[thread_num];
    unsigned worker = 0;
    while(worker < thread_num){
        processors[worker] = new SingleExpectValue(worker, thread_num, this);
        ++worker;
    }

    // Ids in [0, num_chance) are filed as chance sequences.
    for(int seq = 0; seq<Game<Poker_t>::num_chance; ++seq){
        const int r = Game<Poker_t>::round[seq];
        chances_round[r].push_back(seq);
    }

    // Ids in [num_internal, num_total) are filed as terminal sequences.
    // The ids in between are only queued per run, inside init_().
    for(int seq = Game<Poker_t>::num_internal; seq<Game<Poker_t>::num_total; ++seq){
        const int r = Game<Poker_t>::round[seq];
        terminals_round[r].push_back(seq);
    }
}

/// Releases all award buffers (host and workers), then the workers themselves.
template<typename Poker>
MultithreadExpectValue<Poker>::~MultithreadExpectValue(){
    clear_();

    SingleExpectValue** const stop = processors + thread_num;
    for(SingleExpectValue** slot = processors; slot != stop; ++slot){
        delete *slot;
    }
    delete[] processors;
}

/// Frees every host-side award buffer, nulls the table and asks each worker
/// to drop its own buffers and work lists.
template<typename Poker>
void MultithreadExpectValue<Poker>::clear_(){

    // delete[] accepts null pointers, so unused slots need no special casing.
    for(async_awards_t* buffer : seqs_awards){
        delete[] buffer;
    }
    std::memset(seqs_awards, 0, sizeof(seqs_awards));

    for(unsigned worker = 0; worker != thread_num; ++worker){
        processors[worker]->clear();
    }
}

/**
 * Builds the initial bookkeeping once and copies it into every worker:
 * the per-sequence successor counters, the per-round list of sequences still
 * waiting on successors, and the per-round list of sequences that can be
 * processed right away.
 */
template<typename Poker>
void MultithreadExpectValue<Poker>::init_(){

    int successor_template[Game<Poker_t>::num_total];
    std::forward_list<int> waiting_template[Poker_t::num_rounds];

    // push_front makes each list run from the highest id down to the lowest.
    const auto queue_ids = [&waiting_template](auto first, auto last){
        for(int seq = first; seq<last; ++seq){
            const int r = Game<Poker_t>::round[seq];
            waiting_template[r].push_front(seq);
        }
    };
    queue_ids(0, Game<Poker_t>::num_chance);                               // chance sequences
    queue_ids(Game<Poker_t>::num_chance, Game<Poker_t>::num_internal);     // remaining internal sequences

    // The first num_internal counters start at the action count; the terminal tail starts at 0.
    int* const terminal_tail = std::copy(Game<Poker_t>::num_actions, Game<Poker_t>::num_actions + Game<Poker_t>::num_internal, successor_template);
    std::fill(terminal_tail, successor_template + Game<Poker_t>::num_total, 0);

    for(unsigned th = 0; th<thread_num; ++th){
        auto& worker = *processors[th];
        std::copy(successor_template, successor_template + Game<Poker_t>::num_total, worker.unhandled_successor);

        for (int r = 0; r < Poker_t::num_rounds; ++r){
            worker.unhandled_seqs_round[r] = waiting_template[r];

            std::vector<int>& ready = worker.tohandle_seqs_round[r];
            ready = terminals_round[r];
            // Every round but the last also starts with the next round's chance sequences.
            if(r < Poker_t::num_rounds-1) {
                ready.insert(ready.end(), chances_round[r+1].begin(), chances_round[r+1].end());
            }
        }
    }
}

/**
 * Runs the full computation, walking the rounds from the last one down to
 * round 0, and writes the result into `evs`.
 *
 * @param player_terminal_reaches one reach table per player; forwarded
 *        unchanged to handle_terminal_()
 * @param evs output array of Poker_t::num_players values; filled from
 *        seqs_awards[0][0] once round 0 has been compressed
 *
 * After the computation the debug accumulator `test_awards` is dumped to
 * data/debug/multithread_expect_value.csv. If that file cannot be opened the
 * dump is skipped with a warning; `evs` is already filled at that point.
 * NOTE(review): nothing visible here resets `test_awards`, so the dumped
 * values appear to accumulate across calls — confirm this is intended.
 */
template<typename Poker>
void MultithreadExpectValue<Poker>::compute_evs(const PlayerTerminalReach<Poker_t>* const player_terminal_reaches[Poker_t::num_players], double evs[Poker_t::num_players]) {
    clear_();
    init_();

    for (int round = Poker_t::num_rounds-1; round >= 0; --round){

        // Allocate the host buffers of this round: the back end (terminals) and the
        // front end (chance). Chance buffers use the previous round's isomorphism
        // count, or 1 when round == 0.
        // |----....****| legend: front,back / host thread,worker threads / last iso, iso
        BOOST_LOG_TRIVIAL(debug) << "malloc_round: "<< round;
        malloc_round_(round); // host

        // |........****|
        BOOST_LOG_TRIVIAL(debug) << "deal_terminal: "<< round;
        iteration::multithread_dealing_upto_current_round_hands<Poker_t>(sp_pool, std::bind(&MultithreadExpectValue<Poker_t>::handle_terminal_, this, player_terminal_reaches, std::placeholders::_1, std::placeholders::_2, round), round);

        // Copy atomic host data into the workers' plain buffers, terminals of this round only.
        // |........****| --> |........****|
        BOOST_LOG_TRIVIAL(debug) << "push_terminal: "<< round;
        push_terminal_(round);

        //                    |........****|**** ==> |****........|....
        BOOST_LOG_TRIVIAL(debug) << "iterate_round: "<< round;
        for(unsigned int th = 0; th < thread_num; ++th){
            sp_pool->enqueue([this, th, round](){
                processors[th]->iterate_round(round);
            });
        }
        sp_pool->loop_until_empty();

        // |----........|.... <-- |****........|....
        BOOST_LOG_TRIVIAL(debug) << "compress_chance: "<< round;
        compress_chance_(round);

        // Copy atomic host data into the workers' plain buffers, chance nodes of this round only.
        // |----........|.... --> |----........|....
        if (round > 0) {
            BOOST_LOG_TRIVIAL(debug) << "push_chance: "<< round;
            push_chance_(round);
        }
        else {
            // Round 0 is the end of the walk: sequence 0, isomorphism 0 holds the result.
            std::copy(seqs_awards[0][0], seqs_awards[0][0] + Poker_t::num_players, evs);
        }

        // |----....****|
        BOOST_LOG_TRIVIAL(debug) << "free_round: "<< round;
        free_round_(round);
    }

    // debug field: dump the per-sequence accumulator as CSV.
    FILE * file = fopen("data/debug/multithread_expect_value.csv", "wt");
    if (file == nullptr) {
        // A missing directory or a permission problem must not crash the computation.
        BOOST_LOG_TRIVIAL(warning) << "cannot open data/debug/multithread_expect_value.csv, debug dump skipped";
        return;
    }
    for (int i = 0; i<Game<Poker_t>::num_total; ++i){
        fprintf(file, "%d,", i);
        for (int p = 0; p < Poker_t::num_players; ++p){
            fprintf(file, " %lf,", (double)test_awards[i][p]);
        }
        fprintf(file, "\n");
    }
    // Close the stream so buffered rows are flushed and the handle is not leaked on every call.
    fclose(file);
}

/**
 * Allocates the host-side atomic buffers used while processing `round`:
 * one entry per isomorphism of `round` for each terminal sequence, and one
 * entry per isomorphism of the previous round (a single entry when
 * round == 0) for each chance sequence.
 */
template<typename Poker>
void MultithreadExpectValue<Poker>::malloc_round_(unsigned int round){

    const uint64_t terminal_entries = Hand<Poker_t>::get_isomorphism_size(round, 0);
    for (const int seq : terminals_round[round]){
        seqs_awards[seq] = new async_awards_t[terminal_entries]{{0,0}};
    }

    uint64_t chance_entries = 1;
    if (round != 0){
        chance_entries = Hand<Poker_t>::get_isomorphism_size(round-1, 0);
    }
    for (const int seq : chances_round[round]){
        seqs_awards[seq] = new async_awards_t[chance_entries]{{0,0}};
    }
}

/**
 * Handles one dealt combination of holes and board for `round`: for every
 * terminal sequence of that round it computes the players' awards, weights
 * them by the product of all players' reach values and adds the result into
 * the host's atomic buffers at each player's own isomorphism index.
 *
 * @param player_terminal_reaches one reach table per player, queried with
 *        (sequence id, isomorphism index)
 * @param holes  hole cards of every player
 * @param board  board cards
 * @param round  round whose terminal sequences are processed
 */
template<typename Poker>
void MultithreadExpectValue<Poker>::handle_terminal_(const PlayerTerminalReach<Poker_t>* const player_terminal_reaches[Poker_t::num_players], type::card_t holes[Poker_t::num_players][Poker_t::hole_len[Poker_t::num_rounds-1]], type::card_t board[Poker_t::board_len[Poker_t::num_rounds-1]], int round) {
    // Isomorphism index of each player's hand; one Hand object is reused by swapping the hole.
    uint64_t buckets[Poker_t::num_players];
    Hand<Poker_t> probe(holes[0], board, round);
    buckets[0] = probe.get_hand_isomorphism(0);
    for(int p = 1; p<Poker_t::num_players; ++p){
        probe.change_hole(holes[p]);
        buckets[p] = probe.get_hand_isomorphism(0);
    }

    int ranks[Poker_t::num_players];
    Evaluator<Poker_t>::evaluate_ranks(holes, board, ranks);

    double payoff[Poker_t::num_players];
    for(const int seq_id : terminals_round[round]){
        Sequence<Poker_t> seq(seq_id);
        assert(seq.is_terminal());
        assert(Game<Poker_t>::round[seq_id] == round);

        if(!seq.is_fold()){
            seq.deal_showdown(ranks, payoff);
        }
        else {
            seq.deal_fold(payoff);
        }

        // Product of every player's reach value at this sequence.
        double joint_reach = 1.;
        for (int p = 0; p < Poker_t::num_players; ++p){
            joint_reach *= player_terminal_reaches[p]->get_reach(seq_id, buckets[p]);
        }

        for (int p = 0; p < Poker_t::num_players; ++p){
            seqs_awards[seq_id][buckets[p]][p] += joint_reach * payoff[p];
        }
    }
}

/**
 * Gives every worker a plain-double copy of its own isomorphism slice of the
 * host's atomic buffers, for all terminal sequences of `round`. One pool task
 * is queued per worker and the call returns after the pool has drained.
 */
template<typename Poker>
void MultithreadExpectValue<Poker>::push_terminal_(unsigned int round){

    for (unsigned th = 0; th < thread_num; ++th){

        sp_pool->enqueue([this, th, round](){
            auto& worker = *processors[th];
            const uint64_t count = worker.isomorphism_step_round[round];
            const uint64_t offset = worker.isomorphism_begin_round[round];

            // Terminals come first; they are all already in the worker's to-handle list.
            for(const int seq : terminals_round[round]){
                auto* slice = new SingleExpectValue::awards_t[count];
                worker.seqs_awards[seq] = slice;

                for(uint64_t k = 0; k<count; ++k) {
                    for(int p = 0; p<Poker_t::num_players; ++p){
                        slice[k][p] = seqs_awards[seq][offset + k][p];
                    }
                }
            }
        });
    }
    sp_pool->loop_until_empty();
}

/**
 * Folds the workers' buffers of this round's chance sequences back into the
 * host's atomic buffers. Each isomorphism of `round` is mapped to an index of
 * the previous round (index 0 when round == 0) and its awards, divided by
 * Poker_t::deal_combine_num_round[round], are added there. The workers'
 * buffers for those chance sequences are released afterwards.
 */
template<typename Poker>
void MultithreadExpectValue<Poker>::compress_chance_(unsigned int round){

    for(unsigned int th = 0; th < thread_num; ++th) {
        sp_pool->enqueue([this, round, th](){
            type::card_t hand_c[Poker_t::hand_len[Poker_t::num_rounds-1]];

            auto& worker = *processors[th];
            const uint64_t offset = worker.isomorphism_begin_round[round];
            const uint64_t count = worker.isomorphism_step_round[round];

            for (uint64_t k = 0; k < count; ++k){
                // Work out the index in the previous layer.
                uint64_t isomorphism = offset + k;
                uint64_t last_isomorphism = 0;
                if(round != 0) {
                    int last_round = round - 1;
                    Hand<Poker_t>::hand_unisomorphism(isomorphism, round, 0, hand_c);
                    Hand<Poker_t> last_hand(/*hole*/hand_c, /*board*/hand_c + Poker_t::hole_len[round], last_round);
                    last_isomorphism = last_hand.get_hand_isomorphism(0);
                }

                for(const int seq : chances_round[round]) {
                    for(int p = 0; p<Poker_t::num_players; ++p) {
                        seqs_awards[seq][last_isomorphism][p] += worker.seqs_awards[seq][k][p] / Poker_t::deal_combine_num_round[round];
                    }
                }
            }

            // Release the worker-side buffers that were just folded.
            for(const int seq : chances_round[round]) {
                delete[] worker.seqs_awards[seq];
                worker.seqs_awards[seq] = nullptr;
            }
        });
    }
    sp_pool->loop_until_empty();

}

/**
 * Gives every worker a plain-double copy of the host's atomic buffers for the
 * chance sequences of `round`, sliced with the worker's range of the
 * previous round (round - 1). Requires round > 0.
 */
template<typename Poker>
void MultithreadExpectValue<Poker>::push_chance_(unsigned int round){
    assert(round > 0);

    for (unsigned th = 0; th < thread_num; ++th){

        sp_pool->enqueue([this, th, round](){
            const int last_round = round - 1;
            auto& worker = *processors[th];
            const uint64_t count = worker.isomorphism_step_round[last_round];
            const uint64_t offset = worker.isomorphism_begin_round[last_round];

            // Then extract this worker's part of each chance buffer.
            for(const int seq : chances_round[round]){
                auto* slice = new SingleExpectValue::awards_t[count];
                worker.seqs_awards[seq] = slice;

                for(uint64_t k = 0; k<count; ++k) {
                    for(int p = 0; p<Poker_t::num_players; ++p){
                        slice[k][p] = seqs_awards[seq][offset + k][p];
                    }
                }
            }

        });
    }
    sp_pool->loop_until_empty();
}

/// Releases the host-side buffers of `round`'s terminal and chance sequences
/// and resets their table entries to nullptr.
template<typename Poker>
void MultithreadExpectValue<Poker>::free_round_(unsigned int round){

    const auto release_all = [this](const std::vector<int>& seqs){
        for (const int seq : seqs){
            delete[] seqs_awards[seq];
            seqs_awards[seq] = nullptr;
        }
    };

    release_all(terminals_round[round]);
    release_all(chances_round[round]);
}