import torch
import torch.nn as nn
from harl.utils.envs_tools import get_shape_from_obs_space
from harl.models.base.plain_cnn import PlainCNN
from harl.models.base.plain_mlp import PlainMLP
from harl.models.base.hypermarl import HyperNetType, MLPBase as HyperMLPBase


class DeterministicPolicy(nn.Module):
    """Deterministic policy network for continuous action space.

    The network is an optional ``PlainCNN`` feature extractor (used only for
    3-dimensional observation shapes) followed by a head ``self.pi`` that is
    either a ``PlainMLP`` or, when ``args["use_hypermarl_actor"]`` is truthy,
    a ``HyperMLPBase``. The head output is mapped affinely with
    ``scale * x + mean``, where ``scale`` and ``mean`` are derived from the
    bounds of ``action_space``.
    """

    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        """Initialize DeterministicPolicy model.
        Args:
            args: (dict) arguments containing relevant model information.
                Reads "hidden_sizes", "activation_func",
                "final_activation_func" and, optionally,
                "use_hypermarl_actor" (defaults to False). The dict is not
                modified.
            obs_space: (gym.Space) observation space.
            action_space: (gym.Space) action space.
            device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        super().__init__()
        self.tpdv = dict(dtype=torch.float32, device=device)
        hidden_sizes = args["hidden_sizes"]
        activation_func = args["activation_func"]
        final_activation_func = args["final_activation_func"]
        obs_shape = get_shape_from_obs_space(obs_space)
        self.use_hypermarl = args.get("use_hypermarl_actor", False)
        if len(obs_shape) == 3:
            # 3-dimensional observations go through a CNN whose output width
            # is the first hidden size.
            self.feature_extractor = PlainCNN(
                obs_shape, hidden_sizes[0], activation_func
            )
            feature_dim = hidden_sizes[0]
        else:
            self.feature_extractor = None
            feature_dim = obs_shape[0]
        act_dim = action_space.shape[0]
        # Layer widths of the head: input, hidden layers, action output.
        pi_sizes = [feature_dim] + list(hidden_sizes) + [act_dim]

        if self.use_hypermarl:
            print("Using hypermarl in actor")
            print(pi_sizes)
            # HyperMLPBase is handed every width except the input one through
            # "hidden_sizes". Override that key on a shallow copy: assigning
            # into `args` itself would leak the extended list (with act_dim
            # appended) to every other model built from the same config dict.
            hyper_args = {**args, "hidden_sizes": pi_sizes[1:]}
            self.pi = HyperMLPBase(
                hyper_args,
                obs_shape=[pi_sizes[0]],
                hypernet_type=HyperNetType.ACTOR,
                final_activation_func=final_activation_func,
                use_layer_norm=True,
                generates_final_layer=True,
            )
            print(self.pi)
        else:
            self.pi = PlainMLP(
                sizes=pi_sizes,
                activation_func=activation_func,
                final_activation_func=final_activation_func,
            )

        low = torch.tensor(action_space.low).to(**self.tpdv)
        high = torch.tensor(action_space.high).to(**self.tpdv)
        # Half-range and midpoint of the action bounds. NOTE: these are plain
        # tensor attributes, not registered buffers, so they are not part of
        # state_dict() and are not moved by later .to()/.cuda() calls on the
        # module; they live on `device` from construction.
        self.scale = (high - low) / 2
        self.mean = (high + low) / 2
        self.to(device)

    def forward(self, obs):
        """Compute actions for the given observations.
        Args:
            obs: observation input; passed to the feature extractor when one
                exists, otherwise fed to the head directly.
        Returns:
            Head output rescaled as ``scale * pi(x) + mean``. This lands in
            ``[low, high]`` provided the head's final activation bounds its
            output to ``[-1, 1]`` (depends on the configured
            ``final_activation_func``).
        """
        if self.feature_extractor is not None:
            x = self.feature_extractor(obs)
        else:
            x = obs
        x = self.pi(x)
        x = self.scale * x + self.mean
        return x
    
# from https://github.com/LXXXXR/Kaleidoscope/blob/fa560a9400fa8c9fc8ad6af94b7d2418038060c4/Kalei_MaMuJoCo/src/harl/models/policy_models/deterministic_policy.py#L53
from harl.models.base.plain_mlp import Kalei_MLP
class Kalei_DeterministicPolicy(DeterministicPolicy):
    """Deterministic continuous-action policy whose head is a ``Kalei_MLP``.

    Mirrors ``DeterministicPolicy`` but builds ``self.pi`` as a ``Kalei_MLP``
    and threads a ``mask_id`` through ``forward``. The mask-related helpers
    simply delegate to the head.
    """

    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        """Build the Kaleidoscope policy network.
        Args:
            args: (dict) arguments containing relevant model information;
                also forwarded unchanged to ``Kalei_MLP``.
            obs_space: (gym.Space) observation space.
            action_space: (gym.Space) action space.
            device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        # Initialise nn.Module directly: the parent constructor is bypassed,
        # so none of its setup (including its own `pi` head) runs here.
        nn.Module.__init__(self)
        self.tpdv = {"dtype": torch.float32, "device": device}

        widths = args["hidden_sizes"]
        act_fn = args["activation_func"]
        out_act_fn = args["final_activation_func"]
        shape = get_shape_from_obs_space(obs_space)

        # Only 3-dimensional observation shapes get a CNN front end.
        if len(shape) != 3:
            self.feature_extractor = None
            in_dim = shape[0]
        else:
            self.feature_extractor = PlainCNN(shape, widths[0], act_fn)
            in_dim = widths[0]

        out_dim = action_space.shape[0]
        layer_sizes = [in_dim, *widths, out_dim]
        self.pi = Kalei_MLP(layer_sizes, act_fn, args, out_act_fn)

        # Half-range and midpoint of the action bounds, used to rescale the
        # head output in forward().
        lower = torch.tensor(action_space.low).to(**self.tpdv)
        upper = torch.tensor(action_space.high).to(**self.tpdv)
        self.scale = (upper - lower) / 2
        self.mean = (upper + lower) / 2
        self.to(device)

    def get_sparsities(self):
        """Return whatever ``self.pi.get_sparsities()`` reports."""
        return self.pi.get_sparsities()

    def _reset_all_masks_weights(self, reset_ratio):
        """Forward ``reset_ratio`` to ``self.pi._reset_all_masks_weights``."""
        return self.pi._reset_all_masks_weights(reset_ratio)

    def mask_diversity_loss(self):
        """Return the head's ``mask_diversity_loss()`` value."""
        return self.pi.mask_diversity_loss()

    def forward(self, obs, mask_id):
        """Compute actions rescaled with the action-space bounds.
        Args:
            obs: observation input; run through the feature extractor when
                one exists.
            mask_id: passed straight through to ``self.pi``.
        Returns:
            ``scale * pi(features, mask_id) + mean``.
        """
        extractor = self.feature_extractor
        features = obs if extractor is None else extractor(obs)
        raw_action = self.pi(features, mask_id)
        return self.scale * raw_action + self.mean