
import torch
import numpy as np
from torch import nn as nn
from torch.nn import functional as F


def swish(x):
    """Apply the swish activation, ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate


def get_affine_params(ensemble_size, in_features, out_features):
    """Create the weight and bias parameters of one ensemble affine layer.

    Args:
        ensemble_size: number of networks in the ensemble (leading dimension).
        in_features: input width of the layer.
        out_features: output width of the layer.

    Returns:
        A ``(weight, bias)`` pair of ``nn.Parameter`` objects with shapes
        ``(ensemble_size, in_features, out_features)`` and
        ``(ensemble_size, 1, out_features)``. Weights are drawn from
        ``truncated_normal`` with std ``1 / (2 * sqrt(in_features))``;
        biases start at zero.
    """
    weight_shape = (ensemble_size, in_features, out_features)
    weight_std = 1.0 / (2.0 * np.sqrt(in_features))
    weight = truncated_normal(size=weight_shape, std=weight_std)

    bias = torch.zeros(ensemble_size, 1, out_features, dtype=torch.float32)

    return nn.Parameter(weight), nn.Parameter(bias)


def truncated_normal(size, std=1.0):
    """Sample a zero-mean normal tensor truncated at two standard deviations.

    The previous hand-rolled sampler drew four candidates per element and
    kept one that fell inside (-2, 2); when none of the four qualified it
    still returned one of them, so out-of-range values could leak through.
    ``nn.init.trunc_normal_`` enforces the bounds for every element.

    Args:
        size: shape of the tensor to create (anything ``torch.zeros`` accepts).
        std: standard deviation of the underlying normal distribution.

    Returns:
        A float tensor of shape ``size`` whose entries lie in
        ``[-2 * std, 2 * std]``.
    """
    std = float(std)
    tensor = torch.zeros(size)
    if std == 0.0:
        # A zero-width distribution is all zeros; trunc_normal_ would divide by std.
        return tensor

    bound = 2.0 * std
    nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-bound, b=bound)
    return tensor


# Default compute device: CUDA when it is available, otherwise the CPU.
TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class PtModel(nn.Module):
    """Ensemble of five-layer MLPs producing a Gaussian mean and variance.

    Every layer stores one weight matrix and bias per ensemble member, so a
    single batched ``matmul`` evaluates all members at once. The final layer
    output is split in half along the last dimension: the first half is the
    predicted mean, the second half the (soft-clamped) log-variance.
    """

    def __init__(self, ensemble_size, in_features, out_features, hidden_features=200):
        """Build the ensemble.

        Args:
            ensemble_size: number of networks in the ensemble.
            in_features: width of the network input.
            out_features: width of the final layer; the first
                ``out_features // 2`` columns are the mean and the rest the
                log-variance.
            hidden_features: width of the four hidden layers (default 200).
        """
        super().__init__()

        self.num_nets = ensemble_size

        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features

        self.lin0_w, self.lin0_b = get_affine_params(ensemble_size, in_features, hidden_features)

        self.lin1_w, self.lin1_b = get_affine_params(ensemble_size, hidden_features, hidden_features)

        self.lin2_w, self.lin2_b = get_affine_params(ensemble_size, hidden_features, hidden_features)

        self.lin3_w, self.lin3_b = get_affine_params(ensemble_size, hidden_features, hidden_features)

        self.lin4_w, self.lin4_b = get_affine_params(ensemble_size, hidden_features, out_features)

        # Input normalisation statistics; not trained, set by fit_input_stats().
        self.inputs_mu = nn.Parameter(torch.zeros(in_features), requires_grad=False)
        # Start at one (not zero) so an un-fitted model normalises by identity
        # instead of dividing by zero in forward().
        self.inputs_sigma = nn.Parameter(torch.ones(in_features), requires_grad=False)

        # Learnable soft bounds on the predicted log-variance.
        self.max_logvar = nn.Parameter(torch.ones(1, out_features // 2, dtype=torch.float32) / 2.0)
        self.min_logvar = nn.Parameter(- torch.ones(1, out_features // 2, dtype=torch.float32) * 10.0)

    # Per-layer weight-decay term, currently disabled; kept for reference.
    # def compute_decays(self):
    #
    #     lin0_decays = 0.000025 * (self.lin0_w ** 2).sum() / 2.0
    #     lin1_decays = 0.00005 * (self.lin1_w ** 2).sum() / 2.0
    #     lin2_decays = 0.000075 * (self.lin2_w ** 2).sum() / 2.0
    #     lin3_decays = 0.000075 * (self.lin3_w ** 2).sum() / 2.0
    #     lin4_decays = 0.0001 * (self.lin4_w ** 2).sum() / 2.0
    #
    #     return lin0_decays + lin1_decays + lin2_decays + lin3_decays + lin4_decays

    def fit_input_stats(self, data):
        """Set the input normalisation statistics from a NumPy data matrix.

        Args:
            data: NumPy array whose first axis indexes samples. The mean and
                standard deviation are taken over that axis; standard
                deviations below 1e-12 are replaced by 1.0 so constant
                features are not divided by (almost) zero.
        """
        mu = np.mean(data, axis=0, keepdims=True)
        sigma = np.std(data, axis=0, keepdims=True)
        sigma[sigma < 1e-12] = 1.0

        # Keep the statistics on whatever device the model currently lives on,
        # so they stay consistent with the weights used in forward().
        device = self.inputs_mu.device
        self.inputs_mu.data = torch.from_numpy(mu).to(device).float()
        self.inputs_sigma.data = torch.from_numpy(sigma).to(device).float()

    def forward(self, inputs, ret_logvar=False):
        """Run all ensemble members on ``inputs``.

        Args:
            inputs: tensor whose last dimension is ``in_features`` and which
                matmul-broadcasts against the per-member weights to a 3-D
                result, e.g. ``(ensemble, batch, in_features)`` or
                ``(batch, in_features)``.
            ret_logvar: when true, return the log-variance instead of the
                variance.

        Returns:
            ``(mean, logvar)`` if ``ret_logvar`` is true, otherwise
            ``(mean, variance)``.
        """
        # Normalise the inputs with the stored statistics.
        hidden = (inputs - self.inputs_mu) / self.inputs_sigma

        hidden_layers = (
            (self.lin0_w, self.lin0_b),
            (self.lin1_w, self.lin1_b),
            (self.lin2_w, self.lin2_b),
            (self.lin3_w, self.lin3_b),
        )
        for weight, bias in hidden_layers:
            hidden = swish(hidden.matmul(weight) + bias)

        outputs = hidden.matmul(self.lin4_w) + self.lin4_b

        half = self.out_features // 2
        mean = outputs[:, :, :half]

        # Softly clamp the log-variance into [min_logvar, max_logvar].
        logvar = outputs[:, :, half:]
        logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
        logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)

        if ret_logvar:
            return mean, logvar

        return mean, torch.exp(logvar)


class ModelHandler(object):
    """Thin wrapper that owns a ``PtModel`` ensemble and exposes predict/loss helpers."""

    def __init__(self, num_ensembles=None, dim_obs=None, dim_act=None):
        """Create the handler and its ensemble model.

        Args:
            num_ensembles: number of networks in the ensemble.
            dim_obs: observation dimensionality.
            dim_act: action dimensionality.

        Any argument left as ``None`` falls back to the module-level variable
        of the same name, which keeps the zero-argument ``ModelHandler()``
        call working.

        Raises:
            NameError: if a value is neither passed nor defined at module level.
        """
        self.num_ensembles = self._resolve_setting('num_ensembles', num_ensembles)
        self.dim_obs = self._resolve_setting('dim_obs', dim_obs)
        self.dim_act = self._resolve_setting('dim_act', dim_act)

        self.npart = self.num_ensembles * 3
        self.model = PtModel(ensemble_size=self.num_ensembles,
                             in_features=self.dim_obs + self.dim_act,
                             out_features=self.dim_obs * 2)

    @staticmethod
    def _resolve_setting(name, value):
        """Return ``value``, or the module-level variable ``name`` when it is None."""
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError("name '%s' is not defined" % name) from None

    def predict_next_obs(self, obs, acs):
        """Sample a next-observation prediction averaged over the ensemble.

        Args:
            obs: observations, shape ``(num_samples, dim_obs)``.
            acs: actions, shape ``(num_samples, dim_act)``.

        Returns:
            Tensor with one row per sample: the ensemble average of one
            Gaussian draw per member.
        """
        inputs = torch.cat((obs, acs), dim=-1)
        # Every ensemble member sees the same batch. Broadcasting a leading
        # ensemble axis works for any batch size, unlike the particle-layout
        # view that required the batch to be a multiple of npart // num_nets.
        inputs = inputs.unsqueeze(0).expand(self.num_ensembles, -1, -1)
        mean, var = self.model(inputs)
        # randn_like already matches mean's device and dtype.
        predictions = mean + torch.randn_like(mean) * var.sqrt()
        return predictions.mean(0)  # average over the ensemble dimension

    def _expand_to_ts_format(self, mat):
        """Regroup a flat ``(rows, dim)`` matrix into ``(num_nets, rows / num_nets, dim)``.

        Rows are viewed as ``(-1, num_nets, npart // num_nets, dim)`` and the
        network axis is then moved to the front.
        """
        dim = mat.shape[-1]

        # e.g. with num_nets=5, npart=5: [10, 5] -> [2, 5, 1, 5]
        reshaped = mat.view(-1, self.model.num_nets, self.npart // self.model.num_nets, dim)
        transposed = reshaped.transpose(0, 1)
        # After, [5, 2, 1, 5]
        reshaped = transposed.contiguous().view(self.model.num_nets, -1, dim)
        # After. [5, 2, 5]
        return reshaped

    def _flatten_to_matrix(self, ts_fmt_arr):
        """Inverse of ``_expand_to_ts_format``: back to a flat ``(rows, dim)`` matrix."""
        dim = ts_fmt_arr.shape[-1]
        reshaped = ts_fmt_arr.view(self.model.num_nets, -1, self.npart // self.model.num_nets, dim)
        transposed = reshaped.transpose(0, 1)
        reshaped = transposed.contiguous().view(-1, dim)
        return reshaped

    def train(self, obs, act, rew):
        """Fit the input statistics and return the Gaussian NLL-style loss.

        No optimiser step is taken here; the caller is responsible for
        back-propagating the returned loss.

        Args:
            obs: observations, shape ``(num_samples, dim_obs)``; also used as
                the regression target.
            act: actions, shape ``(num_samples, dim_act)``.
            rew: rewards; currently unused.

        Returns:
            Scalar loss tensor: the variance-weighted squared error plus
            log-variance, averaged per member and summed over members, plus a
            0.01-weighted penalty on ``max_logvar - min_logvar``.
        """
        inputs = torch.cat([obs, act], dim=1)
        # Use the handler's own model so the fitted statistics and the
        # gradients belong to the model that predict_next_obs() uses.
        model = self.model
        # detach/cpu so tensors on the GPU or tracking gradients convert cleanly.
        model.fit_input_stats(data=inputs.detach().cpu().numpy())
        loss = 0.01 * (model.max_logvar.sum() - model.min_logvar.sum())

        mean, logvar = model(inputs, ret_logvar=True)
        inv_var = torch.exp(-logvar)

        train_losses = ((mean - obs) ** 2) * inv_var + logvar
        train_losses = train_losses.mean(-1).mean(-1).sum()
        return loss + train_losses


if __name__ == '__main__':
    # Smoke test: build a handler, sample a prediction and compute the loss.
    # ModelHandler() looks up num_ensembles, dim_obs and dim_act at module
    # level, so these names must stay global.
    num_samples, num_ensembles = 3, 2
    dim_obs, dim_act = 4, 5

    obs = torch.randn(num_samples, dim_obs)
    act = torch.randn(num_samples, dim_act)
    rew = torch.randn(num_samples, 1)

    handler = ModelHandler()
    next_obs = handler.predict_next_obs(obs=obs, acs=act)
    print(next_obs.shape, obs.shape)

    print(handler.train(obs, act, rew))
