import torch
import torchdiffeq

from . import common
from . import node


class ODERNN(torch.nn.Module):
    """ODE-RNN encoder: a GRU cell whose hidden state is evolved by a learned ODE between observations.

    The hidden state starts at zero and is updated by the GRU cell with the first observation.
    For every later observation, the hidden state is first integrated through ``node.ODEFunc``
    from the previous observation time to the current one (fixed-step midpoint solver). It is
    then updated by the GRU cell with that observation. The final hidden state is mapped through
    a linear readout.

    Args:
        in_size: Number of channels of each observation fed to the GRU cell.
        out_size: Number of output channels produced by the linear readout.
        hidden_size: Size of the GRU / ODE hidden state.
        hidden_hidden_size: Forwarded to ``node.ODEFunc`` as ``hidden_hidden_size``.
        num_layers: Forwarded to ``node.ODEFunc`` as ``num_layers``.
    """

    def __init__(self, in_size, out_size, hidden_size, hidden_hidden_size, num_layers):
        super(ODERNN, self).__init__()

        self._hidden_size = hidden_size

        self._gru_cell = torch.nn.GRUCell(input_size=in_size, hidden_size=hidden_size)
        self._func = node.ODEFunc(hidden_size=hidden_size, hidden_hidden_size=hidden_hidden_size, num_layers=num_layers)
        self._readout = torch.nn.Linear(hidden_size, out_size)

    def forward(self, t, x):
        """Encode each sequence in the batch into a single output vector.

        Args:
            t: 1-D tensor of observation times, one per step along dim 1 of ``x``.
                Presumably sorted increasingly -- TODO confirm with callers.
            x: Tensor of shape (batch_size, seq_len, in_size).

        Returns:
            Tensor of shape (batch_size, out_size).
        """
        batch_size = x.size(0)

        h = torch.zeros(batch_size, self._hidden_size, dtype=x.dtype, device=x.device)
        x_unbound = x.unbind(dim=1)
        h = self._gru_cell(x_unbound[0], h)

        # With a single observation there is nothing to integrate; computing the minimum gap
        # would also fail because t[1:] - t[:-1] is empty.
        if t.size(0) > 1:
            # Fixed-step solver: the step size is the smallest gap between consecutive
            # observation times, so no interval is stepped over by a larger step.
            options = dict(step_size=t[1:].sub(t[:-1]).min())
            prev_ti = t[0]
            for ti, xi in zip(t[1:], x_unbound[1:]):
                # odeint returns the state at each requested time; index 1 is the state at ti.
                h = torchdiffeq.odeint(self._func, h, torch.stack([prev_ti, ti]), method='midpoint', options=options)[1]
                h = self._gru_cell(xi, h)
                prev_ti = ti

        return self._readout(h)


class LatentODE(common.VAE):
    """Latent ODE model: an ODE-RNN encoder produces a Gaussian latent context, decoded by a Neural ODE.

    Encoder (``ODERNN``) input per time step: the observation, then the observation time, then
    the label. Its output is split into a mean and a log-variance of size ``context_channels``
    each. Decoder (``node.NeuralODE``) input: a sampled context concatenated with the label.

    Args:
        input_channels: Number of channels of each observation.
        encoder_hidden_channels: ``hidden_size`` of the encoder ODE-RNN.
        encoder_hidden_hidden_channels: ``hidden_hidden_size`` of the encoder ODE-RNN.
        encoder_num_layers: ``num_layers`` of the encoder ODE-RNN.
        context_channels: Size of the latent context vector.
        decoder_hidden_channels: ``hidden_size`` of the decoder Neural ODE.
        decoder_hidden_hidden_channels: ``hidden_hidden_size`` of the decoder Neural ODE.
        decoder_num_layers: ``num_layers`` of the decoder Neural ODE.
        label_channels: Number of channels of the conditioning label ``y``.
    """

    def __init__(self, input_channels,
                 encoder_hidden_channels, encoder_hidden_hidden_channels, encoder_num_layers,
                 context_channels,
                 decoder_hidden_channels, decoder_hidden_hidden_channels, decoder_num_layers,
                 label_channels):
        super(LatentODE, self).__init__()
        self._context_channels = context_channels

        # +1 input channel for the observation time appended in train_model.
        # 2 * context_channels outputs: mean and log-variance of the latent context.
        self._encoder = ODERNN(in_size=input_channels + label_channels + 1,
                               out_size=2 * context_channels,
                               hidden_size=encoder_hidden_channels,
                               hidden_hidden_size=encoder_hidden_hidden_channels,
                               num_layers=encoder_num_layers)
        self._decoder = node.NeuralODE(in_size=context_channels + label_channels,
                                       out_size=input_channels,
                                       hidden_size=decoder_hidden_channels,
                                       hidden_hidden_size=decoder_hidden_hidden_channels,
                                       num_layers=decoder_num_layers)

    def train_model(self, t, x, y, penalty, seed=None):
        """Compute the reconstruction loss and (optionally) the variational penalty for one batch.

        Args:
            t: Observation times, shape (seq_len,).
            x: Observations, shape (batch_size, seq_len, input_channels).
            y: Labels, shape (batch_size, label_channels).
            penalty: If truthy, compute the KL penalty; otherwise return 0 for it.
            seed: Seed for the latent sample; drawn from ``self.generate_seed()`` when None.

        Returns:
            ``(loss, penalty_)``:
              - ``loss`` is the mean squared error between ``x`` and the decoder output.
              - ``penalty_`` is ``0.1 * mean KL(posterior || N(0, 1))``, or ``0`` when
                ``penalty`` is falsy.
        """
        if seed is None:
            seed = self.generate_seed()
        generator = torch.Generator(y.device).manual_seed(seed)

        batch_size = y.size(0)
        seq_len = t.size(0)

        # Append the observation time and the (time-constant) label to every time step of x.
        t_ = t.unsqueeze(-1).unsqueeze(0).repeat(batch_size, 1, 1)
        y_ = y.unsqueeze(-2).repeat(1, seq_len, 1)
        x_ = torch.cat([x, t_, y_], dim=-1)

        out = self._encoder(t, x_)
        mean, logvar = out[:, :self._context_channels], out[:, self._context_channels:]
        # Equal to sqrt(exp(logvar)), but exp overflows only for twice as large a logvar.
        std = torch.exp(0.5 * logvar)
        # Reparameterised sample from N(mean, std**2), reproducible through `generator`.
        context = mean + torch.randn(mean.shape, dtype=mean.dtype, device=mean.device, generator=generator) * std

        context_ = torch.cat([context, y], dim=1)
        pred_x = self._decoder(t, context_)

        # Autoencoder loss (assumes pred_x has the same shape as x -- TODO confirm NeuralODE output).
        loss = torch.nn.functional.mse_loss(x, pred_x)
        # Variational loss
        if penalty:
            posterior = torch.distributions.Normal(loc=mean, scale=std)
            prior = torch.distributions.Normal(loc=torch.zeros_like(mean), scale=torch.ones_like(std))
            # KL(posterior || prior): the direction that appears in the VAE evidence lower bound.
            penalty_ = 0.1 * torch.distributions.kl_divergence(posterior, prior).mean()
        else:
            penalty_ = 0
        return loss, penalty_

    def generate_sample(self, t, y, seed=None):
        """Decode a standard-normal latent context, conditioned on ``y``, at times ``t``.

        Args:
            t: Times passed to the decoder.
            y: Labels, shape (batch_size, label_channels); one sample is drawn per row.
            seed: Seed for the latent sample; drawn from ``self.generate_seed()`` when None.

        Returns:
            The decoder output for the sampled contexts.
        """
        if seed is None:
            seed = self.generate_seed()
        generator = torch.Generator(y.device).manual_seed(seed)
        context = torch.randn(y.size(0), self._context_channels, dtype=y.dtype, device=y.device, generator=generator)
        context_ = torch.cat([context, y], dim=1)
        return self._decoder(t, context_)
