import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lfrl.torch.networks import Mlp
import lfrl.torch.pytorch_util as ptu


class GMM(Mlp):
    """MLP whose output parameterizes a Gaussian mixture with diagonal covariance.

    The underlying network emits, per input row, a flat vector of size
    ``2 * n_components * output_dim + n_components`` laid out as
    ``[means | logstds | mixture logits]``.  The mean and logstd sections are
    each split into ``n_components`` consecutive groups of ``output_dim``
    values, one group per mixture component.

    NOTE(review): this assumes ``Mlp.forward`` maps a ``(B, input_dim)``
    tensor to a ``(B, output_size)`` tensor -- confirm against ``Mlp``.
    """

    def __init__(
            self,
            input_dim,
            output_dim,
            hidden_sizes,
            n_components=1,
            max_logstd=0.5,
            min_logstd=-5,
            **kwargs
    ):
        """Build the network and the (non-trainable) logstd bounds.

        Args:
            input_dim: forwarded to ``Mlp`` as ``input_size``.
            output_dim: dimensionality of each mixture component.
            hidden_sizes: forwarded to ``Mlp`` unchanged.
            n_components: number of mixture components.
            max_logstd: upper soft bound applied to every predicted logstd.
            min_logstd: lower soft bound applied to every predicted logstd.
            **kwargs: forwarded to ``Mlp.__init__``.
        """
        # each head outputs (mean, logstd); we must also output logits
        super().__init__(
            hidden_sizes,
            input_size=input_dim,
            output_size=2 * n_components * output_dim + n_components,
            **kwargs
        )

        self.n_components = n_components

        # Stored as frozen parameters so they follow the module across
        # devices and are saved in the state dict, but are never optimized.
        self.max_logstd = nn.Parameter(
            ptu.ones(output_dim * n_components) * max_logstd, requires_grad=False)
        self.min_logstd = nn.Parameter(
            ptu.ones(output_dim * n_components) * min_logstd, requires_grad=False)

    def _split_components(self, flat):
        """Turn a ``(B, n_components * D)`` tensor into ``(n_components, B, D)``.

        The split must be done per row first and only then transposed.
        Reshaping the row-major ``(B, n_components * D)`` buffer straight to
        ``(n_components, B, D)`` would re-chunk the flat memory and mix values
        belonging to different batch elements whenever both ``B`` and
        ``n_components`` exceed one.
        """
        batch_size = flat.shape[0]
        per_sample = flat.reshape(batch_size, self.n_components, -1)
        # contiguous() keeps the result usable with .view(), as it was when
        # it came directly out of torch.reshape.
        return per_sample.permute(1, 0, 2).contiguous()

    def forward(self, input, deterministic=False, return_dist=False):
        """Run the network and either sample from or describe the mixture.

        Args:
            input: batched input with at least two dimensions; the slicing
                below treats dim 0 as batch and dim 1 as features.
            deterministic: if True, no sample is drawn.
            return_dist: if True, also return the mixture parameters.

        Returns:
            * ``deterministic and return_dist``: ``(mean, logstd, logits)``.
            * ``deterministic`` only: ``mean``.
            * ``return_dist`` only: ``(sample, mean, std, logits)``.
            * neither: ``sample``.

            ``mean`` / ``logstd`` / ``std`` are ``(n_components, B, D)`` and
            ``logits`` is ``(n_components, B)``.  ``sample`` is whatever
            ``MixtureSameFamily.sample()`` yields for the distribution built
            by ``get_distribution``.  Note that the deterministic branch
            returns the *log* standard deviation while the sampling branch
            returns the standard deviation itself.
        """
        assert len(input.shape) > 1, 'not implemented yet for singledim input'

        output = super().forward(input)
        # The last n_components columns are the mixture logits; everything
        # before them is the concatenation [means | logstds].
        output, logits = output[:, :-self.n_components], output[:, -self.n_components:]
        mean, logstd = torch.chunk(output, 2, dim=-1)

        # Smoothly squash logstd into (min_logstd, max_logstd) with softplus
        # instead of a hard clamp, so gradients do not vanish at the bounds.
        logstd = self.max_logstd - F.softplus(self.max_logstd - logstd)
        logstd = self.min_logstd + F.softplus(logstd - self.min_logstd)

        mean = self._split_components(mean)
        logstd = self._split_components(logstd)
        logits = logits.permute(1, 0)  # (nc, B)

        if deterministic:
            if return_dist:
                return mean, logstd, logits  # NOTE: we return the logstd here...
            else:
                return mean

        std = logstd.exp()
        gmm = self.get_distribution(mean, logstd, logits)
        sample = gmm.sample()

        if return_dist:
            return sample, mean, std, logits  # but the std here.
        else:
            return sample

    def get_distribution(self, mean, logstd, logits):
        """Build the mixture distribution from component-major parameters.

        Args:
            mean: ``(n_components, B, D)`` component means.
            logstd: ``(n_components, B, D)`` component log standard deviations.
            logits: ``(n_components, B)`` unnormalized mixture weights.

        Returns:
            A ``MixtureSameFamily`` of independent diagonal Normals, with the
            batch dimension first and the component dimension second.
        """
        mean = mean.permute(1, 0, 2)  # (B, nc, N)
        logstd = logstd.permute(1, 0, 2)  # (B, nc, N)
        logits = logits.permute(1, 0)  # (B, nc)

        mix = torch.distributions.Categorical(logits=logits)
        comp = torch.distributions.Normal(mean, logstd.exp())
        # Treat the last dimension as a single event (diagonal Gaussian).
        comp = torch.distributions.independent.Independent(comp, 1)
        gmm = torch.distributions.mixture_same_family.MixtureSameFamily(mix, comp)

        return gmm

    def get_log_prob(self, input, output):
        """Return the mixture log-density of ``output`` given ``input``.

        The result of ``log_prob`` gets a trailing singleton dimension added.
        """
        mean, logstd, logits = self.forward(input, deterministic=True, return_dist=True)
        gmm = self.get_distribution(mean, logstd, logits)
        return gmm.log_prob(output).unsqueeze(dim=-1)

    def get_loss(self, input, output, weights=None):
        """Return the (optionally weighted) mean negative log-likelihood.

        Args:
            input: network input, see ``forward``.
            output: targets whose likelihood is evaluated.
            weights: optional per-sample weights with one entry per batch
                element; any shape that flattens to ``B`` values is accepted.
        """
        log_probs = self.get_log_prob(input, output)  # trailing dim of size 1
        if weights is not None:
            # Align weights as a column so each sample is scaled by its own
            # weight.  A (1, B) row would broadcast against the (B, 1)
            # log-probs into a (B, B) outer product.
            log_probs = log_probs * weights.reshape(-1, 1)
        loss = -log_probs.mean()
        return loss
