#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "open_spiel/abseil-cpp/absl/flags/flag.h"
#include "open_spiel/abseil-cpp/absl/flags/parse.h"
#include "open_spiel/algorithms/tabular_exploitability.h"
#include "open_spiel/algorithms/best_response.h"
#include "open_spiel/algorithms/expected_returns.h"
#include "open_spiel/policy.h"
#include "open_spiel/spiel.h"
#include "open_spiel/spiel_utils.h"
#include "full/algorithm_base.hpp"
#include "full/cfr.hpp"
#include "full/pcfr.hpp"
#include "full/dcfr.hpp"
#include "full/pcfr.hpp"
#include "logger.hpp"
#include "utils.hpp"

// Command-line flags.
//
// General run configuration. `seed` and `num_iters` are also embedded in the
// names of the CSV log files written by main().
ABSL_FLAG(std::string, logdir, "logs", "Directory to log to");
ABSL_FLAG(int, num_iters, 1000, "How many iters to run for.");
ABSL_FLAG(int, seed, 0, "Seed for random number generation");
ABSL_FLAG(std::string, game_name, "kuhn_poker", "Game to run CFR on.");
// NOTE(review): regularizer, perturbation_regularizer and learning_rate are
// declared here but not read by any code visible in this file.
ABSL_FLAG(std::string, regularizer, "entropy", "Regularization function.");
ABSL_FLAG(std::string, perturbation_regularizer, "kl", "Regularization function.");
ABSL_FLAG(double, learning_rate, 0.1, "Learning rate.");
// Forwarded to the pcfr_plus solver.
ABSL_FLAG(double, mutation_rate, 0.01, "Mutation rate");
ABSL_FLAG(int, update_anchoring_interval, 0, "How often to update anchoring strategies.");
// Forwarded to the dcfr solver as its alpha / beta / gamma arguments.
ABSL_FLAG(double, alpha, 1.5, "Alpha parameter of the dcfr solver.");
ABSL_FLAG(double, beta, 0.0, "Beta parameter of the dcfr solver.");
ABSL_FLAG(double, gamma, 2.0, "Gamma parameter of the dcfr solver.");
ABSL_FLAG(int, report_every, 100, "How often to report exploitability.");
// Forwarded to the pcfr_plus solver.
ABSL_FLAG(int, asym_player_id, -1, "Asymmetric player id.");
ABSL_FLAG(std::string, alg_name, "cfr", "Algorithm to run");

// Returns the hard-coded pair of game values (index = player) for `game_name`.
//
// The name must match one of the table entries exactly, parameters included;
// any other name yields {0.0, 0.0}. The selected values are also echoed to
// stderr.
std::vector<double> get_game_values(const std::string& game_name) {
    struct KnownGame {
        const char* name;
        double value_p0;
        double value_p1;
    };
    static const KnownGame kKnownGames[] = {
        {"kuhn_poker", -1.0 / 18.0, 1.0 / 18.0},
        {"leduc_poker", -0.08560642, 0.08560642},
        {"leduc_poker(suit_isomorphism=True)", -0.08560642, 0.08560642},
        {"liars_dice(dice_sides=4)", 1.0 / 16.0, -1.0 / 16.0},
        {"liars_dice(dice_sides=6)", -0.0271218, 0.0271218},
        {"turn_based_simultaneous_game(game=goofspiel(imp_info=True,num_cards=4,points_order=descending))", 0.0, 0.0},
        {"turn_based_simultaneous_game(game=goofspiel(imp_info=True,num_cards=5,points_order=descending))", 0.0, 0.0},
    };

    // Fallback for games that are not in the table.
    std::vector<double> values{0.0, 0.0};
    for (const KnownGame& entry : kKnownGames) {
        if (game_name == entry.name) {
            values[0] = entry.value_p0;
            values[1] = entry.value_p1;
            break;
        }
    }

    std::cerr << "Game values of " << game_name << ": "<< values[0] << " " << values[1] << std::endl;
    return values;
}

// choose algorithm
std::shared_ptr<open_spiel::algorithms::AlgorithmBase> create_solver(std::string alg_name, std::shared_ptr<const open_spiel::Game> game) {
    if (alg_name == "cfr")
        return std::make_shared<open_spiel::algorithms::MyCFRSolver>(*game);
    else if (alg_name == "cfr_plus")
        return std::make_shared<open_spiel::algorithms::MyCFRPlusSolver>(*game);
    else if (alg_name == "dcfr")
        return std::make_shared<open_spiel::algorithms::DCFRSolver>(
            *game,
            /*alpha=*/absl::GetFlag(FLAGS_alpha),
            /*beta=*/absl::GetFlag(FLAGS_beta),
            /*gamma=*/absl::GetFlag(FLAGS_gamma)
        );
    else if (alg_name == "lcfr"){
        return std::make_shared<open_spiel::algorithms::LCFRSolver>(*game);
    }
    else if (alg_name == "pcfr_plus")
        return std::make_shared<open_spiel::algorithms::PCFRPlusSolver>(*game,
                /*mutation_rate=*/absl::GetFlag(FLAGS_mutation_rate),
                /*update_anchoring_interval=*/absl::GetFlag(FLAGS_update_anchoring_interval),
                /*asym_player_id=*/absl::GetFlag(FLAGS_asym_player_id)
        );
    else
        throw std::invalid_argument("Invalid algorithm name");
}

// For each player of `game`, builds a TabularBestResponse for that player
// against `policy` and evaluates it at the initial state.
//
// Returns one value per player; entry p is the value reported by player p's
// best response.
std::vector<double> compute_best_response_values(const open_spiel::Game& game, const open_spiel::Policy& policy) {
    namespace algs = open_spiel::algorithms;

    std::unique_ptr<open_spiel::State> initial_state = game.NewInitialState();
    std::vector<double> values(game.NumPlayers());

    auto player = open_spiel::Player{0};
    while (player < game.NumPlayers()) {
        algs::TabularBestResponse responder(game, player, &policy);
        values[player] = responder.Value(*initial_state);
        ++player;
    }
    return values;
}

// Example code for using CFR+ to solve Kuhn Poker.
int main(int argc, char **argv) {
    absl::ParseCommandLine(argc, argv);
    std::cerr << "Starting run_full" << std::endl;
    std::shared_ptr<const open_spiel::Game> game = open_spiel::LoadGame(absl::GetFlag(FLAGS_game_name));
    // get game value
    std::vector<double> game_values = get_game_values(absl::GetFlag(FLAGS_game_name));
    std::cerr << "Loaded game " << game->GetType().long_name << std::endl;
    std::shared_ptr<open_spiel::algorithms::AlgorithmBase> solver = create_solver(absl::GetFlag(FLAGS_alg_name), game);
    std::string solver_name;
    solver_name = solver->GetName();
    std::cerr << "Created solver " << solver_name << std::endl;
    std::cerr << "Starting Full " << absl::GetFlag(FLAGS_alg_name) << " on " << game->GetType().long_name << "..." << std::endl;
    std::string logdir = absl::GetFlag(FLAGS_logdir) + "/cpp/" + absl::GetFlag(FLAGS_game_name) + "/";
    logdir += solver_name;
    MakeDir(logdir);
    std::string current_filename = logdir + "/current_exploitability" + "_seed_" + std::to_string(absl::GetFlag(FLAGS_seed)) + "_num_iters_" + std::to_string(absl::GetFlag(FLAGS_num_iters)) + ".csv";
    std::string current_header = ",0";
    std::cerr << "Logging to " << current_filename << std::endl;
    Logger current_logger(current_filename, current_header);
    std::string current_filename_0 = logdir + "/current_exploitability_0" + "_seed_" + std::to_string(absl::GetFlag(FLAGS_seed)) + "_num_iters_" + std::to_string(absl::GetFlag(FLAGS_num_iters)) + ".csv";
    std::string current_header_0 = ",0";
    std::string average_filename = logdir + "/average_exploitability" + "_seed_" + std::to_string(absl::GetFlag(FLAGS_seed)) + "_num_iters_" + std::to_string(absl::GetFlag(FLAGS_num_iters)) + ".csv";
    std::string average_header = ",0";
    Logger average_logger(average_filename, average_header);
    std::cerr << "Logging to " << current_filename_0 << std::endl;
    Logger current_logger_0(current_filename_0, current_header_0);
    std::string current_filename_1 = logdir + "/current_exploitability_1" + "_seed_" + std::to_string(absl::GetFlag(FLAGS_seed)) + "_num_iters_" + std::to_string(absl::GetFlag(FLAGS_num_iters)) + ".csv";
    std::string current_header_1 = ",0";
    std::cerr << "Logging to " << current_filename_1 << std::endl;
    Logger current_logger_1(current_filename_1, current_header_1);
    for (int i = 0; i < absl::GetFlag(FLAGS_num_iters) + 1; ++i) {
        solver->EvaluateAndUpdatePolicy();
        if (i % absl::GetFlag(FLAGS_report_every) == 0 ||
            i == absl::GetFlag(FLAGS_num_iters) || i < absl::GetFlag(FLAGS_report_every)) {
            double exploit_average = open_spiel::algorithms::Exploitability(*game, *solver->AveragePolicy());
            double exploit_current = open_spiel::algorithms::Exploitability(*game, *solver->CurrentPolicy());
            double log10_exploit_current = std::log10(exploit_current);
            double log10_exploit_average = std::log10(exploit_average);
            current_logger.log(std::to_string(i) + "," + std::to_string(log10_exploit_current));
            average_logger.log(std::to_string(i) + "," + std::to_string(log10_exploit_average));
            std::cerr << "Iteration " << i << " exploit_average=" << exploit_average
                      << " exploit_current=" << exploit_current << std::endl;
            std::vector<double> best_response_values = compute_best_response_values(*game, *solver->CurrentPolicy());
            double log10_exploit_current_0 = std::log10(best_response_values[1] - game_values[1]);
            double log10_exploit_current_1 = std::log10(best_response_values[0] - game_values[0]);
            current_logger_0.log(std::to_string(i) + "," + std::to_string(log10_exploit_current_0));
            current_logger_1.log(std::to_string(i) + "," + std::to_string(log10_exploit_current_1));
            std::cerr << "Iteration " << i
                      << " exploit_current_0=" << best_response_values[1] - game_values[1]
                      << " exploit_current_1=" << best_response_values[0] - game_values[0] << std::endl;
        }
    }
    std::vector<double> expected_returns = open_spiel::algorithms::ExpectedReturns(
        *game->NewInitialState(), *solver->CurrentPolicy(),
        /*depth_limit*/-1);
    std::cerr << "Expected returns: " << expected_returns[0] << " " << expected_returns[1] << std::endl;
    Logger policy_logger(logdir + "/policy.txt", "");
    std::string serialize_policy = (*solver->CurrentTabularPolicy()).Serialize();
    policy_logger.log(serialize_policy);
    auto deserialized_policy = open_spiel::DeserializeTabularPolicy(serialize_policy);
    std::cerr << open_spiel::algorithms::Exploitability(*game, *deserialized_policy) << std::endl;
    current_logger.flush();
    current_logger_0.flush();
    current_logger_1.flush();
    policy_logger.flush();
}