import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.autograd import grad
from mat.utils.util import get_gard_norm, huber_loss, mse_loss, quantile_huber_loss
from mat.utils.valuenorm import ValueNorm
from mat.algorithms.utils.util import check, average_agent_encoders_by_adj, average_attention_params_by_adj

class MATTrainer:
    """
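    Trainer class for MAT to update policies.
    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
    :param policy: policy to update; the trainer uses its `evaluate_actions`, `transformer`,
                   `optimizer` / `optimizers`, `agent_parameters`, `obs_dim` and `algorithm_name` members.
    :param num_agents: (int) number of agents; mini-batches are reshaped to [-1, num_agents, ...].
    :param device: (torch.device) specifies the device to run on (cpu/gpu).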
    """
    def __init__(self,
                 args,
                 policy,
                 num_agents,
                 device=torch.device("cpu")):
        """
        Copy the hyper-parameters from ``args`` and allocate trainer-owned tensors.

        :param args: (argparse.Namespace) configuration; every attribute read below must be present.
        :param policy: policy object to update (stored as ``self.policy``).
        :param num_agents: (int) number of agents; sizes the identity matrix and the value normalizer.
        :param device: (torch.device) device on which all trainer-owned tensors are created.
        """
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy
        self.num_agents = num_agents
        self.num_quants = args.n_quants
        self.n_embd = args.n_embd
        self.truelyDistributed = args.truelyDistributed
        self.avg_critic = args.avg_critic

        self.clip_param = args.clip_param
        self.ppo_epoch = args.ppo_epoch
        self.num_mini_batch = args.num_mini_batch
        self.mini_batch_size = args.mini_batch_size
        self.data_chunk_length = args.data_chunk_length
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm
        self.huber_delta = args.huber_delta
        # Allocate on the configured device rather than a hard-coded "cuda:0", so the
        # trainer also works on CPU-only machines and on GPUs other than index 0.
        self.gnn_loss_coef = torch.tensor(args.gnn_loss_coef, dtype=torch.float32, device=device)

        self._use_recurrent_policy = args.use_recurrent_policy
        self._use_naive_recurrent = args.use_naive_recurrent_policy
        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_valuenorm = args.use_valuenorm
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks
        self.dec_actor = args.dec_actor
        self.detach = args.detach

        # [1, num_agents, num_agents] identity, broadcast over the batch when self-loops
        # are added to adjacency matrices.
        self.eye = torch.eye(self.num_agents, device=device).unsqueeze(0)

        if self._use_valuenorm:
            self.value_normalizer = ValueNorm(self.num_agents, self.num_quants, norm_axes=0, device=self.device)
        else:
            self.value_normalizer = None

    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
        """
        Calculate the per-agent value function loss.

        :param values: (torch.Tensor) value function predictions.
        :param value_preds_batch: (torch.Tensor) "old" value predictions from data batch (used for value clip loss)
        :param return_batch: (torch.Tensor) reward to go returns.
        :param active_masks_batch: (torch.Tensor) denotes if agent is active or dead at a given timestep.

        :return value_loss: (torch.Tensor) value function loss, reduced over dim 0 (the batch dimension
                            of the tensors passed by ``ppo_update``).
        """
        if self._use_valuenorm:
            # Update the running statistics first, then normalise once and reuse the target.
            # NOTE(review): assumes ``normalize`` has no side effects, so one call is
            # equivalent to the repeated calls made previously — confirm against ValueNorm.
            self.value_normalizer.update(return_batch)
            target = self.value_normalizer.normalize(return_batch)
        else:
            target = return_batch

        # PPO-style value clipping: keep the new prediction within clip_param of the old one.
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param)

        if self.num_quants == 1:
            error_clipped = target - value_pred_clipped
            error_original = target - values
            if self._use_huber_loss:
                value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
                value_loss_original = huber_loss(error_original, self.huber_delta)
            else:
                value_loss_clipped = mse_loss(error_clipped)
                value_loss_original = mse_loss(error_original)
        else:
            # Several quantiles per value: use the quantile Huber loss on (target, prediction).
            value_loss_clipped = quantile_huber_loss(target, value_pred_clipped)
            value_loss_original = quantile_huber_loss(target, values)

        if self._use_clipped_value_loss:
            value_loss = torch.max(value_loss_original, value_loss_clipped)
        else:
            value_loss = value_loss_original

        if self._use_value_active_masks:
            # Masked mean over dim 0. The epsilon mirrors the policy-loss reduction and
            # avoids NaN when an agent is inactive for the entire batch.
            value_loss = (value_loss * active_masks_batch).sum(dim=0) / (active_masks_batch.sum(dim=0) + 1e-8)
        else:
            value_loss = value_loss.mean(dim=0)

        return value_loss


    def gnn_consensus_loss(self, x, edge_index):
        """
        Per-agent consensus loss computed from an edge list.

        For each agent ``i`` the edges whose target entry (``edge_index[:, 1, :]``) equals ``i``
        are selected, and the MSE between the features of the source nodes and of the target
        nodes is stored as that agent's loss. Agents without any such edge keep a loss of zero.

        :param x: (torch.Tensor) node features, indexed as ``x[:, node, :]``
                  (assumes [batch_size, n_agents, feature_dim] — TODO confirm).
        :param edge_index: (torch.Tensor) indexed as ``edge_index[:, 0, :]`` (sources) and
                           ``edge_index[:, 1, :]`` (targets); -1 presumably marks padding.
        :return: (torch.Tensor) consensus loss per agent, shape [num_agents, 1].
        """
        # Allocate next to the features instead of on a hard-coded "cuda:0".
        total_consensus_loss = torch.zeros(self.num_agents, dtype=torch.float32, device=x.device)

        sources = edge_index[:, 0, :]
        targets = edge_index[:, 1, :]

        for i in range(self.num_agents):
            # Agent ids are non-negative, so ``targets == i`` already excludes -1 entries.
            valid_mask = targets == i  # [batch_size, max_edges]

            source_indices = sources[valid_mask].long()  # [total_valid_edges]
            if source_indices.numel() == 0:
                # No incoming edge for this agent: leave its loss at zero and carry on,
                # rather than returning early and discarding the other agents' losses.
                continue
            edge_indices = targets[valid_mask].long()    # [total_valid_edges]

            # NOTE(review): the indices are pooled over the whole batch and then applied to
            # every batch element of ``x`` — confirm this cross-batch gathering is intended.
            source_x = x[:, source_indices, :]
            edge_x = x[:, edge_indices, :]
            if self.detach:
                # Stop gradients through the target side of each edge.
                edge_x = edge_x.detach()

            total_consensus_loss[i] = F.mse_loss(source_x, edge_x, reduction='mean')

        return total_consensus_loss.unsqueeze(-1)

    def adj_gnn_consensus_loss(self, x, adj):
        """
        Adjacency-masked consensus loss, one value per agent.

        x:   [batch_size, n_agents, feature_dim]
        adj: [batch_size, n_agents, n_agents] (0/1 adjacency matrix); modified IN PLACE —
             ``self.eye`` is added to it, so the caller's tensor gains self-loops too.
        returns: [n_agents, 1] consensus loss per agent

        NOTE(review): the masked differences are averaged over the neighbour axis and then
        additionally divided by the neighbour count — confirm this double normalisation is intended.
        """
        # Include "self" in every neighbourhood (in-place on the caller's tensor).
        adj.add_(self.eye)

        # Broadcast the embeddings against each other: agent i along dim 1, agent j along dim 2.
        others = x.unsqueeze(1)                      # [batch_size, 1, n_agents, feature_dim]
        if self.detach:
            others = others.detach()
        pairwise_sq = (x.unsqueeze(2) - others).pow(2).mean(dim=-1)  # [batch_size, n_agents, n_agents]

        # Keep only connected pairs, then reduce over the neighbour axis.
        masked_mean = (pairwise_sq * adj).mean(dim=-1)               # [batch_size, n_agents]

        # Scale by neighbourhood size; epsilon guards against division by zero.
        degree = adj.sum(dim=-1)                                     # [batch_size, n_agents]
        scaled = masked_mean / (degree + 1e-8)

        # Batch average, returned as a column vector.
        return scaled.mean(dim=0).unsqueeze(-1)                      # [n_agents, 1]


    def ppo_update(self, sample, episode, iter_step, obs_dim=None):
        """
        Run one PPO gradient step on a single mini-batch.

        :param sample: 15-tuple produced by the buffer's data generator, unpacked below.
        :param episode: not used by this method (kept for interface compatibility).
        :param iter_step: not used by this method (kept for interface compatibility).
        :param obs_dim: not used by this method; the feature slice uses ``self.policy.obs_dim``.

        :return: tuple ``(value_loss, grad_norm, policy_loss, dist_entropy, grad_norm,
                 imp_weights, gnn_consensus_loss)``. The losses are averaged over agents (dim 0)
                 and detached; ``dist_entropy`` is a Python float; ``grad_norm`` appears twice
                 (critic and actor slots). In the distributed branch ``grad_norm`` is the mean
                 of the per-agent gradient norms.
        """
        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch, adjcency_matrix_batch, next_obs_batch, edge_index_batch = sample
        # next_obs_batch and edge_index_batch are not used by this update, so they are
        # neither converted nor reshaped.

        def ensure_tensor(x):
            """Move arrays/tensors to the trainer's device and dtype; pass anything else through."""
            if isinstance(x, np.ndarray):
                return torch.from_numpy(x).to(**self.tpdv)
            return x.to(**self.tpdv) if isinstance(x, torch.Tensor) else x

        def per_agent(x):
            """Convert and reshape to [batch_size, num_agents, last_dim]."""
            x = ensure_tensor(x)
            return x.view(-1, self.num_agents, x.shape[-1])

        old_action_log_probs_batch = per_agent(old_action_log_probs_batch)
        adv_targ = per_agent(adv_targ)
        value_preds_batch = per_agent(value_preds_batch)
        return_batch = per_agent(return_batch)
        active_masks_batch = per_agent(active_masks_batch)
        obs_batch = per_agent(obs_batch)
        share_obs_batch = per_agent(share_obs_batch)
        rnn_states_batch = per_agent(rnn_states_batch)
        rnn_states_critic_batch = per_agent(rnn_states_critic_batch)
        actions_batch = per_agent(actions_batch)
        adjcency_matrix_batch = per_agent(adjcency_matrix_batch)
        masks_batch = per_agent(masks_batch)
        # available actions are optional: only reshape when they are provided.
        if available_actions_batch is not None:
            available_actions_batch = per_agent(available_actions_batch)

        values, action_log_probs, dist_entropy = self.policy.evaluate_actions(
            share_obs_batch,
            obs_batch,
            rnn_states_batch,
            rnn_states_critic_batch,
            actions_batch,
            masks_batch,
            available_actions_batch,
            active_masks_batch,
        )

        # Clipped surrogate objective, reduced over the batch dimension (per agent).
        imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
        surr1 = imp_weights * adv_targ
        surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
        min_surr = torch.min(surr1, surr2)

        if self._use_policy_active_masks:
            policy_loss = -(min_surr * active_masks_batch).sum(dim=0) / (active_masks_batch.sum(dim=0) + 1e-8)
        else:
            policy_loss = -min_surr.mean(dim=0)

        value_loss = self.cal_value_loss(
            values,
            value_preds_batch,
            return_batch,
            active_masks_batch,
        )

        # Consensus loss on the features after the first ``policy.obs_dim`` entries.
        # NOTE: adj_gnn_consensus_loss adds ``self.eye`` to the adjacency tensor in place, so
        # the matrix used for parameter averaging further down already includes self-loops.
        gnn_consensus_loss = self.adj_gnn_consensus_loss(obs_batch[:, :, self.policy.obs_dim:], adjcency_matrix_batch)
        loss = policy_loss - dist_entropy * self.entropy_coef + value_loss * self.value_loss_coef + self.gnn_loss_coef * gnn_consensus_loss

        if self.truelyDistributed:
            for optimizer in self.policy.optimizers:
                optimizer.zero_grad()

            # Materialise each agent's parameters once so the same objects are used for
            # differentiation, gradient assignment and clipping.
            agent_params = [list(self.policy.agent_parameters(i)) for i in range(self.num_agents)]

            # Compute every agent's gradients before any optimizer step, so all gradients
            # are taken at the same parameter values.
            all_grads = [
                torch.autograd.grad(
                    outputs=loss[i],
                    inputs=params,
                    retain_graph=True,  # the graph is shared by all agents' losses
                    create_graph=False,
                    allow_unused=True,
                )
                for i, params in enumerate(agent_params)
            ]

            grad_norms = []
            for i, (params, grads) in enumerate(zip(agent_params, all_grads)):
                for param, param_grad in zip(params, grads):
                    if param_grad is not None:
                        param.grad = param_grad
                if self._use_max_grad_norm:
                    agent_norm = torch.nn.utils.clip_grad_norm_(params, self.max_grad_norm)
                else:
                    # Always define a norm so the return statement cannot hit an unbound name.
                    agent_norm = get_gard_norm(params)
                grad_norms.append(agent_norm)
                self.policy.optimizers[i].step()

            grad_norm = sum(grad_norms) / max(len(grad_norms), 1)
        else:
            # Single shared optimizer: average the per-agent losses into one scalar.
            loss = loss.mean()

            self.policy.optimizer.zero_grad()
            loss.backward()

            if self._use_max_grad_norm:
                grad_norm = nn.utils.clip_grad_norm_(
                    self.policy.transformer.parameters(),
                    self.max_grad_norm
                )
            else:
                grad_norm = get_gard_norm(self.policy.transformer.parameters())

            self.policy.optimizer.step()

        if self.policy.algorithm_name == 'mappo_dgnn_dsgd':
            # Average encoder / classifier-head / attention parameters between agents using
            # the first adjacency matrix of the mini-batch.
            obs_encoder = self.policy.transformer.obs_encoder
            average_agent_encoders_by_adj(obs_encoder.agent_encoders, adjcency_matrix_batch[0])
            average_agent_encoders_by_adj(obs_encoder.node_classifier_heads, adjcency_matrix_batch[0])
            average_attention_params_by_adj(obs_encoder.atts, adjcency_matrix_batch[0])

        # Agent-averaged values for logging; detached so callers that accumulate them do
        # not keep the autograd graph of every mini-batch alive.
        avg_value_loss = value_loss.mean(dim=0).detach()
        avg_policy_loss = policy_loss.mean(dim=0).detach()
        avg_gnn_consensus_loss = gnn_consensus_loss.mean(dim=0).detach()

        return avg_value_loss, grad_norm, avg_policy_loss, dist_entropy.mean().item(), grad_norm, imp_weights.detach(), avg_gnn_consensus_loss

    def train(self, buffer, episode, obs_dim=None):
        """
        Perform a training update using minibatch GD.

        :param buffer: buffer containing training data; must provide ``advantages``,
                       ``active_masks`` and ``feed_forward_generator_transformer``.
        :param episode: forwarded unchanged to ``ppo_update``.
        :param obs_dim: forwarded unchanged to ``ppo_update``.

        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc),
                            averaged over the number of mini-batch updates actually performed.
        """
        # Normalise advantages with statistics taken only over active entries: inactive
        # entries are marked NaN on a copy and filtered out before computing mean / std.
        advantages_copy = buffer.advantages.clone()
        advantages_copy[buffer.active_masks[:-1] == 0.0] = torch.nan

        valid_advantages = advantages_copy[~torch.isnan(advantages_copy)]
        mean_advantages = valid_advantages.mean()
        std_advantages = valid_advantages.std(unbiased=False)

        advantages = (buffer.advantages - mean_advantages) / (std_advantages + 1e-5)

        def detached(value):
            """Drop autograd history from tensors so the accumulators hold no graphs."""
            return value.detach() if isinstance(value, torch.Tensor) else value

        train_info = {
            'value_loss': 0,
            'policy_loss': 0,
            'dist_entropy': 0,
            'actor_grad_norm': 0,
            'critic_grad_norm': 0,
            'ratio': 0,
            'gnn_consensus_loss': 0,
        }

        # Count the updates that really happen instead of assuming
        # ppo_epoch * num_mini_batch, so the averages stay correct whatever the
        # generator yields.
        num_updates = 0

        for epoch in range(self.ppo_epoch):
            data_generator = buffer.feed_forward_generator_transformer(
                advantages, self.num_mini_batch, mini_batch_size=self.mini_batch_size)

            for sample in data_generator:
                value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights, avg_gnn_consensus_loss \
                    = self.ppo_update(sample, episode, epoch, obs_dim)

                train_info['value_loss'] += detached(value_loss)
                train_info['policy_loss'] += detached(policy_loss)
                train_info['dist_entropy'] += detached(dist_entropy)
                train_info['actor_grad_norm'] += detached(actor_grad_norm)
                train_info['critic_grad_norm'] += detached(critic_grad_norm)
                train_info['ratio'] += detached(imp_weights).mean()
                train_info['gnn_consensus_loss'] += detached(avg_gnn_consensus_loss)
                num_updates += 1

        # With no updates every entry is still 0; skip the division to avoid a
        # ZeroDivisionError.
        if num_updates:
            for k in train_info:
                train_info[k] /= num_updates

        return train_info

    def prep_training(self):
        """Call ``self.policy.train()`` (presumably switches the policy's networks to training mode)."""
        self.policy.train()

    def prep_rollout(self):
        """Call ``self.policy.eval()`` (presumably switches the policy's networks to evaluation mode)."""
        self.policy.eval()
