import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
from open_spiel.python import rl_agent

class PGAgent:
    """Base class for the on-policy actor-critic agents in this module.

    Collects one episode at a time in a ``TrajectoryBuffer``, turns it into
    GAE advantages/returns, and pushes transitions into a fixed-size
    ``PGTrainBuffer``; every time that buffer fills, :meth:`learn` performs one
    gradient step. Subclasses must assign ``self.net`` (a module mapping a batch
    of ``(num_actions + state_dim)``-wide states to ``(logits, state_values)``)
    and ``self.optimizer``.

    Every stored state is ``concat(legal_mask, info_state)``: the first
    ``num_actions`` entries are the 0/1 legal-action mask built in :meth:`step`.
    """

    def __init__(self, params: dict = None):
        """
        params: dict, contains the parameters for the agent and the experiment.
            Missing keys fall back to the defaults below; ``None`` is treated
            as an empty dict.
        """
        if params is None:
            # Every field below is read with params.get(), which would raise
            # AttributeError on the advertised default of None.
            params = {}
        print("Initializing " + type(self).__name__ + " with parameters:", params)
        self.params = params
        self.player_id = params.get('player_id', 0)
        self.num_actions = params.get('num_actions', 4)
        self.state_dim = params.get('state_dim', 16)
        self.gamma = params.get('gamma', 0.99)
        self.lam = params.get('lam', 0.95)
        self.batch_size = params.get('batch_size', 128)
        self.minibatch_size = params.get('minibatch_size', 32)
        self.d_model = params.get('d_model', 128)
        self.lr = params.get('lr', 0.001)
        self.lr_end = 0.3 * self.lr
        self.lr_decay_duration = 10 * params.get('train_group_size', 1000)

        self.device = torch.device(params.get('device', "cuda"))
        self.entropy_coef = params.get('entropy_coef', 0.001)
        self.value_loss_coef = params.get('value_loss_coef', 0.5)

        self.train_counter = 0
        self.trajectory_buffer = self.TrajectoryBuffer()
        self.train_buffer = self.PGTrainBuffer(self.state_dim, self.num_actions, self.batch_size, self.device)
        self.critic_loss = nn.MSELoss()
        # Per-update diagnostics appended by learn().
        self.value_loss_history = []
        self.policy_loss_history = []
        self.entropy_loss_history = []
        self.ratio_history = []
        self.prediction_history = []
        self.return_history = []
        self.new_probs_history = []
        self.clip_fraction_history = []

    def get_lr(self, power=1.0):
        """Return a learning rate decayed polynomially from ``lr`` to ``lr_end``.

        Progress is ``train_counter / lr_decay_duration``, saturating at 1.
        Note: nothing in this class advances ``train_counter``, so unless other
        code does, this returns ``self.lr``.
        """
        decay_steps = min(self.train_counter, self.lr_decay_duration)
        decayed_lr = (self.lr_end + (self.lr - self.lr_end) * (1 - decay_steps / self.lr_decay_duration) ** power)
        return decayed_lr

    def get_clip(self, power=1.0):
        """Polynomially decayed clip coefficient (currently unused).

        Requires ``clip_coef`` and ``clip_coef_end`` attributes, which this
        class does not define; calling it without setting them raises
        AttributeError.
        """
        decay_steps = min(self.train_counter, self.lr_decay_duration)
        decayed_clip_coef = (self.clip_coef_end + (self.clip_coef - self.clip_coef_end) * (1 - decay_steps / self.lr_decay_duration) ** power)
        return decayed_clip_coef

    def learn(self):
        """Run one actor-critic gradient step on the train buffer, then clear it.

        loss = -mean(log_prob * normalized_advantage)
               + value_loss_coef * MSE(V(s), returns)
               - entropy_coef * mean(entropy)
        Gradients are clipped to a global L2 norm of 0.5. The stored old state
        values and old log-probs are not used by this update.
        """
        # A learning-rate schedule (get_lr) is currently disabled. To enable it,
        # set each optimizer param_group["lr"] to self.get_lr() here and
        # increment self.train_counter on every call.
        states, actions, _, advantages, returns, _ = self.train_buffer.get()
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        logits, pred_state_values = self.net(states)
        critic_loss = self.value_loss_coef * self.critic_loss(pred_state_values.view(-1), returns)
        # Illegal actions (0 in the leading mask columns) get a huge negative
        # logit, so they carry ~zero probability and entropy contribution.
        legal_mask = states[:, :self.num_actions]
        masked_logits = logits.masked_fill(legal_mask == 0, -1e9)
        action_categorical = torch.distributions.Categorical(logits=masked_logits)
        new_log_probs = action_categorical.log_prob(actions)
        actor_loss = -(new_log_probs * advantages.detach()).mean()
        entropy_loss = -self.entropy_coef * action_categorical.entropy().mean()
        loss = actor_loss + critic_loss + entropy_loss
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=0.5)
        self.optimizer.step()
        with torch.no_grad():
            self.value_loss_history.append(critic_loss.item())
            self.policy_loss_history.append(actor_loss.item())
            self.entropy_loss_history.append(entropy_loss.item())
            self.prediction_history.append(pred_state_values.mean().item())
            self.return_history.append(returns.mean().item())
            self.new_probs_history.append(torch.exp(new_log_probs).mean().item())
        self.train_buffer.clear()

    def finish_trajectory(self):
        """Convert the buffered episode into training samples and clear it.

        Computes GAE advantages (gamma, lam) and returns = advantages + V(s).
        The episode is treated as terminal, so the value after the last step
        is 0. learn() runs whenever the train buffer fills up.
        """
        if not self.trajectory_buffer.states:
            return
        traj_states, traj_actions, traj_rewards, traj_log_probs = self.trajectory_buffer.get()
        states_tensor = torch.tensor(traj_states, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            _, state_values_tensor = self.net(states_tensor)
        state_values = state_values_tensor.cpu().numpy().flatten()
        advantages = np.zeros_like(state_values)
        gae = 0.0
        next_value = 0.0
        for i in reversed(range(len(traj_rewards))):
            # δ_t = r_t + γ * V(s_{t+1}) - V(s_t)
            delta = traj_rewards[i] + self.gamma * next_value - state_values[i]
            # A_t = δ_t + γ * λ * A_{t+1}
            gae = delta + self.gamma * self.lam * gae
            advantages[i] = gae
            next_value = state_values[i]
        returns = advantages + state_values

        for transition in zip(traj_states, traj_actions, state_values, advantages, returns, traj_log_probs):
            if self.train_buffer.add(*transition):
                self.learn()
        self.trajectory_buffer.clear()

    def choose_action(self, state):
        """Sample an action for one flat state (legal mask first).

        Returns ``(action, dist)``: a 0-d action tensor and the Categorical
        over ``num_actions`` with illegal actions masked to logit -1e9.
        """
        legal_mask = state[:self.num_actions]
        state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            logits, _ = self.net(state_tensor)
        legal_mask = torch.as_tensor(legal_mask, device=self.device)
        masked_logits = logits.masked_fill(legal_mask == 0, -1e9).flatten()
        dist = torch.distributions.Categorical(logits=masked_logits)
        action = dist.sample()
        return action, dist

    def step(self, time_step, is_evaluation=False):
        '''
        Act on an OpenSpiel time step and record the transition for training.

        time_step: open_spiel.python.TimeStep
        Returns an rl_agent.StepOutput, or None on the terminal step.
        '''
        # For OpenSpiel TTT and Breakthrough, info_state[0] == info_state[1],
        # so to self-play you may need to mirror the state (for Breakthrough,
        # exchange info_state[0][0:25] and info_state[0][25:50]).
        # Right now, the agent can only play as player 0.
        if time_step.last():
            if is_evaluation:
                return
            # The episode can end before this agent ever acted; there is then
            # no recorded decision to credit the final reward to.
            if self.trajectory_buffer.rewards:
                self.trajectory_buffer.set_reward(time_step.rewards[self.player_id])
            self.finish_trajectory()
            return
        state = np.asarray(time_step.observations["info_state"][0])
        legal_actions = time_step.observations["legal_actions"][self.player_id]
        legal_mask = np.zeros(self.num_actions, dtype=np.float32)
        legal_mask[legal_actions] = 1.0
        state = np.concatenate((legal_mask, state), axis=0)  # (num_actions + state_dim, )
        action, dist = self.choose_action(state=state)
        log_prob = dist.log_prob(action).detach()
        if not is_evaluation:
            # The reward observed now belongs to this agent's previous decision.
            if self.trajectory_buffer.rewards:
                self.trajectory_buffer.set_reward(time_step.rewards[self.player_id])
            self.trajectory_buffer.append(state, action.item(), log_prob)
        return rl_agent.StepOutput(action=action.item(), probs=dist.probs.cpu().detach().numpy())

    def save_model(self, path='model.pth'):
        """Save the network's state_dict to ``path``."""
        torch.save(self.net.state_dict(), path)

    def load_model(self, path='model.pth'):
        """Load a state_dict from ``path`` into the network on ``self.device``."""
        self.net.load_state_dict(torch.load(path, map_location=self.device))
        self.net.to(self.device)

    class TrajectoryBuffer:
        """Per-episode lists of states, actions, rewards and log-probs."""

        def __init__(self):
            self.states = []
            self.actions = []
            self.rewards = []
            self.log_probs = []

        def append(self, state, action, log_prob, reward=0):
            """Record one decision; its reward is usually filled in later via set_reward()."""
            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.log_probs.append(log_prob)

        def set_reward(self, reward):
            """Overwrite the reward of the most recent decision (IndexError if empty)."""
            self.rewards[-1] = reward

        def clear(self):
            self.states.clear()
            self.actions.clear()
            self.rewards.clear()
            self.log_probs.clear()

        def get(self):
            """Return (states, actions, rewards) as NumPy arrays plus the log-prob list."""
            return np.asarray(self.states), np.asarray(self.actions), np.asarray(self.rewards), self.log_probs

    class PGTrainBuffer:
        """Fixed-capacity, preallocated on-device batch of training transitions."""

        def __init__(self, state_dim, num_actions, batch_size, device="cuda"):
            self.batch_size = batch_size
            self.device = device
            self.states = torch.zeros((batch_size, state_dim + num_actions), dtype=torch.float32, device=device)
            self.actions = torch.zeros((batch_size,), dtype=torch.long, device=device)
            self.state_values = torch.zeros((batch_size,), dtype=torch.float32, device=device)
            self.advantages = torch.zeros((batch_size,), dtype=torch.float32, device=device)
            self.returns = torch.zeros((batch_size,), dtype=torch.float32, device=device)
            self.log_probs = torch.zeros((batch_size,), dtype=torch.float32, device=device)
            # A plain int: a device tensor here forced a host/device sync on
            # every comparison and was replaced by an int in clear() anyway.
            self.count = 0

        def add(self, state, action, state_val, advantage, ret, log_prob):
            """Store one transition; return True when the buffer just became full.

            Raises ValueError if the buffer is already full.
            """
            if self.count >= self.batch_size:
                raise ValueError("Train buffer is full, cannot add more data")
            idx = self.count
            self.states[idx] = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            self.actions[idx] = int(action)
            self.state_values[idx] = float(state_val)
            self.advantages[idx] = float(advantage)
            self.returns[idx] = float(ret)
            self.log_probs[idx] = log_prob
            self.count += 1
            return self.count == self.batch_size

        def clear(self):
            self.states.zero_()
            self.actions.zero_()
            self.state_values.zero_()
            self.advantages.zero_()
            self.returns.zero_()
            self.log_probs.zero_()
            self.count = 0

        def get(self):
            """Return the full preallocated tensors (all ``batch_size`` rows)."""
            return self.states, self.actions, self.state_values, self.advantages, self.returns, self.log_probs

class A2CAgent(PGAgent):
    """A2C agent whose actor and critic share a two-layer ReLU MLP trunk."""

    def __init__(self, params: dict = None):
        super().__init__(params)
        self.net = self.ActorCriticNet(self.state_dim, self.num_actions, d_model=self.d_model).to(self.device)
        self.critic_loss = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)

    class ActorCriticNet(nn.Module):
        """MLP trunk with linear policy-logit and state-value heads."""

        def __init__(self, state_dim, num_actions, d_model=128):
            # nn.Module.__init__ must run before any attribute assignment.
            super().__init__()
            self.num_actions = num_actions
            self.shared_base = nn.Sequential(
                nn.Linear(state_dim, d_model),
                nn.ReLU(),
                nn.Linear(d_model, d_model),
                nn.ReLU()
            )
            self.actor_head = nn.Linear(d_model, num_actions)
            self.critic_head = nn.Linear(d_model, 1)

        def forward(self, state):
            """Map (B, num_actions + state_dim) states to (logits (B, A), values (B,)).

            The leading num_actions columns (the legal mask) are dropped; only
            the observation part is fed to the trunk.
            """
            state = state[:, self.num_actions:]
            base_output = self.shared_base(state)
            logits = self.actor_head(base_output)
            state_value = self.critic_head(base_output).squeeze(-1)
            return logits, state_value
        
class CNNA2CAgent(A2CAgent):
    """A2C agent whose actor-critic network is a small CNN over a stacked board.

    The observation (after the legal-action mask) is viewed as
    ``(num_layers, grid_size, grid_size)``, with
    ``grid_size = int(sqrt(state_dim // num_layers))``.
    """

    def __init__(self, params: dict = None):
        super().__init__(params)
        self.num_layers = params.get('num_layers', 4)
        cells_per_layer = self.state_dim // self.num_layers
        self.grid_size = int(math.sqrt(cells_per_layer))
        cnn = self.CNNActorCriticNet(
            num_actions=self.num_actions,
            num_layers=self.num_layers,
            grid_size=self.grid_size,
            d_model=self.d_model,
        )
        self.net = cnn.to(self.device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)

    class CNNActorCriticNet(nn.Module):
        """Two 3x3 conv layers + 2x2 max-pool, a shared FC layer, and actor/critic heads.

        Dropout (p=0.6) follows each stage; it is active whenever the module
        is in training mode.
        """

        def __init__(self, num_actions, num_layers, grid_size, d_model=128):
            super().__init__()
            self.num_layers = num_layers
            self.grid_size = grid_size
            self.num_actions = num_actions
            conv_stack = [
                nn.Conv2d(num_layers, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Dropout(p=0.6),
                nn.Conv2d(32, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Dropout(p=0.6),
            ]
            self.conv_base = nn.Sequential(*conv_stack)
            # Padded 3x3 convs keep the grid size; the single 2x2 pool halves it (floor).
            pooled_side = grid_size // 2
            self.shared_fc = nn.Sequential(
                nn.Linear(64 * pooled_side * pooled_side, d_model),
                nn.ReLU(),
                nn.Dropout(p=0.6),
            )
            self.actor_head = nn.Linear(d_model, num_actions)
            self.critic_head = nn.Linear(d_model, 1)

        def forward(self, state):
            """Map (B, num_actions + state_dim) states to (logits (B, A), values (B,))."""
            board = state[:, self.num_actions:].view(-1, self.num_layers, self.grid_size, self.grid_size)
            feature_maps = self.conv_base(board)
            hidden = self.shared_fc(feature_maps.view(feature_maps.size(0), -1))
            return self.actor_head(hidden), self.critic_head(hidden).squeeze(-1)

class ResNetA2CAgent(A2CAgent):
    """A2C agent with a residual convolutional tower and linear actor/critic heads.

    The observation (after the legal-action mask) is viewed as
    ``(num_layers, side, side)`` with ``side = int(sqrt(state_dim // num_layers))``.
    """

    def __init__(self, params: dict = None):
        if params is None:
            params = {}
        super().__init__(params)
        self.in_channels = params.get('num_layers', 4)
        side = int(math.sqrt(self.state_dim // self.in_channels))
        self.height = side
        self.width = side
        self.num_res_blocks = params.get('num_res_blocks', 6)
        self.net = self.ResNetActorCriticNet(
            num_actions=self.num_actions,
            in_channels=self.in_channels,
            height=self.height,
            width=self.width,
            num_res_blocks=self.num_res_blocks,
            d_model=self.d_model
        ).to(self.device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)

    class ResNetActorCriticNet(nn.Module):
        """Conv+BN stem, ``num_res_blocks`` residual blocks, and linear heads on the flattened features."""

        def __init__(self, num_actions, in_channels, height, width, num_res_blocks, d_model=128):
            super().__init__()
            self.num_actions = num_actions
            self.in_channels = in_channels
            self.height = height
            self.width = width
            self.stem = nn.Sequential(
                nn.Conv2d(in_channels, d_model, kernel_size=3, padding=1),
                nn.BatchNorm2d(d_model),
                nn.ReLU()
            )
            res_blocks = [self.ResidualBlock(d_model) for _ in range(num_res_blocks)]
            self.res_body = nn.Sequential(*res_blocks)

            # All convolutions are 3x3 with stride 1 and padding 1, and nothing
            # pools, so the tower outputs d_model channels on the same
            # (height, width) grid. Computing the size analytically avoids a
            # dummy forward pass, which in training mode would update every
            # BatchNorm's running statistics with a meaningless all-zero input.
            flattened_size = d_model * height * width

            self.actor_head = nn.Linear(flattened_size, num_actions)
            self.critic_head = nn.Linear(flattened_size, 1)

        def forward(self, flat_state):
            """Map (B, num_actions + C*H*W) states to (logits (B, A), values (B,))."""
            board_state = flat_state[:, self.num_actions:]
            reshaped_board = board_state.view(
                -1, self.in_channels, self.height, self.width
            )
            stem_out = self.stem(reshaped_board)
            res_out = self.res_body(stem_out)
            flattened_output = res_out.view(res_out.size(0), -1)
            logits = self.actor_head(flattened_output)
            # Squeeze to (B,) to match the other actor-critic nets in this module.
            state_value = self.critic_head(flattened_output).squeeze(-1)
            return logits, state_value

        class ResidualBlock(nn.Module):
            """Two 3x3 conv+BN layers with an identity skip connection."""

            def __init__(self, num_filters):
                super().__init__()
                self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
                self.bn1 = nn.BatchNorm2d(num_filters)
                self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
                self.bn2 = nn.BatchNorm2d(num_filters)

            def forward(self, x):
                residual = x
                out = F.relu(self.bn1(self.conv1(x)))
                out = self.bn2(self.conv2(out))
                out += residual  # The skip connection
                out = F.relu(out)
                return out

class LGAN_A2CAgent(PGAgent):
    """A2C agent built on a legal-action attention network (LegalAttentionNet_PG).

    The network already gives illegal actions a logit of -inf (see its
    forward()), so choose_action samples from the raw logits without masking.
    """

    def __init__(self, params):
        super().__init__(params)
        self.num_heads = params.get('num_heads', 8)
        self.k_embedding_dim = params.get('k_embedding_dim', 16)
        self.net = self.LegalAttentionNet_PG(input_dim=self.state_dim, num_actions=self.num_actions,
                    d_model=self.d_model, num_heads=self.num_heads, k_embedding_dim=self.k_embedding_dim,
                    num_layers=params.get('num_layers', 1), average_func=params.get('average_func', 'mlp')).to(self.device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)

    def choose_action(self, state):
        """Sample an action for one flat state (legal mask first).

        Returns ``(action, dist)``: a 0-d action tensor and a Categorical over
        ``num_actions``.
        """
        state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            logits, _ = self.net(state_tensor)
        # Flatten (1, A) -> (A,) so the action is 0-d and dist.probs has shape
        # (num_actions,), the same shapes PGAgent.choose_action produces.
        dist = torch.distributions.Categorical(logits=logits.flatten())
        action = dist.sample()
        return action, dist

    class LegalAttentionNet_PG(nn.Module):
        """Attention from per-action queries to per-input-position keys.

        Each action has a learned query embedding (zeroed when the action is
        illegal). Keys are built from each input scalar concatenated with a
        fixed sin/cos positional encoding. The per-action attention weights
        over input positions are applied to the raw input values to produce
        z, and logits = log(z). The critic is an MLP over the projected keys.
        """

        def __init__(self, input_dim, num_actions, d_model=512, num_heads=8, k_embedding_dim=16, num_layers=1, average_func='mlp'):
            super().__init__()
            assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
            if average_func not in ('mean', 'mlp'):
                raise ValueError(f"average_func must be 'mean' or 'mlp', got {average_func!r}")
            self.num_actions = num_actions
            self.input_dim = input_dim
            self.d_model = d_model
            self.k_proj = nn.Linear(1 + k_embedding_dim, d_model)
            # NOTE(review): scores are scaled by sqrt(d_model) although each head's
            # dot product spans only d_head dimensions.
            self.temperature1 = math.sqrt(d_model)
            self.d_head = d_model // num_heads
            self.num_heads = num_heads
            self.q_embedding = nn.Embedding(num_actions, d_model)
            self.register_buffer('action_indices', torch.arange(num_actions))
            # Bias-free, so an all-zero input (illegal action) maps to exactly 0.
            self.head_mlp = nn.Linear(num_heads, 1, bias=False)
            nn.init.constant_(self.head_mlp.weight, 1.0)
            if num_layers == 0:
                raise NotImplementedError("learnable PE does not work, forbidden for now.")
            elif num_layers == 1:
                self.register_buffer("k_embedding", self.build_sincos_pe(input_dim, k_embedding_dim))
            else:
                self.register_buffer("k_embedding", self.build_sincos_pe_3d(input_dim, k_embedding_dim, num_layers))
            # Stored as a mode string rather than a lambda: lambdas cannot be
            # pickled, and under copy.deepcopy a lambda closing over self would
            # keep using the original module's head_mlp.
            self.average_mode = average_func
            self.critic_1 = nn.Linear(d_model, 1)
            self.critic_2 = nn.Linear(input_dim, 1)

        def average_func(self, x):
            """Reduce the trailing head axis: (..., h) -> (...)."""
            if self.average_mode == 'mean':
                return torch.mean(x, dim=-1)
            return self.head_mlp(torch.relu(x)).squeeze(-1)

        def forward(self, x):
            """Return (logits (B, A), state_value (B,)).

            x: (B, num_actions + input_dim), or one unbatched row; the first
               num_actions columns are the 0/1 legal-action mask.
            For illegal actions the zeroed query makes z exactly 0 (for finite
            inputs), so their logit is -inf. Legal actions get a 1e-8 floor on z.
            """
            # B: batch size, batch_size
            # S: length of the input, input_dim
            # A: number of actions, num_action
            # d_model: number of latent variables
            # h: number of heads, num_heads
            # d_head: number of latent variables for each head
            if x.dim() == 1:
                x = x.unsqueeze(0)                                                          # inference
            batch_size = x.size(0)
            legal_mask = x[:, :self.num_actions]                                            # (B, A)
            Q = self.q_embedding(self.action_indices).unsqueeze(0).expand(batch_size, -1, -1)  # (B, A, d_model)
            Q = Q * legal_mask.unsqueeze(-1)                                                # (B, A, d_model)
            Q = Q.reshape(batch_size, self.num_actions, self.num_heads, self.d_head).transpose(1, 2)  # (B, h, A, d_head)
            K = x[:, self.num_actions:].unsqueeze(-1)                                       # (B, S, 1)
            K_pos_vectors = self.k_embedding.expand(batch_size, -1, -1)                     # (B, S, k_embedding_dim)
            K = torch.cat([K, K_pos_vectors], dim=-1)                                       # (B, S, 1 + k_embedding_dim)
            K = self.k_proj(K)                                                              # (B, S, d_model)
            # Split d_model into heads per position, then move S last. A plain
            # reshape straight to (B, h, d_head, S) would reinterpret memory and
            # mix positions with features instead of transposing.
            K2 = K.reshape(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 3, 1)  # (B, h, d_head, S)
            attn_scores = torch.matmul(Q, K2) / self.temperature1                           # (B, h, A, S)
            attn_scores = attn_scores.permute(0, 2, 3, 1)                                   # (B, A, S, h)
            attn_weights = torch.relu(self.average_func(attn_scores))                       # (B, A, S)
            V = x[:, self.num_actions:].unsqueeze(-1)                                       # (B, S, 1)
            z = torch.relu(torch.matmul(attn_weights, V).squeeze(-1))                       # (B, A)
            z = z + legal_mask * 1e-8
            logits = torch.log(z)                                                           # (B, A)
            state_value = torch.relu(self.critic_1(K)).squeeze(-1)                          # (B, S)
            state_value = self.critic_2(state_value).squeeze(-1)                            # (B,)
            return logits, state_value

        def get_attn_weights(self, x):
            """Return the (A, S) attention map for one unbatched state as a NumPy array.

            x: 1-D tensor of length num_actions + input_dim (legal mask first).
            """
            legal_mask = x[:self.num_actions].unsqueeze(-1)                                 # (A, 1)
            action_indices = torch.arange(self.num_actions, device=x.device)                # (A,)
            Q = self.q_embedding(action_indices)                                            # (A, d_model)
            Q = Q * legal_mask
            Q = Q.reshape(self.num_actions, self.num_heads, self.d_head).permute(1, 0, 2)   # (h, A, d_head)
            K = x[self.num_actions:].unsqueeze(-1)                                          # (S, 1)
            K_pos_vectors = self.k_embedding.squeeze(0)                                     # (S, k_embedding_dim)
            K = torch.cat([K, K_pos_vectors], dim=-1)                                       # (S, 1 + k_embedding_dim)
            K = self.k_proj(K)                                                              # (S, d_model)
            # Same head split as forward(): (S, d_model) -> (S, h, d_head) -> (h, d_head, S).
            K = K.reshape(-1, self.num_heads, self.d_head).permute(1, 2, 0)                 # (h, d_head, S)
            attn_scores = torch.matmul(Q, K) / self.temperature1                            # (h, A, S)
            attn_scores = attn_scores.permute(1, 2, 0)                                      # (A, S, h)
            attn_weights = torch.relu(self.average_func(attn_scores))                       # (A, S)
            return attn_weights.detach().cpu().numpy()

        def build_sincos_pe(self, input_dim, d_pe):
            """Return a (1, input_dim, d_pe) sin/cos positional encoding.

            Sines fill even channels and cosines odd channels, using base
            2 * input_dim instead of the usual 10000.
            """
            assert d_pe % 2 == 0, "d_pe must be even for sin+cos"
            pos = torch.arange(input_dim).unsqueeze(1)                      # (input_dim, 1)
            i = torch.arange(d_pe // 2).unsqueeze(0)                        # (1, d_pe//2)
            base = 2 * input_dim
            angle = pos / (base ** (2 * i / d_pe))                          # (input_dim, d_pe//2)
            pe = torch.zeros(input_dim, d_pe)
            pe[:, 0::2] = torch.sin(angle)
            pe[:, 1::2] = torch.cos(angle)
            return pe.unsqueeze(0)                                          # (1, input_dim, d_pe)

        def build_sincos_pe_3d(self, input_dim, d_pe, num_layers):
            """Return a (1, input_dim, d_pe) encoding over (layer, row, col) positions.

            Assumes the input is flattened in [layer-row-col] order. Channels
            are split as 4 for the layer index plus (d_pe - 4) / 2 each for
            row and column.
            """
            assert d_pe > 4, "check d_pe."
            # Row and column encodings each get (d_pe - 4) // 2 channels and each
            # must be even for build_sincos_pe, so d_pe - 4 must be divisible by 4.
            assert (d_pe - 4) % 4 == 0, "d_pe - 4 must be divisible by 4."
            assert input_dim % num_layers == 0, "input_dim must be divisible by num_layer."
            layer_size = input_dim // num_layers
            grid_size = math.isqrt(layer_size)
            assert grid_size * grid_size == layer_size, "Each layer's size must be a perfect square."

            d_layer = 4
            d_remain = d_pe - d_layer
            d_x = d_remain // 2
            d_y = d_remain // 2
            pe_z = self.build_sincos_pe(num_layers, d_layer).squeeze(0)   # (L, d_layer)
            pe_x = self.build_sincos_pe(grid_size, d_x).squeeze(0)        # (H, d_x)
            pe_y = self.build_sincos_pe(grid_size, d_y).squeeze(0)        # (W, d_y)
            pe_z = pe_z[:, None, None, :].expand(num_layers, grid_size, grid_size, d_layer)
            pe_x = pe_x[None, :, None, :].expand(num_layers, grid_size, grid_size, d_x)
            pe_y = pe_y[None, None, :, :].expand(num_layers, grid_size, grid_size, d_y)

            pe = torch.cat([pe_z, pe_x, pe_y], dim=-1)  # (L, H, W, d_pe)
            pe = pe.reshape(1, input_dim, d_pe)         # flatten to (1, input_dim, d_pe)
            return pe


