import torch
from torch import nn
from src import torch_helpers
import math

class MLPParallelEnsembleLayer(torch.nn.Module):
    """Fully connected layer that evaluates all ensemble members at once.

    Every member owns its own ``(in_features, out_features)`` weight matrix
    and, optionally, its own bias vector. The per-member parameters are
    stacked along the first axis, so one batched matmul evaluates the whole
    ensemble.
    """

    def __init__(self, ensemble_size, in_features, out_features, bias=True):
        """
        :param ensemble_size: number of members stacked along axis 0
        :param in_features: size of the last axis of the input
        :param out_features: size of the last axis of the output
        :param bias: when False no bias parameter is created and none is added
        """
        super().__init__()

        self.ensemble_size = ensemble_size
        self.in_features = in_features
        self.out_features = out_features

        # Allocated uninitialised; reset_parameters() fills in the values.
        self.weight = torch.nn.parameter.Parameter(
            torch.empty(ensemble_size, in_features, out_features)
        )
        if bias:
            self.bias = torch.nn.parameter.Parameter(
                torch.empty(ensemble_size, out_features)
            )
        else:
            # Register the name so `self.bias` exists and reads as None.
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the weight via the project helper and zero the bias."""
        torch_helpers.torch_truncated_normal_initializer(self.weight)
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply every member's affine map to its slice of ``input``.

        ``input`` is matrix-multiplied with the stacked weights, so its last
        axis must be ``in_features``; for an input of shape
        ``(ensemble_size, batch, in_features)`` the result has shape
        ``(ensemble_size, batch, out_features)``.

        The bias is added in the transposed layout on purpose: according to
        the original author the TensorRT model does not work otherwise,
        because it expects a certain shape of the bias vector.
        """
        projected = input @ self.weight
        if self.bias is None:
            # Layer was built with bias=False: there is nothing to add.
            return projected

        return (projected.transpose(0, 1) + self.bias[None, ...]).transpose(0, 1)

    def extra_repr(self) -> str:
        """Return the constructor arguments for ``repr()`` of the module."""
        return (
            f"ensemble_size={self.ensemble_size}, in_features={self.in_features}, "
            f"out_features={self.out_features}, bias={self.bias is not None}"
        )


class MLPParallelEnsemble(torch.nn.Module):
    """Ensemble of ``n`` MLPs built from ``MLPParallelEnsembleLayer`` blocks.

    ``self.layers`` is an ``nn.Sequential`` whose hidden entries are
    ``nn.Sequential(layer[, activation])`` pairs, followed by the bare output
    layer and, if configured, a bare output activation.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        n,
        hidden_shape,
        activation="relu",
        output_activation="none",
        l1_reg=0.0,
        l2_reg=0.0,
        # NOTE(review): "bug_uniform" looks like a typo of "big_uniform" (the
        # bias default) - confirm against torch_helpers.initializer_from_string.
        weight_initializer="bug_uniform",
        bias_initializer="big_uniform",
        use_spectral_normalization=False,
    ):
        """
        :param input_dim: number of input features of the first layer
        :param output_dim: number of output features of the last layer
        :param n: ensemble size passed to every layer
        :param hidden_shape: sequence with the width of each hidden layer
        :param activation: name resolved by ``torch_helpers.activation_from_string``
            and applied after every hidden layer
        :param output_activation: name resolved the same way and applied after
            the output layer
        :param l1_reg: factor applied in :meth:`L1_losses`
        :param l2_reg: factor applied in :meth:`L2_losses`
        :param weight_initializer: name handed to
            ``torch_helpers.initializer_from_string``
        :param bias_initializer: name handed to
            ``torch_helpers.initializer_from_string``
        :param use_spectral_normalization: not supported; ``build`` raises
            when a hidden layer is requested with it enabled
        """
        super().__init__()

        self.n = n
        self.hidden_shape = hidden_shape
        self.activation = torch_helpers.activation_from_string(activation)
        self.output_activation = torch_helpers.activation_from_string(output_activation)
        self.l1_reg = l1_reg
        self.l2_reg = l2_reg
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer
        self.use_spectral_normalization = use_spectral_normalization

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.build()
        self.initialize()

    def _weight_layers(self):
        """Yield the weight-carrying layer of every entry in ``self.layers``.

        Hidden entries are ``nn.Sequential`` wrappers whose first element is
        the layer; the output layer is stored bare. Entries without a
        ``weight`` (e.g. a bare output activation) are skipped.
        """
        for entry in self.layers:
            candidate = entry[0] if isinstance(entry, nn.Sequential) else entry
            if getattr(candidate, "weight", None) is not None:
                yield candidate

    def initialize(self):
        """Apply the configured weight/bias initializer to every layer."""
        for layer in self._weight_layers():
            layer.apply(
                torch_helpers.initializer_from_string(
                    self.weight_initializer, self.bias_initializer
                )
            )

    def build(self):
        """Create the layer stack and register it as ``layers_list``/``layers``."""
        # list(...) so that a tuple hidden_shape is accepted as well.
        all_dims = [self.input_dim] + list(self.hidden_shape)
        hidden_layers = []

        for in_dim, out_dim in zip(all_dims[:-1], all_dims[1:]):
            if self.use_spectral_normalization:
                # Spectral normalisation of ensemble layers is not available.
                raise NotImplementedError("Not ported from mbrl yet")
            layer = [MLPParallelEnsembleLayer(self.n, in_dim, out_dim)]
            if self.activation is not None:
                layer.append(self.activation())
            hidden_layers.append(nn.Sequential(*layer))

        # The output layer (and its optional activation) is appended bare,
        # i.e. not wrapped in its own nn.Sequential.
        output_layer = [MLPParallelEnsembleLayer(self.n, all_dims[-1], self.output_dim)]
        if self.output_activation is not None:
            output_layer.append(self.output_activation())

        layers = hidden_layers + output_layer

        # Both containers are kept so existing attribute names stay available.
        self.layers_list = torch.nn.ModuleList(layers)
        self.layers = nn.Sequential(*self.layers_list)

    def to(self, *args, **kwargs):
        """Forward to ``nn.Module.to``; accepts every signature it supports.

        ``self.layers`` is a registered sub-module, so the base implementation
        already converts it.
        """
        return super().to(*args, **kwargs)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)

    def L1_losses(self):
        """Return ``l1_reg`` times the summed per-member 1-norms of all weights."""
        l1_losses = sum(
            torch.sum(layer.weight.norm(1, dim=(1, 2)))
            for layer in self._weight_layers()
        )
        return l1_losses * self.l1_reg

    def L2_losses(self):
        """Return ``l2_reg`` times the summed per-member 2-norms of all weights.

        The norm itself is summed, not its square.
        """
        l2_losses = sum(
            torch.sum(layer.weight.norm(2, dim=(1, 2)))
            for layer in self._weight_layers()
        )
        return l2_losses * self.l2_reg


class EnsembleOfDiscriminators(nn.Module):
    """Ensemble of classifiers over ``num_classes`` built on ``MLPParallelEnsemble``.

    Logits produced by the ensemble carry the member index on axis
    ``self.ensemble_dim`` (0) and the class index on the last axis.
    """

    def __init__(
        self,
        num_classes,
        num_inputs,
        n,
        hidden_shape,
        activation="relu",
        output_activation="none",
        l1_reg=0.0,
        l2_reg=0.0,
        weight_initializer="uniform",
        bias_initializer="uniform",
        use_spectral_normalization=False,
        w_disdain=0.0,
        entropy_penalty=1.0
    ) -> None:
        """
        :param num_classes: output width of the ensemble (number of classes)
        :param num_inputs: input width of the ensemble
        :param n: number of ensemble members
        :param w_disdain: weight of the disagreement term in :meth:`skill_reward`
        :param entropy_penalty: weight of the entropy term in :meth:`loss`

        All remaining arguments are passed unchanged to ``MLPParallelEnsemble``.
        """
        super().__init__()

        self.w_disdain = w_disdain
        self.entropy_penalty = entropy_penalty
        # Axis of the logits that enumerates the ensemble members.
        self.ensemble_dim = 0

        # Assigning the module as an attribute registers it as a sub-module.
        self.mlp_ensemble = MLPParallelEnsemble(
            input_dim=num_inputs,
            output_dim=num_classes,
            n=n,
            hidden_shape=hidden_shape,
            activation=activation,
            output_activation=output_activation,
            l1_reg=l1_reg,
            l2_reg=l2_reg,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
            use_spectral_normalization=use_spectral_normalization,
        )

    def to(self, *args, **kwargs):
        """Forward to ``nn.Module.to``; accepts every signature it supports.

        ``self.mlp_ensemble`` is a registered sub-module, so the base
        implementation already converts it.
        """
        return super().to(*args, **kwargs)

    def forward(self, x):
        """Concatenate the tensors in ``x`` along the last axis and return logits.

        :param x: sequence of tensors accepted by ``torch.cat``
        """
        x = torch.cat(x, dim=-1)
        return self.mlp_ensemble(x)

    def avg_of_entropy(self, p):
        """Return the per-member entropy of ``p`` averaged over the ensemble axis."""
        logp = torch.log(p + 1e-8)  # epsilon guards against log(0)
        member_entropy = -(p * logp).sum(dim=-1, keepdim=True)
        return member_entropy.mean(dim=self.ensemble_dim)

    def skill_reward(self, x, z=None):
        """Compute the reward from the ensemble-averaged class distribution.

        :param x: input forwarded to :meth:`forward`
        :param z: optional tensor with the same shape as the averaged
            log-probabilities; when given, the log-probabilities are weighted
            by it and summed over the class axis
        :return: ``(reward, logits, info)`` where ``reward`` is
            ``diayn + w_disdain * disdain`` and ``info`` holds the entries
            ``'logp'``, ``'diayn'`` and ``'disdain'``
        """
        logits = self.forward(x)
        p = torch.softmax(logits, dim=-1)

        p_avr = torch.mean(p, dim=self.ensemble_dim)
        logp_avr = torch.log(p_avr + 1e-8)

        entropy_of_avg = -torch.sum(p_avr * logp_avr, dim=-1, keepdim=True)

        if z is not None:
            assert logp_avr.shape == z.shape
            diayn_reward = torch.sum(logp_avr * z, dim=-1, keepdim=True) + math.log(z.shape[-1])
        else:
            # Without z the reward keeps one entry per class.
            diayn_reward = logp_avr + math.log(logp_avr.shape[-1])
        # Entropy of the mean prediction minus the mean of member entropies.
        disdain_reward = entropy_of_avg - self.avg_of_entropy(p)

        info = {
            'logp': diayn_reward - math.log(logp_avr.shape[-1]),
            'diayn': diayn_reward,
            'disdain': disdain_reward,
        }

        return diayn_reward + self.w_disdain * disdain_reward, logits, info

    def loss(self, logits, z, w_e, reduction='mean'):
        """Weighted cross-entropy summed over members plus an entropy term.

        :param logits: ensemble logits; axis 0 is iterated as the member axis
        :param z: targets accepted by ``torch.nn.CrossEntropyLoss``
        :param w_e: weights multiplied onto the per-sample loss
        :param reduction: ``'mean'``, ``'sum'`` or ``'none'`` (return the
            weighted per-sample loss without reducing it)
        :raises ValueError: if ``reduction`` is none of the above
        """
        if reduction not in ('mean', 'sum', 'none'):
            # Previously an unknown reduction silently returned None.
            raise ValueError(
                f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}"
            )

        ce_func = torch.nn.CrossEntropyLoss(reduction='none')
        # Cross-entropy is summed (not averaged) over the members.
        ce = sum(ce_func(member_logits, z) for member_logits in logits)

        # entropy penalty
        p = torch.softmax(logits, dim=-1)
        entropy = -torch.sum(p * torch.log(p + 1e-8), dim=-1, keepdim=True)
        per_sample = w_e * (ce[:, None] + entropy.mean(0) * self.entropy_penalty)

        if reduction == 'mean':
            return torch.mean(per_sample)
        if reduction == 'sum':
            return torch.sum(per_sample)
        return per_sample

    def log_probs(self, x):
        """Log-probabilities of every ensemble member (no averaging)."""
        logits = self.forward(x)
        return torch.log_softmax(logits, dim=-1)

    def avg_log_probs(self, x):
        """Log-softmax of the logits averaged over the ensemble axis."""
        logits = self.forward(x)
        avg_logits = torch.mean(logits, dim=self.ensemble_dim)
        return torch.log_softmax(avg_logits, dim=-1)

    def log_prob(self, x, y):
        """Per-member log-probabilities weighted by ``y`` and summed over classes."""
        return (self.log_probs(x) * y[None]).sum(dim=-1, keepdim=True)

    def probs(self, x):
        """Probabilities of every ensemble member (no averaging)."""
        return torch.exp(self.log_probs(x))
