import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
from open_spiel.python import rl_agent

class PPOAgent:
    """PPO (clipped surrogate) agent base for OpenSpiel-style environments.

    This base class owns the hyper-parameters, the per-episode trajectory
    buffer, the fixed-size training buffer and the PPO update loop. It does NOT
    create ``self.net`` or ``self.optimizer``; a subclass must assign both.
    ``self.net(states)`` must return ``(logits, state_values)`` where ``states``
    is laid out as ``[legal_mask (num_actions) | observation (state_dim)]``.
    """

    def __init__(self, params: dict = None):
        """
        params: dict, contains the parameters for the agent and the experiment.
            ``None`` is treated as an empty dict, so every default applies.
            Optional switches (both off by default):
              - ``use_lr_schedule``: anneal the optimizer LR with ``get_lr()``
                once per ``learn()`` call.
              - ``record_history``: append per-minibatch diagnostics to the
                ``*_history`` lists during ``learn()``.
        """
        params = {} if params is None else params
        print("Initializing " + type(self).__name__ + " with parameters:", params)
        self.params = params
        self.player_id = params.get('player_id', 0)
        self.num_actions = params.get('num_actions', 4)
        self.state_dim = params.get('state_dim', 16)
        self.gamma = params.get('gamma', 0.99)
        self.lam = params.get('lam', 0.95)
        self.batch_size = params.get('batch_size', 128)
        self.minibatch_size = params.get('minibatch_size', 32)
        self.d_model = params.get('d_model', 128)
        self.lr = params.get('lr', 0.001)
        self.lr_end = 0.3 * self.lr
        self.lr_decay_duration = 10 * params.get('train_group_size', 1000)
        self.use_lr_schedule = params.get('use_lr_schedule', False)
        self.record_history = params.get('record_history', False)

        self.device = torch.device(params.get('device', "cuda"))
        self.entropy_coef = params.get('entropy_coef', 0.001)
        self.value_loss_coef = params.get('value_loss_coef', 0.5)

        self.ppo_epoches = params.get('ppo_epoches', 10)
        self.value_clip_coef = params.get('value_clip_coef', 0.2)
        self.clip_coef = params.get('clip_coef', 0.2)
        # Target of the clip-coefficient schedule in get_clip(); defaults to
        # clip_coef (no decay). get_clip() is not applied by learn().
        self.clip_coef_end = params.get('clip_coef_end', self.clip_coef)

        self.train_counter = 0
        self.trajectory_buffer = self.TrajectoryBuffer()
        self.train_buffer = self.PGTrainBuffer(self.state_dim, self.num_actions, self.batch_size, self.device)
        # Kept for backward compatibility; learn() computes the clipped value
        # loss elementwise and does not use this module.
        self.critic_loss = nn.MSELoss()
        self.value_loss_history = []
        self.policy_loss_history = []
        self.entropy_loss_history = []
        self.ratio_history = []
        self.prediction_history = []
        self.return_history = []
        self.new_probs_history = []
        self.clip_fraction_history = []

    def _decayed(self, start, end, power):
        """Polynomially anneal from ``start`` to ``end`` over ``lr_decay_duration`` train steps."""
        if self.lr_decay_duration <= 0:
            return end
        progress = min(self.train_counter, self.lr_decay_duration) / self.lr_decay_duration
        return end + (start - end) * (1 - progress) ** power

    def get_lr(self, power=1.0):
        """Return the learning rate for the current ``train_counter`` (``lr`` -> ``lr_end``)."""
        return self._decayed(self.lr, self.lr_end, power)

    def get_clip(self, power=1.0):
        """Return the clip coefficient for the current ``train_counter`` (not used by ``learn``)."""
        return self._decayed(self.clip_coef, self.clip_coef_end, power)

    def learn(self):
        """Run the PPO update on one full training buffer, then clear it.

        For ``ppo_epoches`` epochs the buffer is shuffled and split into
        minibatches of ``minibatch_size``. Each minibatch minimizes

            clipped surrogate policy loss
            + value_loss_coef * clipped value loss
            - entropy_coef * policy entropy

        with gradients clipped to a global norm of 0.5. Advantages are
        normalized over the whole batch. Does nothing if the buffer is not full
        (training on a partial buffer would include zero-padded rows).
        """
        if self.train_buffer.count < self.batch_size:
            return
        if self.use_lr_schedule:
            lr = self.get_lr()
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = lr
            self.train_counter += 1
        # Dropout/BatchNorm must be in training mode for the update; the
        # inference paths below switch the net to eval mode.
        self.net.train()

        states_raw, actions_raw, old_state_values_raw, advantages_raw, returns_raw, old_log_probs_raw = self.train_buffer.get()
        advantages_raw = (advantages_raw - advantages_raw.mean()) / (advantages_raw.std() + 1e-8)
        for _ in range(self.ppo_epoches):
            perm = torch.randperm(self.batch_size, device=states_raw.device)
            for start in range(0, self.batch_size, self.minibatch_size):
                idx = perm[start:start + self.minibatch_size]
                states = states_raw[idx]
                actions = actions_raw[idx]
                old_state_values = old_state_values_raw[idx]
                advantages = advantages_raw[idx]
                returns = returns_raw[idx]
                old_log_probs = old_log_probs_raw[idx]

                logits, pred_state_values = self.net(states)
                # Flatten so a (B, 1) value head cannot broadcast against (B,) targets.
                pred_state_values = pred_state_values.reshape(-1)

                # Clipped value loss: elementwise max of the squared errors, then mean.
                value_clipped = old_state_values + torch.clamp(
                    pred_state_values - old_state_values, -self.value_clip_coef, self.value_clip_coef
                )
                value_err_unclipped = (pred_state_values - returns) ** 2
                value_err_clipped = (value_clipped - returns) ** 2
                critic_loss = self.value_loss_coef * torch.max(value_err_unclipped, value_err_clipped).mean()

                # The first num_actions columns of each state are the legal-action mask.
                legal_mask = states[:, :self.num_actions]
                masked_logits = logits.masked_fill(legal_mask == 0, -1e9)
                action_categorical = torch.distributions.Categorical(logits=masked_logits)
                new_log_probs = action_categorical.log_prob(actions)
                ratio = torch.exp(new_log_probs - old_log_probs.detach())
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_coef, 1 + self.clip_coef) * advantages
                actor_loss = -torch.min(surr1, surr2).mean()
                entropy_loss = -self.entropy_coef * action_categorical.entropy().mean()
                loss = actor_loss + critic_loss + entropy_loss

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=0.5)
                self.optimizer.step()

                if self.record_history:
                    with torch.no_grad():
                        self.value_loss_history.append(critic_loss.item())
                        self.policy_loss_history.append(actor_loss.item())
                        self.entropy_loss_history.append(entropy_loss.item())
                        self.prediction_history.append(pred_state_values.mean().item())
                        self.return_history.append(returns.mean().item())
                        self.new_probs_history.append(torch.exp(new_log_probs).mean().item())
                        self.ratio_history.append(ratio.mean().item())
                        self.clip_fraction_history.append(
                            ((ratio - 1).abs() > self.clip_coef).float().mean().item()
                        )
        self.train_buffer.clear()

    def finish_trajectory(self):
        """Turn the buffered episode into GAE advantages/returns and move it to the train buffer.

        The episode is treated as terminal: the value after the last step is 0.
        ``learn()`` runs each time the train buffer becomes full.
        """
        if len(self.trajectory_buffer.states) == 0:
            return
        traj_states, traj_actions, traj_rewards, traj_log_probs = self.trajectory_buffer.get()
        states_tensor = torch.tensor(traj_states, dtype=torch.float32, device=self.device)
        if self.net.training:
            self.net.eval()
        with torch.no_grad():
            _, state_values_tensor = self.net(states_tensor)
        state_values = state_values_tensor.cpu().numpy().flatten()
        advantages = np.zeros_like(state_values)
        gae = 0.0
        next_value = 0.0
        for i in reversed(range(len(traj_rewards))):
            # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            delta = traj_rewards[i] + self.gamma * next_value - state_values[i]
            # A_t = delta_t + gamma * lambda * A_{t+1}
            gae = delta + self.gamma * self.lam * gae
            advantages[i] = gae
            next_value = state_values[i]
        returns = advantages + state_values

        for i in range(len(traj_states)):
            if self.train_buffer.add(traj_states[i], traj_actions[i], state_values[i], advantages[i], returns[i], traj_log_probs[i]):
                self.learn()
        self.trajectory_buffer.clear()

    def choose_action(self, state):
        """Sample an action from the legal-action-masked policy.

        state: 1-D array laid out as ``[legal_mask (num_actions) | observation]``.
        Returns ``(action, dist)``: a sampled action tensor and the Categorical
        distribution over all ``num_actions`` (illegal actions get ~0 probability).
        """
        state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        legal_mask = state_tensor[0, :self.num_actions]
        # Inference must not use Dropout or per-sample BatchNorm statistics.
        if self.net.training:
            self.net.eval()
        with torch.no_grad():
            logits, _ = self.net(state_tensor)
        masked_logits = logits.masked_fill(legal_mask == 0, -1e9).flatten()
        dist = torch.distributions.Categorical(logits=masked_logits)
        action = dist.sample()
        return action, dist

    def step(self, time_step, is_evaluation=False):
        '''
        time_step: open_spiel.python.TimeStep

        Returns an ``rl_agent.StepOutput`` for non-terminal steps and ``None``
        on the terminal step. Outside evaluation, each move is recorded; the
        reward observed at the next call is credited to the previous move, and
        the terminal step flushes the episode through ``finish_trajectory()``.
        '''
        # for OpenSpiel, TTT and Breakthrough, info_state[0] = info_state[1]
        # So if you want to self play, you might need to mirror the state
        # exchange info_state[0][0:25] and info_state[0][25:50] for Breakthrough
        # Right now, the agent can only play as player 0
        if time_step.last():
            if is_evaluation:
                return
            self.trajectory_buffer.set_reward(time_step.rewards[self.player_id])
            self.finish_trajectory()
            return
        state = np.asarray(time_step.observations["info_state"][0])
        legal_actions = time_step.observations["legal_actions"][self.player_id]
        legal_mask = np.zeros(self.num_actions, dtype=np.float32)
        legal_mask[legal_actions] = 1.0
        state = np.concatenate((legal_mask, state), axis=0)  # (num_actions + state_dim, )
        action, dist = self.choose_action(state=state)
        log_prob = dist.log_prob(action).detach()
        if not is_evaluation:
            # Reward observed now belongs to the previous move (no-op on the first move).
            self.trajectory_buffer.set_reward(time_step.rewards[self.player_id])
            self.trajectory_buffer.append(state, action.item(), log_prob)
        return rl_agent.StepOutput(action=action.item(), probs=dist.probs.cpu().detach().numpy())

    def save_model(self, path='model.pth'):
        """Save the network's state_dict to ``path``."""
        torch.save(self.net.state_dict(), path)

    def load_model(self, path='model.pth'):
        """Load a state_dict from ``path`` into the network and move it to ``self.device``."""
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        self.net.load_state_dict(torch.load(path, map_location=self.device))
        self.net.to(self.device)

    class TrajectoryBuffer:
        """Per-episode lists of states, actions, rewards and behavior log-probs."""

        def __init__(self):
            self.states = []
            self.actions = []
            self.rewards = []
            self.log_probs = []

        def append(self, state, action, log_prob, reward=0):
            """Record one move; its reward is normally filled in later via ``set_reward``."""
            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.log_probs.append(log_prob)

        def set_reward(self, reward):
            """Overwrite the reward of the most recent move; no-op when the buffer is empty."""
            if self.rewards:
                self.rewards[-1] = reward

        def clear(self):
            """Drop all recorded moves."""
            self.states.clear()
            self.actions.clear()
            self.rewards.clear()
            self.log_probs.clear()

        def get(self):
            """Return ``(states, actions, rewards)`` as numpy arrays plus the raw log-prob list."""
            return np.asarray(self.states), np.asarray(self.actions), np.asarray(self.rewards), self.log_probs

    class PGTrainBuffer:
        """Fixed-size, preallocated buffer holding one PPO training batch on ``device``."""

        def __init__(self, state_dim, num_actions, batch_size, device="cuda"):
            self.batch_size = batch_size
            self.device = device
            self.states = torch.zeros((batch_size, state_dim + num_actions), dtype=torch.float32, device=device)
            # int64 is the index dtype Categorical.log_prob uses internally.
            self.actions = torch.zeros((batch_size,), dtype=torch.int64, device=device)
            self.state_values = torch.zeros((batch_size,), dtype=torch.float32, device=device)
            self.advantages = torch.zeros((batch_size,), dtype=torch.float32, device=device)
            self.returns = torch.zeros((batch_size,), dtype=torch.float32, device=device)
            self.log_probs = torch.zeros((batch_size,), dtype=torch.float32, device=device)
            # Plain int: no device sync per add(), and the same type after clear().
            self.count = 0

        def add(self, state, action, state_val, advantage, ret, log_prob):
            """Store one sample; return True when the buffer has just become full.

            Raises:
                ValueError: if the buffer is already full.
            """
            if self.count >= self.batch_size:
                raise ValueError("Train buffer is full, cannot add more data")
            i = self.count
            self.states[i] = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            self.actions[i] = int(action)
            self.state_values[i] = float(state_val)
            self.advantages[i] = float(advantage)
            self.returns[i] = float(ret)
            self.log_probs[i] = log_prob
            self.count += 1
            return self.count == self.batch_size

        def clear(self):
            """Zero every slot and reset the fill counter."""
            self.states.zero_()
            self.actions.zero_()
            self.state_values.zero_()
            self.advantages.zero_()
            self.returns.zero_()
            self.log_probs.zero_()
            self.count = 0

        def get(self):
            """Return the underlying tensors (not copies)."""
            return self.states, self.actions, self.state_values, self.advantages, self.returns, self.log_probs

class PPOAgent(PPOAgent):
    """Concrete PPO agent with a two-hidden-layer MLP actor-critic.

    Subclasses the earlier ``PPOAgent`` definition and re-binds the name; it
    adds ``self.net`` and an Adam ``self.optimizer`` (learning rate ``lr``).
    """

    def __init__(self, params: dict = None):
        super().__init__(params)
        self.net = self.ActorCriticNet(self.state_dim, self.num_actions, d_model=self.d_model).to(self.device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)

    class ActorCriticNet(nn.Module):
        """Shared MLP trunk with a policy-logits head and a scalar value head."""

        def __init__(self, state_dim, num_actions, d_model=128):
            # nn.Module must be initialised before any attribute is assigned.
            super().__init__()
            self.num_actions = num_actions
            self.shared_base = nn.Sequential(
                nn.Linear(state_dim, d_model),
                nn.ReLU(),
                nn.Linear(d_model, d_model),
                nn.ReLU()
            )
            self.actor_head = nn.Linear(d_model, num_actions)
            self.critic_head = nn.Linear(d_model, 1)

        def forward(self, state):
            """Return ``(logits, state_value)`` with shapes ``(B, num_actions)`` and ``(B,)``.

            The first ``num_actions`` columns of ``state`` (the legal-action mask)
            are dropped before the trunk.
            """
            observation = state[:, self.num_actions:]
            base_output = self.shared_base(observation)
            logits = self.actor_head(base_output)
            state_value = self.critic_head(base_output).squeeze(-1)
            return logits, state_value
        
class CNNPPOAgent(PPOAgent):
    """PPO agent with a small convolutional actor-critic.

    The observation part of each state is reshaped to
    ``(num_layers, grid_size, grid_size)``, so ``state_dim`` must equal
    ``num_layers * grid_size ** 2``.
    """

    def __init__(self, params: dict = None):
        params = {} if params is None else params
        super().__init__(params)
        self.num_layers = params.get('num_layers', 4)
        self.grid_size = int(math.sqrt((self.state_dim // self.num_layers)))
        # Fail early instead of mis-reshaping observations inside forward().
        if self.num_layers * self.grid_size ** 2 != self.state_dim:
            raise ValueError(
                f"state_dim={self.state_dim} cannot be reshaped into "
                f"{self.num_layers} square planes"
            )
        self.net = self.CNNActorCriticNet(
            num_actions=self.num_actions,
            num_layers=self.num_layers,
            grid_size=self.grid_size,
            d_model=self.d_model
        ).to(self.device)
        # Replaces whatever optimizer the parent constructor created.
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)

    class CNNActorCriticNet(nn.Module):
        """Two conv layers + 2x2 max-pool, a shared FC layer, and policy/value heads."""

        def __init__(self, num_actions, num_layers, grid_size, d_model=128):
            super().__init__()
            self.num_layers = num_layers
            self.grid_size = grid_size
            self.num_actions = num_actions
            self.conv_base = nn.Sequential(
                nn.Conv2d(num_layers, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Dropout(p=0.6),
                nn.Conv2d(32, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Dropout(p=0.6)
            )
            # 3x3/padding-1 convs keep the spatial size; MaxPool2d(2) floors it by half.
            conv_out_h = grid_size // 2
            conv_out_w = grid_size // 2
            flattened_size = 64 * conv_out_h * conv_out_w
            self.shared_fc = nn.Sequential(
                nn.Linear(flattened_size, d_model),
                nn.ReLU(),
                nn.Dropout(p=0.6)
            )
            self.actor_head = nn.Linear(d_model, num_actions)
            self.critic_head = nn.Linear(d_model, 1)

        def forward(self, state):
            """Return ``(logits, state_value)`` with shapes ``(B, num_actions)`` and ``(B,)``.

            The first ``num_actions`` columns (legal-action mask) are dropped and
            the rest is reshaped to ``(B, num_layers, grid_size, grid_size)``.
            """
            observation = state[:, self.num_actions:]
            # reshape (not view): the column slice above is not contiguous in general.
            planes = observation.reshape(-1, self.num_layers, self.grid_size, self.grid_size)
            conv_features = self.conv_base(planes)
            flattened_features = conv_features.reshape(conv_features.size(0), -1)
            shared_output = self.shared_fc(flattened_features)
            logits = self.actor_head(shared_output)
            state_value = self.critic_head(shared_output).squeeze(-1)
            return logits, state_value

class ResNetPPOAgent(PPOAgent):
    """PPO agent whose actor-critic is a residual CNN (conv stem + residual blocks + linear heads).

    The observation part of each state is reshaped to
    ``(num_layers, side, side)``, so ``state_dim`` must equal
    ``num_layers * side ** 2``.
    """

    def __init__(self, params: dict = None):
        params = {} if params is None else params
        super().__init__(params)
        self.in_channels = params.get('num_layers', 4)
        side = int(math.sqrt((self.state_dim // self.in_channels)))
        # Fail early instead of mis-reshaping observations inside forward().
        if self.in_channels * side * side != self.state_dim:
            raise ValueError(
                f"state_dim={self.state_dim} cannot be reshaped into "
                f"{self.in_channels} square planes"
            )
        self.height = side
        self.width = side
        self.num_res_blocks = params.get('num_res_blocks', 6)

        self.net = self.ResNetActorCriticNet(
            num_actions=self.num_actions,
            in_channels=self.in_channels,
            height=self.height,
            width=self.width,
            num_res_blocks=self.num_res_blocks,
            d_model=self.d_model
        ).to(self.device)

        # Replaces whatever optimizer the parent constructor created.
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr)

    class ResNetActorCriticNet(nn.Module):
        """Conv+BN stem, ``num_res_blocks`` residual blocks, then linear policy/value heads."""

        def __init__(self, num_actions, in_channels, height, width, num_res_blocks, d_model=128):
            super().__init__()
            self.num_actions = num_actions
            self.in_channels = in_channels
            self.height = height
            self.width = width
            self.stem = nn.Sequential(
                nn.Conv2d(in_channels, d_model, kernel_size=3, padding=1),
                nn.BatchNorm2d(d_model),
                nn.ReLU()
            )
            res_blocks = [self.ResidualBlock(d_model) for _ in range(num_res_blocks)]
            self.res_body = nn.Sequential(*res_blocks)

            # Every conv is 3x3 / stride 1 / padding 1, so the body output is
            # (B, d_model, height, width). Computing the size analytically avoids
            # a dummy forward pass, which in training mode would update the
            # BatchNorm running statistics with an all-zero input.
            flattened_size = d_model * height * width

            self.actor_head = nn.Linear(flattened_size, num_actions)
            self.critic_head = nn.Linear(flattened_size, 1)

        def forward(self, flat_state):
            """Return ``(logits, state_value)`` with shapes ``(B, num_actions)`` and ``(B,)``.

            The first ``num_actions`` columns (legal-action mask) are dropped and
            the rest is reshaped to ``(B, in_channels, height, width)``.
            """
            board_state = flat_state[:, self.num_actions:]
            # reshape (not view): the column slice above is not contiguous in general.
            reshaped_board = board_state.reshape(
                -1, self.in_channels, self.height, self.width
            )
            stem_out = self.stem(reshaped_board)
            res_out = self.res_body(stem_out)
            flattened_output = res_out.reshape(res_out.size(0), -1)
            logits = self.actor_head(flattened_output)
            # Squeeze to (B,) like the other nets; a (B, 1) value would broadcast
            # against (B,) targets in the PPO value loss.
            state_value = self.critic_head(flattened_output).squeeze(-1)
            return logits, state_value

        class ResidualBlock(nn.Module):
            """Two 3x3 conv+BN layers with an identity skip connection."""

            def __init__(self, num_filters):
                super().__init__()
                self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
                self.bn1 = nn.BatchNorm2d(num_filters)
                self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
                self.bn2 = nn.BatchNorm2d(num_filters)

            def forward(self, x):
                residual = x
                out = F.relu(self.bn1(self.conv1(x)))
                out = self.bn2(self.conv2(out))
                out += residual  # The skip connection
                out = F.relu(out)
                return out
