#pragma once

#include "template/domain_independent/strategy.hpp"

// Per-player reach table for the game tree.
//
// NOTE(review): this declaration is out of sync with the out-of-class definitions later in
// this file. Those are all written for a class named PlayerTerminalReach with a
// double** `probability`, a PlayerTerminalReach(int) constructor and
// compute_reach(const Strategy<Poker_t>&). As written the file does not compile;
// reconcile the two designs before use.
template<typename Poker>
class RegretReach{ // Reach excludes chance probabilities (translated from the original note).
public:
    using Poker_t = Poker;

    // NOTE(review): named PlayerReach although the enclosing class is RegretReach, so these
    // do not declare a constructor/destructor and will not compile. No definition with this
    // (int, bucket_size[]) signature is visible in this file.
    PlayerReach(int player, uint64_t bucket_size[Poker_t::num_rounds]);
    ~PlayerReach();

    // NOTE(review): no definition is visible in this file.
    int get_player() const;
    // NOTE(review): the compute_reach defined in this file takes only
    // (const Strategy<Poker_t>&); no definition of this RegretStrategy / sample_tasks
    // overload is visible.
    void compute_reach(const RegretStrategy<Poker_t>& strat, const std::vector<std::array<uint64_t, Poker_t::num_rounds>>& sample_tasks);
    // Reach stored for node u under hand-isomorphism bucket `isomorphism`.
    double get_reach(int u, uint64_t isomorphism) const;
    // recursive_compute_reach reads `probability` and `player` of the object it is given.
    // NOTE(review): it is given a PlayerTerminalReach, not a RegretReach, so befriending it
    // here does not grant that access. Naming a specialization as a friend also requires
    // recursive_compute_reach and PlayerTerminalReach to be declared before this class;
    // neither declaration is visible above.
    friend void recursive_compute_reach<Poker_t>( PlayerTerminalReach<Poker_t>& player_terminal_reach
                                                , const Strategy<Poker_t>& strategy
                                                , const Sequence<Poker_t>& seq
                                                , double reach
                                                , uint64_t round_buckets[Poker_t::num_rounds]);

protected:
    // One table pointer per node (num_total entries); the constructor below instead
    // allocates num_terminal tables.
    // NOTE(review): the constructor in this file assigns `probability = new double*[...]`,
    // which is ill-formed for an array member.
    // NOTE(review): no copy constructor/assignment is declared. If the destructor frees
    // these tables, copying an instance would free them twice.
    double* probability[Game<Poker_t>::num_total];
    // Player whose own action probabilities are accumulated (compared against
    // Game<Poker_t>::whose_turn in recursive_compute_reach).
    int player;
};

template<typename Poker_t>
PlayerReach<Poker_t>::PlayerTerminalReach(int player):player(player){
    probability = new double*[Game<Poker_t>::num_terminal];
    std::memset(probability, 0, sizeof(double*) * Game<Poker_t>::num_terminal);

    for(int i = 0; i < Game<Poker_t>::num_terminal; ++i){
        int u = i + Game<Poker_t>::num_internal;

        //拿到terminal节点由哪轮internal节点导出, 
        int round = Game<Poker_t>::round[u];

        uint64_t n = Poker_t::num_hand_isomorphism_round[round];
        probability[i] = new double[n];
        std::fill(probability[i], probability[i] + n, -1.0);
    }
}

// Frees every per-terminal reach table, then the pointer table itself, and resets
// `probability` to nullptr. Does nothing if the pointer table is null.
//
// NOTE(review): no copy control for this class is visible here. If none exists, copying
// an instance and destroying both copies would free the same tables twice.
template<typename Poker_t>
PlayerTerminalReach<Poker_t>::~PlayerTerminalReach(){
    if (!probability)
        return;

    for(int i = 0; i < Game<Poker_t>::num_terminal; ++i)
        delete[] probability[i];

    // Releasing the pointer table directly; zeroing it first would be a dead store.
    delete[] probability;
    probability = nullptr;
}

// Depth-first walk of the game tree from `seq` that records, for every terminal node, the
// product of the tracked player's own action probabilities along the path. Chance and
// opponent actions contribute a factor of 1.0.
//
// Results go to player_terminal_reach.probability[terminal_id][bucket], where:
//   - terminal_id = node id - Game<Poker_t>::num_internal
//   - bucket = round_buckets[Game<Poker_t>::round[node]]
//
// Entries start at the -1.0 sentinel written by the constructor. The first visit of a
// (terminal, bucket) pair stores the reach; later visits assert that they agree.
//
// Parameters:
//   player_terminal_reach  output tables; its `player` selects whose strategy is applied
//   strategy               source of that player's action distribution per node/bucket
//   seq                    current node
//   reach                  accumulated product from the root down to `seq`
//   round_buckets          hand-isomorphism bucket for each round
template<typename Poker_t>
void recursive_compute_reach( PlayerTerminalReach<Poker_t>& player_terminal_reach
                            , const Strategy<Poker_t>& strategy
                            , const Sequence<Poker_t>& seq
                            , double reach
                            , uint64_t round_buckets[Poker_t::num_rounds]
                            ){
    const int u = seq.get_id();
    const int round = Game<Poker_t>::round[u];

    if(seq.is_terminal()) {
        const int terminal_id = u - Game<Poker_t>::num_internal;
        double& slot = player_terminal_reach.probability[terminal_id][round_buckets[round]];
        // Unfilled entries hold the -1.0 sentinel. Any value >= 0, including a legitimate
        // 0.0 reach, is already filled and must be consistent. The previous `> 0` test
        // treated stored zeros as unfilled and skipped this check.
        if(slot >= 0.0) {
            assert(slot == reach);
        }
        else {
            slot = reach;
        }
        return;
    }

    // Only the tracked player's strategy is applied; at every other node each action
    // passes the current reach through unchanged.
    const double* this_player_tuple = nullptr;
    if(Game<Poker_t>::whose_turn[u] == player_terminal_reach.player)
        this_player_tuple = strategy.get_strategy(seq, round_buckets[round]);

    for(int i = 0; i < Game<Poker_t>::num_actions[u]; ++i) {
        const double action_probability = this_player_tuple ? this_player_tuple[i] : 1.0;
        recursive_compute_reach(player_terminal_reach, strategy, seq.do_action(i), reach * action_probability, round_buckets);
    }
}

// Fills the terminal reach tables for every final-round hand.
//
// For each hand isomorphism of the last round:
//   1. Expand it back into concrete cards with Hand<Poker_t>::hand_unindex.
//   2. Derive that hand's isomorphism index for every round.
//   3. Walk the whole tree from the root (node 0) with an initial reach of 1.0.
//
// NOTE(review): the full tree is traversed once per final-round isomorphism. Terminals in
// earlier rounds are therefore revisited many times with the same bucket.
template<typename Poker>
void PlayerTerminalReach<Poker>::compute_reach(const Strategy<Poker_t>& strategy){
    const int last_round = Poker_t::num_rounds - 1;
    const uint64_t num_final_hands = Poker_t::num_hand_isomorphism_round[last_round];

    type::card_t cards[Poker_t::hand_len[last_round]];
    uint64_t buckets[Poker_t::num_rounds];

    for (uint64_t iso = 0; iso < num_final_hands; ++iso){
        Hand<Poker_t>::hand_unindex(iso, last_round, cards);

        // Pointer just past the first hole_len[last_round] cards of the unindexed hand,
        // presumably the end of the hole cards; confirm against Hand's constructor.
        type::card_t* const after_hole = cards + Poker_t::hole_len[last_round];

        for (int r = 0; r <= last_round; ++r){
            Hand<Poker_t> round_hand(cards, after_hole, r);
            buckets[r] = round_hand.get_isomorphism();
        }

        recursive_compute_reach<Poker_t>(*this, strategy, Sequence<Poker_t>(0), 1.0, buckets);
    }
}

// Returns the stored reach for node `u` under hand-isomorphism bucket `isomorphism`.
// `u` is a global node id. Tables exist only for terminal nodes
// (u >= Game<Poker_t>::num_internal). Entries not yet filled by compute_reach hold -1.0.
template<typename Poker>
double PlayerTerminalReach<Poker>::get_reach(int u, uint64_t isomorphism) const {
    const int terminal_id = u - Game<Poker_t>::num_internal;
    const double* const table = probability[terminal_id];
    return table[isomorphism];
}