# Adapted from https://openreview.net/forum?id=-N7PBXqOUJZ
# Modified for handling irregularly sampled time-series

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def gaussian_init_(n_units, std=1):
    """Draw a square matrix with i.i.d. zero-mean Gaussian entries.

    Args:
        n_units: side length of the returned matrix; also divides ``std``.
        std: base scale; each entry has standard deviation ``std / n_units``.

    Returns:
        Tensor of shape ``(n_units, n_units)`` in the default tensor dtype
        (``torch.Tensor`` is used to build the distribution parameters).
    """
    scale = std / n_units
    normal = torch.distributions.Normal(torch.Tensor([0]), torch.Tensor([scale]))
    # The distribution has batch shape (1,), so samples arrive as (n, n, 1);
    # drop that trailing singleton axis.
    samples = normal.sample((n_units, n_units))
    return samples.squeeze(-1)


class LipschitzRNN(nn.Module):
    """Lipschitz RNN unrolled with update steps scaled by per-sample time gaps.

    At every step ``t`` the hidden state is updated as::

        dt = tau * timespans[:, t]
        h  = h + dt * alpha * (h @ A) + dt * tanh(h @ W + E(x_t))

    where ``A`` and ``W`` are rebuilt from the learnable matrices ``B`` and
    ``C`` on every forward call::

        A = beta * (B - B^T) + (1 - beta) * (B + B^T) - gamma * I
        W = beta * (C - C^T) + (1 - beta) * (C + C^T) - gamma * I

    The output of each step is ``D(h)``. Scaling the update by the observed
    time gap is what lets the model consume irregularly sampled series.

    Args:
        n_in: input features per time step (input size of ``E``).
        n_units: hidden state size.
        n_out: output size (output size of ``D``).
        return_sequences: if True, ``forward`` returns every step's output;
            otherwise only the last (mask-aware) output.
        beta: mix between the skew-symmetric and symmetric parts of B / C.
        gamma: diagonal shift subtracted from A and W.
        tau: global multiplier applied to the time gaps.
        pi: stored but not used by ``forward``.
        init_std: scale passed to ``gaussian_init_`` for ``B`` and ``C``.
        alpha: weight of the linear ``h @ A`` term.
    """

    def __init__(
        self,
        n_in,
        n_units,
        n_out,
        return_sequences,
        beta=0.8,
        gamma=0.01,
        tau=1.0,
        pi=0.0,
        init_std=1,
        alpha=1,
    ):
        super(LipschitzRNN, self).__init__()

        self.n_out = n_out
        self.n_units = n_units
        self.return_sequences = return_sequences
        self.gamma = gamma
        self.beta = beta
        self.tau = tau
        self.alpha = alpha
        self.pi = pi  # not referenced by forward; kept for interface compatibility

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()  # not referenced by forward

        self.E = nn.Linear(n_in, n_units)  # input projection
        self.D = nn.Linear(n_units, n_out)  # readout
        # Plain attribute (not a parameter/buffer): Module.to()/.double() do
        # not move or cast it, so forward casts it explicitly.
        self.I = torch.eye(n_units)

        self.C = nn.Parameter(gaussian_init_(n_units, std=init_std))
        self.B = nn.Parameter(gaussian_init_(n_units, std=init_std))

    def forward(self, x, timespans, mask=None):
        """Run the recurrence over the whole sequence.

        Args:
            x: inputs indexed as ``x[:, t]``; each step is flattened to
                ``(batch, -1)`` before ``E``, so presumably
                ``(batch, seq, n_in)``.
            timespans: per-step time gaps indexed as ``timespans[:, t]`` and
                reshaped to ``(batch, 1)``, e.g. ``(batch, seq)`` or
                ``(batch, seq, 1)``.
            mask: optional per-step observation mask indexed like
                ``timespans`` (float, int or bool). It only affects the result
                when ``return_sequences`` is False: a step with mask 0 keeps
                the previous output instead of the current one.

        Returns:
            ``(batch, seq, n_out)`` if ``return_sequences`` is True, else
            ``(batch, n_out)``.
        """
        batch_size = x.size(0)
        seq_len = x.size(1)
        device = x.device
        # Follow the parameters' dtype so the model keeps working after
        # .double(); a hard-coded float32 state breaks torch.matmul(h, A).
        dtype = self.B.dtype
        I = self.I.to(device=device, dtype=dtype)
        h = torch.zeros(batch_size, self.n_units, device=device, dtype=dtype)

        # A and W do not depend on t, so build them once per call.
        A = (
            self.beta * (self.B - self.B.transpose(1, 0))
            + (1 - self.beta) * (self.B + self.B.transpose(1, 0))
            - self.gamma * I
        )
        W = (
            self.beta * (self.C - self.C.transpose(1, 0))
            + (1 - self.beta) * (self.C + self.C.transpose(1, 0))
            - self.gamma * I
        )

        outputs = []
        last_output = torch.zeros((batch_size, self.n_out), device=device, dtype=dtype)
        for t in range(seq_len):
            ts = self.tau * timespans[:, t].reshape(batch_size, 1)
            z = self.E(x[:, t].reshape(batch_size, -1))
            h = (
                h
                + ts * self.alpha * torch.matmul(h, A)
                + ts * self.tanh(torch.matmul(h, W) + z)
            )
            current_output = self.D(h)
            outputs.append(current_output)

            if mask is None:
                last_output = current_output
            else:
                # Cast so bool/int masks work: `1.0 - bool_tensor` raises.
                cur_mask = mask[:, t].reshape(batch_size, 1).to(current_output.dtype)
                last_output = cur_mask * current_output + (1.0 - cur_mask) * last_output

        if not self.return_sequences:
            return last_output  # only last (mask-aware) item
        if not outputs:
            # torch.stack rejects an empty list; return an empty sequence.
            return torch.zeros((batch_size, 0, self.n_out), device=device, dtype=dtype)
        return torch.stack(outputs, dim=1)  # entire sequence
