import math
import torch
import torchdiffeq
import torchsde

from . import common
from . import latent_ode
from . import node


class _ACNFFunc(torch.nn.Module):
    """Vector field of the augmented CNF used by ``_ACNF``.

    The ODE state ``z`` packs four pieces side by side along dim 1::

        [ state (in_size) | context (context_size) | eps (in_size) | log-prob accumulator (1) ]

    Components of the returned time-derivative:

    * ``state`` moves along an MLP vector field fed with ``[t, state, context]``.
    * ``context`` moves along ``augment_func(t, context)``.
    * ``eps`` is held constant (zero derivative); it is the probe vector of the divergence estimate.
    * The accumulator integrates ``-eps^T (df/dstate) eps``.
    """

    def __init__(self, in_size, context_size, hidden_size, num_layers, augment_func):
        super().__init__()
        self._in_size = in_size
        self._context_size = context_size

        # The MLP sees one time column plus the state and context columns and predicts d(state)/dt.
        mlp_in = 1 + in_size + context_size
        self._mlp = common.MLP(in_size=mlp_in,
                               out_size=in_size,
                               hidden_size=hidden_size,
                               num_layers=num_layers,
                               tanh=True)
        self._augment_func = augment_func

    def forward(self, t, z):
        """Return dz/dt for the packed state ``z`` at time ``t`` (expected to be a 0-dim tensor)."""
        n = self._in_size
        ctx_stop = n + self._context_size
        eps_stop = ctx_stop + n
        batch = z.size(0)

        # Grad mode is forced on so the vector-Jacobian product below can be taken even when the
        # caller runs under torch.no_grad() (as _ACNF.sample does).
        with torch.enable_grad():
            x = z[:, :n]
            ctx = z[:, n:ctx_stop]
            probe = z[:, ctx_stop:eps_stop]
            if not x.requires_grad:
                x = x.detach().requires_grad_()

            time_col = t[None, None].expand(batch, 1)
            dx = self._mlp(torch.cat([time_col, x, ctx], dim=1))

            # probe^T (d dx / d x), then contracted with the probe again -> scalar per row.
            vjp, = torch.autograd.grad(dx, x, probe, create_graph=True)
            trace_est = torch.sum(vjp * probe, dim=1, keepdim=True)

            d_ctx = self._augment_func(t, ctx)
            return torch.cat([dx, d_ctx, torch.zeros_like(dx), -trace_est], dim=1)


class _ACNF(torch.nn.Module):
    """Augmented continuous normalizing flow conditioned on an evolving context.

    The context follows its own learned ODE (``node.ODEFunc``). The data follows ``_ACNFFunc``, whose vector
    field is conditioned on that context. The context passed to ``forward``/``sample`` is treated as its value
    at t=1.

    Directions of integration:

    * Density evaluation integrates data (t=0) -> base (t=1).
    * Sampling integrates base (t=1) -> data (t=0).

    Args:
        in_size: number of data channels.
        context_size: number of context channels.
        hidden_size: hidden width used by both the context ODE and the data vector field.
        num_layers: number of layers used by both networks.
        method: ``torchdiffeq.odeint`` solver name used for every integration (default ``'midpoint'``).
        step_size: value passed as ``options['step_size']`` to the solver (default ``0.25``). This option
            applies to torchdiffeq's fixed-grid solvers such as the default.
    """

    def __init__(self, in_size, context_size, hidden_size, num_layers, method='midpoint', step_size=0.25):
        super(_ACNF, self).__init__()
        self._in_size = in_size
        self._context_size = context_size
        # Single source of truth for solver settings, shared by forward() and sample() so both directions match.
        self._method = method
        self._step_size = step_size

        self._augment_func = node.ODEFunc(hidden_size=context_size,
                                          hidden_hidden_size=hidden_size,
                                          num_layers=num_layers)
        self._func = _ACNFFunc(in_size=in_size,
                               context_size=context_size,
                               hidden_size=hidden_size,
                               num_layers=num_layers,
                               augment_func=self._augment_func)

    def _solve(self, func, y0, t):
        """Integrate ``func`` from ``y0`` over the time grid ``t`` and return the solution at ``t[-1]``."""
        return torchdiffeq.odeint(func, y0, t, method=self._method,
                                  options=dict(step_size=self._step_size))[-1]

    def forward(self, x, context, logp, generator=None):
        """Return the batch-mean log-likelihood of ``x``.

        Args:
            x: data of shape (batch_size, in_size).
            context: context at t=1, of shape (batch_size, context_size).
            logp: callable mapping (batch_size, in_size) -> (batch_size,), the base log-density.
            generator: optional ``torch.Generator`` used to draw the divergence probe vector.
        """
        t_backward = torch.linspace(1, 0, 2, dtype=x.dtype, device=x.device)
        # Run the context ODE backwards to obtain the context at t=0, where x lives.
        context_start = self._solve(self._augment_func, context, t_backward)

        eps = torch.randn(x.shape, dtype=x.dtype, device=x.device, generator=generator)
        zero = torch.zeros(x.size(0), 1, dtype=x.dtype, device=x.device)
        y0 = torch.cat([x, context_start, eps, zero], dim=1)
        t_forward = torch.linspace(0, 1, 2, dtype=x.dtype, device=x.device)
        out = self._solve(self._func, y0, t_forward)
        # out has shape (batch_size, in_size + context_size + in_size + 1)

        z = out[:, :self._in_size]
        # Last column accumulates the integral of -(estimated divergence) from t=0 to t=1.
        delta_logp = out[:, -1]
        logpx = logp(z) - delta_logp
        return logpx.mean()

    def sample(self, noise, context):
        """Map base samples ``noise`` (batch_size, in_size) at t=1 back to data space at t=0.

        ``context`` has shape (batch_size, context_size) and is the context at t=1.
        """
        with torch.no_grad():
            # A zero probe vector makes the (unused) log-prob accumulator stay at zero.
            eps = torch.zeros_like(noise)
            zero = torch.zeros(noise.size(0), 1, dtype=noise.dtype, device=noise.device)

            y1 = torch.cat([noise, context, eps, zero], dim=1)
            t = torch.linspace(1, 0, 2, dtype=noise.dtype, device=noise.device)
            out = self._solve(self._func, y1, t)
            return out[:, :self._in_size]


class CTFP(common.VAE):
    """Latent continuous-time flow process.

    Components:

    * An ODE-RNN encoder maps the observed path, with observation times and labels appended, to a Gaussian
      posterior over a context vector.
    * An augmented CNF (``_ACNF``) maps each observation, conditioned on the context, its time and the label,
      to a base Brownian-motion path. That path is scored by ``_brownian_logp``.
    """

    # Variance of the Gaussian placed on the first point of the base Brownian motion. The CTFP paper offsets the
    # Brownian motion so its first value has non-trivial variance. This value is shared by _brownian_logp
    # (density) and generate_sample (sampling) so the two stay consistent.
    _INITIAL_VARIANCE = 0.2

    def __init__(self, input_channels,
                 encoder_hidden_channels, encoder_hidden_hidden_channels, encoder_num_layers,
                 context_channels,
                 decoder_hidden_channels, decoder_num_layers,
                 label_channels):
        super(CTFP, self).__init__()
        self._input_channels = input_channels
        self._context_channels = context_channels
        self._label_channels = label_channels

        self._encoder = latent_ode.ODERNN(in_size=input_channels + label_channels + 1,
                                          out_size=2 * context_channels,
                                          hidden_size=encoder_hidden_channels,
                                          hidden_hidden_size=encoder_hidden_hidden_channels,
                                          num_layers=encoder_num_layers)
        self._acnf = _ACNF(in_size=input_channels,
                           context_size=context_channels + 1 + label_channels,
                           hidden_size=decoder_hidden_channels,
                           num_layers=decoder_num_layers)

    def _time_and_label(self, t, y):
        """Broadcast times (seq_len,) and labels (batch_size, label_channels) to per-step features.

        Returns a tensor of shape (batch_size, seq_len, 1 + label_channels).
        """
        batch_size = y.size(0)
        seq_len = t.size(0)
        t_ = t.unsqueeze(-1).unsqueeze(0).repeat(batch_size, 1, 1)
        y_ = y.unsqueeze(-2).repeat(1, seq_len, 1)
        return torch.cat([t_, y_], dim=-1)

    def _brownian_logp(self, t, z):
        """Log-density of flattened paths under a Brownian motion with a Gaussian initial point.

        Model for each path:

        * The first value is N(0, _INITIAL_VARIANCE).
        * Each later value is N(previous value, t_i - t_{i-1}).

        Args:
            t: observation times, shape (seq_len,).
            z: paths flattened batch-major, shape (batch_size * seq_len, input_channels).

        Returns:
            Per-row log-density summed over channels, shape (batch_size * seq_len,).
        """
        seq_len = t.size(0)
        z = z.reshape(-1, seq_len, self._input_channels)  # (batch_size, seq_len, input_channels)
        batch_size = z.size(0)

        # Transition mean is the PREVIOUS point of the path; the first point is centred at zero.
        zeros = torch.zeros(batch_size, 1, self._input_channels, dtype=z.dtype, device=z.device)
        mean = torch.cat([zeros, z[:, :-1]], dim=1)

        initial_var = torch.tensor([self._INITIAL_VARIANCE], dtype=z.dtype, device=z.device)
        # Shape (1, seq_len, 1); broadcasts against (batch_size, seq_len, input_channels).
        var = torch.cat([initial_var, t[1:] - t[:-1]]).unsqueeze(0).unsqueeze(-1)

        logp = -0.5 * math.log(2 * math.pi) - 0.5 * var.log() - 0.5 * (z - mean).pow(2) / var
        return logp.reshape(batch_size * seq_len, self._input_channels).sum(dim=1)

    def _augment_context(self, context, t, y):
        """Repeat ``context`` (batch_size, context_channels) over time and append time and label features.

        Returns a tensor of shape (batch_size * seq_len, context_channels + 1 + label_channels), ordered
        batch-major to match the flattened observations.
        """
        seq_len = t.size(0)
        context = context.unsqueeze(1).repeat(1, seq_len, 1)
        context = torch.cat([context, self._time_and_label(t, y)], dim=2)
        return context.reshape(-1, self._context_channels + 1 + self._label_channels)

    def train_model(self, t, x, y, penalty, seed=None):
        """Compute the training loss for one batch.

        Args:
            t: observation times, shape (seq_len,).
            x: observations, shape (batch_size, seq_len, input_channels).
            y: labels, shape (batch_size, label_channels).
            penalty: if truthy, also compute a KL regulariser on the context posterior.
            seed: seed for all randomness in this call; drawn via ``generate_seed`` when None.

        Returns:
            ``(loss, penalty_)``:

            * ``loss`` is the negative mean log-likelihood.
            * ``penalty_`` is ``0.1 * mean KL(posterior || N(0, I))``, or ``0`` when ``penalty`` is falsy.
        """
        if seed is None:
            seed = self.generate_seed()
        generator = torch.Generator(y.device).manual_seed(seed)

        batch_size = y.size(0)
        seq_len = t.size(0)

        x_ = torch.cat([x, self._time_and_label(t, y)], dim=-1)
        out = self._encoder(t, x_)
        mean, logvar = out[:, :self._context_channels], out[:, self._context_channels:]
        # exp(0.5 * logvar) == sqrt(exp(logvar)) but overflows far later for large logvar.
        std = (0.5 * logvar).exp()
        context = mean + torch.randn(mean.shape, dtype=mean.dtype, device=mean.device, generator=generator) * std

        context = self._augment_context(context, t, y)

        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(batch_size * seq_len, self._input_channels)
        logpx = self._acnf(x_flat, context, lambda z: self._brownian_logp(t, z), generator)

        # MLE loss
        loss = -logpx
        # Variational loss
        if penalty:
            posterior = torch.distributions.Normal(loc=mean, scale=std)
            prior = torch.distributions.Normal(loc=torch.zeros_like(mean), scale=torch.ones_like(std))
            # kl_divergence(p, q) computes KL(p || q); the variational objective needs KL(posterior || prior).
            penalty_ = 0.1 * torch.distributions.kl_divergence(posterior, prior).mean()
        else:
            penalty_ = 0
        return loss, penalty_

    def generate_sample(self, t, y, seed=None, mean=False):
        """Draw sample paths of shape (batch_size, seq_len, input_channels).

        Args:
            t: observation times, shape (seq_len,).
            y: labels, shape (batch_size, label_channels).
            seed: seed for all randomness in this call; drawn via ``generate_seed`` when None.
            mean: if True, use an all-zero base path instead of sampling the Brownian motion. The context is
                still sampled.
        """
        if seed is None:
            seed = self.generate_seed()
        generator = torch.Generator(y.device).manual_seed(seed)

        batch_size = y.size(0)
        seq_len = t.size(0)

        if mean:
            noise = torch.zeros(batch_size * seq_len, self._input_channels, dtype=y.dtype, device=y.device)
        else:
            noise_bm = torchsde.BrownianInterval(t[0], t[-1], (batch_size, self._input_channels),
                                                 dtype=y.dtype, device=y.device, entropy=seed)
            noise_bm0 = torch.randn(batch_size, self._input_channels, dtype=y.dtype, device=y.device,
                                    generator=generator) * math.sqrt(self._INITIAL_VARIANCE)
            noise = torch.stack([noise_bm(ti) for ti in t]) + noise_bm0  # shape (seq_len, batch_size, input_channels)
            noise = noise.transpose(0, 1).reshape(batch_size * seq_len, self._input_channels)

        # Draw the context from the N(0, I) prior with the seeded generator so the whole sample is
        # reproducible from `seed`.
        context = torch.randn(batch_size, self._context_channels, dtype=y.dtype, device=y.device,
                              generator=generator)
        context = self._augment_context(context, t, y)

        out = self._acnf.sample(noise, context)
        return out.reshape(batch_size, seq_len, self._input_channels)
