import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch.utils.data import DataLoader, TensorDataset

################################## PPO Policy ##################################
class RolloutBuffer:
    """Transition storage filled during rollouts and drained by a PPO update.

    Every attribute is a plain list; entries at the same index are used
    together as one step of the rollout. ``clear`` empties the lists in
    place, so references to them held elsewhere observe the reset.
    """

    def __init__(self):
        # Appended to by PPO.select_action in this file.
        self.actions = []
        self.states = []
        self.logprobs = []
        # Presumably appended to by the training loop (not visible here) — confirm.
        self.rewards = []
        # Not appended to anywhere in this file (the append is commented out).
        self.state_values = []
        self.is_terminals = []

    def clear(self):
        """Empty every stored list in place without rebinding the attributes."""
        for stored in (
            self.actions,
            self.states,
            self.logprobs,
            self.rewards,
            self.state_values,
            self.is_terminals,
        ):
            stored.clear()


class ActorCritic(nn.Module):
    """Actor and critic networks used by the PPO agent.

    The actor produces either the mean of a Gaussian with a fixed diagonal
    covariance (continuous actions; variances kept in ``action_var``) or the
    probabilities of a categorical distribution (discrete actions). The critic
    maps a state to a single value estimate. ``forward`` is deliberately left
    unimplemented: call ``act`` while collecting rollouts and ``evaluate``
    while updating.
    """

    def __init__(self, device, state_dim_actor, state_dim_critic, action_dim, has_continuous_action_space, action_std_init):
        super().__init__()

        self.device = device
        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_dim = action_dim
            # Variance = std ** 2, one entry per action dimension.
            initial_var = action_std_init * action_std_init
            self.action_var = torch.full((action_dim,), initial_var).to(self.device)
            # Gaussian-mean head; the final Tanh keeps every mean component in [-1, 1].
            self.actor = nn.Sequential(
                nn.BatchNorm1d(state_dim_actor),
                nn.Linear(state_dim_actor, 500),
                nn.Tanh(),
                nn.Linear(500, 450),
                nn.Tanh(),
                nn.Linear(450, action_dim),
                nn.Tanh(),
            )
        else:
            # Categorical head; Softmax turns the last layer's outputs into probabilities.
            self.actor = nn.Sequential(
                nn.Linear(state_dim_actor, 512),
                nn.Tanh(),
                nn.Linear(512, 450),
                nn.Tanh(),
                nn.Linear(450, 128),
                nn.Tanh(),
                nn.Linear(128, action_dim),
                nn.Softmax(dim=-1),
            )

        # Scalar value head, shared by both action-space types.
        self.critic = nn.Sequential(
            nn.BatchNorm1d(state_dim_critic),
            nn.Linear(state_dim_critic, 500),
            nn.Tanh(),
            nn.Linear(500, 450),
            nn.Tanh(),
            nn.Linear(450, 1),
        )

    def set_action_std(self, new_action_std):
        """Replace the Gaussian variances with ``new_action_std ** 2`` (continuous only)."""
        if not self.has_continuous_action_space:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")
            return
        squared = new_action_std * new_action_std
        self.action_var = torch.full((self.action_dim,), squared).to(self.device)

    def forward(self):
        raise NotImplementedError

    def act(self, state):
        """Sample one action per state from the current policy.

        Returns:
            ``(action, log_prob)``, both detached from the autograd graph.
        """
        if not self.has_continuous_action_space:
            dist = Categorical(self.actor(state))
        else:
            mean = self.actor(state)
            # Leading singleton dim lets the (A, A) covariance broadcast over the batch.
            covariance = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(mean, covariance)

        sampled = dist.sample()
        log_prob = dist.log_prob(sampled)
        return sampled.detach(), log_prob.detach()

    def evaluate(self, state, action):
        """Score stored actions under the current policy.

        Returns:
            ``(log_probs, state_values, entropy)``; ``state_values`` is the raw
            critic output and gradients flow through all three.
        """
        if self.has_continuous_action_space:
            mean = self.actor(state)
            per_sample_var = self.action_var.expand_as(mean)
            covariance = torch.diag_embed(per_sample_var).to(self.device)
            dist = MultivariateNormal(mean, covariance)
            # A squeezed single-action batch arrives 1-D; restore the action axis.
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)
        else:
            dist = Categorical(self.actor(state))

        log_probs = dist.log_prob(action)
        entropy = dist.entropy()
        values = self.critic(state)
        return log_probs, values, entropy

class PPO:
    """PPO agent with a clipped surrogate objective and GAE advantages.

    Two ActorCritic copies are kept: ``policy`` is the network optimised in
    ``update_gae``, while ``policy_old`` generates rollouts in
    ``select_action`` and receives ``policy``'s weights after every update.
    """

    def __init__(self, device, state_dim_actor, state_dim_critic, action_dim, lr_actor, lr_critic, gae_lambda, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):
        """Build the networks, optimizer and rollout buffer.

        Args:
            device: Torch device the networks and tensors are placed on.
            state_dim_actor: Input size of the actor network.
            state_dim_critic: Input size of the critic network.
            action_dim: Action-vector size (continuous) or number of actions (discrete).
            lr_actor: Adam learning rate for the actor parameters.
            lr_critic: Adam learning rate for the critic parameters.
            gae_lambda: GAE smoothing factor (lambda).
            gamma: Discount factor.
            K_epochs: Number of passes over the rollout per update.
            eps_clip: Clipping range for the probability ratio.
            has_continuous_action_space: Gaussian policy if True, categorical otherwise.
            action_std_init: Initial Gaussian std (continuous only).
        """
        self.device = device

        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_std = action_std_init

        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(self.device, state_dim_actor, state_dim_critic, action_dim, has_continuous_action_space, action_std_init).to(self.device)
        # Adam's weight_decay adds an L2 penalty term to the gradients of both networks.
        self.optimizer = torch.optim.Adam([
                        {'params': self.policy.actor.parameters(), 'lr': lr_actor, 'weight_decay': 0.01},
                        {'params': self.policy.critic.parameters(), 'lr': lr_critic, 'weight_decay': 0.01}
                    ])

        self.policy_old = ActorCritic(self.device, state_dim_actor, state_dim_critic, action_dim, has_continuous_action_space, action_std_init).to(self.device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        # Inference mode by default; update_gae switches `policy` to train mode
        # only while optimising.
        self.policy.eval()
        self.policy_old.eval()

        self.MseLoss = nn.MSELoss()

    def set_action_std(self, new_action_std):
        """Set the Gaussian std on both networks (continuous action spaces only)."""
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling PPO::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Subtract ``action_std_decay_rate`` from the std and apply it to both networks.

        The result is rounded to 4 decimals and clamped from below at
        ``min_action_std``.
        """
        if self.has_continuous_action_space:
            self.action_std = round(self.action_std - action_std_decay_rate, 4)
            if self.action_std <= min_action_std:
                self.action_std = min_action_std
            self.set_action_std(self.action_std)
        else:
            print("WARNING : Calling PPO::decay_action_std() on discrete action space policy")

    def select_action(self, state):
        """Sample an action from ``policy_old`` and record the step in the buffer.

        Args:
            state: Array-like observation accepted by ``torch.FloatTensor``.

        Returns:
            A flattened NumPy array (continuous) or the Python number from
            ``.item()`` (discrete).
        """
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            if self.has_continuous_action_space:
                # The continuous actor begins with BatchNorm1d, which needs a batch axis.
                state = state.unsqueeze(0)
            action, action_logprob = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)

        if self.has_continuous_action_space:
            return action.detach().cpu().numpy().flatten()
        return action.item()

    def update_gae(self, satellites_ppo=None, batch_size=64):
        """Optimise ``policy`` on the buffered rollout, then sync ``policy_old``.

        In each of ``K_epochs`` epochs, advantages and returns are recomputed
        with the current critic (see ``compute_gae``), and the rollout is
        shuffled into minibatches for clipped-PPO gradient steps. Afterwards
        the new weights are copied into ``policy_old``, the buffer is cleared
        and both networks are put back in eval mode.

        Args:
            satellites_ppo: Not read by this method; accepted so existing
                callers keep working.
            batch_size: Minibatch size for each gradient step (default 64).
        """
        # Set unconditionally: BatchNorm must use batch statistics while
        # optimising, independent of what the caller passes in.
        self.policy.train()

        rewards = self.buffer.rewards
        is_terminals = self.buffer.is_terminals

        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(self.device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(self.device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(self.device)

        for _ in range(self.K_epochs):
            # compute_gae turns the values into constant targets, so no autograd
            # graph is needed for this forward pass.
            with torch.no_grad():
                old_state_values = self.policy.critic(old_states)
            advantages, returns = self.compute_gae(rewards, old_state_values, is_terminals)

            dataset = TensorDataset(old_states, old_actions, old_logprobs, advantages, returns)
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

            for batch_states, batch_actions, batch_logprobs, batch_advantages, batch_returns in loader:
                # BatchNorm1d raises in train mode on a single-sample batch, which
                # occurs whenever len(dataset) % batch_size == 1; skip it.
                if batch_states.shape[0] < 2:
                    continue

                logprobs, state_values, dist_entropy = self.policy.evaluate(batch_states, batch_actions)

                # Critic output is (B, 1); match the (B,) returns.
                state_values = torch.squeeze(state_values)

                # Probability ratio pi_theta / pi_theta_old.
                ratios = torch.exp(logprobs - batch_logprobs)

                surr1 = ratios * batch_advantages
                surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * batch_advantages

                # Clipped objective + value loss - entropy bonus.
                loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, batch_returns) - 0.01 * dist_entropy

                self.optimizer.zero_grad()
                loss.mean().backward()
                self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())

        self.buffer.clear()

        self.policy.eval()
        self.policy_old.eval()

    def compute_gae(self, rewards, values, is_terminals):
        """Compute Generalized Advantage Estimation (GAE) advantages and returns.

        The episode is cut at every truthy ``is_terminals[step]``: the value
        following that step is treated as 0. The step after the final entry
        is also bootstrapped with 0.

        Args:
            rewards: Sequence of per-step rewards, convertible with ``float``.
            values: Value estimates per step, either as a tensor (e.g. the
                critic's ``(N, 1)`` output, flattened here) or as a sequence of
                numbers / one-element tensors.
            is_terminals: Sequence of per-step terminal flags.

        Returns:
            ``(advantages, returns)`` as float32 tensors on ``self.device``.
            Advantages are normalised to zero mean / unit std when there is
            more than one step; with a single step the std is undefined (NaN),
            so the raw value is returned.
        """
        # Pull the values off the device once instead of indexing a tensor per step.
        if torch.is_tensor(values):
            values = values.detach().flatten().tolist()
        else:
            values = [float(v) for v in values]

        n_steps = len(rewards)
        advantages = [0.0] * n_steps
        returns = [0.0] * n_steps
        gae = 0.0
        next_value = 0.0

        for step in reversed(range(n_steps)):
            if is_terminals[step]:
                # Nothing to bootstrap from past the end of an episode.
                next_value = 0.0
                gae = 0.0

            delta = float(rewards[step]) + self.gamma * next_value - values[step]
            gae = delta + self.gamma * self.gae_lambda * gae
            advantages[step] = gae
            returns[step] = gae + values[step]  # Return = Advantage + Value

            next_value = values[step]

        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)

        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-7)

        return advantages, returns

    def save(self, checkpoint_path):
        """Write ``policy_old``'s state dict to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a state dict saved by ``save`` into both networks."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
        self.policy_old.load_state_dict(checkpoint)
        self.policy.load_state_dict(checkpoint)
        self.policy_old.to(self.device)
        self.policy.to(self.device)