import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

class RolloutBuffer:
    """Step-by-step storage of rollout data.

    Used as one buffer per agent, as a temporary merged buffer, and as the
    holder of policy representations (``policy_rep``). The lists are expected
    to stay aligned index-by-index, since consumers stack and zip them
    together.
    """

    # Every per-step list; ``clear`` and ``merge`` iterate this so that no
    # field can be forgotten by one of them.
    _FIELDS = ("actions", "states", "logprobs", "rewards", "is_terminals", "policy_rep")

    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []
        self.policy_rep = []

    def clear(self):
        """Empty every list in place (existing references to the lists stay valid)."""
        for name in self._FIELDS:
            getattr(self, name).clear()

    def merge(self, other_buffer):
        """Append all of ``other_buffer``'s entries to this buffer, field by field.

        ``other_buffer`` itself is left unchanged.
        """
        for name in self._FIELDS:
            getattr(self, name).extend(getattr(other_buffer, name))


class ActorCritic(nn.Module):
    """Actor network over one agent's observation plus a critic over a joint state.

    The actor maps an input of width ``state_dim`` either to the mean of a
    Gaussian whose diagonal covariance is ``action_var`` (continuous actions)
    or to softmax action probabilities (discrete actions). The critic maps an
    input of width ``joint_state_dim`` to a single value.
    """

    def __init__(self, state_dim, joint_state_dim, action_dim, has_continuous_action_space, action_std_init, device):
        super().__init__()

        self.has_continuous_action_space = has_continuous_action_space
        self.device = device

        if has_continuous_action_space:
            self.action_dim = action_dim
            # Per-dimension variance (std squared) of the Gaussian policy. It is a
            # plain tensor attribute, not a parameter or registered buffer.
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)

        # Actor: two tanh hidden layers of width 64. The discrete head ends in a
        # softmax so its output can be handed to Categorical as probabilities.
        actor_layers = [
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
        ]
        if not has_continuous_action_space:
            actor_layers.append(nn.Softmax(dim=-1))
        self.actor = nn.Sequential(*actor_layers)

        # Critic: same hidden shape with a single output unit.
        self.critic = nn.Sequential(
            nn.Linear(joint_state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

        for network in (self.actor, self.critic):
            network.to(device)

    def set_action_std(self, new_action_std):
        """Set the Gaussian policy's standard deviation to ``new_action_std``.

        On a discrete-action policy, nothing changes and a warning is printed.
        """
        if not self.has_continuous_action_space:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")
            return
        self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(self.device)

    def forward(self):
        # Use act() / evaluate() instead; there is no single forward pass.
        raise NotImplementedError

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            ``(action, log_prob)``, both detached from the autograd graph.
        """
        if self.has_continuous_action_space:
            mean = self.actor(state)
            # A leading batch axis of size 1 lets the covariance broadcast against the mean.
            covariance = torch.diag(self.action_var).unsqueeze(dim=0)
            policy_dist = MultivariateNormal(mean, covariance)
        else:
            policy_dist = Categorical(self.actor(state))

        sampled = policy_dist.sample()
        sampled_logprob = policy_dist.log_prob(sampled)
        return sampled.detach(), sampled_logprob.detach()

    def evaluate(self, state, action, joint_state):
        """Score previously taken actions under the current networks.

        Args:
            state: Actor input (per-agent observations).
            action: Actions to score.
            joint_state: Critic input.

        Returns:
            ``(log_probs, state_values, entropy)``. Here ``state_values`` is the raw
            critic output, whose last dimension has size 1.
        """
        if self.has_continuous_action_space:
            mean = self.actor(state)
            covariance = torch.diag_embed(self.action_var.expand_as(mean)).to(self.device)
            policy_dist = MultivariateNormal(mean, covariance)

            if self.action_dim == 1:
                # Single-action environments: give actions an explicit action axis.
                action = action.reshape(-1, self.action_dim)
        else:
            policy_dist = Categorical(self.actor(state))

        logprobs = policy_dist.log_prob(action)
        entropy = policy_dist.entropy()
        values = self.critic(joint_state)
        return logprobs, values, entropy

# Multi-agent PPO (MAPPO): one shared actor-critic, with a separate rollout buffer for each cooperative agent (e.g. each predator)
class MAPPO:
    """Multi-agent PPO with one shared actor-critic and a separate buffer per agent.

    Every agent acts through the same ``ActorCritic``. The actor sees only that
    agent's own observation. The critic is fed a joint state, built in
    ``update`` by concatenating all agents' observations side by side (plus a
    policy representation when ``policy_rep`` is enabled).
    """

    def __init__(self, state_dim, joint_state_dim, action_dim, policy_dim_loc, lr_actor, lr_critic, gamma, K_epochs, eps_clip,
                 has_continuous_action_space, device, agent_num=2, policy_rep=False, action_std_init=0.6):
        """Build the trained policy, its rollout copy, the optimizer and the buffers.

        Args:
            state_dim: Input width of the actor (one agent's observation).
            joint_state_dim: Input width of the critic; must match the joint state
                assembled in ``update``.
            action_dim: Number of discrete actions, or size of a continuous action.
            policy_dim_loc: With ``policy_rep`` enabled, the number of trailing
                features removed from each agent's observation before the
                observations are joined.
            lr_actor: Adam learning rate for the actor parameters.
            lr_critic: Adam learning rate for the critic parameters.
            gamma: Discount factor for the Monte Carlo returns.
            K_epochs: Number of optimisation passes per ``update`` call.
            eps_clip: PPO clipping range for the probability ratio.
            has_continuous_action_space: Gaussian policy if True, categorical otherwise.
            device: Torch device for the networks and training tensors.
            agent_num: Number of agents; one buffer is created per agent.
            policy_rep: Whether ``policy_rep_buffer.policy_rep`` is appended to the
                critic's joint state. Nothing in this class appends to that
                list, so the caller presumably fills it.
            action_std_init: Initial action standard deviation (continuous only).
        """
        self.has_continuous_action_space = has_continuous_action_space
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.policy_rep = policy_rep
        self.policy_dim_loc = policy_dim_loc
        self.agent_num = agent_num

        # Separate buffers for each cooperative agent, keyed "0", "1", ...
        self.buffers = {str(i): RolloutBuffer() for i in range(agent_num)}
        self.policy_rep_buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, joint_state_dim, action_dim, has_continuous_action_space, action_std_init, device)
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        # Frozen copy used to collect rollouts; synced after every update.
        self.policy_old = ActorCritic(state_dim, joint_state_dim, action_dim, has_continuous_action_space, action_std_init, device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()
        self.device = device

    def select_action(self, state, agent_id):
        """Sample an action for one agent with ``policy_old`` and record the step.

        Args:
            state: The agent's observation (anything ``torch.FloatTensor`` accepts).
            agent_id: The agent's buffer key. Ints are accepted as well and are
                converted with ``str`` to match the keys created in ``__init__``.

        Returns:
            A flattened numpy array for continuous actions, or a Python number
            (via ``.item()``) for discrete actions.
        """
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            action, action_logprob = self.policy_old.act(state)

        buffer = self.buffers[str(agent_id)]
        buffer.states.append(state)
        buffer.actions.append(action)
        buffer.logprobs.append(action_logprob)

        if self.has_continuous_action_space:
            return action.detach().cpu().numpy().flatten()
        return action.item()

    def _discounted_returns(self, rewards, is_terminals):
        """Return the discounted reward-to-go for one agent's step list, reset at terminal steps."""
        returns = []
        running = 0
        for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
            if is_terminal:
                running = 0
            running = reward + (self.gamma * running)
            returns.append(running)
        returns.reverse()
        return returns

    def _build_joint_states(self, old_states):
        """Place all agents' observations side by side to form the critic input.

        ``old_states`` holds agent 0's steps first, then agent 1's, and so on,
        with the same number of steps per agent. The result is tiled
        ``agent_num`` times so that it lines up row for row with ``old_states``.
        """
        samples_per_agent = len(old_states) // self.agent_num
        agent_states = torch.split(old_states, samples_per_agent, dim=0)
        if self.policy_rep:
            # Drop the trailing policy_dim_loc features. An explicit end index is
            # used because `[:, :-0]` would drop every feature.
            keep = max(old_states.shape[1] - self.policy_dim_loc, 0)
            agent_states = [agent_state[:, :keep] for agent_state in agent_states]
            joint_states = torch.cat((*agent_states, self.policy_rep_tensor), dim=1)
        else:
            joint_states = torch.cat(agent_states, dim=1)
        return joint_states.repeat(self.agent_num, 1)

    def update(self):
        """Run ``K_epochs`` clipped-PPO optimisation passes over all buffered steps.

        Afterwards ``policy_old`` is synced to the trained weights and every
        buffer (including ``policy_rep_buffer``) is cleared.

        Returns:
            The mean loss tensor of the last epoch, or ``None`` when
            ``K_epochs`` is 0.

        Raises:
            ValueError: If the agent buffers hold different numbers of states,
                which would misalign the row-by-row joint states.
        """
        buffers = list(self.buffers.values())
        state_counts = [len(buffer.states) for buffer in buffers]
        if len(set(state_counts)) > 1:
            raise ValueError(f"MAPPO.update requires equally sized agent buffers, got state counts {state_counts}")

        # Monte Carlo returns are computed per agent. Otherwise, when a buffer does
        # not end on a terminal step, the next agent's returns would leak into it.
        returns = []
        for buffer in buffers:
            returns.extend(self._discounted_returns(buffer.rewards, buffer.is_terminals))

        # Normalise the returns.
        returns = torch.tensor(returns, dtype=torch.float32).to(self.device)
        returns = (returns - returns.mean()) / (returns.std() + 1e-7)

        # Stack all agents' steps in buffer order (agent "0" first).
        combined_buffer = RolloutBuffer()
        for buffer in buffers:
            combined_buffer.merge(buffer)
        old_states = torch.squeeze(torch.stack(combined_buffer.states, dim=0)).detach().to(self.device)
        old_actions = torch.squeeze(torch.stack(combined_buffer.actions, dim=0)).detach().to(self.device)
        old_logprobs = torch.squeeze(torch.stack(combined_buffer.logprobs, dim=0)).detach().to(self.device)
        if self.policy_rep:
            self.policy_rep_tensor = torch.stack(self.policy_rep_buffer.policy_rep).squeeze(1).detach().to(self.device)

        joint_states = self._build_joint_states(old_states)

        mean_loss = None
        for _ in range(self.K_epochs):
            # Re-evaluate the stored actions under the current policy and critic.
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions, joint_states)

            # Match the state_values dimensions to the returns tensor.
            state_values = torch.squeeze(state_values)

            # Probability ratio pi_theta / pi_theta_old.
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Clipped surrogate objective.
            advantages = returns - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # Final loss: policy term + 0.5 * value MSE - 0.01 * entropy bonus.
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, returns) - 0.01 * dist_entropy
            mean_loss = loss.mean()

            self.optimizer.zero_grad()
            mean_loss.backward()
            self.optimizer.step()

        # Copy the new weights into the rollout policy.
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.clear_buffer()
        return mean_loss

    def save_model(self, checkpoint_path):
        """Save ``policy_old``'s state_dict to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def clear_buffer(self):
        """Empty ``policy_rep_buffer`` and every per-agent buffer."""
        self.policy_rep_buffer.clear()
        for buffer in self.buffers.values():
            buffer.clear()

    def load_model(self, checkpoint_path):
        """Load one checkpoint into both ``policy_old`` and ``policy``.

        Tensors are mapped to the storage they were saved from (CPU-safe
        loading). The file is read once; ``load_state_dict`` copies the values,
        so both networks can share the loaded dict.
        """
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)

    def eval(self):
        """Switch both networks to evaluation mode."""
        self.policy.eval()
        self.policy_old.eval()
