from copy import deepcopy
import torch
import torch.nn as nn
from harl.models.base.plain_cnn import PlainCNN
from harl.models.base.plain_mlp import Kalei_MLP, PlainMLP
from harl.utils.envs_tools import get_shape_from_obs_space
from harl.models.base.hypermarl import HyperNetType, MLPBase as HyperMLPBase


def get_combined_dim(cent_obs_feature_dim, act_spaces):
    """Return the input width of a critic fed observation features plus every agent's action.

    Args:
        cent_obs_feature_dim: width of the centralized observation feature vector.
        act_spaces: iterable of action spaces, dispatched on their class name.

    Returns:
        ``cent_obs_feature_dim`` plus, for each space, ``shape[0]`` when the
        class is named "Box", ``n`` when it is named "Discrete", and the sum of
        ``nvec`` for any other space.
    """

    def _action_width(act_space):
        # Dispatch on the class name rather than isinstance so that no
        # particular space library has to be imported here.
        kind = act_space.__class__.__name__
        if kind == "Box":
            return act_space.shape[0]
        if kind == "Discrete":
            return act_space.n
        # Any other space is expected to expose per-component sizes in nvec.
        return sum(act_space.nvec)

    total = cent_obs_feature_dim
    for act_space in act_spaces:
        total += _action_width(act_space)
    return total


class ContinuousQNet(nn.Module):
    """Q Network for continuous and discrete action space. Outputs the q value given global states and actions.
    Note that the name ContinuousQNet emphasizes its structure that takes observations and actions as input and outputs
    the q values. Thus, it is commonly used to handle continuous action space; meanwhile, it can also be used in
    discrete action space.
    """

    def __init__(self, args, cent_obs_space, act_spaces, device=torch.device("cpu")):
        """Build the optional feature extractor and the Q-value network.

        Args:
            args: configuration mapping. Reads "activation_func" and
                "hidden_sizes", and optionally "use_hypermarl_critic"
                (default False) and "num_agents" (default 1). It is not
                modified by this constructor.
            cent_obs_space: centralized observation space; its shape is obtained
                through get_shape_from_obs_space.
            act_spaces: iterable of action spaces whose sizes are added to the
                network input width.
            device: device the module is moved to after construction.
        """
        super().__init__()
        activation_func = args["activation_func"]
        hidden_sizes = args["hidden_sizes"]
        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
        self.use_hypermarl = args.get("use_hypermarl_critic", False)
        self.num_agents = args.get("num_agents", 1)
        if len(cent_obs_shape) == 3:
            # Rank-3 observations go through a CNN that emits hidden_sizes[0] features.
            self.feature_extractor = PlainCNN(
                cent_obs_shape, hidden_sizes[0], activation_func
            )
            cent_obs_feature_dim = hidden_sizes[0]
        else:
            self.feature_extractor = None
            cent_obs_feature_dim = cent_obs_shape[0]
        # Layer widths: [input] + hidden layers + [scalar Q output].
        sizes = (
            [get_combined_dim(cent_obs_feature_dim, act_spaces)]
            + list(hidden_sizes)
            + [1]
        )
        if self.use_hypermarl:
            # The hyper-network base reads its layer widths from
            # args["hidden_sizes"]: everything except the input width.
            # Work on a copy so the caller's config is left untouched; writing
            # into the shared dict would make every later network built from
            # the same args see an extra trailing 1 in its hidden sizes.
            hyper_args = deepcopy(args)
            hyper_args["hidden_sizes"] = sizes[1:]
            mlp = HyperMLPBase(
                hyper_args,
                obs_shape=[sizes[0]],
                hypernet_type=HyperNetType.CRITIC,
                final_activation_func="identity",
                use_layer_norm=True,
                generates_final_layer=True,
            )
            print(mlp)
        else:
            mlp = PlainMLP(sizes=sizes, activation_func=activation_func)
        self.mlp = mlp
        self.to(device)

    def forward(self, cent_obs, actions):
        """Compute Q values for the given centralized observations and actions.

        Args:
            cent_obs: centralized observations; passed through the feature
                extractor when one was built, otherwise used as-is.
            actions: actions, concatenated with the features on the last dim.

        Returns:
            The output of the underlying network for the concatenated input.
        """
        if self.feature_extractor is not None:
            feature = self.feature_extractor(cent_obs)
        else:
            feature = cent_obs

        if self.use_hypermarl:
            # The last num_agents features (expected to be the one-hot agent
            # ids) must stay at the very end of the input, after the actions.
            # Slice the last dim with "..." for both parts so that inputs with
            # extra leading dims are split consistently.
            obs = feature[..., : -self.num_agents]
            one_hot_agent_ids = feature[..., -self.num_agents :]
            concat_x = torch.cat([obs, actions, one_hot_agent_ids], dim=-1)
        else:
            concat_x = torch.cat([feature, actions], dim=-1)
        q_values = self.mlp(concat_x)
        return q_values

# from https://github.com/LXXXXR/Kaleidoscope/blob/fa560a9400fa8c9fc8ad6af94b7d2418038060c4/Kalei_MaMuJoCo/src/harl/models/value_function_models/continuous_q_net.py#L64
class KaleiContinuousQNet(ContinuousQNet):
    """Q network whose MLP is a Kalei_MLP with one mask per ensemble member.

    Takes the same observation/action inputs as ContinuousQNet, plus a
    ``mask_id`` that is forwarded to the underlying Kalei_MLP.
    """

    def __init__(self, args, cent_obs_space, act_spaces, device=torch.device("cpu")):
        """Build the masked Q network.

        Args:
            args: configuration mapping. Reads "activation_func",
                "hidden_sizes", "n_critics" and
                ["ensemble_args"]["critic_Kalei_args"]. It is not modified by
                this constructor.
            cent_obs_space: centralized observation space; its shape is obtained
                through get_shape_from_obs_space.
            act_spaces: iterable of action spaces whose sizes are added to the
                network input width.
            device: device the module is moved to after construction.

        Raises:
            NotImplementedError: if the observation shape has three dimensions.
        """
        # Intentionally start the MRO lookup after ContinuousQNet: only the
        # nn.Module initialisation is wanted, not the parent's network setup.
        super(ContinuousQNet, self).__init__()
        activation_func = args["activation_func"]
        hidden_sizes = args["hidden_sizes"]
        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
        if len(cent_obs_shape) == 3:
            raise NotImplementedError(
                "KaleiContinuousQNet does not support 3-dimensional observations"
            )
        self.feature_extractor = None
        cent_obs_feature_dim = cent_obs_shape[0]
        # Layer widths: [input] + hidden layers + [scalar Q output].
        sizes = (
            [get_combined_dim(cent_obs_feature_dim, act_spaces)]
            + list(hidden_sizes)
            + [1]
        )
        temp_args = deepcopy(args)
        # bad design, need to be refactored
        # Take the critic settings from the copy, not from ``args``: writing
        # n_masks through the original nested dict would leak into the caller's
        # config and defeat the deepcopy above.
        kalei_args = temp_args["ensemble_args"]["critic_Kalei_args"]
        # one mask for each ensemble member
        kalei_args["n_masks"] = args["n_critics"]
        temp_args["Kalei_args"] = kalei_args
        self.mlp = Kalei_MLP(sizes, activation_func, temp_args)
        self.to(device)

    def forward(self, cent_obs, actions, mask_id):
        """Compute Q values using the mask selected by ``mask_id``.

        Args:
            cent_obs: centralized observations.
            actions: actions, concatenated with the observations on the last dim.
            mask_id: mask selector forwarded to the underlying Kalei_MLP.

        Returns:
            The output of the Kalei_MLP for the concatenated input.
        """
        if self.feature_extractor is not None:
            feature = self.feature_extractor(cent_obs)
        else:
            feature = cent_obs
        concat_x = torch.cat([feature, actions], dim=-1)
        q_values = self.mlp(concat_x, mask_id)
        return q_values

    def mask_diversity_loss(self):
        """Return the mask diversity loss reported by the underlying Kalei_MLP."""
        return self.mlp.mask_diversity_loss()

    def reset_mask(self, mask_id):
        """Ask the underlying Kalei_MLP to reset the mask identified by ``mask_id``."""
        self.mlp._reset_mask(mask_id)