import argparse
import numpy as np
import random
import copy
import matplotlib.pyplot as plt
from open_spiel.python.algorithms import mcts
from open_spiel.python.algorithms import mcts_agent
from open_spiel.python import rl_environment
import DQNAgents
import PGAgents
import PPOAgents
import os
import pyspiel
import csv
import json
import ast


def test_vs_MCTS(params, model_path=None, test_group_size=100, random_seed = 0, mcts_simulations = 1000):
    total_rewards_p1 = []
    total_rewards_p2 = []
    game = pyspiel.load_game(params['game'])
    params['num_actions'] = game.num_distinct_actions()
    params['state_dim'] = game.observation_tensor_size()
    env = rl_environment.Environment(game, include_full_state = True) # MCTS needs include_full_state = True
    evaluator = mcts.RandomRolloutEvaluator(n_rollouts=1, random_state=np.random.RandomState(random_seed))
    mcts_bot = mcts.MCTSBot(
            game=pyspiel.load_game(params['game']),
            uct_c=2.0,
            max_simulations=mcts_simulations,
            evaluator=evaluator,
            random_state=np.random.RandomState(random_seed),
        )
    match params['agent']:
        case 'LGAN_DQN':
            agent = DQNAgents.LGAN_DQNAgent(params)
        case 'Unmasked_DQN':
            agent = DQNAgents.Unmasked_DQNAgent(params)
        case 'Masked_DQN':
            agent = DQNAgents.Masked_DQNAgent(params)
        case 'LGAN_A2C':
            agent = PGAgents.LGAN_A2CAgent(params)
        case 'RES_DQN':
            agent = DQNAgents.ResNetDQNAgent(params)
        case 'LGAN_A2C':
            agent = PGAgents.LGAN_A2CAgent(params)
        case 'A2C':
            agent = PGAgents.A2CAgent(params)
        case 'CNN_A2C':
            agent = PGAgents.CNNA2CAgent(params)
        case 'RES_A2C':
            agent = PGAgents.ResNetA2CAgent(params)
        case 'PPO':
            agent = PPOAgents.PPOAgent(params)
        case _: 
            raise NotImplementedError
    if model_path:
        if "p1" in model_path:
            agent.load_model(model_path)
            mcts_oppo = mcts_agent.MCTSAgent(player_id=1, num_actions = game.num_distinct_actions(), mcts_bot = mcts_bot)
            agents = [agent, mcts_oppo]
        elif "p2" in model_path:
            agent.load_model(model_path)
            agent.player_id = 1
            mcts_oppo = mcts_agent.MCTSAgent(player_id=0, num_actions = game.num_distinct_actions(), mcts_bot = mcts_bot)
            agents = [mcts_oppo, agent]
    for _ in range(test_group_size):
        time_step = env.reset()
        while not time_step.last():
            agent_to_act = agents[time_step.observations["current_player"]]
            agent_output = agent_to_act.step(time_step, is_evaluation=True)
            time_step = env.step([agent_output.action])
        # Episode is over, step all agents with final info state.
        total_rewards_p1.append(time_step.rewards[0])
        total_rewards_p2.append(time_step.rewards[1])
    winrate_p1 = np.sum(np.array(total_rewards_p1) > 0) / test_group_size
    winrate_p2 = np.sum(np.array(total_rewards_p2) > 0) / test_group_size
    return [winrate_p1, winrate_p2]

if __name__ == "__main__":
    # Command-line entry point: evaluate one saved model against MCTS and append
    # the per-seat win rates to <model dir>/result.csv.
    parser = argparse.ArgumentParser(
        description="Evaluate a trained agent against an MCTS opponent.")
    parser.add_argument('--model', required=True,
                        help="directory containing model.pth; its parent must hold train_log.txt")
    parser.add_argument('--group_size', type=int, default=100,
                        help="number of evaluation games")
    parser.add_argument('--seed', type=int, default=0,
                        help="random seed for the MCTS bot")
    parser.add_argument('--mcts_simulations', type=int, default=1000)
    args = parser.parse_args()

    model_path = os.path.join(args.model, "model.pth")
    # The training log lives in the parent of the model directory. normpath drops
    # a trailing separator so dirname() yields the parent, not the directory itself.
    run_dir = os.path.dirname(os.path.normpath(args.model))
    train_log_file = os.path.join(run_dir, "train_log.txt")
    with open(train_log_file, 'r', encoding='utf-8') as f:
        f.readline()  # the first line is skipped; the params dict is on line 2
        second_line = f.readline()
    if not second_line.strip():
        parser.error(f"{train_log_file} has no parameter dict on its second line")
    # literal_eval only accepts Python literals, so it cannot run code from the log.
    params = ast.literal_eval(second_line.strip())

    result = test_vs_MCTS(params=params, model_path=model_path, test_group_size=args.group_size, random_seed=args.seed, mcts_simulations=args.mcts_simulations)
    row = {'p1_winrate': result[0], 'p2_winrate': result[1]}

    output_file = os.path.join(args.model, "result.csv")
    with open(output_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=row.keys())
        # Append mode positions at end-of-file, so tell() == 0 means the file is
        # new or empty and still needs a header.
        if f.tell() == 0:
            writer.writeheader()
        writer.writerow(row)



