#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  This source code is licensed under the license found in the
#  LICENSE file in the root directory of this source tree.
#

from dataclasses import dataclass, MISSING
from typing import Dict, Iterable, Tuple, Type

import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn.distributions import NormalParamExtractor
from torch.distributions import Categorical
from torchrl.data import Composite, Unbounded
from torchrl.modules import (
    IndependentNormal,
    MaskedCategorical,
    ProbabilisticActor,
    TanhNormal,
)
from torchrl.objectives import ClipPPOLoss, LossModule, ValueEstimators

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig
from benchmarl.algorithms.mappo import Mappo
from benchmarl.models.common import ModelConfig
from benchmarl.modules.cermic_module import CermicModule, reshape_batch_agents, reshape_agents_to_batch


class MappoCermic(Mappo):
    """Multi Agent PPO with Cermic Module.

    Builds on :class:`Mappo`: one ``CermicModule`` is created per agent group,
    its intrinsic reward is added to the environment reward in
    :meth:`process_batch`, and the module itself is trained in :meth:`update`.

    Args:
        share_param_critic (bool): Whether to share the parameters of the critics within agent groups
        clip_epsilon (scalar): weight clipping threshold in the clipped PPO loss equation.
        entropy_coef (scalar): entropy multiplier when computing the total loss.
        critic_coef (scalar): critic loss multiplier when computing the total loss.
        loss_critic_type (str): loss function for the value discrepancy.
            Can be one of "l1", "l2" or "smooth_l1".
        lmbda (float): The GAE lambda
        scale_mapping (str): positive mapping function to be used with the std.
            choices: "softplus", "exp", "relu", "biased_softplus_1";
        use_tanh_normal (bool): if ``True``, use TanhNormal as the continuous action distribution with support bound
            to the action domain. Otherwise, an IndependentNormal is used.
        minibatch_advantage (bool): if ``True``, advantage computation is performed on minibatches of size
            ``experiment.config.on_policy_minibatch_size`` instead of the full
            ``experiment.config.on_policy_collected_frames_per_batch``, this helps not exploding memory usage
        cermic_tau (float): The momentum network update rate for Cermic Module
        cermic_loss_var_weight (float): Weight for KL divergence loss in Cermic Module
        cermic_loss_l2_weight (float): Weight for L2 reconstruction loss in Cermic Module
        cermic_loss_nce_weight (float): Weight for NCE contrastive loss in Cermic Module
        cermic_aug (bool): Whether to use data augmentation in Cermic Module
        cermic_intrinsic_reward_coef (float): Coefficient for intrinsic reward
    """

    def __init__(
        self,
        share_param_critic: bool,
        clip_epsilon: float,
        entropy_coef: float,
        critic_coef: float,
        loss_critic_type: str,
        lmbda: float,
        scale_mapping: str,
        use_tanh_normal: bool,
        minibatch_advantage: bool,
        cermic_tau: float,
        cermic_loss_var_weight: float,
        cermic_loss_l2_weight: float,
        cermic_loss_nce_weight: float,
        cermic_aug: bool,
        cermic_intrinsic_reward_coef: float,
        **kwargs
    ):
        # The PPO-related arguments are handled entirely by the Mappo parent.
        super().__init__(
            share_param_critic=share_param_critic,
            clip_epsilon=clip_epsilon,
            entropy_coef=entropy_coef,
            critic_coef=critic_coef,
            loss_critic_type=loss_critic_type,
            lmbda=lmbda,
            scale_mapping=scale_mapping,
            use_tanh_normal=use_tanh_normal,
            minibatch_advantage=minibatch_advantage,
            **kwargs
        )

        self.cermic_tau = cermic_tau
        # NOTE(review): the three loss weights below are stored but are not
        # forwarded to ``CermicModule`` anywhere in this class -- confirm
        # whether the module is supposed to receive them.
        self.cermic_loss_var_weight = cermic_loss_var_weight
        self.cermic_loss_l2_weight = cermic_loss_l2_weight
        self.cermic_loss_nce_weight = cermic_loss_nce_weight
        self.cermic_aug = cermic_aug
        self.cermic_intrinsic_reward_coef = cermic_intrinsic_reward_coef

        # One Cermic module per agent group, filled in by ``_get_loss``.
        self.cermic_modules = {}
        # Reserved for the number of agents per group (not written in this class).
        self.agents_per_group = {}

    def _get_loss(
        self, group: str, policy_for_loss: TensorDictModule, continuous: bool
    ) -> Tuple[LossModule, bool]:
        """Build the PPO loss for ``group`` and create the group's Cermic module.

        Args:
            group (str): agent group name.
            policy_for_loss (TensorDictModule): policy handed to the parent loss builder.
            continuous (bool): whether the group's action space is continuous.

        Returns:
            The ``(loss_module, need_state)`` pair produced by the parent class, unchanged.
        """
        loss_module, need_state = super()._get_loss(group, policy_for_loss, continuous)

        # Continuous actions: size of the last action dimension.
        # Discrete actions: number of categories of the action space.
        action_spec = self.action_spec[group, "action"]
        action_dim = action_spec.shape[-1] if continuous else action_spec.space.n

        # Drop the leading (agent) dimension of the observation spec.
        obs_shape = self.observation_spec[group]["observation"].shape[1:]

        self.cermic_modules[group] = CermicModule(
            tau=self.cermic_tau,
            aug=self.cermic_aug,
            observation_shape=obs_shape,
            action_dim=action_dim,
        ).to(self.device)

        return loss_module, need_state

    def _get_parameters(self, group: str, loss: ClipPPOLoss) -> Dict[str, Iterable]:
        """Return the parent's parameter groups plus the Cermic module parameters.

        The Cermic parameters are exposed under the ``"loss_cermic_module"`` key,
        which is the key :meth:`update` uses to look up its optimizer.
        """
        params = super()._get_parameters(group, loss)

        # ``Module.parameters()`` is a one-shot generator; store a list so the
        # entry can safely be iterated more than once.
        params["loss_cermic_module"] = list(self.cermic_modules[group].parameters())

        return params

    def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
        """Add the scaled Cermic intrinsic reward to the batch reward.

        The entry ``("next", group, "reward")`` is overwritten with
        ``extrinsic + coef * intrinsic``; the two components are also stored
        under ``("next", group, "extrinsic_reward")`` and
        ``("next", group, "intrinsic_reward")``. The batch is then passed to
        the parent ``process_batch``.

        Args:
            group (str): agent group name.
            batch (TensorDictBase): collected batch for the group.

        Returns:
            The batch returned by the parent ``process_batch``.
        """
        cermic = self.cermic_modules[group]

        obs = batch.get((group, "observation"))
        next_obs = batch.get(("next", group, "observation"))
        # NOTE(review): this keeps only index 0 along dim 1 and re-inserts it as
        # a size-1 dimension; ``update`` does the same. Confirm this is the
        # layout ``CermicModule`` expects.
        next_obs = next_obs[:, 0].unsqueeze(1)
        actions = batch.get((group, "action"))
        # e.g. torch.Size([10, 600, 4, 1]) in the original author's run
        extrinsic_reward = batch.get(("next", group, "reward"))

        intrinsic_reward = cermic.calculate_cermic_reward(
            obs,
            next_obs,
            actions,
            prev_extrinsic_reward=extrinsic_reward,
        )

        int_reward = intrinsic_reward * self.cermic_intrinsic_reward_coef
        if isinstance(int_reward, torch.Tensor):
            # Rewards are data, not something to differentiate through: make
            # sure no autograd graph from the Cermic networks is kept alive by
            # the tensors stored in the batch.
            int_reward = int_reward.detach()

        total_reward = extrinsic_reward + int_reward

        batch.set(("next", group, "reward"), total_reward)
        batch.set(("next", group, "extrinsic_reward"), extrinsic_reward)
        batch.set(("next", group, "intrinsic_reward"), int_reward)

        return super().process_batch(group, batch)

    def _extract_position_info(self, batch: TensorDictBase, group: str) -> Dict[str, object]:
        """Collect position-related entries of ``batch`` for ``group``.

        Looks up ``(group, "position")``, ``(group, "other_positions")`` and
        ``(group, "vision_range")`` and copies the ones that are present under
        the keys ``"agent_position"``, ``"other_positions"`` and
        ``"vision_range"``.

        Returns:
            A dict that always contains ``"vision_range"``: when the batch has
            no such entry, the action spec's ``vision_range`` attribute is used
            if it exists, otherwise ``5.0``.
        """
        position_keys = {
            "agent_position": (group, "position"),
            "other_positions": (group, "other_positions"),
            "vision_range": (group, "vision_range"),
        }

        available = batch.keys(include_nested=True)
        position_info = {
            name: batch.get(tensor_key)
            for name, tensor_key in position_keys.items()
            if tensor_key in available
        }

        if "vision_range" not in position_info:
            # Fall back to the spec attribute, then to a fixed default.
            position_info["vision_range"] = getattr(
                self.action_spec[group, "action"], "vision_range", 5.0
            )

        # The dict can never be empty here, so it is returned directly.
        return position_info

    def update(
        self,
        group: str,
        group_obs: torch.Tensor,
        group_next_obs: torch.Tensor,
        group_ac: torch.Tensor,
        group_reward: torch.Tensor,
        group_position: torch.Tensor,
    ) -> Dict[str, float]:
        """Run one optimisation step of the group's Cermic module.

        Computes the Cermic loss, steps the ``"loss_cermic_module"`` optimizer
        of the experiment, then applies the module's momentum update.

        Args:
            group (str): agent group name.
            group_obs (torch.Tensor): observations; dimension 2 is read as the number of agents.
            group_next_obs (torch.Tensor): next observations; only index 0 along
                dimension 1 is used (see the note in the body).
            group_ac (torch.Tensor): actions.
            group_reward (torch.Tensor): rewards, passed to the module as calibration data.
            group_position (torch.Tensor): positions, passed to the module as calibration data.

        Returns:
            Logging values keyed by ``"{group}/{name}"``: every entry of the
            module's ``loss_info`` plus ``"{group}/CERMIC_Loss_Total"``.
        """
        cermic = self.cermic_modules[group]

        # Read the agent count before the tensors are reshaped.
        num_agents = group_obs.shape[2]
        # NOTE(review): same slicing as in ``process_batch`` -- confirm intent.
        next_obs = group_next_obs[:, 0].unsqueeze(1)

        # Reshape every tensor with the shared helper before feeding the module.
        obs = reshape_batch_agents(group_obs)
        next_obs = reshape_batch_agents(next_obs)
        actions = reshape_batch_agents(group_ac)
        positions = reshape_batch_agents(group_position)
        rewards = reshape_batch_agents(group_reward)

        use_calibration = {
            "position": positions,
            "reward": rewards,
            "num_agents": num_agents,
        }

        cermic_loss = cermic.forward(
            obs,
            next_obs,
            actions,
            use_calibration=use_calibration,
        )

        # The optimizer is owned by the experiment (see ``_get_parameters``).
        optimizer = self.experiment.optimizers[group]["loss_cermic_module"]
        optimizer.zero_grad()
        cermic_loss.backward()
        optimizer.step()

        cermic.momentum_update()

        # Forward the module's loss breakdown; the calibration bounds are
        # converted to Python scalars when they are tensors.
        update_info = {}
        calibration_keys = ("CERMIC_CalibrationUB", "CERMIC_CalibrationLB")
        for k, v in cermic.loss_info.items():
            if k in calibration_keys and isinstance(v, torch.Tensor):
                update_info[f"{group}/{k}"] = v.item()
            else:
                update_info[f"{group}/{k}"] = v

        update_info[f"{group}/CERMIC_Loss_Total"] = cermic_loss.item()

        return update_info


@dataclass
class MappoCermicConfig(AlgorithmConfig):
    """Configuration dataclass for :class:`~benchmarl.algorithms.MappoCermic`.

    The fields mirror the constructor arguments of ``MappoCermic`` one to one.
    Each of them is declared with ``dataclasses.MISSING`` as its value, i.e.
    no default is supplied in this file.
    """

    # --- PPO settings ---
    share_param_critic: bool = MISSING
    clip_epsilon: float = MISSING
    entropy_coef: float = MISSING
    critic_coef: float = MISSING
    loss_critic_type: str = MISSING
    lmbda: float = MISSING
    scale_mapping: str = MISSING
    use_tanh_normal: bool = MISSING
    minibatch_advantage: bool = MISSING
    # --- Cermic module settings ---
    cermic_tau: float = MISSING
    cermic_loss_var_weight: float = MISSING
    cermic_loss_l2_weight: float = MISSING
    cermic_loss_nce_weight: float = MISSING
    cermic_aug: bool = MISSING
    cermic_intrinsic_reward_coef: float = MISSING

    @staticmethod
    def associated_class() -> Type[Algorithm]:
        """Return the algorithm class configured by this dataclass."""
        return MappoCermic

    @staticmethod
    def on_policy() -> bool:
        """Report the algorithm as on-policy."""
        return True

    @staticmethod
    def has_centralized_critic() -> bool:
        """Report that the algorithm has a centralized critic."""
        return True

    @staticmethod
    def supports_continuous_actions() -> bool:
        """Report support for continuous action spaces."""
        return True

    @staticmethod
    def supports_discrete_actions() -> bool:
        """Report support for discrete action spaces."""
        return True