import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lfrl.torch.distributions import TanhNormal
from lfrl.torch.networks import Mlp
import lfrl.torch.pytorch_util as ptu


class GaussianModel(Mlp):
    """Conditional diagonal Gaussian whose parameters are predicted by an ``Mlp``.

    The network emits ``2 * output_dim`` values per input, which are split
    along the last axis into a mean and a log standard deviation. The log std
    is softly clamped towards ``[min_logstd, max_logstd]`` with softplus, so it
    stays (approximately) in range while remaining differentiable everywhere.
    """

    def __init__(
            self,
            input_dim,
            output_dim,
            hidden_sizes,
            max_logstd=0.5,
            min_logstd=-5,
            **kwargs
    ):
        """Build the network and the log-std bounds.

        Args:
            input_dim: passed to ``Mlp`` as ``input_size``.
            output_dim: dimensionality of the modelled Gaussian; the network
                itself outputs ``2 * output_dim`` values (mean and log std).
            hidden_sizes: hidden layer sizes forwarded to ``Mlp``.
            max_logstd: soft upper bound on the predicted log std.
            min_logstd: soft lower bound on the predicted log std.
            **kwargs: forwarded unchanged to ``Mlp.__init__``.
        """
        super().__init__(
            hidden_sizes,
            input_size=input_dim,
            output_size=2*output_dim,
            **kwargs
        )

        # Frozen (requires_grad=False) Parameters: never trained, but registered
        # on the module so they are saved in the state dict and follow .to().
        self.max_logstd = nn.Parameter(
            ptu.ones(output_dim) * max_logstd, requires_grad=False)
        self.min_logstd = nn.Parameter(
            ptu.ones(output_dim) * min_logstd, requires_grad=False)

    def forward(self, input, deterministic=False, return_dist=False, return_log_prob=False):
        """Predict the Gaussian for ``input`` and optionally sample from it.

        Args:
            input: network input, passed to ``Mlp.forward``.
            deterministic: if True, do not sample; return the mean.
            return_dist: with ``deterministic`` return ``(mean, logstd)``;
                when sampling return ``(sample, mean, std)``. Note the
                asymmetry (log std vs. std) -- subclasses rely on the
                deterministic ``(mean, logstd)`` form.
            return_log_prob: only when sampling; return the 5-tuple
                ``(sample, sample, sample, log_prob, None)`` where
                ``log_prob`` is the joint log density of ``sample`` with a
                trailing singleton axis. Checked before ``return_dist``.

        Returns:
            ``mean`` / ``sample`` alone when no ``return_*`` flag applies,
            otherwise the tuples described above.
        """
        output = super().forward(input)
        mean, logstd = torch.chunk(output, 2, dim=-1)

        # Soft clamp: a smooth upper bound followed by a smooth lower bound.
        logstd = self.max_logstd - F.softplus(self.max_logstd - logstd)
        logstd = self.min_logstd + F.softplus(logstd - self.min_logstd)

        if deterministic:
            if return_dist:
                return mean, logstd  # NOTE: log std here, std in the sampled branch.
            return mean

        std = logstd.exp()
        # Independent(..., 1) treats the last axis as the event, so log_prob
        # sums over dimensions and yields one value per batch element.
        dist = torch.distributions.Independent(torch.distributions.Normal(mean, std), 1)
        sample = dist.sample()

        if return_log_prob:
            # Score the drawn sample -- not the raw (2 * output_dim)-wide
            # network output, which does not even match the event shape.
            log_prob = dist.log_prob(sample).unsqueeze(-1)
            return sample, sample, sample, log_prob, None

        if return_dist:
            return sample, mean, std
        return sample

    def get_log_prob(self, input, output):
        """Return the joint log density of ``output`` given ``input``.

        The result has a trailing singleton axis (``unsqueeze(-1)``), i.e. one
        value per batch element.
        """
        mean, logstd = self.forward(input, deterministic=True, return_dist=True)
        dist = torch.distributions.Independent(
            torch.distributions.Normal(mean, logstd.exp()), 1)
        return dist.log_prob(output).unsqueeze(-1)

    def get_log_probs(self, input, output):
        """Alias of :meth:`get_log_prob`."""
        return self.get_log_prob(input, output)

    def get_loss(self, input, output, weights=None):
        """Return the (optionally weighted) mean negative log-likelihood.

        Args:
            input: conditioning input.
            output: targets whose likelihood is maximised.
            weights: optional per-sample weights; reshaped with
                ``view(-1, 1)``, so presumably one weight per batch row.
        """
        log_probs = self.get_log_prob(input, output)
        if weights is not None:
            log_probs = log_probs * weights.view(-1, 1)
        return -log_probs.mean()


class TanhGaussianModel(GaussianModel):
    """``GaussianModel`` whose samples are squashed into (-1, 1) by ``tanh``.

    The parent still predicts a pre-squash mean and a soft-clamped log std;
    this class wraps them in a ``TanhNormal`` distribution.
    """

    def _pre_tanh_params(self, input):
        """Return the parent's pre-squash ``(mean, logstd)`` for ``input``."""
        return super().forward(input, deterministic=True, return_dist=True)

    def forward(self, input, deterministic=False, return_dist=False, return_logprob=False):
        """Squash the parent Gaussian through ``tanh`` and optionally sample.

        Args:
            input: network input.
            deterministic: return ``tanh(mean)`` (plus the pre-squash log std
                when ``return_dist``) instead of sampling.
            return_dist: when sampling, return
                ``(sample, tanh(mean), pre-squash std)``.
            return_logprob: when sampling, return ``(sample, log_prob)`` with
                ``log_prob`` summed over the last axis (keepdim). Checked
                before ``return_dist``.
        """
        pre_mean, pre_logstd = self._pre_tanh_params(input)

        if deterministic:
            squashed_mean = torch.tanh(pre_mean)
            return (squashed_mean, pre_logstd) if return_dist else squashed_mean

        pre_std = torch.exp(pre_logstd)
        squashed = TanhNormal(pre_mean, pre_std)
        # Reparameterised sample; the pre-tanh value is reused for log_prob
        # rather than recovered by inverting tanh.
        action, pre_tanh_action = squashed.rsample(return_pretanh_value=True)

        if return_logprob:
            per_dim = squashed.log_prob(action, pre_tanh_value=pre_tanh_action)
            return action, per_dim.sum(-1, keepdim=True)
        if return_dist:
            return action, torch.tanh(pre_mean), pre_std
        return action

    def get_log_prob(self, input, output):
        """Return log p(output | input), summed over the last axis (keepdim).

        NOTE(review): ``output`` is passed to ``TanhNormal.log_prob`` without a
        pre-tanh value, so values at exactly +/-1 may yield infinite log-probs
        depending on how ``TanhNormal`` inverts tanh -- confirm.
        """
        pre_mean, pre_logstd = self._pre_tanh_params(input)
        squashed = TanhNormal(pre_mean, torch.exp(pre_logstd))
        return squashed.log_prob(output).sum(-1, keepdim=True)


class UnconditionalPrior(nn.Module):
    """Input-independent, tanh-squashed diagonal Gaussian with learnable parameters.

    ``forward`` and ``get_log_prob`` share the call signature of
    ``TanhGaussianModel``, but only the leading (batch) size of ``input`` is
    used: the learned mean and log std are tiled to ``(batch, output_dim)``.
    """

    def __init__(self, output_dim):
        """Create a zero-mean, unit-std (log std 0) prior of size ``output_dim``."""
        super().__init__()

        self.mean = nn.Parameter(ptu.zeros(output_dim), requires_grad=True)
        self.logstd = nn.Parameter(ptu.zeros(output_dim), requires_grad=True)
        self.output_dim = output_dim

    def _batched_params(self, input):
        """Tile mean and log std to batch-first shape ``(input.shape[0], output_dim)``."""
        batch_size = input.shape[0]
        mean = self.mean.view(1, -1).repeat(batch_size, 1)
        logstd = self.logstd.view(1, -1).repeat(batch_size, 1)
        return mean, logstd

    def forward(self, input, deterministic=False, return_dist=False, return_logprob=False):
        """Sample from (or summarise) the prior for a batch the size of ``input``.

        Args:
            input: only ``input.shape[0]`` (batch size) is read.
            deterministic: return ``tanh(mean)`` (plus the pre-squash log std
                when ``return_dist``) instead of sampling.
            return_dist: when sampling, return
                ``(sample, tanh(mean), pre-squash std)``.
            return_logprob: when sampling, return ``(sample, log_prob)``.
                Unlike ``TanhGaussianModel`` the log-prob is NOT summed over
                the last axis (presumably per-dimension -- depends on
                ``TanhNormal.log_prob``). Checked before ``return_dist``.
        """
        mean, logstd = self._batched_params(input)

        if deterministic:
            if return_dist:
                return torch.tanh(mean), logstd
            return torch.tanh(mean)

        std = torch.exp(logstd)
        tanh_normal = TanhNormal(mean, std)
        output, pretanh_value = tanh_normal.rsample(return_pretanh_value=True)

        if return_logprob:
            return output, tanh_normal.log_prob(output, pre_tanh_value=pretanh_value)
        if return_dist:
            return output, torch.tanh(mean), std
        return output

    def get_log_prob(self, input, output):
        """Return log p(output) under the prior (not summed over the last axis).

        ``output`` is expected batch-first, ``(input.shape[0], output_dim)``.
        """
        mean, logstd = self._batched_params(input)
        tanh_normal = TanhNormal(mean, torch.exp(logstd))
        return tanh_normal.log_prob(output)

    def sample(self, n=1, return_log_probs=False):
        """Draw ``n`` samples from the UNSQUASHED Gaussian N(mean, exp(logstd)).

        Unlike ``forward``, no tanh is applied, and the optional log-probs are
        the per-dimension Gaussian log densities of those raw samples.
        """
        # TODO: fix this -- inconsistent with the tanh-squashed forward().
        eps = ptu.randn(n, self.output_dim)
        samples = self.mean + torch.exp(self.logstd) * eps
        if return_log_probs:
            log_prob = -self.logstd - .5*np.log(2*np.pi) - .5*((samples-self.mean) / torch.exp(self.logstd)) ** 2
            return samples, log_prob
        return samples
