# Copyright 2021 The ODE-LSTM Authors. All Rights Reserved.
from functools import partial

import pytorch_lightning as pl
import torch
import numpy as np
from torch import nn
from torchdyn.core import NeuralDE
from torchdyn.nn import Augmenter, DepthCat, Fourier, GalLinear


class Transformer(nn.Module):
    """Transformer encoder applied to a batch-first feature sequence.

    Inputs of shape (B, S, input_size) are linearly embedded to ``units``
    dimensions and passed through a stack of pre-norm encoder layers.

    :param input_size: number of features per sequence element
    :param num_classes: output width of the ``decoder`` head
    :param units: model (embedding) dimension
    :param num_heads: number of attention heads per encoder layer
    :param ff_dim: width of the feed-forward sub-layer
    :param num_layers: number of stacked encoder layers
    :param dropout: dropout probability used inside the encoder layers
    """

    def __init__(self, input_size: int, num_classes=29, units=128, num_heads=16, ff_dim=512, num_layers=3,
                 dropout=0.5):
        super().__init__()
        self.encoder_stack = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=units,
                nhead=num_heads,
                dim_feedforward=ff_dim,
                dropout=dropout,
                norm_first=True,
                batch_first=True
            ),
            num_layers=num_layers)

        # Maps (B, S, input_size) inputs to (B, S, units) representations.
        self.encoder = nn.Linear(in_features=input_size, out_features=units)

        # Maps (B, S, units) representations to (B, S, num_classes). The head is
        # kept as an attribute (and in the state dict) but forward() does not
        # apply it: callers receive the encoder memory.
        self.decoder = nn.Linear(units, num_classes)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Encode ``features`` and return the encoder memory.

        :param features: tensor of shape (B, S, input_size)
        :return: encoder memory of shape (B, S, units)
        """
        embed = self.encoder(features)  # map features to model_dim
        # No attention mask is needed: every sequence element is available.
        # The decoder head used to be evaluated here and its result thrown away;
        # that dead computation is dropped, the returned value is unchanged.
        return self.encoder_stack(embed, mask=None)


class BaseRNNCell(pl.LightningModule):
    """Common base for single-state recurrent cells.

    :param units: width of the hidden state
    :param input_size: width of the per-step input
    """

    def __init__(self, units, input_size):
        super().__init__()
        self.units, self.input_size = units, input_size

    def get_initial_state(self, batch_size):
        """Return a one-element list holding a zero state of shape (batch_size, units).

        The tensor is created with the module's current device and dtype.
        """
        zero_state = torch.zeros(
            (batch_size, self.units),
            dtype=self.dtype,
            device=self.device,
        )
        return [zero_state]


class RNNWrapper(pl.LightningModule):
    """Unrolls a recurrent cell over the second-to-last axis of its input.

    :param cell: module called as ``cell(step_input, state)`` and returning
        ``(output, new_state)``; it must also provide ``get_initial_state``
    """

    def __init__(self, cell):

        super().__init__()
        self.cell = cell

    def forward(self, inputs, hx=None):
        """Run the cell over every step of ``inputs``.

        :param inputs: tensor indexed as ``inputs[..., step, :]``; a 2-D tensor
            is treated as a single sequence and gets a leading batch axis
        :param hx: optional initial state. A single tensor is wrapped into a
            one-element list (the original behaviour); a list/tuple of tensors
            is passed to the cell unchanged, so the state returned by a
            previous call, or a multi-tensor state, can be fed back directly
        :return: ``(outputs, state)`` with the per-step outputs stacked along
            dim 1 and the state produced by the last step
        """
        if inputs.dim() == 2:
            # If just 2 dims are given, assume a single (unbatched) sequence.
            inputs = inputs[None, :, :]
        if hx is None:
            hidden_state = self.get_initial_state(inputs.shape[0])
        elif torch.is_tensor(hx):
            hidden_state = [hx]
        else:
            # Already a state container (e.g. what a previous call returned);
            # wrapping it again would hand the cell a nested list.
            hidden_state = hx
        outputs = []
        # unroll RNN
        for step in range(inputs.shape[-2]):
            out, hidden_state = self.cell(inputs[..., step, :], hidden_state)
            outputs.append(out)
        return torch.stack(outputs, 1), hidden_state

    def get_initial_state(self, batch_size):
        """Delegate to the wrapped cell's ``get_initial_state``."""
        return self.cell.get_initial_state(batch_size)




class MLP(torch.nn.Module):
    """Feed-forward stack of ``num_layers`` Linear + activation stages.

    :param units: output width of every stage
    :param input_size: input width of the first stage
    :param num_layers: number of Linear + activation stages
    :param activation: zero-argument factory producing the activation module
    """

    def __init__(self, units, input_size, num_layers=2, activation=nn.ReLU):
        super().__init__()

        # Widths seen along the stack: the input, then `units` after each stage.
        widths = [input_size, *([units] * num_layers)]
        stages = [
            nn.Sequential(
                nn.Linear(in_features=fan_in, out_features=fan_out),
                activation(),
            )
            for fan_in, fan_out in zip(widths, widths[1:])
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply all stages to ``x`` in order."""
        hidden = self.layers(x)
        return hidden


class ExtractLSTMOutput(torch.nn.Module):
    """Selects the output tensor from an ``(output, state)`` pair."""

    def __init__(self, last=False):
        """

        :param last: if True only the final step along dim 1 is returned, otherwise the whole sequence is returned
        """
        super().__init__()
        self.last = last

    def forward(self, x):
        """Unpack the two-element ``x`` and return its first item (or its last step)."""
        sequence, _state = x
        return sequence[:, -1, :] if self.last else sequence


class LSTM(torch.nn.Module):
    """Thin wrapper around a single-layer, batch-first ``torch.nn.LSTM``.

    :param units: hidden size of the LSTM
    :param input_size: number of input features per step
    """

    def __init__(self, units, input_size):
        super().__init__()

        self.cell = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=units,
            batch_first=True,
        )
        # Helper module kept as an attribute; forward() itself does not apply it.
        self.extraction = ExtractLSTMOutput()

    def forward(self, inputs, hx=None):
        """Return whatever the wrapped ``torch.nn.LSTM`` returns for ``(inputs, hx)``."""
        lstm_result = self.cell(inputs, hx)
        return lstm_result


class CTRNNCell(BaseRNNCell):
    """Continuous-time RNN cell integrated with a fixed number of Euler steps.

    :param units: width of the hidden state
    :param input_size: width of the per-step input
    :param num_unfolds: number of explicit Euler sub-steps per call
    """

    def __init__(self, units, input_size, num_unfolds=10):
        super().__init__(units, input_size)
        self.num_unfolds = num_unfolds

        self.tau = torch.nn.Parameter(torch.empty(units))
        torch.nn.init.zeros_(self.tau)
        self.kernel = torch.nn.Parameter(torch.empty(input_size, units))
        torch.nn.init.xavier_uniform_(self.kernel)
        self.recurrent_kernel = torch.nn.Parameter(torch.empty(units, units))
        torch.nn.init.orthogonal_(self.recurrent_kernel)
        self.bias = torch.nn.Parameter(torch.empty(self.units))
        torch.nn.init.zeros_(self.bias)
        self.scale = torch.nn.Parameter(torch.empty(self.units))
        torch.nn.init.ones_(self.scale)

    def dfdt(self, inputs, hidden_state):
        """Time derivative of the state: ``scale * tanh(x W + h U + b) - tau * h``."""
        h_in = torch.matmul(inputs, self.kernel)
        h_rec = torch.matmul(hidden_state, self.recurrent_kernel)
        dh = self.scale * torch.tanh(h_in + h_rec + self.bias)
        dh = dh - hidden_state * self.tau
        return dh

    def solve_euler(self, inputs, hidden_state, elapsed=1.0):
        """Advance the state by ``elapsed`` using ``num_unfolds`` Euler steps.

        :param inputs: input tensor for this step
        :param hidden_state: sequence whose first element is the current state
        :param elapsed: total integration time (float or broadcastable tensor)
        :return: ``(new_state, [new_state])``
        """
        hx = hidden_state[0]
        if self.num_unfolds <= 0:
            # Nothing to integrate; also avoids dividing by zero below.
            return hx, [hx]
        delta_t = elapsed / self.num_unfolds
        for _ in range(self.num_unfolds):
            # Explicit Euler update. The previous code assigned the derivative
            # itself to the state, which is not an integration step.
            hx = hx + delta_t * self.dfdt(inputs, hx)
        return hx, [hx]

    def forward(self, inputs, hx=None):
        """Run one cell step.

        :param inputs: input tensor, or an ``(input, elapsed)`` pair as accepted
            by the other time-aware cells in this file
        :param hx: sequence holding the current state; zeros are used when None
        :return: ``(new_state, [new_state])``
        """
        elapsed = 1.0
        if isinstance(inputs, (tuple, list)) and len(inputs) > 1:
            elapsed = inputs[1]
            inputs = inputs[0]
        if hx is None:
            hx = self.get_initial_state(inputs.shape[0])
        return self.solve_euler(inputs, hx, elapsed)


class CTRNN(RNNWrapper):
    """Sequence-level wrapper that unrolls a ``CTRNNCell``."""

    def __init__(self, units, input_size, num_unfolds=10):
        ctrnn_cell = CTRNNCell(units, input_size, num_unfolds)
        super().__init__(cell=ctrnn_cell)


class ODELSTM(torch.nn.Module):
    """LSTM-style gated cell whose gated output is evolved by a CTRNN cell.

    The state is a pair ``(cell_state, ode_state)``, both of width ``units``.

    :param units: width of both state tensors
    :param input_size: width of the per-step input
    """

    def __init__(self, units, input_size):
        # Module.__init__ must run before any sub-module is assigned;
        # assigning `self.ctrnn` first raises AttributeError.
        super().__init__()
        self.units = units
        self.state_size = (units, units)
        # The CTRNN is fed the gated LSTM output, which has width `units`
        # (not `input_size`), so its input kernel is sized accordingly.
        self.ctrnn = CTRNN(self.units, self.units, num_unfolds=4)

        self.input_kernel = torch.nn.Parameter(torch.empty(input_size, 4 * self.units))
        torch.nn.init.xavier_uniform_(self.input_kernel)
        self.recurrent_kernel = torch.nn.Parameter(torch.empty(self.units, 4 * self.units))
        torch.nn.init.orthogonal_(self.recurrent_kernel)
        self.bias = torch.nn.Parameter(torch.empty(4 * self.units))
        torch.nn.init.zeros_(self.bias)

        self.built = True

    def get_initial_state(self, batch_size=None):
        """Return a pair of zero tensors of shape (batch_size, units).

        A plain ``nn.Module`` has no ``dtype``/``device`` attributes, so both
        are taken from this module's own parameters.
        """
        like = self.bias
        return (
            torch.zeros([batch_size, self.units], dtype=like.dtype, device=like.device),
            torch.zeros([batch_size, self.units], dtype=like.dtype, device=like.device),
        )

    def forward(self, inputs, states):
        """Run one step.

        :param inputs: input tensor, or an ``(input, elapsed)`` pair
        :param states: pair ``(cell_state, ode_state)``
        :return: ``(ode_output, [new_cell_state, new_ode_state])``
        """
        cell_state, ode_state = states
        elapsed = 1.0  # noqa: F841 - see the review note below
        if isinstance(inputs, (tuple, list)) and len(inputs) > 1:
            elapsed = inputs[1]
            inputs = inputs[0]

        z = (
                torch.matmul(inputs, self.input_kernel)
                + torch.matmul(ode_state, self.recurrent_kernel)
                + self.bias
        )
        i, ig, fg, og = torch.chunk(z, 4, dim=-1)

        input_activation = torch.tanh(i)
        input_gate = torch.sigmoid(ig)
        forget_gate = torch.sigmoid(fg + 3.0)
        output_gate = torch.sigmoid(og)

        new_cell = cell_state * forget_gate + input_activation * input_gate
        ode_input = torch.tanh(new_cell) * output_gate

        # Single-step evolution: call the wrapped cell directly. The sequence
        # wrapper `self.ctrnn` unrolls over a time axis and cannot take a
        # per-step input.
        # NOTE(review): `elapsed` is parsed above but not forwarded, because the
        # cell is called with a plain tensor here - forward it once the cell's
        # contract for (input, elapsed) pairs is confirmed.
        ode_output, new_ode_state = self.ctrnn.cell(ode_input, [ode_state])

        return ode_output, [new_cell, new_ode_state[0]]


class ExtractNODEOutput(torch.nn.Module):
    """Selects the trajectory from a two-element ``(first, trajectory)`` pair."""

    def __init__(self, last=False):
        """

        :param last: if True only the final entry along dim 0 is returned, otherwise a slice of the sequence is returned
        """
        super().__init__()
        self.last = last

    def forward(self, x):
        """Return ``trajectory[-1]`` or all entries but the last with dims 0 and 1 swapped."""
        _first, trajectory = x
        if not self.last:
            # NOTE(review): this drops the LAST entry along dim 0 and keeps the
            # first - confirm this is the intended slice of the trajectory.
            return trajectory[:-1].transpose(0, 1)
        return trajectory[-1]


class ANODE(torch.nn.Module):
    """Augmented neural ODE applied independently to every step of a sequence.

    :param units: width of the projected input and of the output
    :param input_size: width of the per-step input
    :param mlp_layers: number of stages in the MLP vector field
    :param augment_dims: number of dimensions added by the ``Augmenter``
    :param num_unfolds: the integration span is ``linspace(0, 1, 1 + num_unfolds)``
    :param solver: solver name handed to ``NeuralDE``
    :param sensitivity: sensitivity method handed to ``NeuralDE``
    :param kwargs: accepted and ignored
    """

    def __init__(self, units, input_size, mlp_layers=2, augment_dims=10, num_unfolds=1, solver='dopri5',
                 sensitivity='adjoint', **kwargs):
        super().__init__()
        self.t_span = torch.linspace(0, 1, 1 + num_unfolds)
        self.solver = solver
        self.sensitivity = sensitivity

        # The vector field acts on the augmented state.
        state_dim = units + augment_dims
        vector_field = MLP(state_dim, state_dim, mlp_layers)
        self.node = NeuralDE(vector_field, solver=self.solver, sensitivity=self.sensitivity)

        # Project -> augment -> integrate -> pick the trajectory -> project back.
        pipeline = [
            nn.Linear(input_size, units),
            Augmenter(augment_dims=augment_dims),
            self.node,
            ExtractNODEOutput(),
            nn.Linear(state_dim, units),
        ]
        self.model = nn.Sequential(*pipeline)
        self.node.t_span = self.t_span

    def forward(self, inputs, hx=None):
        """Apply the pipeline to every step and ``hstack`` the results.

        :param inputs: tensor indexed as ``inputs[..., step, :]``; a 2-D tensor
            gets a leading batch axis first
        :param hx: unused
        """
        if inputs.dim() == 2:
            # If just 2 dims are given, assume a single (unbatched) sequence.
            inputs = inputs.unsqueeze(0)
        per_step = [self.model(inputs[..., step, :]) for step in range(inputs.shape[-2])]
        return torch.hstack(per_step)


class GalNODE(ANODE):
    """ANODE variant whose vector field uses a Galerkin (``GalLinear``) layer.

    :param units: hidden width of the vector field
    :param input_size: width of the state the vector field acts on
    :param kwargs: forwarded to ``ANODE.__init__`` (e.g. ``solver``, ``sensitivity``)
    """

    def __init__(self, units, input_size, **kwargs):
        # ANODE.__init__ requires `units` and `input_size`; forwarding only
        # **kwargs made construction fail with a TypeError every time.
        super().__init__(units, input_size, **kwargs)
        self.units = units

        # vector field parametrized by a NN with "GalLinear" layer
        # notice how DepthCat is still used since Galerkin layers make use of `s` (though in a different way compared
        # to concatenation)
        func = nn.Sequential(DepthCat(1),
                             GalLinear(input_size, units, expfunc=Fourier(5)),
                             nn.Tanh(),
                             nn.Linear(units, input_size))

        # Define NeuralDE, replacing both the node and the pipeline built by the
        # parent. The solver/sensitivity stored by the parent are reused; their
        # defaults ('dopri5' / 'adjoint') match the previously hard-coded values.
        # NOTE(review): the inherited forward() sends each step through
        # `self.model`, which is now the bare NeuralDE rather than the parent's
        # projection/augment/extract pipeline - confirm its output format is
        # what callers expect.
        self.node = self.model = NeuralDE(func, solver=self.solver, sensitivity=self.sensitivity)


class CTGRU(torch.nn.Module):
    """Continuous-time GRU cell keeping ``M`` traces per unit.

    https://arxiv.org/abs/1710.04110

    :param units: number of hidden units
    :param input_size: width of the per-step input
    :param M: number of time-scale traces stored for every unit
    :param kwargs: forwarded to ``torch.nn.Module.__init__``
    """

    def __init__(self, units, input_size, M=8, **kwargs):
        super(CTGRU, self).__init__(**kwargs)

        self.units = units
        self.M = M
        self.state_size = units * self.M

        # Pre-computed tau table (as recommended in paper): tau_0 = 1 and each
        # following time scale is sqrt(10) times the previous one.
        taus = []
        tau = 1.0
        for _ in range(self.M):
            taus.append(tau)
            tau = tau * (10.0 ** 0.5)
        table_dtype = torch.get_default_dtype()
        # Registered as non-persistent buffers: they follow .to()/.cuda() with
        # the rest of the module while the state-dict keys stay unchanged.
        self.register_buffer("ln_tau_table", torch.tensor(np.log(taus), dtype=table_dtype), persistent=False)
        self.register_buffer("tau_table", torch.tensor(taus, dtype=table_dtype), persistent=False)

        self.retrieval_layer = torch.nn.Linear(self.units + input_size, self.units * self.M)

        self.detect_layer = torch.nn.Linear(self.units + input_size, self.units)
        self.update_layer = torch.nn.Linear(self.units + input_size, self.units * self.M)

    def get_initial_state(self, batch_size):
        """Return ``[zeros(batch_size, units, M)]`` on the module's device and dtype."""
        return [torch.zeros([batch_size, self.units, self.M],
                            dtype=self.tau_table.dtype, device=self.tau_table.device)]

    def forward(self, inputs, hx=None):
        """Run one step.

        :param inputs: input tensor of shape (batch, input_size), or an
            ``(input, elapsed)`` pair; ``elapsed`` may be a float or a tensor
            holding one value per sample (or a single value)
        :param hx: sequence whose first element is the trace state, shaped
            (batch, units, M) or flattened to (batch, units * M); zeros if None
        :return: ``(h_next, [h_hat_next])`` with shapes (batch, units) and
            (batch, units, M)
        """
        elapsed = 1.0
        if isinstance(inputs, (tuple, list)) and len(inputs) > 1:
            elapsed = inputs[1]
            inputs = inputs[0]

        # Read only after unpacking: a list/tuple has no `.shape`.
        batch_dim = inputs.shape[0]

        if hx is None:
            h_hat = self.get_initial_state(batch_dim)[0]
        else:
            # The state may arrive flattened to 2-D; restore (batch, units, M).
            h_hat = torch.reshape(hx[0], [batch_dim, self.units, self.M])

        h = torch.sum(h_hat, dim=2)

        # Retrieval
        fused_input = torch.cat([inputs, h], dim=-1)
        ln_tau_r = self.retrieval_layer(fused_input)
        ln_tau_r = torch.reshape(ln_tau_r, shape=[batch_dim, self.units, self.M])
        sf_input_r = -torch.square(ln_tau_r - self.ln_tau_table)
        rki = torch.softmax(sf_input_r, dim=2)

        q_input = torch.sum(rki * h_hat, dim=2)
        reset_value = torch.cat([inputs, q_input], dim=1)
        qk = torch.tanh(self.detect_layer(reset_value))
        qk = torch.reshape(qk, [batch_dim, self.units, 1])  # in order to broadcast

        ln_tau_s = self.update_layer(fused_input)
        ln_tau_s = torch.reshape(ln_tau_s, shape=[batch_dim, self.units, self.M])
        sf_input_s = -torch.square(ln_tau_s - self.ln_tau_table)
        ski = torch.softmax(sf_input_s, dim=2)

        # Now the elapsed time enters the state update
        base_term = (1 - ski) * h_hat + ski * qk
        if torch.is_tensor(elapsed):
            # Align per-sample elapsed times with the batch axis so the decay
            # broadcasts as (batch, 1, M) against the (batch, units, M) state.
            elapsed = elapsed.reshape(-1, 1, 1)
        exp_term = torch.exp(-elapsed / self.tau_table)
        h_hat_next = base_term * exp_term

        # Compute new state
        h_next = torch.sum(h_hat_next, dim=2)
        return h_next, [h_hat_next]


class ElmanCell(BaseRNNCell):
    """Elman-style cell: a tanh recurrent layer followed by a linear read-out.

    :param units: width of the hidden state and of the output
    :param input_size: width of the per-step input
    """

    def __init__(self, units, input_size):
        super().__init__(units, input_size)
        self.rec_layer = torch.nn.Linear(in_features=units + input_size, out_features=units)
        self.out_layer = torch.nn.Linear(in_features=units, out_features=units)

    def forward(self, inputs, hidden_states):
        """Return ``(output, [new_state])`` for one step.

        :param inputs: input tensor for this step
        :param hidden_states: sequence whose first element is the current state
        """
        previous = hidden_states[0]
        pre_activation = self.rec_layer(torch.cat((inputs, previous), dim=-1))
        new_states = torch.tanh(pre_activation)
        return self.out_layer(new_states), [new_states]


class ElmanRNN(RNNWrapper):
    """Sequence-level wrapper that unrolls an ``ElmanCell``."""

    def __init__(self, units, input_size):
        elman_cell = ElmanCell(units, input_size)
        super().__init__(cell=elman_cell)


class BidirectionalRNN(torch.nn.Module):
    """Cross-coupled LSTM cell and RNN cell.

    Each cell receives the input concatenated with a state of the other cell.
    The state is a triple ``(lstm_h, lstm_c, rnn_h)``, all of width ``units``.

    :param units: width of every state tensor and of the output
    :param input_size: input width as an int, a shape tuple whose last entry is
        the width, or a nested tuple whose first element is such a shape
    :param kwargs: forwarded to ``torch.nn.Module.__init__``
    """

    def __init__(self, units, input_size, **kwargs):
        # Module.__init__ must run before sub-modules are assigned.
        super().__init__(**kwargs)
        self.units = units
        self.state_size = (units, units, units)

        if isinstance(input_size, int):
            input_dim = input_size
        else:
            input_dim = input_size[-1]
            if isinstance(input_size[0], tuple):
                # Nested tuple
                input_dim = input_size[0][-1]

        # Both cells consume the input concatenated with a `units`-wide state,
        # so they are sized for the fused width. torch cells allocate their
        # weights in the constructor; there is no Keras-style `.build()`.
        fused_dim = input_dim + self.units
        self.ctrnn = torch.nn.RNNCell(input_size=fused_dim, hidden_size=self.units)
        self.lstm = torch.nn.LSTMCell(input_size=fused_dim, hidden_size=self.units)

    def forward(self, inputs, states):
        """Run one step.

        :param inputs: input tensor, or an ``(input, elapsed)`` pair
        :param states: sequence ``(lstm_h, lstm_c, rnn_h)``
        :return: ``(lstm_out + rnn_out, [new_lstm_h, new_lstm_c, new_rnn_h])``
        """
        if isinstance(inputs, (tuple, list)) and len(inputs) > 1:
            # NOTE(review): the elapsed time (inputs[1]) is accepted for
            # interface parity with the other cells but is not used - the
            # torch RNNCell/LSTMCell take a plain tensor input.
            inputs = inputs[0]

        lstm_h, lstm_c, rnn_h = states[0], states[1], states[2]
        lstm_input = torch.cat([inputs, rnn_h], dim=-1)
        # The RNN branch is coupled to states[1], as in the original wiring.
        ctrnn_input = torch.cat([inputs, lstm_c], dim=-1)

        new_lstm_h, new_lstm_c = self.lstm(lstm_input, (lstm_h, lstm_c))
        new_rnn_h = self.ctrnn(ctrnn_input, rnn_h)

        fused_output = new_lstm_h + new_rnn_h
        return (
            fused_output,
            [new_lstm_h, new_lstm_c, new_rnn_h],
        )


class GRUDCell(BaseRNNCell):
    """GRU-D style cell with a learned exponential decay of the hidden state.

    Implemented according to
    https://www.nature.com/articles/s41598-018-24271-9.pdf
    without the masking

    :param units: width of the hidden state
    :param input_size: width of the per-step input, or a nested shape tuple
        whose first element ends with that width
    """

    def __init__(self, units, input_size):
        if isinstance(input_size, tuple):
            # Nested tuple: resolve to the plain feature width. The previous
            # code stored it in an unused variable and then failed on
            # `units + input_size`.
            input_size = input_size[0][-1]

        super().__init__(units, input_size)
        self.state_size = units

        self._reset_gate = torch.nn.Linear(self.units + input_size, self.units)
        torch.nn.init.xavier_uniform_(self._reset_gate.weight)
        self._detect_signal = torch.nn.Linear(self.units + input_size, self.units)
        torch.nn.init.xavier_uniform_(self._detect_signal.weight)
        self._update_gate = torch.nn.Linear(self.units + input_size, self.units)
        torch.nn.init.xavier_uniform_(self._update_gate.weight)
        self._d_gate = torch.nn.Linear(input_size, self.units)
        torch.nn.init.xavier_uniform_(self._d_gate.weight)
        self.d_gate_relu = torch.nn.ReLU()

    def forward(self, inputs, hx=None):
        """Run one step.

        :param inputs: input tensor, or an ``(input, elapsed)`` pair; without
            an explicit ``elapsed`` a tensor of ones shaped like the input is
            used (``elapsed`` is fed to a Linear layer of ``input_size`` inputs)
        :param hx: sequence whose first element is the current state; zeros
            are used when None
        :return: ``(new_state, [new_state])``
        """
        if isinstance(inputs, (tuple, list)) and len(inputs) > 1:
            elapsed = inputs[1]
            inputs = inputs[0]
        else:
            elapsed = torch.ones_like(inputs)

        if hx is None:
            # The signature advertises hx=None, so fall back to a zero state.
            hx = self.get_initial_state(inputs.shape[0])

        # Decay factor gamma = exp(-relu(W_d * elapsed + b_d)), in (0, 1].
        dt = self.d_gate_relu(self._d_gate(elapsed))
        gamma = torch.exp(-dt)
        h_hat = hx[0] * gamma

        fused_input = torch.cat([inputs, h_hat], dim=-1)
        rt = torch.sigmoid(self._reset_gate(fused_input))
        zt = torch.sigmoid(self._update_gate(fused_input))

        reset_value = torch.cat([inputs, rt * h_hat], dim=-1)
        h_tilde = torch.tanh(self._detect_signal(reset_value))

        # Compute new state
        ht = zt * h_hat + (1.0 - zt) * h_tilde

        return ht, [ht]


class GRUD(RNNWrapper):
    """Sequence-level wrapper that unrolls a ``GRUDCell``."""

    def __init__(self, units, input_size):
        grud_cell = GRUDCell(units, input_size)
        super().__init__(cell=grud_cell)


class PhasedLSTMCell(pl.LightningModule):
    """LSTM cell whose state updates are modulated by a periodic time gate.

    Implemented according to
    https://papers.nips.cc/paper/6310-phased-lstm-accelerating-recurrent-network-training-for-long-or-event-based
    -sequences.pdf

    :param units: width of the cell and hidden states
    :param input_size: width of the per-step input, or a nested shape tuple
        whose first element ends with that width
    """

    def __init__(self, units, input_size):
        super().__init__()
        self.units = units
        self.state_size = (units, units)
        self.initializer = "glorot_uniform"
        self.recurrent_initializer = "orthogonal"

        if isinstance(input_size, tuple):
            # Nested tuple
            input_size = input_size[0][-1]

        self.input_kernel = torch.nn.Parameter(torch.empty(input_size, 4 * self.units))
        torch.nn.init.xavier_uniform_(self.input_kernel)
        self.recurrent_kernel = torch.nn.Parameter(torch.empty(self.units, 4 * self.units))
        torch.nn.init.orthogonal_(self.recurrent_kernel)
        self.bias = torch.nn.Parameter(torch.empty(4 * self.units))
        torch.nn.init.zeros_(self.bias)
        # Raw (pre-softplus) period, open ratio and phase shift of the time gate.
        self.tau = torch.nn.Parameter(torch.empty(1))
        torch.nn.init.zeros_(self.tau)
        self.ron = torch.nn.Parameter(torch.empty(1))
        torch.nn.init.zeros_(self.ron)
        self.s = torch.nn.Parameter(torch.empty(1))
        torch.nn.init.zeros_(self.s)

    def get_initial_state(self, batch_size=None):
        """Return a pair of zero tensors of shape (batch_size, units)."""
        return (
            torch.zeros([batch_size, self.units], dtype=self.dtype, device=self.device),
            torch.zeros([batch_size, self.units], dtype=self.dtype, device=self.device),
        )

    def forward(self, inputs, states):
        """Run one step.

        :param inputs: input tensor, or an ``(input, elapsed)`` pair
        :param states: pair ``(cell_state, hidden_state)``
        :return: ``(h, [c, h])``
        """
        cell_state, hidden_state = states
        elapsed = 1.0
        if isinstance(inputs, (tuple, list)) and len(inputs) > 1:
            elapsed = inputs[1]
            inputs = inputs[0]

        # Leaky constant taken from the paper
        alpha = 0.001
        # Make sure these values are positive
        tau = torch.nn.functional.softplus(self.tau)
        s = torch.nn.functional.softplus(self.s)
        ron = torch.nn.functional.softplus(self.ron)

        # Phase within the current period, in [0, 1).
        phit = ((elapsed - s) % tau) / tau
        # Openness of the time gate: rises linearly from 0 to 1 over the first
        # half of the open phase (2*phi/r_on), falls back to 0 over the second
        # half (2 - 2*phi/r_on), and leaks with slope alpha while closed.
        # The rising branch previously multiplied by r_on instead of dividing,
        # which made the gate discontinuous at phi = r_on / 2.
        kt = torch.where(
            torch.less(phit, 0.5 * ron),
            2 * phit / ron,
            torch.where(torch.less(phit, ron), 2.0 - 2 * phit / ron, alpha * phit),
        )

        z = (
                torch.matmul(inputs, self.input_kernel)
                + torch.matmul(hidden_state, self.recurrent_kernel)
                + self.bias
        )
        i, ig, fg, og = torch.chunk(z, 4, dim=-1)

        input_activation = torch.tanh(i)
        input_gate = torch.sigmoid(ig)
        forget_gate = torch.sigmoid(fg + 1.0)
        output_gate = torch.sigmoid(og)

        # Blend the proposed LSTM update with the previous state using the gate.
        c_tilde = cell_state * forget_gate + input_activation * input_gate
        c = kt * c_tilde + (1.0 - kt) * cell_state

        h_tilde = torch.tanh(c_tilde) * output_gate
        h = kt * h_tilde + (1.0 - kt) * hidden_state

        return h, [c, h]


class PhasedLSTM(RNNWrapper):
    """Sequence-level wrapper that unrolls a ``PhasedLSTMCell``."""

    def __init__(self, units, input_size):
        phased_cell = PhasedLSTMCell(units, input_size)
        super().__init__(cell=phased_cell)


class GRUODE(torch.nn.Module):
    """GRU-ODE cell integrated with fixed-step explicit Euler.

    Implemented according to
    https://arxiv.org/pdf/1905.12374.pdf
    without the Bayesian stuff

    :param units: width of the hidden state
    :param input_size: input width as an int, a shape tuple whose last entry is
        the width, or a nested tuple whose first element is such a shape
    :param num_unfolds: number of Euler sub-steps per call
    """

    def __init__(self, units, input_size, num_unfolds=4):
        super().__init__()
        self.units = units
        self.num_unfolds = num_unfolds
        self.state_size = units

        # Resolve the feature width. Previously an int failed on
        # `input_size[-1]` and a tuple failed on `units + input_size`, so the
        # class could not be constructed with either.
        if isinstance(input_size, int):
            input_dim = input_size
        else:
            input_dim = input_size[-1]
            if isinstance(input_size[0], tuple):
                # Nested tuple
                input_dim = input_size[0][-1]

        fused_dim = units + input_dim
        reset_linear = torch.nn.Linear(fused_dim, self.units)
        self.reset_gate = torch.nn.Sequential(reset_linear,
                                              torch.nn.Sigmoid())
        torch.nn.init.ones_(reset_linear.bias)
        self.detect_signal = torch.nn.Sequential(torch.nn.Linear(fused_dim, self.units),
                                                 torch.nn.Tanh())
        self.update_gate = torch.nn.Sequential(torch.nn.Linear(fused_dim, self.units),
                                               torch.nn.Sigmoid())

    def _dh_dt(self, inputs, states):
        """Time derivative of the state: ``(1 - z) * (g - h)``."""
        fused_input = torch.cat([inputs, states], dim=-1)
        rt = self.reset_gate(fused_input)
        zt = self.update_gate(fused_input)

        reset_value = torch.cat([inputs, rt * states], dim=-1)
        gt = self.detect_signal(reset_value)

        # Compute new state
        dhdt = (1.0 - zt) * (gt - states)
        return dhdt

    def euler(self, inputs, hidden_state, delta_t):
        """One explicit Euler step of size ``delta_t``."""
        dy = self._dh_dt(inputs, hidden_state)
        return hidden_state + delta_t * dy

    def forward(self, inputs, states):
        """Advance the state by ``elapsed`` (default 1.0) in ``num_unfolds`` steps.

        :param inputs: input tensor, or an ``(input, elapsed)`` pair
        :param states: sequence whose first element is the current state
        :return: ``(new_state, [new_state])``
        """
        elapsed = 1.0
        if isinstance(inputs, (tuple, list)) and len(inputs) > 1:
            elapsed = inputs[1]
            inputs = inputs[0]

        delta_t = elapsed / self.num_unfolds
        hidden_state = states[0]
        for _ in range(self.num_unfolds):
            hidden_state = self.euler(inputs, hidden_state, delta_t)
        return hidden_state, [hidden_state]


class HawkesLSTMCell(pl.LightningModule):
    """Continuous-time LSTM cell in the style of the neural Hawkes process.

    https://papers.nips.cc/paper/7252-the-neural-hawkes-process-a-neurally-self-modulating-multivariate-point-process.pdf

    The state is a triple ``(c, c_bar, h)``, all of width ``units``.

    :param units: width of every state tensor
    :param input_size: width of the per-step input
    """

    def __init__(self, units, input_size):
        super().__init__()
        self.units = units  # state is a triple

        # One fused projection yields all seven gate pre-activations.
        self.input_kernel = torch.nn.Parameter(torch.empty(input_size, 7 * self.units))
        torch.nn.init.xavier_uniform_(self.input_kernel)
        self.recurrent_kernel = torch.nn.Parameter(torch.empty(self.units, 7 * self.units))
        torch.nn.init.orthogonal_(self.recurrent_kernel)
        self.bias = torch.nn.Parameter(torch.empty(7 * self.units))
        torch.nn.init.zeros_(self.bias)

    def get_initial_state(self, batch_size):
        """Return a triple of zero tensors of shape (batch_size, units)."""
        return (
            torch.zeros([batch_size, self.units], dtype=self.dtype, device=self.device),
            torch.zeros([batch_size, self.units], dtype=self.dtype, device=self.device),
            torch.zeros([batch_size, self.units], dtype=self.dtype, device=self.device),
        )

    def forward(self, inputs, hx):
        """Run one step.

        :param inputs: input tensor, or an ``(input, elapsed)`` pair as accepted
            by the other time-aware cells in this file; without ``elapsed`` the
            elapsed time is 1 (the previous fixed behaviour)
        :param hx: triple ``(c, c_bar, h)``
        :return: ``(output_state, [new_c, new_c_bar, output_state])``
        """
        c, c_bar, h = hx
        delta_t = 1  # is the elapsed time
        if isinstance(inputs, (tuple, list)) and len(inputs) > 1:
            delta_t = inputs[1]
            inputs = inputs[0]

        z = (
                torch.matmul(inputs, self.input_kernel)
                + torch.matmul(h, self.recurrent_kernel)
                + self.bias
        )
        i, ig, fg, og, ig_bar, fg_bar, d = torch.chunk(z, 7, dim=-1)

        input_activation = torch.tanh(i)
        input_gate = torch.sigmoid(ig)
        input_gate_bar = torch.sigmoid(ig_bar)
        forget_gate = torch.sigmoid(fg)
        forget_gate_bar = torch.sigmoid(fg_bar)
        output_gate = torch.sigmoid(og)
        delta_gate = torch.nn.functional.softplus(d)  # positive decay rate

        new_c = c * forget_gate + input_activation * input_gate
        new_c_bar = c_bar * forget_gate_bar + input_activation * input_gate_bar

        # The cell value decays exponentially from new_c towards new_c_bar.
        c_t = new_c_bar + (new_c - new_c_bar) * torch.exp(-delta_gate * delta_t)
        output_state = torch.tanh(c_t) * output_gate

        return output_state, [new_c, new_c_bar, output_state]


class HawkesLSTM(RNNWrapper):
    """Sequence-level wrapper that unrolls a ``HawkesLSTMCell``."""

    def __init__(self, units, input_size):
        hawkes_cell = HawkesLSTMCell(units, input_size)
        super().__init__(cell=hawkes_cell)
