import torch
import torch.nn as nn
import numpy as np



def mlp_(sizes, activation, output_activation=nn.Identity):
    """
    Build a plain fully connected network (Linear + activation pairs only).

    Args:
        sizes (list): Layer widths, input width first and output width last.
            Consecutive entries define one ``nn.Linear`` each.
        activation (Type[nn.Module]): Zero-argument callable (typically a module
            class) instantiated after every Linear except the last one.
        output_activation (Type[nn.Module]): Zero-argument callable instantiated
            after the last Linear. Defaults to nn.Identity.

    Returns:
        nn.Sequential: The stacked layers; empty when ``sizes`` has fewer than
        two entries.
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        # Only the last Linear is followed by the output activation.
        nonlinearity = output_activation if idx >= final_idx else activation
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(nonlinearity())
    return nn.Sequential(*modules)


def mlp(sizes, activation, output_activation=nn.Identity, layernorm=True, dropout=0.0):
    """
    Build a fully connected network with optional LayerNorm and Dropout.

    Every hidden stage is ``Linear -> [LayerNorm] -> activation -> [Dropout]``;
    the final stage is ``Linear -> output_activation`` with no normalisation
    or dropout.

    Args:
        sizes (list): Layer widths, input width first and output width last.
        activation (Type[nn.Module]): Zero-argument callable instantiated for
            each hidden stage.
        output_activation (Type[nn.Module]): Zero-argument callable instantiated
            after the final Linear. Defaults to nn.Identity.
        layernorm (bool): Insert ``nn.LayerNorm`` right after each hidden Linear.
        dropout (float): Dropout probability; a Dropout layer is appended to
            each hidden stage only when this is greater than 0.

    Returns:
        nn.Sequential: The constructed network; empty when ``sizes`` has fewer
        than two entries.
    """
    n_linear = len(sizes) - 1
    modules = []
    for idx in range(n_linear):
        width = sizes[idx + 1]
        modules.append(nn.Linear(sizes[idx], width))

        if idx == n_linear - 1:
            # Output stage: just the output activation.
            modules.append(output_activation())
            continue

        if layernorm:
            modules.append(nn.LayerNorm(width))
        modules.append(activation())
        if dropout > 0.0:
            modules.append(nn.Dropout(dropout))
    return nn.Sequential(*modules)


class EnsembleQCritic(nn.Module):
    '''
    An ensemble of independent Q networks. Taking the minimum over the
    ensemble (see ``predict``) addresses the overestimation issue.

    Args:
        obs_dim (int): The dimension of the observation space.
        act_dim (int): The dimension of the action space.
        hidden_sizes (List[int]): The sizes of the hidden layers in the neural network.
        activation (Type[nn.Module]): The activation function to use between layers.
        use_layer_norm (bool): Whether each hidden layer is followed by LayerNorm.
        num_q (int): The number of Q networks to include in the ensemble (at least 1).
    '''

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, use_layer_norm=True, num_q=2):
        super().__init__()
        assert num_q >= 1, "num_q param should be at least 1"

        layer_sizes = [obs_dim + act_dim] + list(hidden_sizes) + [1]
        self.q_nets = nn.ModuleList([
            mlp(layer_sizes, activation, layernorm=use_layer_norm)
            for _ in range(num_q)
        ])

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: Kaiming-uniform weights, zero biases."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, obs, act=None):
        """
        Evaluate every Q network on the same input.

        Args:
            obs (torch.Tensor): Observations. When ``act`` is None this tensor is
                fed to the networks unchanged, so it must already have
                ``obs_dim + act_dim`` features in its last dimension.
            act (torch.Tensor, optional): Actions, concatenated to ``obs`` along
                the last dimension.

        Returns:
            List[torch.Tensor]: One output per network. The outputs are NOT
            squeezed: each keeps a trailing dimension of size 1.
        """
        # Because the trailing dimension is kept, targets must be shaped like
        # the outputs; e.g. shape [3] - shape [3, 1] broadcasts to [3, 3].
        data = obs if act is None else torch.cat([obs, act], dim=-1)
        return [q(data) for q in self.q_nets]

    def predict(self, obs, act):
        """
        Return the element-wise minimum over the ensemble and the raw outputs.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]: The minimum (same shape as
            a single network output) and the list returned by ``forward``.
        """
        q_list = self.forward(obs, act)
        qs = torch.stack(q_list)  # [num_q, *batch_dims, 1]
        return torch.min(qs, dim=0).values, q_list

    def loss(self, target, q_list=None):
        """
        Sum of the per-network MSE losses against a shared target.

        Args:
            target (torch.Tensor): Regression target. It should have the same
                shape as each entry of ``q_list``; mse_loss broadcasts
                mismatched shapes instead of failing.
            q_list (List[torch.Tensor]): Network outputs, e.g. from ``forward``.
                Required despite the ``None`` default.

        Returns:
            torch.Tensor: Scalar loss (sum over ensemble members).

        Raises:
            TypeError: If ``q_list`` is not provided.
        """
        if q_list is None:
            raise TypeError("loss() requires q_list (the outputs returned by forward())")
        losses = [torch.nn.functional.mse_loss(q, target) for q in q_list]
        return sum(losses)


class EnsembleDoubleQCritic(nn.Module):
    '''
    Two ensembles of Q networks (q1 and q2) evaluated on the same input, to
    address the overestimation issue.

    Args:
        obs_dim (int): The dimension of the observation space.
        act_dim (int): The dimension of the action space.
        hidden_sizes (List[int]): The sizes of the hidden layers in the neural network.
        activation (Type[nn.Module]): The activation function to use between layers.
        num_q (int): The number of Q networks in each of the two ensembles (at least 1).
    '''

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, num_q=2):
        super().__init__()
        assert num_q >= 1, "num_q param should be at least 1"

        layer_sizes = [obs_dim + act_dim] + list(hidden_sizes) + [1]
        # Honour the caller-supplied activation; it used to be ignored in
        # favour of a hard-coded nn.ReLU, unlike the sibling ensemble classes.
        self.q1_nets = nn.ModuleList([
            mlp(layer_sizes, activation)
            for _ in range(num_q)
        ])
        self.q2_nets = nn.ModuleList([
            mlp(layer_sizes, activation)
            for _ in range(num_q)
        ])

    def forward(self, obs, act):
        """
        Evaluate both ensembles on the concatenated (obs, act) input.

        Returns:
            Tuple[List[torch.Tensor], List[torch.Tensor]]: Outputs of the q1 and
            q2 networks. The outputs are NOT squeezed: each keeps a trailing
            dimension of size 1.
        """
        # Because the trailing dimension is kept, targets must be shaped like
        # the outputs; e.g. shape [3] - shape [3, 1] broadcasts to [3, 3].
        data = torch.cat([obs, act], dim=-1)
        q1 = [q(data) for q in self.q1_nets]
        q2 = [q(data) for q in self.q2_nets]
        return q1, q2

    def predict(self, obs, act):
        """
        Return the element-wise minimum of each ensemble plus the raw outputs.

        Returns:
            Tuple: ``(q1_min, q2_min, q1_list, q2_list)`` where each minimum has
            the same shape as a single network output.
        """
        q1_list, q2_list = self.forward(obs, act)
        # Stack along a NEW leading ensemble axis. vstack would concatenate the
        # [batch, 1] outputs into [num_q * batch, 1], and the min over dim 0
        # would then collapse the whole batch into a single value.
        qs1, qs2 = torch.stack(q1_list), torch.stack(q2_list)  # [num_q, *batch_dims, 1]
        qs1_min, qs2_min = torch.min(qs1, dim=0).values, torch.min(qs2, dim=0).values
        return qs1_min, qs2_min, q1_list, q2_list

    def loss(self, target, q_list=None):
        """
        Sum of the per-network mean squared errors against a shared target.

        Args:
            target (torch.Tensor): Regression target. It should have the same
                shape as each entry of ``q_list``; the subtraction broadcasts
                mismatched shapes instead of failing.
            q_list (List[torch.Tensor]): Outputs of one ensemble (q1 or q2).
                Required despite the ``None`` default.

        Returns:
            torch.Tensor: Scalar loss (sum over ensemble members).

        Raises:
            TypeError: If ``q_list`` is not provided.
        """
        if q_list is None:
            raise TypeError("loss() requires q_list (one of the lists returned by forward())")
        losses = [((q - target)**2).mean() for q in q_list]
        return sum(losses)


class EnsembleValue(nn.Module):
    '''
    An ensemble of independent state-value networks. Taking the minimum over
    the ensemble (see ``predict``) addresses the overestimation issue.

    Args:
        obs_dim (int): The dimension of the observation space.
        hidden_sizes (List[int]): The sizes of the hidden layers in the neural network.
        activation (Type[nn.Module]): The activation function to use between layers.
        use_layer_norm (bool): Whether each hidden layer is followed by LayerNorm.
        num_v (int): The number of Value networks to include in the ensemble (at least 1).
    '''

    def __init__(self, obs_dim, hidden_sizes, activation, use_layer_norm=True, num_v=2):
        super().__init__()
        assert num_v >= 1, "num_v param should be at least 1"

        layer_sizes = [obs_dim] + list(hidden_sizes) + [1]
        self.v_nets = nn.ModuleList([
            mlp(layer_sizes, activation, layernorm=use_layer_norm)
            for _ in range(num_v)
        ])

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: Kaiming-uniform weights, zero biases."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, obs):
        """
        Evaluate every value network on ``obs``.

        Returns:
            List[torch.Tensor]: One output per network. The outputs are NOT
            squeezed: each keeps a trailing dimension of size 1.
        """
        return [v(obs) for v in self.v_nets]

    def predict(self, obs):
        """
        Return the element-wise minimum over the ensemble and the raw outputs.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]: The minimum (same shape as
            a single network output) and the list returned by ``forward``.
        """
        v_list = self.forward(obs)
        vs = torch.stack(v_list)  # [num_v, *batch_dims, 1]
        return torch.min(vs, dim=0).values, v_list

    def loss(self, target, v_list=None):
        """
        Sum of the per-network MSE losses against a shared target.

        Args:
            target (torch.Tensor): Regression target. It should have the same
                shape as each entry of ``v_list``; mse_loss broadcasts
                mismatched shapes instead of failing.
            v_list (List[torch.Tensor]): Network outputs, e.g. from ``forward``.
                Required despite the ``None`` default.

        Returns:
            torch.Tensor: Scalar loss (sum over ensemble members).

        Raises:
            TypeError: If ``v_list`` is not provided.
        """
        if v_list is None:
            raise TypeError("loss() requires v_list (the outputs returned by forward())")
        losses = [torch.nn.functional.mse_loss(v, target) for v in v_list]
        return sum(losses)


class EnsembleTauValue(nn.Module):
    '''
    An ensemble of Value networks conditioned on an extra input ``tau``, which
    is embedded by a small MLP and concatenated to the observation. Taking the
    minimum over the ensemble (see ``predict``) addresses the overestimation
    issue.

    Args:
        obs_dim (int): The dimension of the observation space.
        hidden_sizes (List[int]): The sizes of the hidden layers in the neural network.
        activation (Type[nn.Module]): The activation function to use between layers.
        num_v (int): The number of Value networks to include in the ensemble (at least 1).
        tau_embed_dim (int): Width of the tau embedding appended to the observation.
        tau_hidden_dim (int): Hidden width of the tau embedding MLP.
    '''

    def __init__(self, obs_dim, hidden_sizes, activation, num_v=2,
                 tau_embed_dim=64, tau_hidden_dim=256):
        super().__init__()
        assert num_v >= 1, "num_v param should be at least 1"

        # The value networks consume [obs, tau_embed], so their input width
        # must grow by exactly the embedding width produced by tau_mlp.
        layer_sizes = [obs_dim + tau_embed_dim] + list(hidden_sizes) + [1]
        self.v_nets = nn.ModuleList([
            mlp(layer_sizes, activation)
            for _ in range(num_v)
        ])

        self.tau_mlp = nn.Sequential(
            nn.Linear(1, tau_hidden_dim),
            nn.Mish(),
            nn.Linear(tau_hidden_dim, tau_embed_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: Kaiming-uniform weights, zero biases."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, obs, tau):
        """
        Evaluate every value network on the observation joined with the tau embedding.

        Args:
            obs (torch.Tensor): Observations.
            tau (torch.Tensor): Conditioning input whose last dimension has
                size 1 (it is fed to ``nn.Linear(1, ...)``); the leading
                dimensions must match ``obs`` for the concatenation.
                NOTE(review): presumably a quantile/expectile level - confirm
                against the caller.

        Returns:
            List[torch.Tensor]: One output per network. The outputs are NOT
            squeezed: each keeps a trailing dimension of size 1.
        """
        tau_embed = self.tau_mlp(tau)
        data = torch.cat([obs, tau_embed], dim=-1)
        return [v(data) for v in self.v_nets]

    def predict(self, obs, tau):
        """
        Return the element-wise minimum over the ensemble and the raw outputs.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]: The minimum (same shape as
            a single network output) and the list returned by ``forward``.
        """
        v_list = self.forward(obs, tau)
        vs = torch.stack(v_list)  # [num_v, *batch_dims, 1]
        return torch.min(vs, dim=0).values, v_list

    def loss(self, target, v_list=None):
        """
        Sum of the per-network MSE losses against a shared target.

        Args:
            target (torch.Tensor): Regression target. It should have the same
                shape as each entry of ``v_list``; mse_loss broadcasts
                mismatched shapes instead of failing.
            v_list (List[torch.Tensor]): Network outputs, e.g. from ``forward``.
                Required despite the ``None`` default.

        Returns:
            torch.Tensor: Scalar loss (sum over ensemble members).

        Raises:
            TypeError: If ``v_list`` is not provided.
        """
        if v_list is None:
            raise TypeError("loss() requires v_list (the outputs returned by forward())")
        losses = [torch.nn.functional.mse_loss(v, target) for v in v_list]
        return sum(losses)


class EnsembleTauQCritic(nn.Module):
    '''
    An ensemble of Q networks conditioned on an extra input ``tau``, which is
    embedded by a small MLP and concatenated to the observation and action.
    Taking the minimum over the ensemble (see ``predict``) addresses the
    overestimation issue.

    Args:
        obs_dim (int): The dimension of the observation space.
        act_dim (int): The dimension of the action space.
        hidden_sizes (List[int]): The sizes of the hidden layers in the neural network.
        activation (Type[nn.Module]): The activation function to use between layers.
        num_q (int): The number of Q networks to include in the ensemble (at least 1).
        tau_embed_dim (int): Width of the tau embedding appended to the input.
        tau_hidden_dim (int): Hidden width of the tau embedding MLP.
    '''

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, num_q=2,
                 tau_embed_dim=64, tau_hidden_dim=256):
        super().__init__()
        assert num_q >= 1, "num_q param should be at least 1"

        # The Q networks consume [obs, act, tau_embed], so their input width
        # must grow by exactly the embedding width produced by tau_mlp.
        layer_sizes = [obs_dim + act_dim + tau_embed_dim] + list(hidden_sizes) + [1]
        self.q_nets = nn.ModuleList([
            mlp(layer_sizes, activation)
            for _ in range(num_q)
        ])

        self.tau_mlp = nn.Sequential(
            nn.Linear(1, tau_hidden_dim),
            nn.Mish(),
            nn.Linear(tau_hidden_dim, tau_embed_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: Kaiming-uniform weights, zero biases."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, obs, act, tau):
        """
        Evaluate every Q network on (obs, act) joined with the tau embedding.

        Args:
            obs (torch.Tensor): Observations.
            act (torch.Tensor): Actions.
            tau (torch.Tensor): Conditioning input whose last dimension has
                size 1 (it is fed to ``nn.Linear(1, ...)``); the leading
                dimensions must match ``obs`` and ``act`` for the concatenation.
                NOTE(review): presumably a quantile/expectile level - confirm
                against the caller.

        Returns:
            List[torch.Tensor]: One output per network. The outputs are NOT
            squeezed: each keeps a trailing dimension of size 1.
        """
        tau_embed = self.tau_mlp(tau)
        data = torch.cat([obs, act, tau_embed], dim=-1)
        return [q(data) for q in self.q_nets]

    def predict(self, obs, act, tau):
        """
        Return the element-wise minimum over the ensemble and the raw outputs.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]: The minimum (same shape as
            a single network output) and the list returned by ``forward``.
        """
        q_list = self.forward(obs, act, tau)
        qs = torch.stack(q_list)  # [num_q, *batch_dims, 1]
        return torch.min(qs, dim=0).values, q_list

    def loss(self, target, q_list=None):
        """
        Sum of the per-network MSE losses against a shared target.

        Args:
            target (torch.Tensor): Regression target. It should have the same
                shape as each entry of ``q_list``; mse_loss broadcasts
                mismatched shapes instead of failing.
            q_list (List[torch.Tensor]): Network outputs, e.g. from ``forward``.
                Required despite the ``None`` default.

        Returns:
            torch.Tensor: Scalar loss (sum over ensemble members).

        Raises:
            TypeError: If ``q_list`` is not provided.
        """
        if q_list is None:
            raise TypeError("loss() requires q_list (the outputs returned by forward())")
        losses = [torch.nn.functional.mse_loss(q, target) for q in q_list]
        return sum(losses)
