## Reproduced from the PyTorch implementation: https://github.com/arsedler9/lfads-torch


import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import OrderedDict


# Public API: only the `lfads` factory function is exported by a star-import.
__all__ = ["lfads"]


class Reparameterize(nn.Module):
    """Draw a Gaussian sample via the reparameterization trick.

    Given a mean and a log-variance, returns ``mean + std * noise`` where
    ``noise`` is standard-normal with the same shape/dtype/device as the mean.
    """

    def forward(self, z_mu, z_log_var):
        # Standard deviation recovered from the log-variance.
        std = torch.exp(0.5 * z_log_var)
        noise = torch.randn_like(z_mu)
        return z_mu + std * noise


class ClippedGRU(nn.Module):
    """Unrolled GRU whose hidden state is clamped after every step.

    Args:
        input_dim: Size of each per-step input vector.
        hidden_state_dim: Size of the hidden state.
        clip_value: Hidden state entries are clamped to
            ``[-clip_value, clip_value]`` after each update.
    """

    def __init__(self, input_dim, hidden_state_dim, clip_value=float("inf")):
        super().__init__()

        self.cell = nn.GRUCell(input_dim, hidden_state_dim, bias=True)
        self.clip_value = clip_value

    def forward(self, x, h_0):
        """Run the cell along dim 0 of ``x`` starting from ``h_0``.

        Returns:
            A tuple ``(states, last_state)`` where ``states`` stacks the
            clamped hidden state of every step along dim 0 and ``last_state``
            is the hidden state after the final step.
        """
        bound = self.clip_value
        state = h_0
        states = []

        for step_input in x.unbind(dim=0):
            state = self.cell(step_input, state).clamp(min=-bound, max=bound)
            states.append(state)

        return torch.stack(states, dim=0), state


class Encoder(nn.Module):
    """Bidirectional clipped-GRU encoder.

    Produces the mean and log-variance of the initial-condition latent ``g0``
    from the final states of a forward and a backward pass over the input, and
    (when the controller is enabled) a per-step controller input built from
    the full forward/backward state sequences of a second GRU pair.

    Args:
        encod_input_dim: Size of each per-step input vector.
        g0_enc_dim: Hidden size of each g0 encoder direction.
        g0_dim: Size of the ``g0`` latent.
        con_enc_dim: Hidden size of each controller-input encoder direction.
        clip_value: Clamp bound passed to every ``ClippedGRU``.
        dropout_rate: Dropout probability applied to the input and to the
            concatenated final g0-encoder states.
        use_controller: Whether to build and run the controller-input encoders.
    """

    def __init__(self, encod_input_dim, g0_enc_dim, g0_dim, con_enc_dim, clip_value=float("inf"), dropout_rate=0.0, use_controller=True):
        super().__init__()

        # Forward / backward recurrences whose final states parameterize g0.
        self.f_enc_g0_input = ClippedGRU(encod_input_dim, g0_enc_dim, clip_value)
        self.b_enc_g0_input = ClippedGRU(encod_input_dim, g0_enc_dim, clip_value)

        # Learnable initial hidden states, stored for one sample and repeated
        # across the batch in ``forward``.
        self.f_enc_g0_input_h0 = nn.Parameter(torch.zeros(1, g0_enc_dim))
        self.b_enc_g0_input_h0 = nn.Parameter(torch.zeros(1, g0_enc_dim))

        # Both directions are concatenated before the linear heads.
        bidirectional_width = 2 * g0_enc_dim
        self.enc_g0_mu = nn.Linear(bidirectional_width, g0_dim)
        self.enc_g0_log_var = nn.Linear(bidirectional_width, g0_dim)

        self.use_controller = use_controller
        if use_controller:
            self.f_enc_con_input = ClippedGRU(encod_input_dim, con_enc_dim, clip_value)
            self.b_enc_con_input = ClippedGRU(encod_input_dim, con_enc_dim, clip_value)

            self.f_enc_con_input_h0 = nn.Parameter(torch.zeros(1, con_enc_dim))
            self.b_enc_con_input_h0 = nn.Parameter(torch.zeros(1, con_enc_dim))

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Encode ``x`` (iterated along dim 0, batched along dim 1).

        Returns:
            ``(g0_mu, g0_log_var, con_input)``. ``con_input`` has a trailing
            dimension of size 0 when the controller is disabled.
        """
        n_steps, batch = x.shape[0], x.shape[1]

        forward_seq = self.dropout(x)
        # The backward direction simply consumes the sequence reversed along dim 0.
        backward_seq = forward_seq.flip(0)

        _, last_f = self.f_enc_g0_input(forward_seq, self.f_enc_g0_input_h0.repeat(batch, 1))
        _, last_b = self.b_enc_g0_input(backward_seq, self.b_enc_g0_input_h0.repeat(batch, 1))

        summary = self.dropout(torch.cat((last_f, last_b), dim=1))
        g0_mu = self.enc_g0_mu(summary)
        g0_log_var = self.enc_g0_log_var(summary)

        if not self.use_controller:
            # Empty placeholder matching the dtype and device of the input.
            return g0_mu, g0_log_var, x.new_zeros(n_steps, batch, 0)

        states_f, _ = self.f_enc_con_input(forward_seq, self.f_enc_con_input_h0.repeat(batch, 1))
        states_b, _ = self.b_enc_con_input(backward_seq, self.b_enc_con_input_h0.repeat(batch, 1))
        # Re-align the backward states with the original step order.
        states_b = states_b.flip(0)
        con_input = torch.cat((states_f, states_b), dim=2)

        return g0_mu, g0_log_var, con_input


class NormalizedLinear(nn.Linear):
    """Linear layer whose weight rows are L2-normalized at call time.

    The stored ``weight`` parameter is left untouched; normalization along
    dim 1 is applied on every forward pass before the affine map.
    """

    def forward(self, x):
        return F.linear(x, F.normalize(self.weight, p=2, dim=1), self.bias)


class Decoder(nn.Module):
    """Generator (and optional controller) that unrolls factors from ``g0``.

    At every step the controller GRU (if enabled) consumes the current
    controller input concatenated with the previous factors and emits the
    mean / log-variance of the inferred input ``u``; a sample of ``u`` drives
    the generator GRU, whose clamped state is projected to factors through a
    row-normalized linear map.

    Args:
        con_input_dim: Size of each per-step controller input vector.
        con_dim: Hidden size of the controller GRU cell.
        u_dim: Size of the inferred input fed to the generator cell.
        g0_dim: Hidden size of the generator cell (and size of ``g0``).
        factor_dim: Size of the factor vector.
        clip_value: Clamp bound applied to generator and controller states.
        dropout_rate: Dropout probability.
        use_controller: Whether to build and run the controller.
    """

    def __init__(self, con_input_dim, con_dim, u_dim, g0_dim, factor_dim, clip_value=float("inf"), dropout_rate=0.0, use_controller=True):
        super().__init__()

        self.gen_cell = nn.GRUCell(u_dim, g0_dim, bias=True)

        self.use_controller = use_controller
        if self.use_controller:
            self.con_cell = nn.GRUCell(con_input_dim + factor_dim, con_dim)
            self.con_h0 = nn.Parameter(torch.zeros(1, con_dim), requires_grad=True)

            self.con_u_mu = nn.Linear(con_dim, u_dim)
            self.con_u_log_var = nn.Linear(con_dim, u_dim)

            self.sample = Reparameterize()

        self.fac_linear = NormalizedLinear(g0_dim, factor_dim, bias=False)

        self.dropout = nn.Dropout(dropout_rate)
        self.clip_value = clip_value

    def forward(self, g0, con_input):
        """Unroll the generator for ``con_input.size(0)`` steps.

        Args:
            g0: Initial generator state, one row per batch element.
            con_input: Controller inputs; dim 0 is iterated over and dim 1 is
                the batch. Only its first two sizes (and dtype/device) are
                used when the controller is disabled.

        Returns:
            ``(g, u_mu, u_log_var, f)`` stacked along dim 0. Without a
            controller, ``u_mu`` and ``u_log_var`` are empty tensors with a
            trailing dimension of size 0.
        """
        T, B = con_input.size(0), con_input.size(1)

        g_prev = g0
        f_prev = self.fac_linear(self.dropout(g0))
        g_steps = []
        f_steps = []

        if self.use_controller:
            con_h = torch.tile(self.con_h0, (B, 1))
            u_mu = []
            u_log_var = []
        else:
            u_mu = con_input.new_zeros(T, B, 0)
            u_log_var = con_input.new_zeros(T, B, 0)
            # Without a controller the generator runs autonomously. Feed it an
            # all-zero input of exactly the width the cell was built with, so
            # a non-zero ``u_dim`` no longer causes a shape mismatch (for
            # ``u_dim == 0`` this is the same empty input as before).
            autonomous_u = con_input.new_zeros(B, self.gen_cell.input_size)

        for t in range(T):
            if self.use_controller:
                con_in = self.dropout(torch.cat((con_input[t], f_prev), dim=1))
                con_h = self.con_cell(con_in, con_h)
                con_h = torch.clamp(con_h, -self.clip_value, self.clip_value)

                u_mu.append(self.con_u_mu(con_h))
                u_log_var.append(self.con_u_log_var(con_h))

                u = self.sample(u_mu[-1], u_log_var[-1])
            else:
                u = autonomous_u

            g_prev = self.gen_cell(u, g_prev)
            g_prev = torch.clamp(g_prev, -self.clip_value, self.clip_value)
            f_prev = self.fac_linear(self.dropout(g_prev))

            g_steps.append(g_prev)
            f_steps.append(f_prev)

        g = torch.stack(g_steps, dim=0)
        f = torch.stack(f_steps, dim=0)
        if self.use_controller:
            u_mu = torch.stack(u_mu, dim=0)
            u_log_var = torch.stack(u_log_var, dim=0)

        return g, u_mu, u_log_var, f


class LFADS(nn.Module):
    """LFADS model: read-in, encoder, decoder and per-session rate read-out.

    Args:
        input_dim: Observation size, either a single int or a sequence with
            one entry per session.
        encod_input_dim: Size of the encoder input. With a single session the
            read-in is the identity, so it must equal that session's size.
        factor_dim: Size of the factor vector.
        g0_enc_dim: Hidden size of each g0 encoder direction.
        g0_dim: Size of the ``g0`` latent / generator state.
        con_enc_dim: Hidden size of each controller-input encoder direction.
        con_dim: Hidden size of the controller.
        u_dim: Size of the inferred input.
        clip_value: Clamp bound for recurrent states.
        dropout_rate: Dropout probability.
        prior_g0_var: Variance of the zero-mean Gaussian prior on ``g0``.
        tau: Initial value of the autoregressive prior time constant.
        nvar: Initial value of the autoregressive prior noise variance.

    The controller is enabled only when ``con_enc_dim``, ``con_dim`` and
    ``u_dim`` are all positive.
    """

    def __init__(
        self, input_dim, encod_input_dim,
        factor_dim,
        g0_enc_dim, g0_dim,
        con_enc_dim, con_dim, u_dim,
        clip_value=float("inf"), dropout_rate=0.0, prior_g0_var=0.1, tau=10, nvar=0.1
    ):
        super().__init__()
        self.use_controller = all(size > 0 for size in (con_enc_dim, con_dim, u_dim))

        session_dims = [input_dim] if isinstance(input_dim, int) else input_dim
        if len(session_dims) == 1:
            # Single session: the data is fed to the encoder unchanged.
            assert session_dims[0] == encod_input_dim
            self.readin = nn.ModuleList([nn.Identity()])
            self.readout = nn.ModuleList([nn.Linear(factor_dim, session_dims[0])])
        else:
            # One linear read-in / read-out pair per session.
            self.readin = nn.ModuleList([])
            self.readout = nn.ModuleList([])
            for session_dim in session_dims:
                self.readin.append(nn.Linear(session_dim, encod_input_dim))
                self.readout.append(nn.Linear(factor_dim, session_dim))

        shared = dict(clip_value=clip_value, dropout_rate=dropout_rate, use_controller=self.use_controller)
        self.encoder = Encoder(encod_input_dim, g0_enc_dim, g0_dim, con_enc_dim, **shared)
        self.sample_g0 = Reparameterize()
        # The decoder sees both encoder directions concatenated.
        self.decoder = Decoder(con_enc_dim * 2, con_dim, u_dim, g0_dim, factor_dim, **shared)

        self.dropout = nn.Dropout(dropout_rate)

        self.prior_g0_var = prior_g0_var
        # Learnable parameters of the autoregressive prior on u, kept in log space.
        self.log_tau = nn.Parameter(torch.log(torch.ones(u_dim) * tau))
        self.log_nvar = nn.Parameter(torch.log(torch.ones(u_dim) * nvar))

        self.init_weight()

    def forward(self, x, session=0):
        """Encode ``x``, sample ``g0``, decode factors and map them to rates.

        ``session`` selects which read-in / read-out pair is used. Returns a
        dict with keys ``g0_mu``, ``g0_log_var``, ``u_mu``, ``u_log_var``,
        ``f`` (factors) and ``r`` (exponentiated read-out of the factors).
        """
        g0_mu, g0_log_var, con_input = self.encoder(self.readin[session](x))
        g0 = self.sample_g0(g0_mu, g0_log_var)
        _, u_mu, u_log_var, f = self.decoder(g0, con_input)
        log_rates = self.readout[session](f)

        return {
            "g0_mu": g0_mu,
            "g0_log_var": g0_log_var,
            "u_mu": u_mu,
            "u_log_var": u_log_var,
            "f": f,
            "r": torch.exp(log_rates),
        }

    def _g0_kld(self, g0_mu, g0_log_var):
        # KL(N(mu, exp(log_var)) || N(0, prior_g0_var)), summed over the last
        # dim and averaged over the remaining ones.
        scaled = (g0_mu ** 2 + torch.exp(g0_log_var)) / self.prior_g0_var
        per_dim = -1 + scaled - g0_log_var + math.log(self.prior_g0_var)
        return 0.5 * torch.mean(torch.sum(per_dim, dim=-1))

    def _u_kld(self, u_mu, u_log_var):
        # KL between the per-step posterior on u and a Gaussian prior whose
        # mean is alpha times a freshly drawn posterior sample of the
        # previous step (along dim 0).
        u_sample = u_mu + torch.exp(0.5 * u_log_var) * torch.randn_like(u_mu)
        alpha = torch.exp(-1 / torch.exp(self.log_tau))
        log_pvar = self.log_nvar - torch.log(1 - alpha ** 2)

        prior_mu = torch.roll(u_sample, shifts=1, dims=0) * alpha
        prior_log_var = torch.ones_like(prior_mu) * self.log_nvar
        # The first step has no predecessor: zero mean and log_pvar as its
        # log-variance.
        prior_mu[0] = 0
        prior_log_var[0] = log_pvar

        scaled = ((u_mu - prior_mu) ** 2 + torch.exp(u_log_var)) / torch.exp(prior_log_var)
        per_dim = -1 + scaled - u_log_var + prior_log_var
        return 0.5 * torch.mean(torch.sum(per_dim, dim=-1))

    def compute_loss(self, **kwargs):
        """Compute the weighted Poisson reconstruction + KL objective.

        Required keyword arguments: ``x`` (targets), ``r`` (predicted rates),
        ``g0_mu``, ``g0_log_var``, ``kld_weight``, plus ``u_mu`` and
        ``u_log_var`` when the controller is enabled.

        Returns:
            An ``OrderedDict`` with ``loss`` and detached copies of
            ``reconstruction_loss``, ``kld_loss_g0`` and, with a controller,
            ``kld_loss_u``.
        """
        x, r = kwargs["x"], kwargs["r"]
        # Summed Poisson NLL divided by the sizes of the first two dims of x.
        nll_sum = F.poisson_nll_loss(r, x, log_input=False, reduction="sum")
        reconstruction_loss = nll_sum / x.size(0) / x.size(1)

        kld_loss_g0 = self._g0_kld(kwargs["g0_mu"], kwargs["g0_log_var"])
        kld_loss_u = self._u_kld(kwargs["u_mu"], kwargs["u_log_var"]) if self.use_controller else 0

        kld_weight = kwargs["kld_weight"]
        loss = reconstruction_loss + kld_weight * (kld_loss_g0 + kld_loss_u)

        report = OrderedDict()
        report["loss"] = loss
        report["reconstruction_loss"] = reconstruction_loss.detach()
        report["kld_loss_g0"] = kld_loss_g0.detach()
        if self.use_controller:
            report["kld_loss_u"] = kld_loss_u.detach()
        return report

    def init_weight(self):
        """Re-initialize weights of every submodule.

        Linear / Conv2d and GRUCell weights use Kaiming-normal init,
        BatchNorm weights are set to 1, and every existing bias is set to 0.
        """
        def zero_bias(bias):
            if bias is not None:
                nn.init.constant_(bias, 0)

        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(module.weight)
                zero_bias(module.bias)
            elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                zero_bias(module.bias)
            elif isinstance(module, nn.GRUCell):
                for weight in (module.weight_ih, module.weight_hh):
                    nn.init.kaiming_normal_(weight)
                for bias in (module.bias_ih, module.bias_hh):
                    zero_bias(bias)


def lfads(**kwargs):
    """Factory: build an ``LFADS`` model, forwarding all keyword arguments."""
    model = LFADS(**kwargs)
    return model
