import torch
import torch.nn as nn
from harl.models.base.cnn import CNNBase
from harl.models.base.mlp import MLPBase
from harl.models.base.rnn import RNNLayer
from harl.utils.envs_tools import check, get_shape_from_obs_space
from harl.utils.models_tools import init, get_init_method
from harl.models.base.hypermarl import HyperMLPLayer, HyperNetType, MLPBase as HyperMLPBase


class VNet(nn.Module):
    """V Network. Outputs value function predictions given global states.

    The network is a feature trunk (CNN for 3-D observations, otherwise an
    MLP or HyperMARL MLP), an optional recurrent layer, and a value head.
    The value head is either a plain ``nn.Linear(hidden_sizes[-1], 1)`` or,
    when ``args["use_hypermarl_critic"]`` is set, a HyperMARL
    ``HyperMLPLayer`` built with ``HyperNetType.CRITIC``.
    """

    def __init__(self, args, cent_obs_space, device=torch.device("cpu")):
        """Initialize VNet model.

        Args:
            args: (dict) arguments containing relevant model information.
                Always read: ``hidden_sizes``, ``initialization_method``,
                ``use_naive_recurrent_policy``, ``use_recurrent_policy``,
                ``recurrent_n``. Optional: ``use_hypermarl_critic``
                (default ``False``). When the HyperMARL critic is enabled,
                ``num_agents`` and the ``hypermarl`` sub-dict are required;
                otherwise ``num_agents`` is optional.
            cent_obs_space: (gym.Space) centralized observation space.
            device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        super().__init__()
        self.hidden_sizes = args["hidden_sizes"]
        self.initialization_method = args["initialization_method"]
        self.use_naive_recurrent_policy = args["use_naive_recurrent_policy"]
        self.use_recurrent_policy = args["use_recurrent_policy"]
        self.recurrent_n = args["recurrent_n"]
        self.tpdv = dict(dtype=torch.float32, device=device)
        init_method = get_init_method(self.initialization_method)

        self.use_hypermarl_critic = args.get("use_hypermarl_critic", False)
        # Only the HyperMARL critic consumes the agent count, so the plain
        # critic does not require it to be present in ``args``.
        if self.use_hypermarl_critic:
            self.num_agents = args["num_agents"]
        else:
            self.num_agents = args.get("num_agents")

        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)

        # Feature trunk. Bug fix: the CNN branch previously bound a local
        # variable and never set ``self.base``, leaving the module without a
        # trunk for image-shaped observations.
        if len(cent_obs_shape) == 3:
            # NOTE(review): ``forward`` slices agent ids from the trailing
            # columns of ``cent_obs``, which presumes flat observations, so the
            # HyperMARL head is likely unsupported with the CNN trunk — confirm.
            self.base = CNNBase(args, cent_obs_shape)
        elif self.use_hypermarl_critic:
            print("Using hypermarl in critic")
            self.base = HyperMLPBase(
                args, cent_obs_shape, hypernet_type=HyperNetType.CRITIC
            )
        else:
            self.base = MLPBase(args, cent_obs_shape)

        if self.use_naive_recurrent_policy or self.use_recurrent_policy:
            self.rnn = RNNLayer(
                self.hidden_sizes[-1],
                self.hidden_sizes[-1],
                self.recurrent_n,
                self.initialization_method,
            )

        # Value head producing a single scalar per row.
        if self.use_hypermarl_critic:
            hypermarl = args["hypermarl"]
            self.v_out = HyperMLPLayer(
                input_dim=self.hidden_sizes[-1],
                hidden_sizes=[1],
                initialization_method=self.initialization_method,
                activation_func="identity",  # i.e. no activation function
                num_agents=self.num_agents,
                embedding_size=hypermarl["AGENT_ID_EMBEDDING_DIM"],
                use_agent_id_embeddings=hypermarl["USE_AGENT_ID_EMBEDDINGS"],
                hypernet_hidden_dims=hypermarl["HYPERNET_HIDDEN_DIMS"],
                generate_per_agent=True,  # memory-friendly per-agent generation
                hypernet_type=HyperNetType.CRITIC,
                gain=1,
                use_layer_norm=False,
                use_mlp_hypernet=hypermarl.get("USE_MLP_HYPERNET", True),
            )
        else:
            # Weights use the configured init method; biases start at zero.
            self.v_out = init(
                nn.Linear(self.hidden_sizes[-1], 1),
                init_method,
                lambda x: nn.init.constant_(x, 0),
            )

        self.to(device)

    def forward(self, cent_obs, rnn_states, masks):
        """Compute value predictions from the given inputs.

        Args:
            cent_obs: (np.ndarray / torch.Tensor) observation inputs into network.
            rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
            masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.

        Returns:
            values: (torch.Tensor) value function predictions.
            rnn_states: (torch.Tensor) updated RNN hidden states (returned
                unchanged when no recurrent layer is configured).
        """
        cent_obs = check(cent_obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        critic_features = self.base(cent_obs)
        if self.use_naive_recurrent_policy or self.use_recurrent_policy:
            critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)

        if self.use_hypermarl_critic:
            # Append the last ``num_agents`` columns of cent_obs to the trunk
            # features for the HyperMARL head. These are presumably one-hot
            # agent ids used to select per-agent generated weights
            # (assumption — confirm against the environment wrapper).
            agent_ids_one_hot = cent_obs[:, -self.num_agents:]
            critic_features = torch.cat([critic_features, agent_ids_one_hot], dim=-1)

        values = self.v_out(critic_features)
        return values, rnn_states
