from pickle import NONE
import time

import pandas as pd
import pyspiel
import numpy as np
import random
import torch
import DQNAgents
import PGAgents
import PPOAgents
import yaml
from experiments import *
from open_spiel.python.algorithms import mcts
from open_spiel.python.algorithms import mcts_agent
from open_spiel.python.algorithms import random_agent
from open_spiel.python import rl_environment
from datetime import datetime
import argparse
import os
from multiprocessing import Pool
import subprocess

def default(params, random_seed=0, folder=None):
    # self-play but two models for two players
    game = pyspiel.load_game(params['game'])
    params['num_actions'] = game.num_distinct_actions()
    params['state_dim'] = game.observation_tensor_size()
    match params['agent']:
        case 'LGAN_DQN':
            p1 = DQNAgents.LGAN_DQNAgent(params)
            p2 = DQNAgents.LGAN_DQNAgent(params)
        case 'Unmasked_DQN':
            p1 = DQNAgents.Unmasked_DQNAgent(params)
            p2 = DQNAgents.Unmasked_DQNAgent(params)
        case 'Masked_DQN':
            p1 = DQNAgents.Masked_DQNAgent(params)
            p2 = DQNAgents.Masked_DQNAgent(params)
        case 'RES_DQN':
            p1 = DQNAgents.ResNetDQNAgent(params)
            p2 = DQNAgents.ResNetDQNAgent(params)
        case 'LGAN_A2C':
            p1 = PGAgents.LGAN_A2CAgent(params)
            p2 = PGAgents.LGAN_A2CAgent(params)
        case 'A2C':
            p1 = PGAgents.A2CAgent(params)
            p2 = PGAgents.A2CAgent(params)
        case 'CNN_A2C':
            p1 = PGAgents.CNNA2CAgent(params)
            p2 = PGAgents.CNNA2CAgent(params)
        case 'RES_A2C':
            p1 = PGAgents.ResNetA2CAgent(params)
            p2 = PGAgents.ResNetA2CAgent(params)
        case 'PPO':
            p1 = PPOAgents.PPOAgent(params)
            p2 = PPOAgents.PPOAgent(params)
        case 'LGAN_PPO':
            p1 = PPOAgents.LGAN_PPOAgent(params)
            p2 = PPOAgents.LGAN_PPOAgent(params)
        case _: 
            raise NotImplementedError
    p1.player_id = 0
    p2.player_id = 1
    if 'p1_model' in params.keys():
        p1.load_model(params['p1_model'])
        p1.epsilon_start = 0.1
    if 'p2_model' in params.keys():
        p2.load_model(params['p2_model'])
        p2.epsilon_start = 0.1
    agents = [p1, p2]
    env = rl_environment.Environment(game, include_full_state = True) # MCTS needs include_full_state = True
    train_group_size = params['train_group_size']
    best_res_p1 = -np.inf
    best_res_p2 = -np.inf
    time_stamp_str = datetime.now().strftime("%d_%H_%M_%S")
    patience_counter = 0
    if folder is None:
        folder = time_stamp_str + "/"
    for i in range(30):
        flag = False
        simulate(agents, env, group_size=train_group_size, num_groups=5)

        model_path = folder + params['game'] + "_p1_round_%d" % i
        save_information_breakthrough(agents[0], model_path)
        processes = []
        for j in range(20):
            COMMAND = ["python", "-u", "MCTS_test.py", "--model", "outputs/" + model_path,
                        "--group_size", "50", "--seed", str(j)]
            processes.append(subprocess.Popen(COMMAND, stdout=subprocess.PIPE))
        for process in processes:
            process.wait()
        csv_file = "outputs/" + model_path + "/result.csv"
        if not os.path.exists(csv_file):
            raise FileNotFoundError 
        df = pd.read_csv(csv_file)
        avg_p1_winrate = df['p1_winrate'].mean() * 100
        if best_res_p1 < avg_p1_winrate:
            best_res_p1 = avg_p1_winrate
            flag = True

        model_path = folder + params['game'] + "_p2_round_%d" % i
        save_information_breakthrough(agents[1], model_path)
        processes = []
        for j in range(20):
            COMMAND = ["python", "-u", "MCTS_test.py", "--model", "outputs/" + model_path,
                        "--group_size", "50", "--seed", str(j)]
            processes.append(subprocess.Popen(COMMAND, stdout=subprocess.PIPE))
        for process in processes:
            process.wait()
        csv_file = "outputs/" + model_path + "/result.csv"
        if not os.path.exists(csv_file):
            raise FileNotFoundError 
        df = pd.read_csv(csv_file)
        avg_p2_winrate = df['p2_winrate'].mean() * 100
        if best_res_p2 < avg_p2_winrate:
            best_res_p2 = avg_p2_winrate
            flag = True
        
        print("Avg_p1_winrate: %.1f" %avg_p1_winrate+ "   Avg_p2_winrate: %.1f" %avg_p2_winrate)
        if best_res_p1 == -np.inf and best_res_p2 == -np.inf and i > 5:
            print("Model is not learning, check setting.")
        if not flag and max(best_res_p1, best_res_p2) > 0:
            patience_counter += 1
        if patience_counter > params.get('patience_counter', 0):
            print("No improvement, stopping training.")
            print("Best winrate for p1: %.2f, p2: %.2f" % (best_res_p1, best_res_p2))
            break

    return [best_res_p1, best_res_p2]

def single_model(params, random_seed=0, folder=None):
    # OUTDATED NOT IN USE
    # self-play but one model for two players
    game = pyspiel.load_game(params['game'])
    params['num_actions'] = game.num_distinct_actions()
    params['state_dim'] = game.observation_tensor_size()
    match params['agent']:
        case 'LGAN_DQN':
            player = DQNAgents.LGAN_DQNAgent(params)
        case 'Unmasked_DQN':
            player = DQNAgents.Unmasked_DQNAgent(params)
        case 'Masked_DQN':
            player = DQNAgents.Masked_DQNAgent(params)
        case 'RES_DQN':
            player = DQNAgents.ResNetDQNAgent(params)
        case 'LGAN_A2C':
            player = PGAgents.LGAN_A2CAgent(params)
        case 'A2C':
            player = PGAgents.A2CAgent(params)
        case 'CNN_A2C':
            player = PGAgents.CNNA2CAgent(params)
        case 'RES_A2C':
            player = PGAgents.ResNetA2CAgent(params)
        case 'PPO':
            player = PPOAgents.PPOAgent(params)
        case 'LGAN_PPO':
            player = PPOAgents.LGAN_PPOAgent(params)
        case _: 
            raise NotImplementedError
    if 'player_model' in params.keys():
        player.load_model(params['player_model'])
        player.epsilon_start = 0.1
    evaluator = mcts.RandomRolloutEvaluator(n_rollouts=1, random_state=np.random.RandomState(random_seed))
    mcts_bot = mcts.MCTSBot(
            game=pyspiel.load_game(params['game']),
            uct_c=2.0,
            max_simulations=1000,
            evaluator=evaluator,
            random_state=np.random.RandomState(random_seed),
        )
    p1_mcts = mcts_agent.MCTSAgent(player_id=0, num_actions = game.num_distinct_actions(), mcts_bot = mcts_bot)
    p2_mcts = mcts_agent.MCTSAgent(player_id=1, num_actions = game.num_distinct_actions(), mcts_bot = mcts_bot)
    env = rl_environment.Environment(game, include_full_state = True) # MCTS needs include_full_state = True
    train_group_size = params['train_group_size']
    test_group_size = 100
    best_res_p1 = -np.inf
    best_res_p2 = -np.inf
    time_stamp_str = datetime.now().strftime("%d_%H_%M_%S")
    time_stamp_str += "/"
    patience_counter = 0
    for i in range(20):
        flag = False
        self_play(player, env, group_size=train_group_size, num_groups=5)
        [winrate_p1, _] = test_p1_vs_p2([player, p2_mcts], env, test_group_size=test_group_size, verbose=True)
        if best_res_p1 < winrate_p1:
            best_res_p1 = winrate_p1
            flag = True
        [_, winrate_p2] = test_p1_vs_p2([p1_mcts, player], env, test_group_size=test_group_size, verbose=True)
        flag_str = "Breakthrough_p1_wincount%03d" % int(winrate_p1 * test_group_size) + "_p2_wincount%03d" % int(winrate_p2 * test_group_size)
        if best_res_p2 < winrate_p2:
            best_res_p2 = winrate_p2
            flag = True
        save_information_breakthrough(player, time_stamp_str + flag_str)
        if not flag and max(best_res_p1, best_res_p2) > 0:
            patience_counter += 1
        if patience_counter > params.get('patience_counter', 0):
            print("No improvement, stopping training.")
            print("Best winrate for p1: %.2f, p2: %.2f" % (best_res_p1, best_res_p2))
            break
    return [best_res_p1, best_res_p2]

def ttt(params, random_seed=0, folder=None):
    """Self-play training that reports tic-tac-toe style progress.

    Two agents of type ``params['agent']`` train against each other for
    30 rounds. Only ``'LGAN_DQN'`` and ``'LGAN_A2C'`` are supported. Each round
    trains with ``simulate`` (5 groups of ``params['train_group_size']``). The
    first agent then plays 1000 test games as player 0 against a random agent.
    Finally it is saved with ``save_information_TTT`` to
    ``<folder><round>_<margin * 100>/``.

    Args:
        params: experiment settings. Reads ``'game'``, ``'agent'``,
            ``'train_group_size'`` and optionally ``'p1_model'``. Writes
            ``'num_actions'`` and ``'state_dim'``.
        random_seed: accepted for interface parity with the other modes; not
            used by this function.
        folder: prefix for saved outputs (expected to end with ``/``).
            Defaults to ``"<DD_HH_MM_SS>/"``.

    Returns:
        ``[best_margin, None]``, where the margin is the first agent's win rate
        minus the random agent's win rate.

    Raises:
        NotImplementedError: if ``params['agent']`` is not supported here.
    """
    game = pyspiel.load_game(params['game'])
    action_count = game.num_distinct_actions()
    params['num_actions'] = action_count
    params['state_dim'] = game.observation_tensor_size()

    agent_name = params['agent']
    if agent_name == 'LGAN_DQN':
        agent_cls = DQNAgents.LGAN_DQNAgent
    elif agent_name == 'LGAN_A2C':
        agent_cls = PGAgents.LGAN_A2CAgent
    else:
        raise NotImplementedError
    learner, sparring_partner = agent_cls(params), agent_cls(params)
    learner.player_id, sparring_partner.player_id = 0, 1

    # Warm-start the learner from a checkpoint with reduced initial exploration.
    if 'p1_model' in params:
        learner.load_model(params['p1_model'])
        learner.epsilon_start = 0.1
    random_opponent = random_agent.RandomAgent(player_id=1, num_actions=action_count)
    training_pair = [learner, sparring_partner]
    env = rl_environment.Environment(game, include_full_state=True)  # MCTS needs include_full_state = True
    train_group_size = params['train_group_size']
    test_group_size = 1000
    best_margin = -np.inf
    stamp = datetime.now().strftime("%d_%H_%M_%S")
    if folder is None:
        folder = stamp + "/"

    for round_idx in range(30):
        simulate(training_pair, env, group_size=train_group_size, num_groups=5)
        learner_rate, random_rate = test_p1_vs_p2(
            [learner, random_opponent], env, test_group_size=test_group_size, verbose=True
        )
        margin = learner_rate - random_rate
        if best_margin < margin:
            best_margin = margin
        out_dir = folder + str(round_idx) + "_%.1f" % (margin * 100) + "/"
        save_information_TTT(learner, out_dir)
    return [best_margin, None]

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run self-play experiments.")
    parser.add_argument(
        "--experiment_file",
        type=str,
        default="experiment_setting.yaml",
        help="Path to the experiment setting YAML file."
    )
    parser.add_argument(
        "--name",
        type=str,
        default="ttt_LGAN",
        help="Experiment name."
    )
    parser.add_argument(
        "--seed",
        type=str,
        default="0",
        help="Experiment name."
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="default",
        help="Experiment name."
    )
    args = parser.parse_args()
    
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    random.seed(int(args.seed))
    np.random.seed(int(args.seed))
    torch.manual_seed(int(args.seed))
    if torch.cuda.is_available():
        torch.cuda.manual_seed(int(args.seed))
        torch.cuda.manual_seed_all(int(args.seed))
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

    result_list = []
    with open(args.experiment_file, 'r') as f:
        experiments_setting = yaml.safe_load(f)

    print("Experiment setting:")
    print(experiments_setting['self_play_experiments'][args.name])
    print("=" * 80)
    match args.mode:
        case "default":
            default(experiments_setting['self_play_experiments'][args.name], folder=args.name+"_seed_"+args.seed+"/" , random_seed=int(args.seed))

        case "single_model":
            raise NotImplementedError
            single_model(experiments_setting['self_play_experiments'][args.name])
        case "ttt":
            ttt(experiments_setting['self_play_experiments'][args.name], folder=args.name+"_seed_"+args.seed+"/", random_seed=int(args.seed))
        case _:
            raise NotImplementedError
    