import multiprocessing as mp
import time
from concurrent.futures import ProcessPoolExecutor

import fsspec
import hydra
import numpy as np
import pandas as pd
from algorithms import *
from games import *
from runner import runner


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg):
    """Run ``cfg.n_trials`` independent trials in parallel worker processes.

    The game class named by ``cfg.game.game_name`` and the learning
    algorithm(s) named by ``cfg.player{i}.alg_name`` are looked up among the
    names star-imported into this module. Players are built here only to
    derive the output directory name; each worker rebuilds its own game and
    players in ``run_experiment``.

    Each worker writes its log to
    ``{output_dir}/{game_name}_{game_params}/{alg_name}/results_{trial_id}.csv``.

    Args:
        cfg: Config composed by Hydra from ``conf/config``.
    """
    start_time = time.perf_counter()
    # Master seed: every per-trial seed below is drawn from this stream, so a
    # given cfg.seed reproduces the whole batch.
    np.random.seed(cfg.seed)

    # initialize game
    game_params = dict(cfg.game[cfg.game.game_name])
    # NOTE(review): eval() of a config string executes arbitrary code; it is
    # only used to look up a star-imported class by name. Use trusted configs.
    game, _nash_eq = eval(cfg.game.game_name)(**game_params)

    # initialize players
    players = []
    strategy_spaces = game.strategy_classes()
    for i in range(game.num_players()):
        # Symmetric runs reuse player0's settings for every player.
        player_key = f"player{i}" if cfg.asymmetric_players else "player0"
        player_cfg = getattr(cfg, player_key)
        params = dict(player_cfg)
        learning_alg = eval(player_cfg.alg_name)
        players.append(learning_alg(strategy_spaces[i], **params))

    alg_name = (
        ",".join(player.name() for player in players)
        if cfg.asymmetric_players
        else players[0].name()
    )
    game_param_str = "_".join(f"{key}{value}" for key, value in game_params.items())
    save_path = f"{cfg.output_dir}/{cfg.game.game_name}_{game_param_str}/{alg_name}"

    # run experiments
    print("==========Run experiment==========")
    n_trials = cfg.n_trials
    # Explicit int64 dtype: the default `int` is int32 on some platforms
    # (Windows, NumPy < 2), where a 2**32 upper bound raises ValueError.
    seeds = [
        int(np.random.randint(0, 2**32, dtype=np.int64)) for _ in range(n_trials)
    ]
    # Leave one core for the parent, never request more workers than trials,
    # and never request zero (ProcessPoolExecutor rejects max_workers <= 0).
    max_workers = max(1, min(mp.cpu_count() - 1, n_trials))
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map re-raises a worker's exception only when that result is
        # retrieved, so consume the iterator; otherwise a failed trial would
        # be silently dropped.
        nash_conv_dfs = list(
            pool.map(
                run_experiment,
                range(n_trials),
                [cfg] * n_trials,
                [save_path] * n_trials,
                seeds,
            )
        )

    elapsed_time = time.perf_counter() - start_time
    print(save_path, " elapsed_time:{0}".format(elapsed_time) + "[sec]")


def run_experiment(trial_id, cfg, save_path, seed):
    """Run a single trial and save its log as ``{save_path}/results_{trial_id}.csv``.

    ``main`` dispatches this through a ProcessPoolExecutor, so the game and
    players are rebuilt from ``cfg`` here rather than passed in.

    Args:
        trial_id: Trial index; used in progress messages and the output file name.
        cfg: The same config object ``main`` received.
        save_path: Output directory, as a path/URL understood by ``fsspec.open``.
        seed: Seed for NumPy's global (legacy) RNG in this process.

    Returns:
        The trial log as a DataFrame indexed by its ``t`` column.
    """
    print(f"==========Start trial {trial_id}==========")
    np.random.seed(seed)

    # initialize game
    game_params = dict(cfg.game[cfg.game.game_name])
    # NOTE(review): eval() of a config string executes arbitrary code; it is
    # only used to look up a star-imported class by name. Use trusted configs.
    game, _nash_eq = eval(cfg.game.game_name)(**game_params)

    # initialize players
    players = []
    strategy_spaces = game.strategy_classes()
    for i in range(game.num_players()):
        # Symmetric runs reuse player0's settings for every player.
        player_key = f"player{i}" if cfg.asymmetric_players else "player0"
        player_cfg = getattr(cfg, player_key)
        params = dict(player_cfg)
        learning_alg = eval(player_cfg.alg_name)
        players.append(learning_alg(strategy_spaces[i], **params))

    log = runner.run(
        game, cfg.T, players, cfg.stop_early, cfg.print_interval, cfg.log_interval
    )

    # save log
    df = log.to_dataframe().set_index("t")
    print(f"==========Finish trial {trial_id}==========")
    save_file_path = f"{save_path}/results_{trial_id}.csv"
    # fsspec.open (auto_mkdir defaults to True) creates missing parent dirs.
    with fsspec.open(save_file_path, "w") as f:
        df.to_csv(f)
    return df


# Script entry point. The hydra.main decorator composes cfg (conf/config plus
# any command-line overrides), so main() is called with no arguments. The guard
# also matters for the process pool: under the "spawn" start method, workers
# import this module and must not re-run main().
if __name__ == "__main__":
    main()
