#pragma once
#include "overall_define.h"
#include "template/domain_independent/regret_strategy.hpp"
#include "template/domain_independent/sequence.hpp"
#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

//////////////////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Behaviour-strategy table for a poker game.
 *
 * Layout: tuples[node][hand_isomorphism][action], where `node` is an internal
 * sequence id (Sequence::get_id()), `hand_isomorphism` ranges over
 * Hand::get_isomorphism_size(round, 0) and `action` over Game::num_actions[node].
 * Rows are only allocated for nodes owned by a registered player
 * (see register_player / free_player); unallocated nodes are nullptr.
 *
 * Thread-safety: register_player/free_player and recover_average_strategy take the
 * per-player mutex; get_strategy/set_strategy do not lock.
 *
 * The object owns raw heap memory, so copying and moving are disabled explicitly.
 */
template<typename Poker>
class Strategy{
public:
    using Poker_t = Poker;

    Strategy() = default;
    Strategy(const Strategy&) = delete;
    Strategy& operator=(const Strategy&) = delete;
    Strategy(Strategy&&) = delete;
    Strategy& operator=(Strategy&&) = delete;

    /// Allocate zero-initialised rows for every node owned by player `p` (no-op if already registered).
    void register_player(int p);
    /// Release the rows of player `p` (no-op if not registered).
    void free_player(int p);

    /// Store `tuple` (normalised to sum 1) for the node `seq` and the given hand.
    void set_strategy(const Sequence<Poker_t>& seq, const Hand<Poker_t>& hand, const double * tuple);
    void set_strategy(const Sequence<Poker_t>& seq, uint64_t hand_isomorphism, const double * tuple);
    /// Return the stored row for (seq, hand), or nullptr if it is not allocated.
    const double*  get_strategy(const Sequence<Poker_t>& seq, const Hand<Poker_t>& hand) const ;
    const double*  get_strategy(const Sequence<Poker_t>& seq, uint64_t hand_isomorphism) const ;
    /// Fill this table with the average strategy of `regret_strat`'s player.
    template<bool atomic_update>
    void recover_average_strategy(const Abstraction<Poker_t>& abstraction, const RegretStrategy<Poker_t, atomic_update>& regret_strat);

    void save_txt(FILE* file) const ;
    // NOTE(review): save_binary / load_binary are not defined in this part of the file.
    void save_binary(FILE* file) const;

    void load_txt(const char* file_path);
    void load_binary(const char* file_path);

    ~Strategy();
private:
    double **tuples[Game<Poker_t>::num_internal]{nullptr};
    bool player_registered[Poker_t::num_players]{false};
    std::mutex mtx[Poker_t::num_players]{};
};

#include <cstring>
//////////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Poker_t>
/**
 * Allocate the strategy rows for every internal node owned by player `p`.
 * Rows are zero-initialised: a freshly registered row sums to 0, while any row
 * written by set_strategy sums to 1.
 * Guarded by the player's mutex; calling it twice for the same player is a no-op.
 */
void Strategy<Poker_t>::register_player(int p) {
    // p indexes mtx[] and player_registered[], so it must be a valid player id.
    assert(p >= 0 && p < Poker_t::num_players);
    std::lock_guard<std::mutex> lock(mtx[p]);
    if (player_registered[p])
        return;
    player_registered[p] = true;

    for (int i = 0; i < Game<Poker_t>::num_internal; ++i) {
        if (Game<Poker_t>::whose_turn[i] != p)
            continue;

        const uint64_t isomorphism = Hand<Poker_t>::get_isomorphism_size(Game<Poker_t>::round[i], 0);
        tuples[i] = new double*[isomorphism];
        for (uint64_t j = 0; j < isomorphism; ++j) {
            // value-initialised: all probabilities start at 0
            tuples[i][j] = new double[Game<Poker_t>::num_actions[i]]{};
        }
    }
}

template<typename Poker_t>
/**
 * Release every row owned by player `p` and reset the node pointers to nullptr,
 * so get_strategy returns nullptr for those nodes afterwards.
 * Guarded by the player's mutex; a no-op if the player is not registered.
 */
void Strategy<Poker_t>::free_player(int p) {
    // p indexes mtx[] and player_registered[], so it must be a valid player id.
    assert(p >= 0 && p < Poker_t::num_players);
    std::lock_guard<std::mutex> lock(mtx[p]);
    if (!player_registered[p])
        return;
    player_registered[p] = false;

    for (int i = 0; i < Game<Poker_t>::num_internal; ++i) {
        if (Game<Poker_t>::whose_turn[i] != p)
            continue;

        const uint64_t isomorphism = Hand<Poker_t>::get_isomorphism_size(Game<Poker_t>::round[i], 0);
        for (uint64_t j = 0; j < isomorphism; ++j) {
            delete[] tuples[i][j];
        }
        delete[] tuples[i];
        tuples[i] = nullptr;
    }
}

template<typename Poker_t>
/**
 * Look up the stored action distribution for node `seq` and `hand`.
 * The hand is mapped to its isomorphism with get_hand_isomorphism(0), the same
 * index set_strategy(seq, hand, ...) writes to, so reads and writes agree.
 * @return pointer to Game::num_actions[seq.get_id()] doubles, or nullptr when the
 *         node's owner has not been registered.
 */
const double*  Strategy<Poker_t>::get_strategy(const Sequence<Poker_t>& seq, const Hand<Poker_t>& hand) const {
    assert(Game<Poker_t>::round[seq.get_id()] == hand.get_round());
    // Keep the full 64-bit isomorphism (the previous version truncated it to int).
    const uint64_t hand_isomorphism = hand.get_hand_isomorphism(0);
    return get_strategy(seq, hand_isomorphism);
}

template<typename Poker_t>
/**
 * Look up the stored action distribution for node `seq` and a precomputed
 * hand isomorphism index (round-0 indexing, as used throughout this class).
 * @return pointer to Game::num_actions[seq.get_id()] doubles, or nullptr when the
 *         node's owner has not been registered.
 */
const double*  Strategy<Poker_t>::get_strategy(const Sequence<Poker_t>& seq, uint64_t hand_isomorphism) const {
    const int u = seq.get_id();
    // tuples[] only has entries for internal nodes.
    assert(u >= 0 && u < Game<Poker_t>::num_internal);
    assert(hand_isomorphism < Hand<Poker_t>::get_isomorphism_size(Game<Poker_t>::round[u], 0));

    // `tuples` is an array member and never null; only the per-node rows can be.
    if (!tuples[u] || !tuples[u][hand_isomorphism])
        return nullptr;
    return tuples[u][hand_isomorphism];
}

template<typename Poker_t>
/**
 * Store `tuple` for node `seq` and `hand`, normalised to sum 1.
 * The hand is mapped with get_hand_isomorphism(0) and forwarded to the
 * index-based overload, so both overloads share one code path (including its
 * isomorphism bound check).
 * @param tuple Game::num_actions[seq.get_id()] non-negative weights.
 */
void Strategy<Poker_t>::set_strategy(const Sequence<Poker_t>& seq, const Hand<Poker_t>& hand, const double * tuple) {
    assert(Game<Poker_t>::round[seq.get_id()] == hand.get_round());
    set_strategy(seq, hand.get_hand_isomorphism(0), tuple);
}

template<typename Poker_t>
/**
 * Store `tuple` for node `seq` and hand isomorphism `hand_isomorphism`,
 * normalised so the stored row sums to 1.
 * If the weights sum to zero (or the sum is not a positive number), a uniform
 * distribution is stored instead of the 0/0 NaNs the plain division would produce.
 * @param tuple Game::num_actions[seq.get_id()] non-negative weights.
 * The owning player must have been registered; otherwise the write is dropped
 * (and asserts in debug builds).
 */
void Strategy<Poker_t>::set_strategy(const Sequence<Poker_t>& seq, uint64_t hand_isomorphism, const double *tuple){
    const int u = seq.get_id();
    assert(u >= 0 && u < Game<Poker_t>::num_internal);
    assert(hand_isomorphism < Hand<Poker_t>::get_isomorphism_size(Game<Poker_t>::round[u], 0));

    double* row = tuples[u] ? tuples[u][hand_isomorphism] : nullptr;
    assert(row && "register the node's player before setting its strategy");
    if (!row)
        return;

    const int action_size = Game<Poker_t>::num_actions[u];
    const double sum = std::accumulate(tuple, tuple + action_size, 0.);

    if (sum > 0.) {
        for (int i = 0; i < action_size; ++i)
            row[i] = tuple[i] / sum;
    } else {
        // No probability mass (e.g. never-visited node): fall back to uniform.
        for (int i = 0; i < action_size; ++i)
            row[i] = 1.0 / action_size;
    }
}

/**
 * Depth-first walk of the betting tree rooted at `seq` that copies the average
 * strategy of `regret_strat`'s player into `strat`.
 * At every node owned by that player, each hand isomorphism is mapped to its
 * abstract bucket via abstraction.abstract_view(), the average strategy of that
 * bucket is queried and stored with strat.set_strategy(). Terminal nodes are skipped.
 */
template<typename Poker_t, bool atomic_update>
void recover_average_strategy(Strategy<Poker_t>& strat, const Sequence<Poker_t>& seq, const Abstraction<Poker_t>& abstraction, const RegretStrategy<Poker_t, atomic_update>& regret_strat) {
    /* terminal nodes carry no strategy */
    if (seq.is_terminal())
        return;

    const int u = seq.get_id();
    const int round = Game<Poker_t>::round[u];
    const int action_size = Game<Poker_t>::num_actions[u];

    if (Game<Poker_t>::whose_turn[u] == regret_strat.get_player()) {
        // Owned buffer: released even if a callee throws (the raw new[] leaked then).
        std::vector<double> probability(action_size);
        const uint64_t isomorphism_size = Hand<Poker_t>::get_isomorphism_size(round, 0);
        for (uint64_t priso = 0; priso < isomorphism_size; ++priso) {
            regret_strat.get_average_strategy(seq, abstraction.abstract_view(priso, round), probability.data());
            strat.set_strategy(seq, priso, probability.data());
        }
    }

    /* recurse into every child */
    for (int i = 0; i < action_size; ++i) {
        recover_average_strategy(strat, seq.do_action(i), abstraction, regret_strat);
    }
}

template<typename Poker>
template<bool atomic_update>
/**
 * Rebuild this table from the average strategy stored in `regret_strat`.
 * Ensures the player's rows exist, then walks the tree from sequence id 0
 * (presumably the root — confirm in Sequence) while holding that player's mutex.
 */
void Strategy<Poker>::recover_average_strategy(const Abstraction<Poker_t>& abstraction, const RegretStrategy<Poker_t, atomic_update>& regret_strat) {
    const int owner = regret_strat.get_player();

    // Allocation takes and releases the mutex internally ...
    register_player(owner);

    // ... and the traversal then holds it for its full duration.
    std::lock_guard<std::mutex> guard(mtx[owner]);

    const Sequence<Poker_t> start(0);
    // Global-scope qualification selects the free recursive helper, not this member.
    ::recover_average_strategy<Poker_t>(*this, start, abstraction, regret_strat);
}

// template<typename Poker>
// void Strategy<Poker>::save_txt(FILE* file) const{ 
//     assert(file);

//     if (tuples) {
//         for (int i = 0; i < Game<Poker_t>::num_internal; ++i) {

//             if (tuples[i]) {

//                 Sequence<Poker_t> seq(i);
//                 std::string seq_str = seq.to_string();
//                 int action_size = Game<Poker_t>::num_actions[i];

//                 int round = Game<Poker_t>::round[i];
//                 int hand_len = Poker_t::hand_len[round];
//                 type::card_t *hand = new type::card_t[hand_len];

//                 for(int j = 0, isomorphism_size = Poker_t::num_hand_isomorphism_round[round]; j<isomorphism_size; ++j){
//                     if (!tuples[i][j]) 
//                         continue;
//                     /* 打印牌同构id */
//                     fprintf(file, "%d:", j);

//                     /* 打印状态id */
//                     fprintf(file, "%d:", i);
    
//                     /* 打印牌 */
//                     Hand<Poker_t>::hand_unindex(j, round, hand);
//                     for(int k = 0; k<hand_len; ++k)
//                         fprintf(file, "%s", card_to_string<Poker_t>(hand[k]).c_str());
                    
//                     /* 打印动作 */
//                     fprintf(file, ":%s:", seq_str.c_str());

//                     /* 打印动作概率 */
//                     for (int k = 0; k < action_size; ++k) {
//                         int next_seq = Game<Poker_t>::transition[i][k];
//                         char action = Game<Poker_t>::action_result_from[next_seq];
//                         if((action == 'k' || action == 'c') && k == 0){
//                             fprintf(file, " %.09lf", 0.0);
//                         }
//                         fprintf(file, " %.09lf", tuples[i][j][k]);
//                         if((action == 'k' || action == 'c') && k == action_size-1){
//                             fprintf(file, " %.09lf", 0.0);
//                         }
//                     }
//                     fprintf(file, "\n");
//                 }

//                 delete[] hand;
//             }
//         }
//     }
// }

template<typename Poker>
/**
 * Write the table as text.
 * First line: "player:" followed by "\t<id>" for each registered player.
 * Then one line per allocated row:
 *   "<hand isomorphism>:<node id>:<cards>:<betting sequence>: <p0>@<a0> <p1>@<a1> ..."
 * where <pk> is the probability of action k and <ak> the action character of the
 * resulting sequence (Game::action_result_from). load_txt reads this format back.
 */
void Strategy<Poker>::save_txt(FILE* file) const {
    assert(file);
    if (!file)
        return;

    fprintf(file, "player:");
    for (int p = 0; p < Poker_t::num_players; ++p) {
        if (player_registered[p])
            fprintf(file, "\t%d", p);
    }
    fprintf(file, "\n");

    for (int i = 0; i < Game<Poker_t>::num_internal; ++i) {
        if (!tuples[i])
            continue;

        Sequence<Poker_t> seq(i);
        const std::string seq_str = seq.to_string();
        const int action_size = Game<Poker_t>::num_actions[i];

        const int round = Game<Poker_t>::round[i];
        const int hand_len = Poker_t::hand_len[round];
        std::vector<type::card_t> hand(hand_len);

        const uint64_t isomorphism_size = Hand<Poker_t>::get_isomorphism_size(round, 0);
        for (uint64_t j = 0; j < isomorphism_size; ++j) {
            if (!tuples[i][j])
                continue;

            /* hand isomorphism id (PRIu64: "%lu" is wrong where uint64_t is unsigned long long) */
            fprintf(file, "%" PRIu64 ":", j);

            /* node (sequence) id */
            fprintf(file, "%d:", i);

            /* cards of a representative hand */
            Hand<Poker_t>::hand_unisomorphism(j, round, 0, hand.data());
            for (int k = 0; k < hand_len; ++k)
                fprintf(file, "%s", card_to_string<Poker_t>(hand[k]).c_str());

            /* betting sequence */
            fprintf(file, ":%s:", seq_str.c_str());

            /* action probabilities, each tagged with its action character */
            for (int k = 0; k < action_size; ++k) {
                const int next_seq = Game<Poker_t>::transition[i][k];
                fprintf(file, " %.09lf@%c", tuples[i][j][k], Game<Poker_t>::action_result_from[next_seq]);
            }
            fprintf(file, "\n");
        }
    }
}

template<typename Poker>
/**
 * Load a table written by save_txt().
 * Lines starting with '#' and blank lines are ignored. The first other line must be
 * the "player:" header; every listed (valid) player is registered. Each following
 * line is parsed as "<iso>:<node>:<cards>:<betting>: <p0>[@<a0>] ..." and stored via
 * set_strategy(seq, hand, ...), with the hand rebuilt from <cards> and the node
 * from <betting>. Malformed lines assert in debug builds and are skipped otherwise.
 */
void Strategy<Poker>::load_txt(const char* file_path){
    assert(file_path);
    if (!file_path)
        return;

    /* Close the file on every exit path (the previous version never called fclose). */
    struct FileCloser {
        void operator()(FILE* f) const { if (f) fclose(f); }
    };
    std::unique_ptr<FILE, FileCloser> file(fopen(file_path, "rt"));
    assert(file);
    if (!file)
        return;

    char line[1024];

    /* header: first non-comment line, "player:\t<id>\t<id>..." */
    bool players[Poker_t::num_players]{false};
    while (fgets(line, sizeof(line), file.get())) {
        if (line[0] == '#')
            continue;
        // strncmp stops at the terminator, so short lines are never over-read.
        const bool is_header = strncmp(line, "player:", 7) == 0;
        assert(is_header);
        if (is_header) {
            std::stringstream sstrm(line + 7);
            int player;
            while (sstrm >> player) {
                // ids come from the file: never index outside players[]
                assert(player >= 0 && player < Poker_t::num_players);
                if (player >= 0 && player < Poker_t::num_players)
                    players[player] = true;
            }
        }
        break;
    }

    for (int p = 0; p < Poker_t::num_players; ++p) {
        if (players[p])
            register_player(p);
    }

    std::vector<double> tuple;   // reused across lines
    char hand[1024], betting_sequence[1024];
    while (fgets(line, sizeof(line), file.get())) {

        /* skip comments and blank lines */
        if (line[0] == '#' || line[0] == '\n' || line[0] == '\r' || line[0] == '\0')
            continue;

        /* parse the prefix; widths bound the %[ conversions to the buffers */
        uint64_t isomorphism_idx = 0;
        int sequence_idx = 0;
        int consumed = 0;
        const int fields = sscanf(line, "%" SCNu64 ":%d:%1023[^:]:%1023[^:]:%n",
                                  &isomorphism_idx, &sequence_idx, hand, betting_sequence, &consumed);
        // %n is only written when the trailing ':' matched
        assert(fields >= 4 && consumed > 0);
        if (fields < 4 || consumed <= 0)
            continue;

        /* parse betting sequence */
        Sequence<Poker_t> seq(betting_sequence);

        /* parse hand */
        Hand<Poker_t> h = Hand<Poker_t>::from_string(std::string(hand));

        /* read in the tuple: "<prob>" optionally followed by "@<action char>" */
        const int action_size = Game<Poker_t>::num_actions[seq.get_id()];
        tuple.assign(action_size, 0.);
        const char* cursor = line + consumed;
        bool ok = true;
        for (int a = 0; a < action_size; ++a) {
            int k = 0;
            if (sscanf(cursor, "%lf%n", &tuple[a], &k) != 1) {
                ok = false;
                break;
            }
            cursor += k;
            if (cursor[0] == '@' && cursor[1] != '\0')
                cursor += 2;   // skip the action tag written by save_txt
        }
        assert(ok);
        if (!ok)
            continue;

        assert(*cursor == '\r' || *cursor == '\n' || *cursor == '\0');
        set_strategy(seq, h, tuple.data());
    }
}

template<typename Poker>
/**
 * Free every allocated row and per-node row array, regardless of which players
 * are still registered, then zero the node-pointer table.
 */
Strategy<Poker>::~Strategy(){
    const int node_count = Game<Poker_t>::num_internal;

    for (int node = 0; node < node_count; ++node) {
        double** rows = tuples[node];
        if (rows == nullptr)
            continue;

        const uint64_t row_count = Hand<Poker_t>::get_isomorphism_size(Game<Poker_t>::round[node], 0);
        for (uint64_t r = 0; r < row_count; ++r)
            delete[] rows[r];

        memset(rows, 0, sizeof(double*) * row_count);
        delete[] rows;
    }

    memset(tuples, 0, sizeof(double**) * node_count);
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
