import numpy as np
import torch
from torch import nn as nn
import torch.nn.functional as F

import lfrl.torch.pytorch_util as ptu
from lfrl.torch.networks import Mlp, ParallelizedEnsemble, ParallelizedLSTMEnsemble


def swish(x):
    """Swish (a.k.a. SiLU) activation, applied elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


class ProbabilisticEnsemble(ParallelizedEnsemble):

    """
    Probabilistic ensemble for modeling both aleatoric and epistemic uncertainty,
    as described in PETS (Chua et al. 2018). Generates ensemble predictions in
    a single forward call, rather than running each member separately.

    Predicts (r, d, delta s') from (s, a); this contrasts with some other
    model-based work that does not learn r(s, a) or d(s, a).

    Each member outputs a diagonal Gaussian over ``obs_dim + 2`` targets. The
    predicted log-variance is soft-clamped between the learnable per-dimension
    bounds ``min_logvar`` and ``max_logvar``.

    NOTE(review): ``self.elites`` (a sequence of member indices used by the
    sampling methods) is not set in this class; presumably provided by
    ParallelizedEnsemble or the training code — confirm.
    """

    def __init__(
        self,
        ensemble_size,
        obs_dim,
        action_dim,
        hidden_sizes,
        noise_clip=None,
        hidden_activation=torch.tanh,
        **kwargs
    ):
        """
        :param ensemble_size: number of ensemble members.
        :param obs_dim: observation dimension.
        :param action_dim: action dimension.
        :param hidden_sizes: hidden layer sizes forwarded to ParallelizedEnsemble.
        :param noise_clip: if not None, sampling noise is wrapped into
            ``(-noise_clip, noise_clip)`` with ``torch.fmod``.
        :param hidden_activation: hidden-layer nonlinearity (default tanh;
            ``swish`` is an alternative).
        """
        super().__init__(
            ensemble_size=ensemble_size,
            hidden_sizes=hidden_sizes,
            input_size=obs_dim + action_dim,
            # Mean and log-std for each of the obs_dim + 2 targets.
            output_size=2*(obs_dim + 2),
            hidden_activation=hidden_activation,
            **kwargs
        )

        self.obs_dim, self.action_dim = obs_dim, action_dim
        self.output_size = obs_dim + 2
        self.noise_clip = noise_clip

        # Learnable soft bounds on the log-variance (initially 0.5 and -10).
        self.max_logvar = nn.Parameter(ptu.ones(obs_dim+2) / 2)
        self.min_logvar = nn.Parameter(-ptu.ones(obs_dim+2) * 10)

    def forward(self, input, deterministic=False, return_dist=False):
        """
        Predict with every ensemble member in one call.

        :param deterministic: if True, return the mean instead of a sample.
        :param return_dist: if True, also return the distribution parameters.
        :return: ``mean`` or ``(mean, logstd)`` when deterministic;
            ``samples`` or ``(samples, mean, logstd)`` otherwise.
        """
        output = super().forward(input)
        mean, logstd = torch.chunk(output, 2, dim=-1)
        logvar = 2 * logstd

        # Smoothly clamp logvar into (min_logvar, max_logvar); softplus keeps
        # the clamp differentiable, so the bounds themselves receive gradients.
        logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
        logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
        logstd = logvar / 2

        if deterministic:
            return (mean, logstd) if return_dist else mean

        std = torch.exp(logstd)
        eps = ptu.randn(std.shape)
        if self.noise_clip is not None:
            # fmod wraps (rather than clips) the noise into (-noise_clip, noise_clip).
            eps = torch.fmod(eps, self.noise_clip)
        samples = mean + std * eps

        if return_dist:
            return samples, mean, logstd
        else:
            return samples

    def _select_elites(self, values, inds):
        """Per row, take ``values[self.elites[k]]`` where ``k = inds[row]`` (masked sum)."""
        # inds is (..., 1) and broadcasts against values[e] of shape (..., D).
        selected = (inds == 0).float() * values[self.elites[0]]
        for i in range(1, len(self.elites)):
            selected += (inds == i).float() * values[self.elites[i]]
        return selected

    def sample_with_logprob(self, input):
        """
        Sample from the uniform mixture over elite members and score the sample.

        Assumes ``input`` is (batch, in_dim) and ``forward`` returns
        (ensemble, batch, out_dim) tensors — inferred from the indexing here.

        :return: ``(samples, logprobs, mean, logstd)`` where ``logprobs`` is
            ``log((1/N) * sum_i p(samples | member_i) + 1e-6)`` over the N elites.
        """
        preds, mean, logstd = self.forward(input, deterministic=False, return_dist=True)
        std = torch.exp(logstd)

        # Standard sampling: one uniformly chosen elite per input row.
        inds = torch.randint(0, len(self.elites), input.shape[:-1])
        inds = inds.unsqueeze(dim=-1).to(device=ptu.device)
        samples = self._select_elites(preds, inds)

        # Diagonal-Gaussian log-density of the samples under each member.
        _samples = samples.unsqueeze(dim=0)
        _logprobs = -logstd - .5*np.log(2*np.pi) - .5 * ((_samples - mean) / std) ** 2
        _logprobs = torch.sum(_logprobs, dim=-1)

        # Mixture: p(x) = (1/N) * sum_i p(x | z_i). Evaluated in log space with
        # logsumexp because exp() of a sum over many dimensions overflows or
        # underflows float32. The extra log(1e-6) term reproduces the original
        # log(p(x) + 1e-6) floor exactly.
        n_elites = len(self.elites)
        elite_logprobs = torch.stack([_logprobs[e] for e in self.elites], dim=0)
        elite_logprobs = elite_logprobs - float(np.log(n_elites))
        floor = torch.full_like(elite_logprobs[:1], float(np.log(1e-6)))
        logprobs = torch.logsumexp(torch.cat([elite_logprobs, floor], dim=0), dim=0)

        return samples, logprobs, mean, logstd

    def sample_with_disagreement(self, input):
        """
        Sample from a random elite and measure disagreement with another elite.

        :return: ``(samples, disagreements)`` where ``disagreements`` is the mean
            squared difference between the mean predictions of two elites (distinct
            whenever more than one elite exists), with a trailing singleton dim.
        """
        preds, mean, logstd = self.forward(input, deterministic=False, return_dist=True)
        n_elites = len(self.elites)

        # Sample uniformly from the ensemble
        inds = torch.randint(0, n_elites, input.shape[:-1])

        # Ensure we don't use the same member to estimate disagreement
        inds_b = torch.randint(0, n_elites, input.shape[:-1])
        same = inds == inds_b
        inds_b[same] = torch.fmod(inds_b[same] + 1, n_elites)

        inds = inds.unsqueeze(dim=-1).to(device=ptu.device)
        inds_b = inds_b.unsqueeze(dim=-1).to(device=ptu.device)

        samples = self._select_elites(preds, inds)
        means_a = self._select_elites(mean, inds)
        means_b = self._select_elites(mean, inds_b)

        # We use disagreement = mean squared difference in mean predictions
        disagreements = torch.mean((means_a - means_b) ** 2, dim=-1, keepdim=True)

        return samples, disagreements

    def get_loss(self, x, y, split_by_model=False, return_l2_error=False):
        """
        Gaussian NLL training loss for every ensemble member.

        :param y: targets; if fewer than 3 dims they are repeated across members.
        :param split_by_model: return a list of per-member losses instead of the sum.
        :param return_l2_error: also return the mean absolute error of the means.
        """
        mean, logstd = self.forward(x, deterministic=True, return_dist=True)
        if len(y.shape) < 3:
            y = y.unsqueeze(0).repeat(self.ensemble_size, 1, 1)

        inv_var = torch.exp(-2 * logstd)
        sq_l2_error = (mean - y)**2
        if return_l2_error:
            l2_error = torch.sqrt(sq_l2_error).mean(dim=-1).mean(dim=-1)

        # Gaussian negative log-likelihood (times 2, dropping constants).
        loss = (sq_l2_error * inv_var + 2 * logstd).mean(dim=-1).mean(dim=-1)

        # Encourage lowering max_logvar and increasing min_logvar
        loss += 0.01 * (self.max_logvar.sum() - self.min_logvar.sum())

        if split_by_model:
            losses = [loss[i] for i in range(self.ensemble_size)]
            if return_l2_error:
                l2_errors = [l2_error[i] for i in range(self.ensemble_size)]
                return losses, l2_errors
            else:
                return losses
        else:
            if return_l2_error:
                return loss.sum(), l2_error.mean()
            else:
                return loss.sum()


class ProbabilisticLSTMEnsemble(ParallelizedLSTMEnsemble):

    """
    Recurrent ensemble whose outputs parameterize a diagonal Gaussian over
    ``obs_dim + 2`` targets per step.

    The predicted log standard deviation is soft-clamped between fixed
    (non-trainable) bounds ``min_logstd`` and ``max_logstd``.
    """

    def __init__(
            self,
            ensemble_size,
            obs_dim,
            action_dim,
            **kwargs
    ):
        super().__init__(
            ensemble_size=ensemble_size,
            input_size=obs_dim + action_dim,
            # Mean and log-std for each of the obs_dim + 2 targets.
            output_size=2 * (obs_dim + 2),
            **kwargs
        )

        self.obs_dim, self.action_dim = obs_dim, action_dim
        self.output_size = obs_dim + 2

        # Fixed soft bounds (requires_grad=False) on logstd: roughly [-5, 1].
        self.max_logstd = nn.Parameter(
            ptu.ones(obs_dim + 2), requires_grad=False)
        self.min_logstd = nn.Parameter(
            -ptu.ones(obs_dim + 2) * 5, requires_grad=False)

    def forward(self, input, hidden=None, deterministic=False, return_dist=False):
        """
        Run the ensemble on one input step, threading the recurrent ``hidden`` state.

        Returns, depending on the flags:
          * deterministic and return_dist:   (mean, logstd, hidden)
          * deterministic, not return_dist:  mean
          * stochastic and return_dist:      (samples, mean, logstd)
          * stochastic, not return_dist:     (samples, hidden)

        NOTE(review): the return arity is inconsistent (``hidden`` is dropped in
        two branches). Kept as-is so existing callers keep unpacking correctly.
        """
        output, hidden = super().forward(input, hidden=hidden)
        mean, logstd = torch.chunk(output, 2, dim=-1)

        # Smooth, differentiable clamp of logstd into (min_logstd, max_logstd).
        logstd = self.max_logstd - F.softplus(self.max_logstd - logstd)
        logstd = self.min_logstd + F.softplus(logstd - self.min_logstd)

        if deterministic:
            return (mean, logstd, hidden) if return_dist else mean

        std = torch.exp(logstd)
        eps = ptu.randn(std.shape)
        samples = mean + std * eps

        if return_dist:
            return samples, mean, logstd
        else:
            return samples, hidden

    def get_loss(self, xs, ys, split_by_model=False, return_l2_error=False):
        """
        Sequence Gaussian NLL loss, averaged over timesteps.

        :param xs: inputs, (batch, seq, in_dim) or (ensemble, batch, seq, in_dim);
            the 3-D form is repeated across members.
        :param ys: targets with the same layout as ``xs``.
        :param split_by_model: return a list of per-member losses instead of the sum.
        :param return_l2_error: also return the mean absolute error of the means,
            averaged over timesteps.
        :raises ValueError: if the sequence has no timesteps.
        """
        if len(ys.shape) < 4:
            ys = ys.unsqueeze(0).repeat(self.ensemble_size, 1, 1, 1)
        if len(xs.shape) < 4:
            xs = xs.unsqueeze(0).repeat(self.ensemble_size, 1, 1, 1)

        seq_len = xs.shape[2]
        if seq_len == 0:
            raise ValueError('get_loss requires sequences with at least one timestep')

        hidden = None
        loss = 0.
        l2_error = 0.
        for i in range(seq_len):
            # Teacher forcing: every step is fed the true input xs[:, :, i].
            x = xs[:, :, i]
            y = ys[:, :, i]
            mean, logstd, hidden = self.forward(
                x, hidden=hidden, deterministic=True, return_dist=True)

            inv_var = torch.exp(-2 * logstd)
            sq_l2_error = (mean - y) ** 2
            if return_l2_error:
                l2_error = l2_error + torch.sqrt(sq_l2_error).mean(dim=-1).mean(dim=-1)

            # Gaussian negative log-likelihood (times 2, dropping constants).
            loss = loss + (sq_l2_error * inv_var + 2 * logstd).sum(dim=-1).mean(dim=-1)

        # Average over timesteps (len(xs) would be the ensemble size, not seq length).
        loss = loss / seq_len
        if return_l2_error:
            l2_error = l2_error / seq_len

        if split_by_model:
            losses = [loss[i] for i in range(self.ensemble_size)]
            if return_l2_error:
                l2_errors = [l2_error[i] for i in range(self.ensemble_size)]
                return losses, l2_errors
            else:
                return losses
        else:
            if return_l2_error:
                return loss.sum(), l2_error.mean()
            else:
                return loss.sum()


class GaussianModel(Mlp):

    """
    Single (non-ensemble) MLP that maps (s, a) to a diagonal Gaussian over
    ``obs_dim + 2`` targets.

    The predicted log standard deviation is soft-clamped between the learnable
    per-dimension bounds ``min_logstd`` and ``max_logstd``.
    """

    def __init__(
            self,
            obs_dim,
            action_dim,
            hidden_sizes,
            hidden_activation=torch.tanh,
            **kwargs
    ):
        in_dim = obs_dim + action_dim
        out_dim = obs_dim + 2
        super().__init__(
            hidden_sizes,
            input_size=in_dim,
            # The network emits a mean and a log-std per target dimension.
            output_size=2 * out_dim,
            hidden_activation=hidden_activation,
            **kwargs
        )

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.input_size = in_dim
        self.output_size = out_dim

        # Learnable soft bounds on logstd (initially 1 and -5).
        self.max_logstd = nn.Parameter(ptu.ones(out_dim))
        self.min_logstd = nn.Parameter(-ptu.ones(out_dim) * 5)

    def forward(self, input, deterministic=False, return_dist=False):
        """
        Predict the Gaussian for ``input`` and optionally sample from it.

        :return: ``mean`` or ``(mean, logstd)`` when deterministic;
            ``sample`` or ``(sample, mean, std)`` otherwise.

        NOTE(review): the stochastic branch returns ``std`` while the
        deterministic branch returns ``logstd``; callers must account for this.
        """
        raw_mean, raw_logstd = torch.chunk(super().forward(input), 2, dim=-1)

        # Differentiable soft clamp of logstd into (min_logstd, max_logstd).
        bounded = self.max_logstd - F.softplus(self.max_logstd - raw_logstd)
        bounded = self.min_logstd + F.softplus(bounded - self.min_logstd)

        if deterministic:
            return (raw_mean, bounded) if return_dist else raw_mean

        scale = torch.exp(bounded)
        draw = raw_mean + ptu.randn(scale.shape) * scale
        return (draw, raw_mean, scale) if return_dist else draw

    def get_loss(self, x, y, return_l2_error=False):
        """
        Gaussian NLL (times 2, constants dropped) plus a bound-tightening penalty.

        :param return_l2_error: also return the mean absolute error of the mean.
        """
        mean, logstd = self.forward(x, deterministic=True, return_dist=True)

        sq_err = (mean - y) ** 2
        nll = (sq_err * torch.exp(-2 * logstd) + 2 * logstd).sum(dim=-1).mean()
        # Pull max_logstd down and min_logstd up.
        nll += 0.01 * (self.max_logstd - self.min_logstd).mean()

        if not return_l2_error:
            return nll
        return nll, torch.sqrt(sq_err).mean()
