// 自己组合1096680 双组合690908400
#include "numeral211/game_numeral211.hpp"
#include "numeral211/hand_numeral211.hpp"
#include "numeral211/strategy_numeral211.hpp"
#include "numeral211/regret_strategy_numeral211.hpp"
#include "numeral211/evaluator_numeral211.hpp"
#include "template/algorithm/recursive/multithread_cs_mccfr.hpp"
#include "template/algorithm/iteration/multithread_best_response.hpp"
#include <boost/program_options.hpp>
#include <iostream>
// #include "template/domain_independent/abstraction.h"
#include "template/hand_work/abstr_hits.hpp"
#include "template/hand_work/full_resolution_abstr.hpp"
#include "template/hand_work/xriso_abstr.hpp"
#include "template/hand_work/suite_abstr.hpp"
#include <filesystem>

#include <boost/log/core.hpp>
#include <boost/log/trivial.hpp>
#include <boost/log/expressions.hpp>

#include "template/domain_independent/multithread_player_terminal_reach.hpp"

namespace po = boost::program_options;
using Poker_t = Numeral211;

// nohup stdbuf -o0 ./build/numeral211_train_with_br_pwi --epoch_or_time epoch --sample_batch 1096680 --interval_linear_or_power power --start 73 --end 50209 --logs 18 &


// 2.14 
// tail -fn 30 ./data/numeral211/remote/multithread_train_with_br_sb1096680_ep_73_50209_l18_pwi.log
// tail -fn 30 ./data/numeral211/remote/multithread_train_with_br_sb1096680_ep_73_50209_l18_pwi.csv

int main(int argc, char **argv){ 
#ifndef DEBUG
    boost::log::core::get()->set_filter (boost::log::trivial::severity >= boost::log::trivial::info);
#endif
    std::filesystem::create_directories("data/numeral211/remote/");
    std::string program_name = "data/numeral211/remote/multithread_train_with_br";
    bool b_epoch = false;
    bool b_time = false;
    bool b_linear = false;
    bool b_power = false;
    po::options_description desc("Allowed options");
    desc.add_options()
        ("help", "produce help message")
        ("sample_batch", po::value<uint64_t>())
        ("epoch_or_time", po::value<std::string>())
        ("interval_linear_or_power", po::value<std::string>())
        ("start", po::value<int32_t>())
        ("end", po::value<int32_t>())
        ("logs", po::value<uint64_t>())
    ;

    po::variables_map vm;        
    po::store(po::parse_command_line(argc, argv, desc), vm);
    po::notify(vm);

    uint64_t sample_batch;
    uint64_t logs;
    int32_t et_start;
    int32_t et_end;

    if (vm.count("help")) {
        std::cout << desc << "\n";
        return 0;
    }

    if (vm.count("sample_batch")){
        sample_batch = vm["sample_batch"].as<uint64_t>();
        program_name += "_sb" + std::to_string(sample_batch);
    }
    else
        return 0;

    if (vm.count("epoch_or_time") && vm["epoch_or_time"].as< std::string>() == "epoch") {
        b_epoch = true;
        program_name += "_e";
    }
    else if (vm.count("epoch_or_time") && vm["epoch_or_time"].as< std::string >() == "time") {
        b_time = true;
        program_name += "_t";
    }
    else {
        return 0;
    }

    if (vm.count("interval_linear_or_power") && vm["interval_linear_or_power"].as< std::string>() == "linear") {
        b_linear = true;
        program_name += "l";
    }
    else if (vm.count("interval_linear_or_power") && vm["interval_linear_or_power"].as< std::string >() == "power") {
        b_power = true;
        program_name += "p";
    }
    else {
        return 0;
    }

    if (vm.count("start")){
        et_start = vm["start"].as<int32_t>();
        program_name += "_" + std::to_string(et_start);
    }
    else
        return 0;

    if (vm.count("end")){
        et_end = vm["end"].as<int32_t>();
        program_name += "_" + std::to_string(et_end);
    }
    else
        return 0;

    if (vm.count("logs")){
        logs = vm["logs"].as<uint64_t>();
        program_name += "_l";
        program_name += std::to_string(logs);
        assert(logs > 1);
    }
    else
        return 0;

    /* 为全局随机数生成器赋种子 */
    overall_define::mt_rand.seed(0);

    /* 初始化各种与博弈游戏相关的常数 */
    Game<Poker_t>::init();
    Hand<Poker_t>::init();
    Evaluator<Poker_t>::init();
    printf("num_internal: %d\n", Game<Poker_t>::num_internal);
    printf("num_chance: %d\n", Game<Poker_t>::num_chance);
    printf("num_terminal: %d\n", Game<Poker_t>::num_terminal);
    printf("num_total: %d\n", Game<Poker_t>::num_total);

{
    assert(Game<Poker_t>::initialized);
    assert(Hand<Poker_t>::initialized);
    assert(Evaluator<Poker_t>::initialized);


    /* 为博弈参与人创建抽象 */
    printf("creating abstractions...\n");

    // 2.4._pwi
    Abstraction<Poker_t> train0_abstr(FullResolutionAbstr<Poker_t>({ResolutionConfig{WinningTrace, 0}, ResolutionConfig{WinningTrace, 1}, ResolutionConfig{WinningTrace, 2}}, "Numeral211"));
    Abstraction<Poker_t> train1_abstr(FullResolutionAbstr<Poker_t>({ResolutionConfig{WinningTrace, 0}, ResolutionConfig{WinningTrace, 1}, ResolutionConfig{WinningTrace, 2}}, "Numeral211"));
    Abstraction<Poker_t> exploiter_abstr(XrisoAbstr<Poker_t>({0, 0, 0}));
    program_name += "_pwi";

    Abstraction<Poker_t> *train_abstr[]{&train0_abstr, &train1_abstr};
    PrAbstrHits<Poker_t> brabstrhits;
    brabstrhits.perfect_recall_config(exploiter_abstr);

    /* open file*/
    FILE * log_file = fopen((program_name+".log").c_str(), "wt");
    assert(log_file);
    FILE * csv_file = fopen((program_name+".csv").c_str(), "wt");
    assert(log_file);

    fprintf(log_file, "num_internal: %d\n", Game<Poker_t>::num_internal);
    fprintf(log_file, "num_chance: %d\n", Game<Poker_t>::num_chance);
    fprintf(log_file, "num_terminal: %d\n", Game<Poker_t>::num_terminal);
    fprintf(log_file, "num_total: %d\n", Game<Poker_t>::num_total);


    fprintf(log_file, "train0 abstraction:\n");
    train_abstr[0]->print(log_file);
    fprintf(log_file, "train1 abstraction:\n");
    train_abstr[1]->print(log_file);
    fprintf(log_file, "br abstraction:\n");
    exploiter_abstr.print(log_file);
    fprintf(log_file, "br abstrhits:\n");
    brabstrhits.print(log_file);

    /* initialize memory  */
    fprintf(log_file, "initializing strategies...\n");
    std::shared_ptr<RegretStrategy<Poker_t, true>> sp_regret_strat[2] = {
        std::make_shared<RegretStrategy<Poker_t, true>>( 0
                                                 , train_abstr[0]->bucket_sizes()),
        std::make_shared<RegretStrategy<Poker_t, true>>( 1
                                                 , train_abstr[1]->bucket_sizes()),
    };

    // train variable
    MultithreadCS_MCCFR<Poker_t> m_cs_mccfr(sample_batch, overall_define::num_threads, sp_regret_strat);
    double acfr[Poker_t::num_players];
    std::memset(acfr, 0, sizeof(double)*Poker_t::num_players);
    uint64_t (*buckets)[Poker_t::num_players][Poker_t::num_rounds];
    type::rank_t (*ranks)[Poker_t::num_players];
    buckets = new uint64_t[sample_batch][Poker_t::num_players][Poker_t::num_rounds];
    ranks = new type::rank_t[sample_batch][Poker_t::num_players];

    // br varialbe
    MultithreadPlayerTerminalReach<Poker_t> m_player_terminal_reaches[Poker_t::num_players] {
        MultithreadPlayerTerminalReach<Poker_t>(overall_define::num_threads),
        MultithreadPlayerTerminalReach<Poker_t>(overall_define::num_threads),
    };
    std::function<double(int, uint64_t)> get_reach_funcs[Poker_t::num_players]{
        std::bind(&MultithreadPlayerTerminalReach<Poker_t>::get_reach, &m_player_terminal_reaches[0], std::placeholders::_1, std::placeholders::_2), 
        std::bind(&MultithreadPlayerTerminalReach<Poker_t>::get_reach, &m_player_terminal_reaches[1], std::placeholders::_1, std::placeholders::_2)
    };
    Strategy<Poker_t> strategy[Poker_t::num_players];
    MultithreadBestResponse<Poker_t> best_response(overall_define::num_threads);

    time_t time_start_stamp;
    time_start_stamp = time(NULL);
    double time_upto = difftime( time(NULL), time_start_stamp);
    fprintf(log_file, "train start: %s\n", ctime(&time_start_stamp));

    // fprintf(csv_file, "updates, upto-updates-time, p1acfr, p2acfr, p1br, p2br, upto-br-time\n");
    uint64_t epk = 0;
    for(uint64_t l = 0; l<logs; ++l) {

        do {
            ++epk;
            for (uint64_t b = 0; b < sample_batch; ++b) {
                /* deal hand */
                type::card_t holes[Poker_t::num_players][Poker_t::hole_len[Poker_t::num_rounds-1]], board[Poker_t::board_len[Poker_t::num_rounds-1]];
                deal_all_hand<Poker_t>(holes, board);

                std::vector<Hand<Poker_t>> hand;
                hand.reserve(Poker_t::num_players);
                /*construct hand and abstract the hand*/
                for (int j = 0; j < Poker_t::num_players; ++j){
                    hand.emplace_back(holes[j], board, 0);
                    for(int r = 0; r < Poker_t::num_rounds; ++r){
                        if(r>0)
                            hand[j].add_deal(holes[j]+Poker_t::hole_len[r-1], board+Poker_t::board_len[r-1]);
                        buckets[b][j][r] = train_abstr[j]->abstract_view(hand[j]);
                    }
                }

                // evaluate ranks
                Evaluator<Poker_t>::evaluate_ranks(holes, board, ranks[b]);
            }

            //updata regret
            double reach[Poker_t::num_players];
            std::fill(reach, reach+Poker_t::num_players, 1.);
            double ev[Poker_t::num_players], cfr[Poker_t::num_players];

            m_cs_mccfr.train( buckets
                            , ranks
                            , reach
                            , ev
                            , cfr
                            );

            /* update average CFR*/
            if (epk == 1){
                std::memcpy(acfr, cfr, sizeof(double)*Poker_t::num_players);
            }
            else {
                for (int j=0; j<Poker_t::num_players; ++j){
                    acfr[j] = 1. * (epk-1)/epk * (acfr[j] + cfr[j]/(epk-1));
                }
            }
            time_upto = difftime( time(NULL), time_start_stamp);
        } while ( (b_epoch && b_linear && (epk-et_start) * (logs-1) < (et_end-et_start) * l) 
               || (b_time && b_linear && (time_upto-et_start) * (logs-1) < (et_end-et_start) * l)
               || (b_epoch && b_power && log(1.0*epk/et_start) * (logs-1) < log(1.0*et_end/et_start) * l)
               || (b_time && b_power && log(1.0*time_upto/et_start) * (logs-1) < log(1.0*et_end/et_start) * l));

        fprintf(csv_file, "updates, %lu, time, %.1fs, ", epk * sample_batch, difftime( time(NULL), time_start_stamp));
        
        // acfr
        for(int p = 0;  p < Poker_t::num_players; ++p) {
            fprintf(csv_file, "acfr%d, %.2lf, ", p, acfr[p]);
        }
        fflush(csv_file);

        // #2 exploit each other 
        for(int p = 0; p < Poker_t::num_players; ++p) {
            m_player_terminal_reaches[p].compute_reach(*sp_regret_strat[p], *train_abstr[p]);
            // assert(assert_reach(*p_player_terminal_reachs[p], m_player_terminal_reaches[p]));
        }
        ///////////////////

        fprintf(csv_file, "time, %.1fs\n", difftime( time(NULL), time_start_stamp));
        fflush(csv_file);
    }

    delete[] buckets;
    delete[] ranks;

    time_t time_end_stamp = time(NULL);
    printf("train end: %s\n", ctime(&time_end_stamp));

    printf("saving strategies\n");
    std::string output_str[Poker_t::num_players]{
        "data/numeral211/1numeral211",
        "data/numeral211/2numeral211",
    };
    for (int i = 0; i < Poker_t::num_players; ++i){

        printf("saving player %d to %s...\n", i+1, output_str[i].c_str());

        /* open file*/
        FILE * file = fopen(output_str[i].c_str(), "wt");
        assert(file);

        /* write the header */
        time_t time_start_stamp = time(NULL);
        train_abstr[i]->print(file);
        fprintf(file, 
            "#\n"
            "# numeral211 strategy %s\n"
            "# made on:       %s"
            "# epochs:    %lu\n"
            "#\n",
            output_str[i].c_str(),
            ctime(&time_start_stamp),
            epk);

        /* recover the strategy */
        Strategy<Poker_t> strat;
        strat.recover_average_strategy(*train_abstr[i], *sp_regret_strat[i]);

        // /* write the strategy */
        // strat.save_txt(file);

        /* close the file */
        fclose(file);
    }
    fprintf(log_file, "end:  %fs", difftime( time(NULL), time_start_stamp));

    fclose(csv_file);
    fclose(log_file);
}
    Evaluator<Poker_t>::free();
    Hand<Poker_t>::free();
    Game<Poker_t>::free();
}

// #include <yaml-cpp/yaml.h>
// #include <iostream>
// #include <fstream>
// #include <vector>

// int main() {
//     try {
//         YAML::Node config = YAML::LoadFile("alg_configs.yaml");
//         for (size_t street = 0; street < config["alg_configs"].size(); ++street) {
//             YAML::Node alg = config["alg_configs"][street];
//             std::cout << "Index: " << street << std::endl;
//             std::cout << "Type: " << alg["type"].as<std::string>() << std::endl;
//             std::cout << "Recall from: " << alg["recall_from"].as<size_t>() << std::endl;
//         }
//     } catch (const std::exception& e) {
//         std::cerr << "Exception caught: " << e.what() << std::endl;
//     }

//     return 0;
// }