import math
import numpy as np
import random
import copy
from datetime import datetime
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from open_spiel.python.algorithms import mcts
from open_spiel.python.algorithms import mcts_agent
import seaborn as sns
import os
import torch
import pyspiel

def simulate(agents, env, group_size=10000, num_groups=1, is_evaluation=(False, False)):
    """Play ``num_groups * group_size`` episodes OpenSpiel-style and report player-1 scores.

    Args:
        agents: one agent per player, indexed by
            ``time_step.observations["current_player"]``. Each must expose
            ``step(time_step, is_evaluation=...)`` returning an object with an
            ``action`` attribute during play.
        env: environment with ``reset()`` and ``step([action])``, both
            returning time steps exposing ``last()``, ``observations`` and
            ``rewards``.
        group_size: number of episodes per reporting group.
        num_groups: number of groups to play; one summary line is printed
            per group.
        is_evaluation: per-player flag forwarded to ``agent.step``, indexed by
            player id. Any sequence works. The default is a tuple so it can
            never be mutated across calls.

    Returns:
        list with the mean player-1 final reward of each group.
    """
    print("==============OpenSpiel Style Simulation==============")
    for idx, agent in enumerate(agents):
        print("Player %d : %s" % (idx + 1, agent.__class__.__name__))

    def _eval_flag(player):
        # Players without an explicit flag fall back to training mode (False).
        return is_evaluation[player] if player < len(is_evaluation) else False

    total_rewards_p1 = []
    score_p1_list = []
    game_length_list = []
    for i in range(num_groups):
        for _ in range(group_size):
            time_step = env.reset()
            turn_flag = 0
            while not time_step.last():
                player = time_step.observations["current_player"]
                agent_output = agents[player].step(time_step, is_evaluation=_eval_flag(player))
                time_step = env.step([agent_output.action])
                turn_flag += 1
            total_rewards_p1.append(time_step.rewards[0])
            # Episode is over: step every agent with the terminal time step,
            # forwarding each player's own evaluation flag so an agent marked
            # as evaluation-only is not stepped in training mode here.
            for idx, agent in enumerate(agents):
                agent.step(time_step, is_evaluation=_eval_flag(idx))
            # NOTE(review): the reported length is moves + 1, as before; confirm intent.
            game_length_list.append(turn_flag + 1)
        score_p1 = np.mean(total_rewards_p1[-group_size:])
        score_p1_list.append(score_p1)
        out_str = "%d - %d games :\tavg score_p1: %0.4f\tavg game length %0.1f" % (
            i + 1, group_size, score_p1, np.mean(game_length_list[-group_size:]))
        if hasattr(agents[0], 'get_epsilon'):
            out_str += " epsilon: %0.4f" % agents[0].get_epsilon()
        print(out_str)
    return score_p1_list

def test_p1_vs_p2(agents, env, test_group_size=1000, verbose=False):
    if verbose:
        out_str = "Testing " + " vs ".join([agent.__class__.__name__ for agent in agents])
        print(out_str)
    total_rewards_p1 = []
    total_rewards_p2 = []
    for _ in range(test_group_size):
        time_step = env.reset()
        while not time_step.last():
            agent_to_act = agents[time_step.observations["current_player"]]
            agent_output = agent_to_act.step(time_step, is_evaluation=True)
            time_step = env.step([agent_output.action])
        # Episode is over, step all agents with final info state.
        total_rewards_p1.append(time_step.rewards[0])
        total_rewards_p2.append(time_step.rewards[1])
    winrate_p1 = np.sum(np.array(total_rewards_p1) > 0) / test_group_size
    winrate_p2 = np.sum(np.array(total_rewards_p2) > 0) / test_group_size
    if verbose:
        print("tested %d" % test_group_size + " games, avg winrate of p1 %0.4f" % winrate_p1 
          + " avg winrate of p2 %0.4f" % winrate_p2)
    return [winrate_p1, winrate_p2]

def draw_heatmap(agent, state, num, add_file_name=None):
    """Plot a TicTacToe probe state and the agent's per-action attention maps.

    The figure is a 5x6 grid of heatmaps:

    * row 0, columns 0-2: the three 3x3 input planes of ``state[9:]``
      ("Legal Mask", "Self", "Opponent");
    * even actions (0, 2, ..., 8): columns 3-5 of rows 0-4;
    * odd actions (1, 3, 5, 7): columns 0-2 of rows 1-4.

    Each action's three attention maps share one colour scale (0 to the
    global max). The first map of each action is y-labelled with that
    action's Q-value, coloured blue (> 0), black (== -inf, presumably a
    masked move) or red (otherwise).

    Args:
        agent: object exposing ``device`` and a callable ``policy_net`` with a
            ``get_attn_weights`` method. The attention result is indexed as
            ``[action, input_cell]`` over 9 actions x 27 cells (TODO confirm
            exact shape/type).
        state: 36-element sequence; only ``state[9:]`` is drawn as planes,
            but the whole vector is fed to the network.
        num: index embedded in the output file name ``heatmap<num>.png``.
        add_file_name: sub-directory under ``outputs/``. ``None`` writes
            directly into ``outputs/``.

    Raises:
        AssertionError: if ``state`` does not have 36 entries.
    """
    # Prefer Helvetica, but only prepend it when needed so repeated calls do
    # not keep growing the global font list.
    sans_fonts = list(plt.rcParams['font.sans-serif'])
    if not sans_fonts or sans_fonts[0] != 'Helvetica':
        plt.rcParams['font.sans-serif'] = ['Helvetica'] + [f for f in sans_fonts if f != 'Helvetica']
    plt.rcParams['font.size'] = 12
    plt.rcParams['axes.labelsize'] = 11
    # Concatenating None would raise TypeError; treat it as "no sub-directory".
    path = "outputs/" + (add_file_name if add_file_name is not None else "")
    file_name = "heatmap" + str(num) + ".png"
    assert len(state) == 36, "Expected 36-dim state vec for TTT"
    board_layers = np.asarray(state[9:]).reshape(3, 3, 3)  # (channel, row, col)
    fig, axes = plt.subplots(5, 6, figsize=(19.2, 10.8))
    titles = ["Legal Mask", "Self", "Opponent"]
    for i in range(3):
        sns.heatmap(board_layers[i], ax=axes[0, i], annot=True, cbar=False,
                    square=True, cmap="Blues", vmin=0, vmax=1)
        axes[0, i].set_title(titles[i])
        axes[0, i].set_ylabel("Current State")
        axes[0, i].set_xticklabels([])
        axes[0, i].set_yticklabels([])

    # Build the network input once and reuse it for both queries.
    state_tensor = torch.tensor(state, dtype=torch.float32).to(agent.device)
    attn_map = agent.policy_net.get_attn_weights(state_tensor)
    # Q-values are only displayed, so skip autograd bookkeeping.
    with torch.no_grad():
        q_val = agent.policy_net(state_tensor).detach().cpu().numpy().flatten()
    vmax = attn_map.max()

    for action_idx in range(9):
        # Even actions fill the right half from row 0; odd ones the left half from row 1.
        if action_idx % 2 == 0:
            row, first_col = action_idx // 2, 3
        else:
            row, first_col = action_idx // 2 + 1, 0
        for layer in range(3):
            ax = axes[row, first_col + layer]
            data = attn_map[action_idx, layer * 9: (layer + 1) * 9].reshape(3, 3)
            sns.heatmap(data, ax=ax, annot=True, cbar=False,
                        square=True, cmap="YlOrRd", vmin=0, vmax=vmax)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            if action_idx == 0:
                ax.set_title(titles[layer])
            if layer == 0:
                if q_val[action_idx] > 0:
                    color = 'blue'
                elif q_val[action_idx] == -torch.inf:
                    color = 'black'
                else:
                    color = 'red'
                ax.set_ylabel(f"Action {action_idx} (Q_val={q_val[action_idx]:.2f})", color=color)
    plt.suptitle("Current TicTacToe State (top) and Attention for All Actions (rows 1-9: actions, cols: layers)",
                    y=1.02)
    plt.tight_layout()
    os.makedirs(path, exist_ok=True)
    save_path = os.path.join(path, file_name)
    plt.savefig(save_path, dpi=600, bbox_inches='tight')
    plt.close(fig)

def save_information_TTT(agent, add_file_name=None):
    """Render ``draw_heatmap`` figures for five fixed TicTacToe probe positions.

    Each probe is a 36-element vector. Entries 9..35 form three 3x3 planes
    (legal-move mask, own marks, opponent marks), matching the titles used by
    ``draw_heatmap``. In every probe below, the first 9 entries repeat the
    legal-mask plane. Output files are numbered 0-4 in table order.
    """
    probes = (
        [0, 1, 1, 1, 0, 1, 0, 0, 1,
         0, 1, 1, 1, 0, 1, 0, 0, 1,
         0, 0, 0, 0, 1, 0, 0, 1, 0,
         1, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 1, 0, 1, 1, 1, 1, 0, 0,
         0, 1, 0, 1, 1, 1, 1, 0, 0,
         1, 0, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 1, 1],
        [1, 1, 0, 0, 1, 1, 0, 1, 0,
         1, 1, 0, 0, 1, 1, 0, 1, 0,
         0, 0, 1, 0, 0, 0, 0, 0, 1,
         0, 0, 0, 1, 0, 0, 1, 0, 0],
        [0, 0, 1, 1, 1, 1, 0, 1, 0,
         0, 0, 1, 1, 1, 1, 0, 1, 0,
         0, 0, 0, 0, 0, 0, 1, 0, 1,
         1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 1, 0, 0, 1, 1,
         0, 1, 0, 1, 1, 0, 0, 1, 1,
         1, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 0, 1, 0, 0, 1, 0, 0, 0],
    )
    for index, probe in enumerate(probes):
        draw_heatmap(agent, probe, index, add_file_name)

def _plot_history_panels(agent, panels, nrows, ncols, save_path):
    """Plot each ``(attr, title, ylabel)`` history the agent has into a grid and save it.

    Panels are filled in row-major order. A panel whose attribute is missing
    on ``agent`` is left empty.
    """
    fig, axes = plt.subplots(nrows, ncols, figsize=(19.2, 10.8))
    axes = axes.flatten()
    for ax, (attr, title, ylabel) in zip(axes, panels):
        if not hasattr(agent, attr):
            continue
        ax.plot(getattr(agent, attr))
        ax.set_title(title)
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close(fig)


def save_information_breakthrough(agent, add_file_name=None):
    """Save the agent's model and a plot of its recorded training histories.

    Writes ``model.pth`` (via ``agent.save_model``) into
    ``outputs/<add_file_name>``, then one figure in the same directory:

    * class name contains "DQN": ``q_info.png``, a 2x2 grid of the loss,
      Q max, Q min and Q mean histories;
    * otherwise: ``loss_info.png``, a 4x2 grid of the value loss, policy
      loss, entropy loss, ratio, prediction, return, new-probs and
      clip-fraction histories.

    Any history attribute the agent lacks leaves its panel empty.

    Args:
        agent: object exposing ``save_model(path)`` and, optionally, the
            ``*_history`` sequences listed above.
        add_file_name: sub-directory under ``outputs/``. ``None`` writes
            directly into ``outputs/``.
    """
    # Concatenating None would raise TypeError; treat it as "no sub-directory".
    path = "outputs/" + (add_file_name if add_file_name is not None else "")
    os.makedirs(path, exist_ok=True)
    agent.save_model(os.path.join(path, "model.pth"))

    if "DQN" in agent.__class__.__name__:
        panels = (
            ('loss_history', 'Loss History', 'Loss'),
            ('q_max_history', 'Q Max History', 'Q Max'),
            ('q_min_history', 'Q Min History', 'Q Min'),
            ('q_mean_history', 'Q Mean History', 'Q Mean'),
        )
        _plot_history_panels(agent, panels, 2, 2, os.path.join(path, "q_info.png"))
    else:
        panels = (
            ('value_loss_history', 'value loss History', 'loss'),
            ('policy_loss_history', 'policy loss History', 'loss'),
            ('entropy_loss_history', 'entropy loss History', 'loss'),
            ('ratio_history', 'ratio History', 'ratio'),
            ('prediction_history', 'prediction_history', '~'),
            ('return_history', 'return_history', '~'),
            ('new_probs_history', 'new_probs_history', '~'),
            ('clip_fraction_history', 'clip_fraction_history', '~'),
        )
        _plot_history_panels(agent, panels, 4, 2, os.path.join(path, "loss_info.png"))
