import pyspiel
import numpy as np
import random
import torch
import DQNAgents
import PGAgents
import yaml
from experiments import *
from open_spiel.python.algorithms import mcts
from open_spiel.python.algorithms import mcts_agent
from open_spiel.python import rl_environment
from datetime import datetime

def round_robin_go():
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    agent_names = ['Go_LGAN_DQN_24_mean', 'Go_LGAN_A2C_24_mean', 'Go_RES_DQN', 'Go_RES_A2C']
    agent_list_p1 = []
    agent_list_p2 = []
    experiment_setting_file = "experiment_setting.yaml"
    with open(experiment_setting_file, 'r') as f:
        experiments_setting = yaml.safe_load(f)
    game = pyspiel.load_game('go(board_size=7,komi=4.5)')
    env = rl_environment.Environment(game) 

    for agent_name in agent_names:
        params = experiments_setting['self_play_experiments'][agent_name]
        params['num_actions'] = game.num_distinct_actions()
        params['state_dim'] = game.observation_tensor_size()
        match params['agent']:
            case 'LGAN_DQN':
                p1 = DQNAgents.LGAN_DQNAgent(params)
                p1.load_model("outputs/Go_LGAN_DQN_24_mean_3D_seed_0/go(board_size=7,komi=4.5)_p1_round_3/model.pth")
                agent_list_p1.append(p1)
                p2 = DQNAgents.LGAN_DQNAgent(params)
                p2.load_model("outputs/Go_LGAN_DQN_24_mean_3D_seed_0/go(board_size=7,komi=4.5)_p2_round_2/model.pth")
                agent_list_p2.append(p2)
            case 'LGAN_A2C':
                p1 = PGAgents.LGAN_A2CAgent(params)
                p1.load_model("outputs/Go_LGAN_A2C_24_mean_seed_0/go(board_size=7,komi=4.5)_p1_round_6/model.pth")
                agent_list_p1.append(p1)
                p2 = PGAgents.LGAN_A2CAgent(params)
                p2.load_model("outputs/Go_LGAN_A2C_24_mean_seed_0/go(board_size=7,komi=4.5)_p2_round_7/model.pth")
                agent_list_p2.append(p2)
            case 'RES_DQN':
                p1 = DQNAgents.ResNetDQNAgent(params)
                p1.load_model("outputs/Go_RES_DQN_seed_0/go(board_size=7,komi=4.5)_p1_round_1/model.pth")
                agent_list_p1.append(p1)
                p2 = DQNAgents.ResNetDQNAgent(params)
                p2.load_model("outputs/Go_RES_DQN_seed_0/go(board_size=7,komi=4.5)_p2_round_1/model.pth")
                agent_list_p2.append(p2)
            case 'RES_A2C':
                p1 = PGAgents.ResNetA2CAgent(params)
                p1.load_model("outputs/Go_RES_A2C_seed_0/go(board_size=7,komi=4.5)_p1_round_6/model.pth")
                agent_list_p1.append(p1)
                p2 = PGAgents.ResNetA2CAgent(params)
                p2.load_model("outputs/Go_RES_A2C_seed_0/go(board_size=7,komi=4.5)_p2_round_6/model.pth")
                agent_list_p2.append(p2)
    for p1 in agent_list_p1:
        p1.player_id = 0
        for p2 in agent_list_p2:
            if p1
            p2.player_id = 1
            test_p1_vs_p2([p1, p2], env, test_group_size=1000, verbose=True)
        
if __name__ == "__main__":
    # Script entry point: run the whole tournament, then confirm on stdout.
    round_robin_go()
    done_message = "Round Robin Go experiment completed."
    print(done_message)
