import numpy as np
import torch
import torch.nn as nn

import lfrl.torch.pytorch_util as ptu


def swish(x):
    """Swish (a.k.a. SiLU) activation, applied element-wise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


def get_affine_params(ensemble_size, in_features, out_features):
    """Create stacked weight/bias parameters for an ensemble of linear layers.

    Weights come from ``ptu.randn`` (presumably standard-normal draws -- TODO
    confirm). They are wrapped into (-2, 2) with ``torch.fmod`` and scaled by
    ``1 / (2 * sqrt(in_features))``. Biases start at zero.

    Returns:
        (w, b): ``nn.Parameter`` tensors of shape
        ``(ensemble_size, in_features, out_features)`` and
        ``(ensemble_size, 1, out_features)`` respectively.
    """
    scale = 2.0 * np.sqrt(in_features)
    raw = ptu.randn((ensemble_size, in_features, out_features))
    weight = nn.Parameter(torch.fmod(raw, 2) / scale)

    bias_shape = (ensemble_size, 1, out_features)
    bias = nn.Parameter(torch.zeros(*bias_shape, dtype=torch.float32))

    return weight, bias


class PtModel(nn.Module):
    """Probabilistic ensemble model built from batched per-member MLPs.

    Each of the ``ensemble_size`` members is a 4-layer MLP:
    ``in_features -> 512 -> 512 -> 512 -> out_features``. The weights are
    stacked along a leading ensemble axis, so all members are evaluated with
    one batched ``matmul`` per layer.

    The last layer emits a mean and a log-variance for ``obs_dim + 2`` output
    dimensions. The log-variance is softly bounded between ``min_logvar`` and
    ``max_logvar``. Output index 0 is treated as the reward prediction by
    ``sample_with_disagreement``. The meaning of the other extra output is not
    identified in this file (TODO confirm against the training code).

    Note: ``restrict_dim`` only shrinks ``in_features``. ``forward`` does not
    slice its input, so callers must already pass inputs that are
    ``obs_dim + action_dim - restrict_dim`` wide.
    """

    def __init__(self, ensemble_size, obs_dim, action_dim, restrict_dim=0, activation_func=torch.tanh):
        super().__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.restrict_dim = restrict_dim

        in_features = obs_dim + action_dim - restrict_dim
        # Mean and log-variance for obs_dim + 2 outputs, concatenated.
        out_features = 2 * (obs_dim + 2)

        self.num_nets = ensemble_size  # same value as ensemble_size (alternate name)
        self.ensemble_size = ensemble_size
        self.activation_func = activation_func

        self.in_features = in_features
        self.out_features = out_features

        self.lin0_w, self.lin0_b = get_affine_params(ensemble_size, in_features, 512)

        self.lin1_w, self.lin1_b = get_affine_params(ensemble_size, 512, 512)

        self.lin2_w, self.lin2_b = get_affine_params(ensemble_size, 512, 512)

        self.lin3_w, self.lin3_b = get_affine_params(ensemble_size, 512, out_features)

        # Input normalisation statistics, replaced by fit_input_stats().
        # sigma starts at 1 (not 0), so calling forward() before fitting is an
        # identity normalisation instead of a division by zero.
        self.inputs_mu = nn.Parameter(torch.zeros(in_features), requires_grad=False)
        self.inputs_sigma = nn.Parameter(torch.ones(in_features), requires_grad=False)

        # Soft bounds on the predicted log-variance (see forward()). They are
        # created with requires_grad=False, so they receive no gradients unless
        # that flag is changed elsewhere.
        self.max_logvar = nn.Parameter(torch.ones(1, out_features // 2, dtype=torch.float32) / 2.0, requires_grad=False)
        self.min_logvar = nn.Parameter(- torch.ones(1, out_features // 2, dtype=torch.float32) * 10.0, requires_grad=False)

    def get_params(self):
        """Return cloned copies of all layer weights/biases and the logvar bounds.

        The order matches what ``set_params`` expects.
        """
        params = [
            self.lin0_w, self.lin0_b,
            self.lin1_w, self.lin1_b,
            self.lin2_w, self.lin2_b,
            self.lin3_w, self.lin3_b,
            self.max_logvar,
            self.min_logvar
        ]
        params = [p.clone() for p in params]
        return params

    def set_params(self, params):
        """Load tensors produced by ``get_params`` (same order) into the model.

        The tensors are assigned to ``.data`` directly; no copy is made.
        """
        self.lin0_w.data = params[0]
        self.lin0_b.data = params[1]
        self.lin1_w.data = params[2]
        self.lin1_b.data = params[3]
        self.lin2_w.data = params[4]
        self.lin2_b.data = params[5]
        self.lin3_w.data = params[6]
        self.lin3_b.data = params[7]
        self.max_logvar.data = params[8]
        self.min_logvar.data = params[9]

    def compute_decays(self):
        """Return the L2 weight-decay penalty, summed over all ensemble members.

        Only weights are penalised (biases are not). Each layer has its own
        coefficient.
        """
        lin0_decays = 0.00025 * (self.lin0_w ** 2).sum() / 2.0
        lin1_decays = 0.0005 * (self.lin1_w ** 2).sum() / 2.0
        lin2_decays = 0.0005 * (self.lin2_w ** 2).sum() / 2.0
        lin3_decays = 0.00075 * (self.lin3_w ** 2).sum() / 2.0

        return lin0_decays + lin1_decays + lin2_decays + lin3_decays

    def fit_input_stats(self, data):
        """Set the input normalisation statistics from a numpy array of samples.

        Statistics are taken over axis 0, so rows are samples. Per-feature
        standard deviations below 1e-12 are replaced by 1, so constant
        features are not divided by ~0.
        """
        mu = np.mean(data, axis=0, keepdims=True)
        sigma = np.std(data, axis=0, keepdims=True)
        sigma[sigma < 1e-12] = 1.0

        self.inputs_mu.data = ptu.from_numpy(mu)
        self.inputs_sigma.data = ptu.from_numpy(sigma)

    def forward(self, inputs, deterministic=False, return_dist=False):
        """Run every ensemble member on (normalised) inputs.

        Args:
            inputs: a 2-D ``(batch, in_features)`` tensor, which is fed to
                every member, or a 3-D ``(ensemble_size, batch, in_features)``
                tensor with one batch per member.
            deterministic: if True, return the predicted means instead of samples.
            return_dist: if True, also return the distribution parameters.

        Returns:
            Tensors of shape ``(ensemble_size, batch, obs_dim + 2)``:
            ``mean`` or ``(mean, logstd)`` when ``deterministic``; otherwise
            ``samples`` or ``(samples, mean, logstd)``. Samples are
            ``mean + exp(logstd) * eps`` with ``eps`` drawn by ``ptu.randn``.
        """
        # Transform inputs
        inputs = (inputs - self.inputs_mu) / self.inputs_sigma

        if len(inputs.shape) < 3:
            # Broadcast the same batch to every ensemble member.
            inputs = inputs.unsqueeze(0)
            inputs = inputs.repeat(self.ensemble_size, 1, 1)

        # restrict_dim is not applied here; inputs must already be
        # in_features wide (see the class docstring).

        inputs = inputs.matmul(self.lin0_w) + self.lin0_b
        inputs = self.activation_func(inputs)

        inputs = inputs.matmul(self.lin1_w) + self.lin1_b
        inputs = self.activation_func(inputs)

        inputs = inputs.matmul(self.lin2_w) + self.lin2_b
        inputs = self.activation_func(inputs)

        inputs = inputs.matmul(self.lin3_w) + self.lin3_b

        half = self.out_features // 2
        mean = inputs[:, :, :half]

        # Softly bound logvar from above by max_logvar, then from below by
        # min_logvar. softplus keeps this differentiable, unlike a hard clamp.
        logvar = inputs[:, :, half:]
        logvar = self.max_logvar - nn.functional.softplus(self.max_logvar - logvar)
        logvar = self.min_logvar + nn.functional.softplus(logvar - self.min_logvar)
        logstd = 0.5 * logvar

        if deterministic:
            return (mean, logstd) if return_dist else mean

        std = torch.exp(logstd)
        eps = ptu.randn(std.shape)
        samples = mean + std * eps

        if return_dist:
            return samples, mean, logstd
        else:
            return samples

    @staticmethod
    def _select_members(values, inds):
        """Pick one ensemble member's row for every batch element.

        Returns ``out`` with ``out[n] = values[inds[n], n]``.

        Args:
            values: tensor of shape ``(ensemble_size, batch, dim)``.
            inds: integer tensor of shape ``(batch,)`` holding member indices.
        """
        index = inds.to(device=values.device, dtype=torch.long)
        index = index.view(1, -1, 1).expand(1, values.shape[1], values.shape[2])
        return values.gather(0, index).squeeze(0)

    def sample(self, input, return_dist=False):
        """Draw one sample per input row from a uniformly chosen ensemble member.

        Expects a 2-D ``(batch, in_features)`` input.

        Returns:
            ``samples`` of shape ``(batch, obs_dim + 2)``. With
            ``return_dist``, it also returns the full-ensemble ``mean`` and
            ``logstd`` from ``forward``.
        """
        preds, mean, logstd = self.forward(input, deterministic=False, return_dist=True)

        # Choose an ensemble member uniformly at random for every input row.
        inds = torch.randint(0, preds.shape[0], input.shape[:-1])

        # Previously only member 0 was read, which zeroed every row whose
        # chosen member was not 0. Now every row gets its chosen member's sample.
        samples = self._select_members(preds, inds)

        if return_dist:
            return samples, mean, logstd
        else:
            return samples

    def sample_with_disagreement(self, input, return_dist=False):
        """Sample like ``sample`` and also report ensemble disagreement.

        Disagreement is the mean squared difference between the mean
        predictions of two distinct, randomly chosen members, computed per
        input row. Output index 0 (the reward prediction) is zeroed first, but
        it still counts in the averaging denominator. Expects a 2-D
        ``(batch, in_features)`` input.

        Returns:
            ``(samples, disagreements)`` with shapes ``(batch, obs_dim + 2)``
            and ``(batch, 1)``. With ``return_dist``, it also returns the
            full-ensemble ``mean`` and ``logstd``.
        """
        preds, mean, logstd = self.forward(input, deterministic=False, return_dist=True)

        # Choose an ensemble member uniformly at random for every input row.
        inds = torch.randint(0, preds.shape[0], input.shape[:-1])

        # Pick a second member, shifting it by one wherever it collides with
        # the first, so disagreement is not measured against itself.
        inds_b = torch.randint(0, mean.shape[0], input.shape[:-1])
        same = inds == inds_b
        inds_b[same] = torch.fmod(inds_b[same] + 1, mean.shape[0])

        samples = self._select_members(preds, inds)
        means_a = self._select_members(mean, inds)
        means_b = self._select_members(mean, inds_b)

        # We use disagreement = mean squared difference in mean predictions
        means_a[:, 0] = 0  # don't use reward pred for disagreement calculation
        means_b[:, 0] = 0
        disagreements = torch.mean((means_a - means_b) ** 2, dim=-1, keepdim=True)

        if return_dist:
            return samples, disagreements, mean, logstd
        else:
            return samples, disagreements

    def get_loss(self, x, y, split_by_model=False, return_l2_error=False):
        """Gaussian negative log-likelihood training loss for the ensemble.

        Per member, the loss is ``mean((y - mu)^2 / var + log var)`` over the
        batch and output dims. Added to it:

        * ``0.01 * (sum(max_logvar) - sum(min_logvar))``. Both bounds are
          created with ``requires_grad=False``, so unless that flag is changed
          elsewhere this is a constant offset.
        * ``compute_decays()``, which covers all members. It is added to every
          member's loss, so the summed loss counts it ``ensemble_size`` times.

        Args:
            x: model inputs, as accepted by ``forward``.
            y: targets, either ``(batch, obs_dim + 2)`` (shared by every
                member) or ``(ensemble_size, batch, obs_dim + 2)``.
            split_by_model: return a list of per-member losses instead of
                their sum.
            return_l2_error: also return the "l2 error". Because it is
                ``sqrt`` of the squared error, this is actually a mean
                absolute error.

        Returns:
            ``loss`` or ``(loss, l2_error)``. Each is a scalar (sum of losses /
            mean of errors over members), or a per-member list if
            ``split_by_model``.
        """
        mean, logstd = self.forward(x, deterministic=True, return_dist=True)
        if len(y.shape) < 3:
            y = y.unsqueeze(0).repeat(self.ensemble_size, 1, 1)

        inv_var = torch.exp(-2 * logstd)
        sq_l2_error = (mean - y) ** 2
        if return_l2_error:
            l2_error = torch.sqrt(sq_l2_error).mean(dim=-1).mean(dim=-1)

        # Loss = Gaussian PDF
        loss = (sq_l2_error * inv_var + 2 * logstd).mean(dim=-1).mean(dim=-1)

        # Encourage lowering max_logvar and increasing min_logvar
        loss += 0.01 * (self.max_logvar.sum() - self.min_logvar.sum())
        loss += self.compute_decays()

        if split_by_model:
            losses = [loss[i] for i in range(self.ensemble_size)]
            if return_l2_error:
                l2_errors = [l2_error[i] for i in range(self.ensemble_size)]
                return losses, l2_errors
            else:
                return losses
        else:
            if return_l2_error:
                return loss.sum(), l2_error.mean()
            else:
                return loss.sum()
