#pragma once

#include "template/domain_independent/regret_strategy.hpp"
#include "template/domain_independent/abstraction.h"
#include "template/domain_dependent/hand_base.hpp"

template<typename Poker>
class MultithreadPlayerTerminalReach{ //不包括chance的概率
public:
    using Poker_t = Poker;
    explicit MultithreadPlayerTerminalReach(unsigned int thread_num);
    ~MultithreadPlayerTerminalReach();
    int get_player() const;
    double get_reach(int u, uint64_t isomorphism) const;

    void register_player(int p);
    void free_player();
    template<bool atomic_update>
    void compute_reach(const RegretStrategy<Poker, atomic_update> &regret_strategy, const Abstraction<Poker>& abstraction);

private:
    void malloc_round_(unsigned int round);
    void free_round_(unsigned int round);

protected:
    int player = -1;
    std::mutex mtx;

    struct SinglethreadPlayerTerminalReach;
    friend struct SinglethreadPlayerTerminalReach;

    const unsigned int thread_num;
    SinglethreadPlayerTerminalReach** processors;
    std::shared_ptr<tlx::ThreadPool> sp_pool;

    double* terminal_reaches[Game<Poker_t>::num_terminal]{nullptr};
    double* chance_reaches[Game<Poker_t>::num_chance]{nullptr};
    std::vector<int> terminals_round[Poker_t::num_rounds];
    std::vector<int> chances_round[Poker_t::num_rounds];

    struct SinglethreadPlayerTerminalReach{
    public:
        const MultithreadPlayerTerminalReach& host;

        uint64_t isomorphism_begin_round[Poker_t::num_rounds];
        uint64_t isomorphism_step_round[Poker_t::num_rounds];
    
    public:
        explicit SinglethreadPlayerTerminalReach(unsigned int thread_idx, unsigned int thread_num, MultithreadPlayerTerminalReach* host): host(*host){
            for(int r = 0; r<Poker_t::num_rounds; ++r){
                const uint64_t isomorphism_size = Hand<Poker_t>::get_isomorphism_size(r, 0);
                const uint64_t step = (isomorphism_size + thread_num - 1) / thread_num;
                uint64_t real_step_begin = thread_idx * step;
                uint64_t real_step_end = (real_step_begin + step > isomorphism_size) ? isomorphism_size : real_step_begin + step;
                isomorphism_begin_round[r] = real_step_begin;
                isomorphism_step_round[r] = real_step_end > real_step_begin? real_step_end - real_step_begin : 0;
            }
        }

        template<bool atomic_update>
        void recursive_round(int round, const RegretStrategy<Poker, atomic_update> &regret_strategy, const Abstraction<Poker>& abstraction) {
            double *reaches = new double[isomorphism_step_round[round]];
            for(const int u : host.chances_round[round]) {

                type::card_t *hand = new type::card_t[Poker_t::hand_len[round]];
                if (round > 0) {
                    for (uint64_t priso_offset = 0; priso_offset < isomorphism_step_round[round]; ++priso_offset) {
                        Hand<Poker_t>::hand_unisomorphism(priso_offset+isomorphism_begin_round[round], round, 0, hand);
                        uint64_t last_hand_iso = Hand<Poker_t>(hand, hand + Poker_t::hole_len[round], round-1).get_hand_isomorphism(0);
                        reaches[priso_offset] = host.chance_reaches[u][last_hand_iso];
                    }
                } else {
                    uint64_t last_hand_iso = 0;
                    for (uint64_t priso_offset = 0; priso_offset < isomorphism_step_round[round]; ++priso_offset) {
                        reaches[priso_offset] = host.chance_reaches[u][last_hand_iso];
                    }
                }
                delete[] hand;

                recursive_compute_reaches(Sequence<Poker_t>(u), round, reaches, regret_strategy, abstraction);
            }
            delete[] reaches;
        }

        template<bool atomic_update>
        void recursive_compute_reaches(const Sequence<Poker_t>& seq, int round, const double* reaches, const RegretStrategy<Poker, atomic_update>& regret_strategy, const Abstraction<Poker>& abstraction) {
            int u = seq.get_id();
            int r = Game<Poker_t>::round[u];

            if (seq.is_terminal()){
                int terminal_id = u-Game<Poker_t>::num_internal;
                std::copy(reaches, reaches + isomorphism_step_round[round], host.terminal_reaches[terminal_id]+isomorphism_begin_round[round]);
                // 完成
            }
            else if (r > round) {
                assert(u < Game<Poker_t>::num_chance);
                std::copy(reaches, reaches + isomorphism_step_round[round], host.chance_reaches[u]+isomorphism_begin_round[round]);
                // 完成
            }
            else if (host.player == Game<Poker_t>::whose_turn[u]){
                // 以[action][isomorphism]而非[isomorphism][action]构建reaches
                double** action_seqs_reaches = new double*[Game<Poker_t>::num_actions[u]];
                for (int i = 0; i < Game<Poker_t>::num_actions[u]; ++i) {
                    action_seqs_reaches[i] = new double[isomorphism_step_round[round]];
                    std::copy(reaches, reaches + isomorphism_step_round[round], action_seqs_reaches[i]);
                }

                // 
                double* strategy_tuple = new double[Game<Poker_t>::num_actions[u]];
                for ( uint64_t priso_offset = 0; priso_offset < isomorphism_step_round[round]; ++priso_offset) {
                    uint64_t abstraction_bucket = abstraction.abstract_view(priso_offset + isomorphism_begin_round[round], round);
                    regret_strategy.get_average_strategy(seq, abstraction_bucket, strategy_tuple);
                    for (int i = 0; i < Game<Poker_t>::num_actions[u]; ++i) {
                        action_seqs_reaches[i][priso_offset] *= strategy_tuple[i];
                    }
                }
                delete[] strategy_tuple;

                for (int i = 0; i < Game<Poker_t>::num_actions[u]; ++i) {
                    recursive_compute_reaches(seq.do_action(i), round, action_seqs_reaches[i], regret_strategy, abstraction);
                    delete[] action_seqs_reaches[i];
                }
                /// delete
                delete[] action_seqs_reaches;
                // 完成
            }
            else {
                for (int i = 0; i < Game<Poker_t>::num_actions[u]; ++i) {
                    recursive_compute_reaches(seq.do_action(i), round, reaches, regret_strategy, abstraction);
                }
                // 完成
            }
        }
    };
};

template<typename Poker>
MultithreadPlayerTerminalReach<Poker>::MultithreadPlayerTerminalReach(unsigned int thread_num): thread_num(thread_num) {
    // Shared pool that runs one task per worker for every parallel phase.
    sp_pool = std::make_shared<tlx::ThreadPool>(thread_num);
    std::memset(terminal_reaches, 0, sizeof(double*) * Game<Poker_t>::num_terminal);

    // Bucket node ids by round: chance ids are [0, num_chance),
    // terminal ids are [num_internal, num_total).
    for (int node = 0; node < Game<Poker_t>::num_chance; ++node) {
        const int node_round = Game<Poker_t>::round[node];
        chances_round[node_round].push_back(node);
    }
    for (int node = Game<Poker_t>::num_internal; node < Game<Poker_t>::num_total; ++node) {
        const int node_round = Game<Poker_t>::round[node];
        terminals_round[node_round].push_back(node);
    }

    // One worker per thread; each owns a fixed isomorphism slice of every round.
    processors = new SinglethreadPlayerTerminalReach*[thread_num];
    for (unsigned worker = 0; worker < thread_num; ++worker) {
        processors[worker] = new SinglethreadPlayerTerminalReach(worker, thread_num, this);
    }
}

template<typename Poker>
MultithreadPlayerTerminalReach<Poker>::~MultithreadPlayerTerminalReach() {
    // Release the terminal reach buffers allocated by register_player.
    // `terminal_reaches` is a member array and can never be null, so no guard is needed.
    // Unallocated slots hold nullptr, and delete[] on nullptr is a no-op.
    for (int i = 0; i < Game<Poker_t>::num_terminal; ++i) {
        delete[] terminal_reaches[i];
        terminal_reaches[i] = nullptr;
    }

    // Destroy the per-thread workers and the array that holds them.
    for (unsigned th = 0; th < thread_num; ++th) {
        delete processors[th];
    }
    delete[] processors;
}

template<typename Poker>
int MultithreadPlayerTerminalReach<Poker>::get_player() const {
    // A negative id means no player is currently registered.
    const int registered = player;
    assert(registered >= 0);
    return registered;
}

template<typename Poker>
double MultithreadPlayerTerminalReach<Poker>::get_reach(int u, uint64_t isomorphism) const {
    // Terminal ids start at num_internal; shift into the terminal_reaches index space.
    const int terminal_id = u - Game<Poker_t>::num_internal;
    const double* const terminal_buffer = terminal_reaches[terminal_id];
    return terminal_buffer[isomorphism];
}

template<typename Poker>
void MultithreadPlayerTerminalReach<Poker>::register_player(int p) {
    assert(p >= 0);
    std::lock_guard<std::mutex> lock(mtx);
    // Buffers already exist when a player is registered; only the id changes.
    if (player >= 0) {
        player = p;
        return;
    }

    // First registration: allocate one buffer per terminal, sized to the hand
    // isomorphism count of that terminal's round. Terminals are split across
    // workers with a stride of thread_num.
    for (unsigned int worker = 0; worker < thread_num; ++worker) {
        sp_pool->enqueue([this, worker]() {
            for (int t = worker; t < Game<Poker_t>::num_terminal; t += this->thread_num) {
                const int node = t + Game<Poker_t>::num_internal;
                const int r = Game<Poker_t>::round[node];
                const uint64_t iso_size = Hand<Poker_t>::get_isomorphism_size(r, 0);
                terminal_reaches[t] = new double[iso_size];
            }
        });
    }
    sp_pool->loop_until_empty();
    player = p;
}

template<typename Poker>
void MultithreadPlayerTerminalReach<Poker>::free_player() {
    assert(player >= 0);
    std::lock_guard<std::mutex> lock(mtx);
    player = -1;
    // Release the terminal buffers in parallel, using the same strided
    // partition that register_player used to allocate them.
    for (unsigned int th = 0; th < thread_num; ++th) {
        sp_pool->enqueue([this, th]() {
            for (int i = static_cast<int>(th); i < Game<Poker_t>::num_terminal; i += static_cast<int>(this->thread_num)) {
                // The buffers come from `new double[]`, so they must be released with
                // delete[]; plain delete here was undefined behavior.
                delete[] terminal_reaches[i];
                terminal_reaches[i] = nullptr;
            }
        });
    }
    sp_pool->loop_until_empty();
}

template<typename Poker>
template<bool atomic_update>
void MultithreadPlayerTerminalReach<Poker>::compute_reach(const RegretStrategy<Poker, atomic_update>& regret_strategy, const Abstraction<Poker>& abstraction) {
    register_player(regret_strategy.get_player());

    // Round-0 chance nodes get a single-entry buffer (see malloc_round_).
    malloc_round_(0);
    for (int r = 0; r < Poker_t::num_rounds; ++r) {
        // The chance nodes that open the next round are written during this round's
        // traversal, indexed by this round's isomorphisms, so allocate them first.
        const bool has_next_round = r + 1 < Poker_t::num_rounds;
        if (has_next_round) {
            malloc_round_(r + 1);
        }

        // Each worker processes its own isomorphism slice of round r.
        for (unsigned int th = 0; th < thread_num; ++th) {
            SinglethreadPlayerTerminalReach* const worker = processors[th];
            sp_pool->enqueue([worker, r, &regret_strategy, &abstraction]() {
                worker->recursive_round(r, regret_strategy, abstraction);
            });
        }
        sp_pool->loop_until_empty();

        // This round's chance reaches have been consumed; release them before moving on.
        free_round_(r);
    }
}

template<typename Poker>
void MultithreadPlayerTerminalReach<Poker>::malloc_round_(unsigned int round) {
    // Chance nodes that open `round` are indexed by the previous round's hand
    // isomorphism; round 0 has a single entry.
    uint64_t entries = 1;
    if (round != 0) {
        entries = Hand<Poker_t>::get_isomorphism_size(round - 1, 0);
    }
    for (const int chance_id : chances_round[round]) {
        // List-initialization sets only element 0 to 1.0; the remaining
        // elements are value-initialized to 0.0.
        chance_reaches[chance_id] = new double[entries]{1.};
    }
}

template<typename Poker>
void MultithreadPlayerTerminalReach<Poker>::free_round_(unsigned int round) {
    // Release the reach buffers of the chance nodes that open `round`.
    for (const int u : chances_round[round]) {
        delete[] chance_reaches[u];
        // Reset the slot so no dangling pointer is left behind; a later delete[]
        // or accidental read then hits nullptr instead of freed memory.
        chance_reaches[u] = nullptr;
    }
}