import math
import matplotlib.pyplot as plt
from functools import partial
import itertools
import numpy as np
from tqdm import tqdm
from typing import *
from pylab import cm

import torch
from torch import Tensor, vmap
from torch.func import grad_and_value, jacrev
import torch.nn as nn
from torch.nn.functional import leaky_relu, sigmoid
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
import torch.nn.utils as nn_utils
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.categorical import Categorical
from torch.distributions import MultivariateNormal, Normal, Independent
from torchdiffeq import odeint_adjoint

from zuko.utils import odeint
from zuko.distributions import DiagNormal
from unet import *

from dataloader.dataloader_lds import *

# Module-wide torch configuration applied on import: every tensor created
# without an explicit dtype is float64, and tensors print with 3 decimals.
torch.set_default_dtype(torch.float64)
torch.set_printoptions(precision=3)


class ShiftedTanh(torch.nn.Module):
    """Scaled and shifted tanh: ``a * tanh(x) + b``.

    Output range is ``(b - a, b + a)`` for ``a > 0``. With the defaults
    (``a=1, b=1``) this maps into ``(0, 2)``.

    Args:
        a: multiplicative scale applied to ``tanh(x)`` (previously stored but
            ignored; the default of 1.0 keeps existing behavior unchanged).
        b: additive shift.
    """

    def __init__(self, a=1.0, b=1.0):
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x):
        return self.a * torch.tanh(x) + self.b

def log_normal(x: Tensor) -> Tensor:
    """Log-density of a standard normal, summed over the last dimension.

    Computes ``sum_i -(x_i**2 + log(2*pi)) / 2`` along ``dim=-1``.
    """
    log_two_pi = math.log(2 * math.pi)
    per_coordinate = x.square() + log_two_pi
    return -0.5 * per_coordinate.sum(dim=-1)

def first_eigen_proj(x):
    """Project samples onto the leading eigenvector of their sample covariance.

    Args:
        x: array-like of shape (n_samples, n_features); it is converted with
            ``np.asarray`` (a CPU tensor without grad also works).

    Returns:
        np.ndarray of shape (n_samples,) equal to ``x @ v``, where ``v`` is the
        unit eigenvector of the covariance with the largest eigenvalue. The
        sign of ``v`` is arbitrary. As before, the raw ``x`` (not the centred
        data) is projected.
    """
    x = np.asarray(x)
    # np.cov removes the mean itself. With a single feature it returns a 0-d
    # array, so atleast_2d keeps that case working.
    cov_matrix = np.atleast_2d(np.cov(x, rowvar=False))
    # The covariance is symmetric: eigh guarantees real eigenpairs with
    # eigenvalues in ascending order (eig may return complex values and an
    # arbitrary order).
    _, eigenvectors = np.linalg.eigh(cov_matrix)
    first_eigenvector = eigenvectors[:, -1]
    return np.dot(x, first_eigenvector)


class SoftplusParameterization(nn.Module):
    """Parametrization that maps an unconstrained tensor to positive values.

    Applies softplus element-wise, ``log(1 + exp(X))``.
    """

    def forward(self, X):
        positive = nn.functional.softplus(X)
        return positive

class ScaledSigmoid(nn.Module):
    """Sigmoid rescaled from (0, 1) to (-1, 1), i.e. ``2 * sigmoid(x) - 1``."""

    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        squashed = self.sigmoid(x)
        return 2 * squashed - 1

def scaledsigmoid(x):
    """Map ``x`` into (-1, 1) via ``2 * sigmoid(x) - 1``."""
    doubled = 2 * torch.sigmoid(x)
    return doubled - 1

class MLP(nn.Sequential):
    """Multi-layer perceptron built as an ``nn.Sequential``.

    For every hidden width the block is ``Linear -> [extra] -> activation``,
    where ``[extra]`` is an optional BatchNorm1d / LayerNorm / Dropout. The
    final layer is always a bare ``Linear`` to ``out_features`` (no trailing
    activation, normalization or dropout).

    Args:
        in_features: input dimensionality.
        out_features: output dimensionality.
        hidden_features: widths of the hidden layers.
        fct: activation module inserted after each hidden layer (defaults to a
            fresh ``nn.Tanh()``). The same instance is reused at every
            position, so it should be parameter-free.
        batch_norm: insert ``BatchNorm1d`` after each hidden ``Linear``.
        dropout: insert ``Dropout(p)`` after each hidden ``Linear``.
        weight_norm: wrap every ``Linear`` with weight normalization.
        layer_norm: insert ``LayerNorm`` after each hidden ``Linear``.
        p: dropout probability.

    Only one of ``weight_norm`` / ``batch_norm`` / ``layer_norm`` / ``dropout``
    takes effect; they are checked in that order.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
        fct=None,
        batch_norm=False,
        dropout=False,
        weight_norm=False,
        layer_norm=False,
        p=0.2,
    ):
        if fct is None:
            fct = nn.Tanh()
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            linear_layer = nn.Linear(a, b)
            if weight_norm:
                # The wrapped layer must actually be added. Previously nothing
                # was appended here, which produced an empty network.
                layers.extend([nn_utils.weight_norm(linear_layer), fct])
            elif batch_norm:
                layers.extend([linear_layer, nn.BatchNorm1d(b), fct])
            elif layer_norm:
                layers.extend([linear_layer, nn.LayerNorm(b), fct])
            elif dropout:
                layers.extend([linear_layer, nn.Dropout(p=p), fct])
            else:
                layers.extend([linear_layer, fct])

        # Strip everything appended after the output Linear: the activation
        # plus, when present, the extra norm/dropout module. (The former
        # condition `not weight_norm or ...` lacked parentheses and kept the
        # norm/dropout as the network's final layer.)
        has_extra = (not weight_norm) and (batch_norm or layer_norm or dropout)
        trailing = 2 if has_extra else 1
        super().__init__(*layers[:-trailing])

class weightConstraint(object):
    """Callable for ``module.apply(...)`` that clamps ``module.weight`` to [-1, 1].

    Modules without a ``weight`` attribute, or whose ``weight`` is ``None``
    (e.g. ``LayerNorm(elementwise_affine=False)``), are skipped instead of
    raising. Biases are not touched.
    """

    def __init__(self):
        pass

    def __call__(self, module):
        weight = getattr(module, 'weight', None)
        if not isinstance(weight, torch.Tensor):
            return
        # Writing through .data keeps the update invisible to autograd, as before.
        weight.data = weight.data.clamp(-1, 1)


class LatentDynamicalSystem(nn.Module):
    """Linear-Gaussian latent dynamics prior.

        z_1 ~ N(mu1, diag(exp(log_Q1)))
        z_s ~ N(A z_{s-1}, diag(exp(log_Q)))   for s > 1

    The parameters are fixed plain tensors (not registered parameters or
    buffers), so every method moves them to the required device itself.
    """

    def __init__(self, latent_dim):
        """
        Initialize the latent dynamical system prior.

        Parameters:
        - latent_dim: Dimensionality of the latent space.
        """
        super(LatentDynamicalSystem, self).__init__()
        self.latent_dim = latent_dim
        self.mu1 = torch.zeros(latent_dim)  # mean of z_1
        self.log_Q1 = torch.log(torch.ones(latent_dim) * 0.5)  # log-variances of z_1 (variance 0.5)
        self.A = torch.eye(latent_dim)  # transition matrix
        self.log_Q = torch.log(torch.ones(latent_dim) * 0.5)  # log-variances of the transition noise

    def _get_distribution(self, mu, log_var):
        """Diagonal Gaussian N(mu, diag(exp(log_var))) as a MultivariateNormal."""
        scale = torch.diag_embed(torch.exp(0.5 * log_var))
        return MultivariateNormal(mu, scale_tril=scale)

    @staticmethod
    def _transition_mean(A, z_prev):
        """Row-wise A @ z for z_prev of shape (n, latent_dim) -> (n, latent_dim)."""
        return torch.einsum("ij,nj->ni", A, z_prev)

    def rsample(self, n, S, device):
        """
        Draw n trajectories z_1, ..., z_S with the reparameterization trick.

        Parameters:
        - n: number of trajectories.
        - S: number of time steps (must be >= 1).
        - device: device on which the samples are produced.
        Returns:
        - zs: tensor of shape (S, n, latent_dim).
        """
        if S < 1:
            raise ValueError("S must be >= 1")
        A = self.A.to(device)
        log_Q = self.log_Q.to(device)
        # z_1 ~ N(mu1, Q1)
        z = self._get_distribution(self.mu1.to(device), self.log_Q1.to(device)).rsample((n,))
        zs = [z]
        # z_s ~ N(A z_{s-1}, Q). Collecting in a list avoids in-place writes
        # into a preallocated tensor (and the clone() they required).
        for _ in range(1, S):
            z = self._get_distribution(self._transition_mean(A, z), log_Q).rsample()
            zs.append(z)
        return torch.stack(zs)  # (S, n, latent_dim)

    def sample(self, n, S, device):
        """
        Draw n trajectories z_1, ..., z_S without gradient tracking.

        Same procedure as ``rsample`` (``Distribution.sample`` is ``rsample``
        under ``no_grad``). This also fixes the former ``.squeeze()``, which
        broke the transition mean when ``latent_dim == 1``.

        Returns:
        - zs: tensor of shape (S, n, latent_dim).
        """
        with torch.no_grad():
            return self.rsample(n, S, device)

    def _log_prob(self, z_s, z_s_minus_1):
        """
        Evaluate log p(z_s | z_{s-1}), or log p(z_1) when z_s_minus_1 is None.

        Parameters:
        - z_s: current latent, shape (n, latent_dim).
        - z_s_minus_1: previous latent, shape (n, latent_dim), or None.

        Returns:
        - log_prob: per-sample log probability, squeezed (a 0-d tensor when n == 1).
        """
        device = z_s.device
        if z_s_minus_1 is not None:
            # Mean A z_{s-1}, matching rsample/sample. The former einsum
            # 'ij,ni->nj' computed A^T z_{s-1}.
            means = self._transition_mean(self.A.to(device), z_s_minus_1)  # (n, latent_dim)
            dist_zs = self._get_distribution(means, self.log_Q.to(device))
        else:
            dist_zs = self._get_distribution(self.mu1.to(device), self.log_Q1.to(device))
        return dist_zs.log_prob(z_s).squeeze()


class sRNN(nn.Module):
    """Vanilla-RNN sequence encoder with an additive positional encoding.

    ``forward`` consumes a time-major sequence of length S and returns the
    final hidden state of the LAST RNN layer (both directions concatenated
    when bidirectional).
    """

    def __init__(self, input_size, hidden_size, num_layers, S, freqs:int =2, nonlinearity="tanh", bidirectional=True):
        super(sRNN, self).__init__()
        self.S = S
        self.freqs = freqs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # NOTE(review): timestep_embedding is not defined in this view;
        # presumably it comes from the star-imported `unet` module.
        self.positional_encoding = timestep_embedding
        self.bidirectional = bidirectional
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity=nonlinearity, bidirectional=bidirectional, batch_first=False)
        # Projection of the previous state that `recurse` adds as a skip connection.
        self.skipnn = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Encode x (time-major: dim 0 is time with length S, dim 1 is batch).

        Returns (batch, 2 * hidden_size) when bidirectional, else (batch, hidden_size).
        """
        semb = self.positional_encoding(torch.arange(self.S).to(x.device), self.input_size, self.freqs).unsqueeze(1)
        x = x + semb
        batch_size = x.shape[1]
        num_directions = 2 if self.bidirectional else 1
        h0 = torch.zeros(num_directions * self.num_layers, batch_size, self.hidden_size).to(x.device)
        _, hs = self.rnn(x, h0)
        if self.bidirectional:
            # h_n is ordered (layer0 fwd, layer0 bwd, layer1 fwd, ...), so the
            # last layer's two directions are the final two entries. Indexing
            # hs[0], hs[1] would read the FIRST layer when num_layers > 1.
            return torch.cat([hs[-2], hs[-1]], dim=-1).view(batch_size, self.hidden_size * 2)
        return hs[-1].view(batch_size, self.hidden_size)

    def recurse(self, xs, hsminus):
        """Advance a unidirectional, single-layer RNN by one step.

        Args:
            xs: the new input for one time step (batch, input_size).
            hsminus: previous hidden state (batch, hidden_size), or None to
                start from zeros.

        Returns:
            The new hidden state. When ``hsminus`` is given, a learned
            projection of it (``skipnn``) is added as a skip connection.
        """
        if self.bidirectional:
            raise NotImplementedError("recurse is only supported for unidirectional sRNN")
        if hsminus is None:
            _, hs = self.rnn(xs.unsqueeze(0), None)
            return hs[-1]
        # unsqueeze(0) builds h_0 for a single layer; assumes num_layers == 1.
        _, hs = self.rnn(xs.unsqueeze(0), hsminus.unsqueeze(0))
        return hs[-1] + self.skipnn(hsminus)

class sGRU(nn.Module):
    """GRU sequence encoder over time-major inputs (``batch_first=False``)."""

    def __init__(self, input_size, hidden_size, num_layers, S, F, dsemb:int =8, bidirectional=True, activation="tanh"):
        super(sGRU, self).__init__()
        self.S = S
        self.F = F
        self.dsemb = dsemb
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.gru = nn.GRU(input_size, hidden_size, num_layers, bidirectional=bidirectional, batch_first=False)

    def forward(self, x, indices=None):
        """Run the GRU over x of shape (seq_len, batch, input_size).

        Returns:
            out: per-step GRU outputs.
            emb: when bidirectional, the last layer's final forward and
                backward states concatenated, shape (batch, 2 * hidden_size);
                otherwise the full ``h_n`` of shape (num_layers, batch, hidden_size).

        ``indices`` is accepted for interface compatibility and is unused.
        """
        batch_size = x.shape[1]
        num_directions = 2 if self.bidirectional else 1
        h0 = torch.zeros(num_directions * self.num_layers, batch_size, self.hidden_size).to(x.device)
        out, hs = self.gru(x, h0)
        if self.bidirectional:
            # h_n is ordered (layer0 fwd, layer0 bwd, layer1 fwd, ...); take the
            # last layer (hs[0], hs[1] would be the first layer when num_layers > 1).
            emb = torch.cat([hs[-2], hs[-1]], dim=-1).view(batch_size, self.hidden_size * 2)
        else:
            emb = hs
        return out, emb

    def recurse(self, xs, hsminus, s):
        """Advance a unidirectional GRU by one step.

        Args:
            xs: input for one time step, shape (batch, input_size).
            hsminus: previous ``h_n`` of shape (num_layers, batch, hidden_size),
                or None to start from zeros.
            s: step index; unused, kept for interface compatibility.

        Returns:
            The new ``h_n``, shape (num_layers, batch, hidden_size).
        """
        if self.bidirectional:
            raise NotImplementedError("recurse is only supported for unidirectional sGRU")
        _, hs = self.gru(xs.unsqueeze(0), hsminus)
        return hs

class sGRU_CNN(nn.Module):
    """Frame-wise CNN features followed by a GRU over time.

    Each single-channel frame is encoded by a small CNN plus a linear layer
    into an ``emb_dim`` vector; the GRU summarises the sequence of embeddings.
    """

    def __init__(self, input_size, hidden_size, num_layers, S, F, bidirectional=True, emb_dim=64):
        super(sGRU_CNN, self).__init__()
        self.S = S
        self.F = F
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.emb_dim = emb_dim
        self.gru = nn.GRU(self.emb_dim, hidden_size, num_layers, bidirectional=bidirectional, batch_first=False)
        # CNN feature extractor (single input channel).
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))  # (B*T, 32, 1, 1)
        )
        # Maps pooled CNN features to the GRU input embedding.
        self.fc = nn.Linear(32, self.emb_dim)

    def forward(self, x, indices=None):
        """Encode x of shape (T, B, C, H, W).

        Returns:
            (B, 2 * hidden_size) when bidirectional: the last layer's final
            forward and backward states; otherwise (B, num_layers * hidden_size):
            every layer's final state, concatenated per sample.

        ``indices`` is accepted for interface compatibility and is unused.
        """
        seq_len, batch_size, c, h, w = x.shape
        # reshape (not view) so that non-contiguous inputs also work.
        frames = x.reshape(seq_len * batch_size, c, h, w)
        cnn_features = self.cnn(frames).reshape(seq_len * batch_size, -1)  # (T*B, 32)
        x_emb = self.fc(cnn_features).view(seq_len, batch_size, -1)  # (T, B, emb_dim)

        num_directions = 2 if self.bidirectional else 1
        h0 = torch.zeros(num_directions * self.num_layers, batch_size, self.hidden_size).to(x.device)
        _, hs = self.gru(x_emb, h0)
        if self.bidirectional:
            # The old code viewed this (B, 2H) tensor as (B, H * num_layers),
            # which crashed for num_layers == 1.
            return torch.cat([hs[-2], hs[-1]], dim=-1)
        # h_n is (num_layers, B, H). Move batch first before flattening; a
        # plain view would interleave different batch elements.
        return hs.transpose(0, 1).reshape(batch_size, self.num_layers * self.hidden_size)


class LLK(nn.Module):
    """Conditional CNF velocity field for images x given a latent z and a step position s.

    The flattened x is embedded by an MLP and z by a linear layer. Together
    with Fourier embeddings of the flow time t and of s, they feed an MLP
    that outputs one velocity value per input element.
    """

    def __init__(self, x_features: int, z_features: int, S: int, F:int, num_hidden, dsemb=2, freqs: int = 2, in_ch=1, **kwargs):
        super().__init__()
        self.S = S
        self.F = F
        self.dtemb = 2 * freqs  # width of the flow-time embedding (cos + sin)
        self.dsemb = dsemb  # width of the step embedding (cos + sin)
        self.num_hidden = num_hidden

        self.embz = nn.Linear(z_features, self.num_hidden)
        self.embx = MLP(x_features**2*in_ch, self.num_hidden, hidden_features=[128], fct=nn.Softplus())

        self.fc1 = nn.Sequential(
            MLP(2*self.num_hidden+self.dtemb+self.dsemb, x_features**2*in_ch, **kwargs),
            )
        # Step frequencies: geometric grid from 1 to S, scaled by 2*pi / F.
        freqs_s = torch.exp(torch.linspace(0, torch.log(torch.tensor(S)), self.dsemb // 2))
        self.register_buffer('freqs_s', 2 * torch.pi / self.F * freqs_s)
        # Flow-time frequencies pi, 2*pi, ..., freqs*pi.
        self.register_buffer('freqs_t', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, z:Tensor, s:Tensor) -> Tensor:
        """Velocity at time t; returns a tensor of shape (batch, channels, p, p).

        Assumes x is (batch, in_ch, p, p) with p == x_features.
        """
        nbatch, p = x.shape[0], x.shape[-1]
        x = x.flatten(start_dim=1)
        # embed t
        t = self.freqs_t * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        zemb = self.embz(z)
        xemb = self.embx(x)

        s = self.freqs_s * s[..., None]
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        semb = semb.expand(*xemb.shape[:-1], -1)
        out = self.fc1(torch.cat((temb, zemb, xemb, semb), dim=-1))

        # -1 channels: identical to the former (nbatch, 1, p, p) for in_ch == 1,
        # and no longer crashes for in_ch > 1.
        return out.view(nbatch, -1, p, p)

    def _forward(self, t: Tensor, x: Tensor, z:Tensor, s:Tensor) -> Tensor:
        out = self.forward(t, x, z, s)
        return out, out

    def decode(self, x: Tensor, z: Tensor, t=None, s=0) -> Tensor:
        """Integrate the velocity field from time t (default 1) back to 0, starting at x."""
        if t is None:
            t = 1.
        z = z.clone().detach().requires_grad_(True)
        xt = odeint_adjoint(
            partial(self, z=z, s=s),
            x,
            torch.tensor([t, 0.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]
        return xt

    def decodeS(self, xS: Tensor, zS: Tensor, t=None, indices=None) -> Tensor:
        """Apply ``decode`` to each of the S steps; step s uses position indices[s] (default s / S)."""
        if indices is None:
            indices = torch.arange(self.S, device=xS.device)/self.S

        if t is None:
            t = 1.
        x0S = xS.clone().detach()
        for s in range(self.S):
            x0S[s] = self.decode(xS[s], zS[s], t=t, s=indices[s])
        return x0S

    def log_prob(self, x: Tensor, z: Tensor, s:Tensor) -> Tensor:
        """Exact CNF log-likelihood of x, one value per sample (shape (batch,)).

        The divergence is the full Jacobian trace over ALL non-batch elements
        of x. The previous identity/einsum construction only traced the last
        axis, which is wrong (and mis-shaped) for image inputs. For 2-D x the
        new formula reduces to the old one. Cost: one batched VJP per input
        element.
        """
        nbatch = x.shape[0]
        dim = x[0].numel()
        # eye[k, n, ...] is the k-th unit vector laid out like one sample,
        # broadcast over the batch.
        eye = torch.eye(dim, dtype=x.dtype, device=x.device)
        eye = eye.view(dim, 1, *x.shape[1:]).expand(dim, *x.shape)

        def augmented(t: Tensor, state) -> Tensor:
            x_t, _ = state
            with torch.enable_grad():
                x_t = x_t.requires_grad_()
                dx = self(t, x_t, z, s)

            # jacobian[k, n, j] = d dx[n, k] / d x[n, j] (samples assumed independent)
            jacobian = torch.autograd.grad(dx, x_t, eye, create_graph=True, is_grads_batched=True)[0]
            trace = jacobian.reshape(dim, nbatch, dim).diagonal(dim1=0, dim2=2).sum(dim=-1)

            # The trace is scaled down inside the solver and scaled back below.
            return dx, trace * 1e-2

        ladj = x.new_zeros(nbatch)
        x0, ladj = odeint_adjoint(
            augmented, (x, ladj), torch.tensor([0.,1.], device=x.device), adjoint_params=self.parameters())

        return log_normal(x0[-1].reshape(nbatch, -1)) + ladj[-1] * 1e2

    def log_probS(self, xS:Tensor, zS:Tensor, indices=None):
        """Sum of ``log_prob`` over the S steps; step s uses position indices[s] (default s / S)."""
        if indices is None:
            indices = torch.arange(self.S, device=xS.device)/self.S
        out = 0
        for s in range(self.S):
            out += self.log_prob(xS[s], zS[s], indices[s])
        return out

def sum_except_batch(x):
    """Sum every dimension except the leading (batch) one; returns shape (batch,).

    Uses ``reshape`` rather than ``view`` so non-contiguous inputs (e.g.
    transposed tensors) work too.
    """
    return x.reshape(x.shape[0], -1).sum(dim=-1)

class LLKcnn(nn.Module):
    """Conditional CNF velocity field for images x given a latent z (CNN encoder for x).

    x is encoded by a small CNN, z by a linear layer and the flow time t by a
    Fourier embedding. An MLP maps their concatenation to one velocity value
    per input element.
    """

    def __init__(self, x_features: int, z_features: int, S: int, F: int, num_hidden=16, freqs: int = 4, in_ch=1, **kwargs):
        super().__init__()
        self.S = S
        self.F = F
        self.dtemb = 2 * freqs  # width of the flow-time embedding (cos + sin)
        self.num_hidden = num_hidden
        self.embz = nn.Sequential(
            nn.Linear(z_features, self.num_hidden),
            )

        # Assuming x_features x x_features inputs, the two stride-2 max-pools
        # leave 32 channels of (x_features // 4)^2 pixels. This was hard-coded
        # to 128, which is only right for x_features in 8..11.
        cnn_out = 32 * (x_features // 4) ** 2
        self.cnn = nn.Sequential(
            nn.Conv2d(in_ch, 16, kernel_size=3, stride=1, padding=1),  # keeps spatial size
            nn.Softplus(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves spatial size
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),  # keeps spatial size
            nn.Softplus(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves spatial size
            nn.Flatten(),
            nn.Linear(cnn_out, self.num_hidden)
        )

        self.fc = MLP(2*self.num_hidden+2*freqs, x_features**2*in_ch, **kwargs)

        # Flow-time frequencies pi, 2*pi, ..., freqs*pi.
        self.register_buffer('freqs_t', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, z:Tensor) -> Tensor:
        """Velocity at time t for images x (batch, in_ch, p, p); returns (batch, channels, p, p)."""
        n, p = x.shape[0], x.shape[-1]

        # embed t
        t = self.freqs_t * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*z.shape[:1], -1)
        zemb = self.embz(z)
        xemb = self.cnn(x)

        out = self.fc(torch.cat((xemb, temb, zemb), dim=1))

        # -1 channels: identical to (n, 1, p, p) for in_ch == 1, valid for in_ch > 1.
        return out.view(n, -1, p, p)

    def _forward(self, t: Tensor, x: Tensor, z:Tensor) -> Tensor:
        out = self.forward(t, x, z)
        return out, out

    def decode(self, x: Tensor, z: Tensor, t=None) -> Tensor:
        """Integrate the velocity field from time t (default 1) back to 0, starting at x."""
        if t is None:
            t = 1.
        z = z.clone().detach().requires_grad_(True)
        xt = odeint_adjoint(
            partial(self, z=z),
            x,
            torch.tensor([t, 0.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]
        return xt

    def decodeS(self, xS: Tensor, zS: Tensor, t=None) -> Tensor:
        """Apply ``decode`` to each of the S steps.

        xS is (S, n, ...) and zS is (S, n, zdim).
        """
        if t is None:
            t = 1.
        x0S = xS.clone().detach()
        for s in range(self.S):
            x0S[s] = self.decode(xS[s], zS[s], t=t)
        return x0S

    @staticmethod
    def hutch_trace(f, y, e):
        """Hutchinson's estimator for the Jacobian trace"""
        # With _eps ~ Rademacher (== Bernoulli on -1 +1 with 50/50 chance).
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx * e
        approx_tr_dzdx = sum_except_batch(e_dzdx_e)
        return approx_tr_dzdx

    def log_prob(self, x: Tensor, z: Tensor, source) -> Tensor:
        """CNF log-likelihood of images x using a Hutchinson trace estimate.

        ``source`` must provide ``log_prob`` for the base distribution; its
        result is summed over dims (1, 2, 3). Returns one value per sample.
        """
        z = z.clone().detach().requires_grad_(True)
        # One Rademacher (+1/-1) probe, drawn per call and reused at every
        # solver step. It is cast to x's dtype so the VJP grad_outputs match dx.
        e = (torch.randint(low=0, high=2, size=x.size()).to(x.device) * 2 - 1).to(x.dtype)

        def augmented(t: Tensor, state) -> Tensor:
            x_t, _ = state
            with torch.enable_grad():
                x_t.requires_grad_(True)
                dx = self(t, x_t, z)
                trace = self.hutch_trace(dx, x_t, e)
            # The trace is scaled down inside the solver and scaled back below.
            return dx, trace * 1e-2

        ladj = x.new_zeros(x.shape[0])
        x0, ladj = odeint_adjoint(
            augmented,
            (x, ladj),
            # Same device as the state, consistent with every other solver
            # call in this file (previously always created on the CPU).
            torch.tensor([0., 1.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-7, rtol=1e-7)

        return source.log_prob(x0[-1]).sum(dim=(1, 2, 3)) + ladj[-1] * 1e2

    def log_probS(self, xS:Tensor, zS:Tensor, source):
        """Sum of ``log_prob`` over the S steps."""
        out = 0
        for s in range(self.S):
            out += self.log_prob(xS[s], zS[s], source)
        return out


class LLKc(nn.Module):
    """Conditional CNF velocity field for images x given a latent z.

    The flattened x is concatenated with a linear embedding of z and a
    Fourier embedding of the flow time t, then fed to an MLP that outputs one
    velocity value per input element.
    """

    def __init__(self, x_features: int, z_features: int, S: int, F:int, num_hidden, dsemb=2, freqs: int = 2, in_ch=1, **kwargs):
        super().__init__()
        self.S = S
        self.F = F
        self.dtemb = 2 * freqs  # width of the flow-time embedding (cos + sin)
        self.dsemb = dsemb  # stored but not used by this class
        self.num_hidden = num_hidden

        self.embz = nn.Linear(z_features, self.num_hidden)

        self.fc1 = MLP(self.num_hidden+x_features**2*in_ch+self.dtemb, x_features**2*in_ch, **kwargs)
        # Flow-time frequencies pi, 2*pi, ..., freqs*pi.
        self.register_buffer('freqs_t', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, z:Tensor) -> Tensor:
        """Velocity at time t for x (batch, in_ch, p, p); returns (batch, channels, p, p)."""
        nbatch, p = x.shape[0], x.shape[-1]
        x = x.flatten(start_dim=1)
        # embed t
        t = self.freqs_t * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        zemb = self.embz(z)
        out = self.fc1(torch.cat((temb, zemb, x), dim=-1))

        # -1 channels: identical to (nbatch, 1, p, p) for in_ch == 1, valid for in_ch > 1.
        return out.view(nbatch, -1, p, p)

    def _forward(self, t: Tensor, x: Tensor, z:Tensor) -> Tensor:
        out = self.forward(t, x, z)
        return out, out

    def decode(self, x: Tensor, z: Tensor, t=None) -> Tensor:
        """Integrate the velocity field from time t (default 1) back to 0, starting at x."""
        if t is None:
            t = 1.
        z = z.clone().detach().requires_grad_(True)
        xt = odeint_adjoint(
            partial(self, z=z),
            x,
            torch.tensor([t, 0.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]
        return xt

    def decodeS(self, xS: Tensor, zS: Tensor, t=None) -> Tensor:
        """Apply ``decode`` to every step of xS (S, n, ...) with latents zS (S, n, zdim)."""
        S = xS.shape[0]
        if t is None:
            t = 1.
        x0S = xS.clone().detach()
        for s in range(S):
            x0S[s] = self.decode(xS[s], zS[s], t=t)
        return x0S

    def log_prob(self, x: Tensor, z: Tensor) -> Tensor:
        """Exact CNF log-likelihood of x, one value per sample (shape (batch,)).

        The divergence is the full Jacobian trace over ALL non-batch elements
        of x. The previous identity/einsum construction only traced the last
        axis, which is wrong (and mis-shaped) for image inputs. For 2-D x the
        new formula reduces to the old one. Cost: one batched VJP per input
        element.
        """
        nbatch = x.shape[0]
        dim = x[0].numel()
        # eye[k, n, ...] is the k-th unit vector laid out like one sample,
        # broadcast over the batch.
        eye = torch.eye(dim, dtype=x.dtype, device=x.device)
        eye = eye.view(dim, 1, *x.shape[1:]).expand(dim, *x.shape)

        def augmented(t: Tensor, state) -> Tensor:
            x_t, _ = state
            with torch.enable_grad():
                x_t = x_t.requires_grad_()
                dx = self(t, x_t, z)

            # jacobian[k, n, j] = d dx[n, k] / d x[n, j] (samples assumed independent)
            jacobian = torch.autograd.grad(dx, x_t, eye, create_graph=True, is_grads_batched=True)[0]
            trace = jacobian.reshape(dim, nbatch, dim).diagonal(dim1=0, dim2=2).sum(dim=-1)

            # The trace is scaled down inside the solver and scaled back below.
            return dx, trace * 1e-2

        ladj = x.new_zeros(nbatch)
        x0, ladj = odeint_adjoint(
            augmented, (x, ladj), torch.tensor([0.,1.], device=x.device), adjoint_params=self.parameters())

        return log_normal(x0[-1].reshape(nbatch, -1)) + ladj[-1] * 1e2

    def log_probS(self, xS:Tensor, zS:Tensor):
        """Sum of ``log_prob`` over the first S steps."""
        out = 0
        for s in range(self.S):
            out += self.log_prob(xS[s], zS[s])
        return out


class CNF(nn.Module):
    """Conditional continuous normalizing flow over latent states z.

    The velocity field is an MLP applied to the concatenation of a Fourier
    embedding of the flow time t, an sRNN embedding of the observation
    sequence x, the current latent z, and a Fourier embedding of the step
    index s. ``forward`` returns the NEGATED MLP output.
    """
    def __init__(self, x_features: int, z_features: int, S: int, dsemb: int = 10, freqs: int = 2, nonlinearity="tanh",  **kwargs):
        super().__init__()
        self.S = S
        self.dsemb = dsemb  # width of the step-index embedding (cos + sin halves)
        self.dtemb = 2 * freqs  # width of the flow-time embedding (cos + sin halves)
        self.num_hidden = 2 * z_features # because of bidirectional RNN
        # sRNN is built with its default bidirectional=True, so its output width
        # is 2 * num_hidden (hence the `self.num_hidden*2` in the MLP input size).
        self.emb = sRNN(x_features, self.num_hidden, S=self.S, num_layers=1, freqs=freqs, nonlinearity=nonlinearity) 
        self.fc1 = nn.Sequential(
            MLP(self.dtemb + self.dsemb + z_features + self.num_hidden*2, z_features, **kwargs),
            )
        # Flow-time frequencies pi, 2*pi, ..., freqs*pi.
        self.register_buffer('freqs_t', torch.arange(1, freqs + 1) * torch.pi)
        
        # Step-index frequencies: geometric grid from 1 to S (dsemb // 2 values),
        # scaled by 2*pi / S.
        freqs_s = torch.exp(torch.linspace(0, torch.log(torch.tensor(S)), self.dsemb // 2))
        self.register_buffer('freqs_s', 2 * torch.pi / S * freqs_s)

    def forward(self, t: Tensor, z:Tensor, x: Tensor, s:Tensor) -> Tensor:
        """Return -MLP([temb, xemb, z, semb]).

        z is the current latent (assumed (n, z_features) — TODO confirm); x is
        the observation sequence passed to the sRNN encoder; s is a step-index
        tensor (log_probS/decodeS pass shape (1,)).
        """
        # z = zs, zc = zs-1,

        # embed t
        t = self.freqs_t * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*z.shape[:-1], -1)

        # embed the step index
        s = self.freqs_s * s[..., None]
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        semb = semb.expand(*z.shape[:-1], -1)

        # NOTE(review): the sRNN embedding of x is recomputed at every solver
        # evaluation even though x is fixed during a solve; it could be cached.
        xemb = self.emb(x)
        out = self.fc1(torch.cat((temb, xemb, z, semb), dim=-1))
        return -out 

    def decode(self, z: Tensor, x: Tensor, s:Tensor, t=None) -> Tensor:
        """Integrate forward() (i.e. -MLP) from flow time 1 down to t (default 0), starting at z.

        x is detached, so no gradient flows back into it from here.
        """
        if t is None:
            t = 0.
        x = x.clone().detach().requires_grad_(True)
        zt = odeint_adjoint(
            partial(self, x=x, s=s), 
            z, 
            torch.tensor([1., t], device=x.device), 
            adjoint_params=self.parameters(), 
            atol=1e-8, rtol=1e-8)[-1]
        return zt

    def decodeS(self, zS: Tensor, xS: Tensor, t=None) -> Tensor:
        """Apply ``decode`` to each step s < S; every step is conditioned on the full xS."""
        if t is None:
            t = 0.
        z0S = zS.clone().detach()
        for s in range(self.S):
            z0S[s] = self.decode(zS[s], xS, torch.tensor([s], dtype=torch.long, device=xS.device), t)
        return z0S

    # zc should be samples from the smoothing probability
    def log_prob(self, z: Tensor, zc: Tensor, x: Tensor, s:int, t, prior) -> Tensor:
        """Log density of z under the flow, scored at the endpoint by the prior.

        Integrates dz/dtau = -forward(tau, z) (= +MLP output) from t (default 0)
        to 1 together with the exact Jacobian trace, then returns
        ``prior._log_prob(z_end, zc) + accumulated trace``.

        NOTE(review): decode integrates forward() (= -MLP) from 1 to t, while
        this integrates -forward() (= +MLP) from t to 1. These are not
        inverses of each other; confirm the sign convention is intended.
        NOTE(review): callers (log_probS) pass s as a tensor, not an int.
        """
        if t is None:
            t = 0.
        # Identity basis, one vector per latent coordinate, for an exact trace via
        # batched VJPs; the einsum below assumes z is (n, d) — TODO confirm.
        I = torch.eye(z.shape[-1], dtype=z.dtype, device=z.device)
        I = I.expand(*z.shape, z.shape[-1]).movedim(-1, 0)

        x = x.clone().detach().requires_grad_(True)

        def augmented(t: Tensor, state) -> Tensor:
            z, ladj = state
            with torch.enable_grad():
                z = z.requires_grad_()
                dz = -self(t, z, x, s)

            jacobian = torch.autograd.grad(dz, z, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # The trace is scaled by 1e-2 inside the solver and by 1e2 on return.
            # NOTE(review): presumably this keeps the log-det state small relative
            # to the solver tolerances; confirm.
            return dz, trace * 1e-2

        ladj = torch.zeros_like(z[..., 0])
        z0, ladj = odeint_adjoint(
            augmented, 
            (z, ladj), 
            torch.tensor([t, 1.0], device=x.device), 
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)

        # is this true? our prior is over the entire state-space
        return prior._log_prob(z0[-1], zc) + ladj[-1] * 1e2

    def log_probS(self, zS, xS, t, prior):
        """Sum of ``log_prob`` over steps 0..S-1.

        Step 0 is scored with zc=None (the prior's initial-state branch); step s
        conditions on zS[s-1].
        """
        out = self.log_prob(zS[0], None, xS, torch.tensor([0], dtype=torch.long, device=xS.device), t, prior)

        for s in range(1,self.S):
            out += self.log_prob(zS[s], zS[s-1], xS, torch.tensor([s], dtype=torch.long, device=xS.device), t, prior)
        return out

class AttentionPooling(nn.Module):
    """
    Learned attention pooling over a time-major sequence.

    Every element of the sequence receives a scalar score from a small
    tanh-activated projection. The scores are softmax-normalised over time
    and used to take a weighted average of the elements, giving one
    fixed-length context vector per batch entry.
    """

    def __init__(self, input_dim, attention_dim=None):
        """
        Args:
            input_dim (int): feature size of each sequence element.
            attention_dim (int, optional): width of the hidden scoring layer.
                When None, ``input_dim // 2`` is used (or 1 if that is 0).
        """
        super().__init__()
        self.input_dim = input_dim

        if attention_dim is None:
            attention_dim = input_dim // 2 or 1

        # input_dim -> attention_dim projection (followed by tanh in forward)
        self.attention_layer = nn.Linear(input_dim, attention_dim)
        # attention_dim -> one unnormalised score per element; no bias, since a
        # constant offset would not change the softmax
        self.context_vector_layer = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, sequence_outputs):
        """
        Pool a sequence into one context vector per batch entry.

        Args:
            sequence_outputs (torch.Tensor): shape (seq_len, batch_size, input_dim).

        Returns:
            torch.Tensor: context vectors, shape (batch_size, input_dim).
            torch.Tensor: attention weights, shape (batch_size, seq_len),
                summing to 1 over seq_len.

        Raises:
            ValueError: if the last dimension differs from ``input_dim``.
        """
        seq_len, batch_size, feat_dim = sequence_outputs.shape
        if feat_dim != self.input_dim:
            raise ValueError(
                f"Input feature dimension ({feat_dim}) does not match "
                f"layer's input_dim ({self.input_dim})"
            )

        batch_major = sequence_outputs.permute(1, 0, 2)  # (batch, seq, feat)
        flat = batch_major.reshape(-1, self.input_dim)
        hidden = torch.tanh(self.attention_layer(flat))
        scores = self.context_vector_layer(hidden).reshape(batch_size, seq_len)

        weights = torch.softmax(scores, dim=1)
        context = (batch_major * weights.unsqueeze(-1)).sum(dim=1)
        return context, weights


class fullGauss(nn.Module):
    """Autoregressive full-covariance Gaussian posterior over a latent subsequence.

    At step ``s``, ``self.fc`` maps the concatenation of
    ``[flow-time embedding, step embedding, latent-history state, observation context]``
    to the mean and Cholesky factor of a ``MultivariateNormal`` over ``z_s``.
    The latent history is carried by ``self.rnnz``. The observation context is
    produced once per subsequence by ``_embed_x``.

    Sequence tensors passed to the ``*S`` / ``predictF`` methods are time-major:
    axis 0 is the step and axis 1 is the batch (see ``xS.shape[1]`` usage).
    """

    def __init__(
        self, x_features: int, z_features: int, S: int, F: int,
        dsemb: int = 4, freqs: int = 2, num_hidden=32, nonlinearity="tanh",
        cnn=False, num_layers=2, emb_dim=32, in_ch=1, rnn=False, attention=False, **kwargs
        ):
        super().__init__()
        self.S = S  # subsequence length
        self.F = F  # total length (sets the step-embedding frequency range)
        self.num_hidden = emb_dim  # per-frame observation embedding width
        self.num_hidden_z = num_hidden  # latent-history GRU width
        self.num_layers = num_layers
        self.dsemb = dsemb  # step-embedding width
        self.dtemb = 2 * freqs  # flow-time embedding width (cos + sin)
        self.emb_dim = emb_dim
        self.in_ch = in_ch
        self.cnn = cnn
        self.rnn = rnn
        self.attention = attention
        self.shiftedtanh = ShiftedTanh()

        self.x_features = x_features
        self.z_features = z_features

        # Observation encoder. ``num_embx`` is the width of the flat context
        # vector returned by ``_embed_x``. ``attention`` takes precedence over ``rnn``.
        if attention:
            if cnn:
                self.embx = nn.Sequential(
                    nn.Conv2d(in_ch, 16, kernel_size=3, stride=1, padding=1),  # keeps spatial size
                    nn.Softplus(),
                    nn.MaxPool2d(kernel_size=2, stride=2),  # halves spatial size
                    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),  # keeps spatial size
                    nn.Softplus(),
                    nn.MaxPool2d(kernel_size=2, stride=2),  # halves spatial size
                    nn.Flatten(),
                    # 128 = 32 channels * 2 * 2, i.e. a 2x2 map after the two
                    # pools (e.g. 8x8 frames).
                    nn.Linear(128, self.num_hidden)
                )
            else:
                self.embx = nn.Linear(x_features, self.num_hidden)
            self.att = AttentionPooling(self.num_hidden)
            self.num_embx = self.num_hidden
        elif rnn:
            self.embx = sGRU(x_features, self.num_hidden, num_layers=1, S=self.S, F=self.F, dsemb=dsemb, bidirectional=False)
            self.num_embx = self.num_hidden
        else:
            self.embx = nn.Sequential(nn.Linear(x_features, self.num_hidden))
            self.num_embx = self.S * self.num_hidden

        self.rnnz = sGRU(z_features, self.num_hidden_z, self.num_layers, S=self.S, F=self.F, dsemb=dsemb, bidirectional=False)

        # MLP output = mean (z_features) + lower-triangular entries of L:
        # z_features diagonal + z_features*(z_features-1)/2 strictly-lower entries.
        self.num_l_elements = self.z_features * (self.z_features + 1) // 2
        mlp_output_dim = self.z_features + self.num_l_elements
        mlp_input_dim = self.num_embx + self.num_hidden_z * self.num_layers + self.dtemb + self.dsemb
        self.fc = nn.Sequential(
            MLP(mlp_input_dim, mlp_output_dim, **kwargs),
            )

        self.semb_project = nn.Linear(self.dsemb, self.num_hidden)
        self.register_buffer('freqs_t', torch.arange(1, freqs + 1) * torch.pi)

        # Step-embedding frequencies, log-spaced between 1 and F.
        freqs_s = torch.exp(torch.linspace(0, torch.log(torch.tensor(self.F)), self.dsemb // 2))
        self.register_buffer('freqs_s', 2 * torch.pi / self.F * freqs_s)

        self.register_buffer('diag_indices', torch.arange(self.z_features))

        off_tril_indices = torch.tril_indices(row=self.z_features, col=self.z_features, offset=-1)
        self.register_buffer('off_tril_indices_row', off_tril_indices[0])
        self.register_buffer('off_tril_indices_col', off_tril_indices[1])

    def forward(self, t: Tensor, zcemb: Tensor, xcemb: Tensor, s: Tensor, min_variance=1e-6) -> MultivariateNormal:
        """Return the Gaussian over ``z_s`` given the conditioning embeddings.

        Args:
            t: flow time (tensor, broadcast against the batch).
            zcemb: latent-history state of shape (nbatch, num_hidden_z*num_layers),
                or None at the first step (zeros are used instead).
            xcemb: observation context of shape (nbatch, num_embx).
            s: step position (tensor), embedded sinusoidally.
            min_variance: floor added (as its square root) to the Cholesky diagonal.

        Returns:
            ``MultivariateNormal`` with ``loc`` (nbatch, z_features) and a
            lower-triangular ``scale_tril`` with strictly positive diagonal.
        """
        t = self.freqs_t * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*xcemb.shape[:-1], -1)

        s = self.freqs_s * s[..., None]
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        semb = semb.expand(*xcemb.shape[:-1], -1)

        if zcemb is None:
            # First step: no latent history yet, so feed zeros in its slot.
            zcemb = xcemb.new_zeros(len(xcemb), self.num_hidden_z * self.num_layers)
        phi = self.fc(torch.cat((temb, semb, zcemb, xcemb), dim=-1))

        # phi layout: [mu | raw diagonal | raw strictly-lower entries].
        mu = phi[:, :self.z_features]
        raw_diag_elements = phi[:, self.z_features:self.z_features * 2]
        raw_off_diag_elements = phi[:, self.z_features * 2:]

        # Allocate L with phi's dtype/device so mixed-precision setups stay consistent.
        L = phi.new_zeros(len(xcemb), self.z_features, self.z_features)
        L[:, self.off_tril_indices_row, self.off_tril_indices_col] = raw_off_diag_elements
        # Softplus keeps the diagonal positive; the floor keeps L well conditioned.
        L[:, self.diag_indices, self.diag_indices] = F.softplus(raw_diag_elements) + math.sqrt(min_variance)

        return MultivariateNormal(loc=mu, scale_tril=L)

    def _add_semb(self, x, indices):
        """Add a projected sinusoidal step embedding to per-step features.

        Args:
            x: per-step features of shape (S, nbatch, num_hidden).
            indices: step positions of shape (S,), (S, 1) or (S, nbatch).
        """
        # A 1-D index vector must broadcast over the batch axis, not the time axis.
        if indices.dim() == 1:
            indices = indices.unsqueeze(-1)
        s = self.freqs_s.unsqueeze(0).unsqueeze(0) * indices.unsqueeze(-1)
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        return x + self.semb_project(semb)

    def _embed_x(self, xS: Tensor) -> Tensor:
        """Encode an observed subsequence into a flat context vector per batch element.

        Args:
            xS: observations with steps on axis 0 and batch on axis 1.

        Returns:
            Tensor of shape (nbatch, self.num_embx).
        """
        nbatch = xS.shape[1]
        if self.attention:
            # Embed every frame independently, then pool over time with attention.
            frames = xS.reshape(nbatch * self.S, *xS.shape[2:])
            x_emb = self.embx(frames).view(self.S, nbatch, self.num_hidden)
            x_emb, _ = self.att(x_emb)  # (nbatch, num_hidden)
        elif self.rnn:
            _, x_emb = self.embx(xS)
        else:
            frames = xS.reshape(nbatch * self.S, self.x_features)
            x_emb = self.embx(frames).view(self.S, nbatch, self.num_hidden)
            # Bring batch to the front before flattening time into features, so
            # that each row holds only its own sequence (a direct view would
            # interleave different batch elements).
            x_emb = x_emb.transpose(0, 1)
        return x_emb.reshape(nbatch, self.num_embx)

    def decode(self, zcemb: Tensor, xcemb: Tensor, s: Tensor, t=None) -> Tensor:
        """Draw a reparameterised sample of ``z_s`` (t defaults to 0)."""
        if t is None:
            t = torch.Tensor([0.])[0]
        zt = self(t, zcemb, xcemb, s).rsample()
        return zt

    def decodeS(self, zS: Tensor, xS: Tensor, t=None, indices=None) -> Tensor:
        """Sample a latent subsequence autoregressively.

        Args:
            zS: tensor whose shape/dtype/device the output copies (values are overwritten).
            xS: observed subsequence, steps on axis 0 and batch on axis 1.
            t: flow time (defaults to 0).
            indices: step positions (defaults to ``arange(S) / S``).

        Returns:
            Tensor shaped like ``zS`` holding the sampled latents.
        """
        if t is None:
            t = torch.Tensor([0.])[0]
        if indices is None:
            indices = torch.arange(self.S, device=xS.device) / self.S

        z0S = zS.clone().detach()
        xcemb = self._embed_x(xS)

        hidden = None
        z0S[0] = self.decode(hidden, xcemb, indices[0], t)
        for s in range(1, self.S):
            # Advance the latent-history state with the previous sample.
            hidden = self.rnnz.recurse(z0S[s - 1], hidden, indices[s])
            z0S[s] = self.decode(hidden.view(-1, self.num_hidden_z * self.num_layers), xcemb, indices[s], t)
        return z0S

    def predictF(self, F, zF, xS, t=None, indices=None):  # F > S
        """Roll the latent chain forward for ``F`` steps conditioned on the observed ``xS``.

        Args:
            F: number of steps to generate.
            zF: tensor whose shape/dtype/device the output copies (values are overwritten).
            xS: observed subsequence of length ``self.S`` (steps on axis 0).
            t: flow time (defaults to 0).
            indices: step positions of length ``F`` (defaults to ``arange(F) / F``).
        """
        if t is None:
            t = torch.Tensor([0.])[0]
        if indices is None:
            indices = torch.arange(F, device=xS.device) / F
        xcemb = self._embed_x(xS)

        z0F = zF.clone().detach()
        hidden = None
        z0F[0] = self.decode(hidden, xcemb, indices[0], t)
        for f in range(1, F):
            hidden = self.rnnz.recurse(z0F[f - 1], hidden, indices[f])
            z0F[f] = self.decode(hidden.view(-1, self.num_hidden_z * self.num_layers), xcemb, indices[f], t)

        return z0F

    def log_prob(self, z: Tensor, zcemb: Tensor, xcemb: Tensor, s: int, t, prior) -> Tensor:
        """Log-density of ``z`` at one step (``prior`` is accepted for API parity and unused)."""
        if t is None:
            t = torch.Tensor([0.])[0]
        return self(t, zcemb, xcemb, s).log_prob(z)

    def log_probS(self, zS, xS, t, prior, indices=None):
        """Sum of per-step log-densities of the latent subsequence ``zS`` given ``xS``."""
        if indices is None:
            indices = torch.arange(self.S, device=xS.device) / self.S
        xcemb = self._embed_x(xS)

        hidden = None
        out = self.log_prob(zS[0], hidden, xcemb, indices[0], t, prior)
        for s in range(1, self.S):
            # Condition on the observed (not sampled) previous latent.
            hidden = self.rnnz.recurse(zS[s - 1], hidden, indices[s])
            out = out + self.log_prob(zS[s], hidden.view(-1, self.num_hidden_z * self.num_layers), xcemb, indices[s], t, prior)
        return out


class fullCNF(nn.Module):
    """Continuous-normalizing-flow posterior over a latent subsequence.

    The vector field ``forward`` (negated MLP output) is conditioned on the
    flow time, a step embedding, an observation context and the latent-history
    state from ``self.rnnz``. Sampling integrates it from t=1 to ``t`` with
    ``odeint_adjoint``, and ``log_prob`` integrates the augmented
    (state, trace) system in the opposite direction.

    Sequence tensors are time-major: axis 0 is the step and axis 1 the batch.
    """

    def __init__(self, x_features: int, z_features: int, S: int, F: int, dsemb: int = 4, freqs: int = 2, num_hidden=32, nonlinearity="tanh", cnn=False, num_layers=2, emb_dim=32, in_ch=1, **kwargs):
        super().__init__()
        self.S = S  # subsequence length
        self.F = F  # total length (sets the step-embedding frequency range)
        self.num_hidden = emb_dim  # per-frame observation embedding width
        self.num_hidden_z = num_hidden  # latent-history GRU width
        self.hdim = 128
        self.dsemb = dsemb  # step-embedding width
        self.dtemb = 2 * freqs  # flow-time embedding width (cos + sin)
        self.num_layers = num_layers
        self.emb_dim = emb_dim
        self.cnn = cnn
        self.in_ch = in_ch

        self.hid_features = x_features
        self.x_features = x_features
        self.z_features = z_features
        if not cnn:
            self.embx = nn.Linear(x_features, self.num_hidden)
        else:
            # NOTE: this replaces the boolean flag ``self.cnn`` with the module.
            # ``if self.cnn:`` elsewhere relies on a non-empty Sequential being truthy.
            self.cnn = nn.Sequential(
                nn.Conv2d(in_ch, 16, kernel_size=3, stride=1, padding=1),  # keeps spatial size
                nn.Softplus(),
                nn.MaxPool2d(2, 2),  # halves spatial size
                nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),  # keeps spatial size
                nn.Softplus(),
                nn.MaxPool2d(2, 2),  # halves spatial size
            )
            # fc_cnn expects 32 * 2 * 2 features, i.e. a 2x2 map after the two
            # pools (e.g. 8x8 frames).
            self.fc_cnn = nn.Linear(32 * 2 * 2, self.num_hidden)

        self.rnnz = sGRU(z_features, self.num_hidden_z, self.num_layers, S=self.S, F=self.F, dsemb=dsemb, bidirectional=False)

        # Input order must match forward(): z, semb, temb, xcemb, zcemb.
        self.fc = MLP(self.S * self.num_hidden + self.num_hidden_z * self.num_layers + self.z_features + self.dsemb + self.dtemb, z_features, **kwargs)

        self.semb_project = nn.Linear(self.dsemb, self.num_hidden)
        self.register_buffer('freqs_t', torch.arange(1, freqs + 1) * torch.pi)

        # Step-embedding frequencies, log-spaced between 1 and F.
        freqs_s = torch.exp(torch.linspace(0, torch.log(torch.tensor(self.F)), self.dsemb // 2))
        self.register_buffer('freqs_s', 2 * torch.pi / self.F * freqs_s)

    def forward(self, t: Tensor, z: Tensor, zcemb: Tensor, xcemb: Tensor, s: Tensor) -> Tensor:
        """Vector field dz/dt at flow time ``t``.

        Args:
            t: flow time (tensor).
            z: current latent of shape (nbatch, z_features).
            zcemb: latent-history state (nbatch, num_hidden_z*num_layers), or None
                at the first step (zeros are used instead).
            xcemb: observation context of shape (nbatch, S*num_hidden).
            s: step position (tensor), embedded sinusoidally.
        """
        t = self.freqs_t * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*z.shape[:-1], -1)

        s = self.freqs_s * s[..., None]
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        semb = semb.expand(*z.shape[:-1], -1)

        if zcemb is None:
            # First step: no latent history yet, so feed zeros in its slot.
            zcemb = xcemb.new_zeros(len(xcemb), self.num_hidden_z * self.num_layers)
        out = self.fc(torch.cat((z, semb, temb, xcemb, zcemb), dim=-1))
        return - out

    def _add_semb(self, x, indices):
        """Add a projected sinusoidal step embedding to per-step features.

        Args:
            x: per-step features of shape (S, nbatch, num_hidden).
            indices: step positions of shape (S,), (S, 1) or (S, nbatch).
        """
        # A 1-D index vector must broadcast over the batch axis, not the time
        # axis (otherwise the sum only works, incorrectly, when nbatch == S).
        if indices.dim() == 1:
            indices = indices.unsqueeze(-1)
        s = self.freqs_s.unsqueeze(0).unsqueeze(0) * indices.unsqueeze(-1)
        semb = torch.cat((s.cos(), s.sin()), dim=-1)
        return x + self.semb_project(semb)

    def _embed_x(self, xS: Tensor, indices: Tensor) -> Tensor:
        """Encode an observed subsequence into a flat context vector per batch element.

        Args:
            xS: observations with steps on axis 0 and batch on axis 1. With the CNN
                encoder each frame must reshape to (in_ch, x_features, x_features).
            indices: step positions of the ``self.S`` observed frames.

        Returns:
            Tensor of shape (nbatch, S*num_hidden).
        """
        nbatch = xS.shape[1]
        if self.cnn:
            frames = xS.reshape(nbatch * self.S, self.in_ch, self.x_features, self.x_features)
            cnn_features = self.cnn(frames).reshape(nbatch * self.S, -1)
            x_emb = self.fc_cnn(cnn_features).view(self.S, nbatch, self.num_hidden)
        else:
            frames = xS.reshape(nbatch * self.S, self.x_features)
            x_emb = self.embx(frames).view(self.S, nbatch, self.num_hidden)
        x_emb = self._add_semb(x_emb, indices)
        # Bring batch to the front before flattening time into features so each
        # row holds only its own sequence.
        return x_emb.transpose(0, 1).reshape(nbatch, self.S * self.num_hidden)

    def decode(self, z: Tensor, zcemb: Tensor, xcemb: Tensor, s: Tensor, t=None) -> Tensor:
        """Integrate the vector field from flow time 1 to ``t`` (default 0), starting at ``z``."""
        if t is None:
            t = 0.
        # Detached copies: gradients w.r.t. the conditioning are not propagated
        # through the adjoint solve.
        xcemb = xcemb.clone().detach().requires_grad_(True)
        if zcemb is not None:
            zcemb = zcemb.clone().detach().requires_grad_(True)
        zt = odeint_adjoint(
            partial(self, zcemb=zcemb, xcemb=xcemb, s=s),
            z,
            torch.tensor([1., t], device=xcemb.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]
        return zt

    def decodeS(self, zS: Tensor, xS: Tensor, t=None, indices=None) -> Tensor:
        """Push every step of ``zS`` through the flow, autoregressively conditioned.

        The history state at step s is built from the *decoded* latent of step s-1.
        """
        if t is None:
            t = 0.
        if indices is None:
            indices = torch.arange(self.S, device=xS.device) / self.S

        z0S = zS.clone().detach()
        xcemb = self._embed_x(xS, indices)

        hidden = None
        z0S[0] = self.decode(zS[0], hidden, xcemb, indices[0], t)
        for s in range(1, self.S):
            hidden = self.rnnz.recurse(z0S[s - 1], hidden, indices[s])
            z0S[s] = self.decode(zS[s], hidden.view(-1, self.num_hidden_z * self.num_layers), xcemb, indices[s], t)
        return z0S

    def forwardS(self, ztS: Tensor, xS: Tensor, t=None, indices=None):
        """Average over steps of the vector field evaluated along ``ztS``.

        Returns:
            Tensor of shape (nbatch, z_features).
        """
        if t is None:
            t = 0.
        # forward() indexes t[..., None], so a Python float must become a tensor.
        t = torch.as_tensor(t, device=xS.device)
        if indices is None:
            indices = torch.arange(self.S, device=xS.device) / self.S

        xcemb = self._embed_x(xS, indices)

        hidden = None
        rt = self(t, ztS[0], hidden, xcemb, indices[0])
        for s in range(1, self.S):
            hidden = self.rnnz.recurse(ztS[s - 1], hidden, indices[s])
            rt = rt + self(t, ztS[s], hidden.view(-1, self.num_hidden_z * self.num_layers), xcemb, indices[s])
        return rt / self.S

    def predictF(self, F, zF, xS, t=None, indices=None):  # F > S
        """Decode ``F`` steps of ``zF`` conditioned on the observed length-``S`` ``xS``."""
        if t is None:
            t = 0.
        if indices is None:
            indices = torch.arange(F, device=xS.device) / F
        # NOTE(review): the observed frames are embedded with the first S step
        # positions; confirm this matches how indices were used in training.
        xcemb = self._embed_x(xS, indices[:self.S])

        z0F = zF.clone().detach()
        hidden = None
        z0F[0] = self.decode(zF[0], hidden, xcemb, indices[0], t)
        for f in range(1, F):
            hidden = self.rnnz.recurse(z0F[f - 1], hidden, indices[f])
            z0F[f] = self.decode(zF[f], hidden.view(-1, self.num_hidden_z * self.num_layers), xcemb, indices[f], t)

        return z0F

    def log_prob(self, z: Tensor, zc: Tensor, zcemb: Tensor, xcemb: Tensor, s: int, t, prior) -> Tensor:
        """Log-density of ``z`` at flow time ``t`` via the instantaneous change of variables.

        Integrates the augmented system (z, trace of the Jacobian) from ``t`` to 1
        and evaluates ``prior._log_prob`` at the endpoint, conditioned on ``zc``.
        """
        if t is None:
            t = 0.
        # One identity row per output dim for the batched vector-Jacobian products.
        I = torch.eye(z.shape[-1], dtype=z.dtype, device=z.device)
        I = I.expand(*z.shape, z.shape[-1]).movedim(-1, 0)

        xcemb = xcemb.clone().detach().requires_grad_(True)
        if zc is not None:
            zc = zc.clone().detach().requires_grad_(True)
            zcemb = zcemb.clone().detach().requires_grad_(True)

        def augmented(t: Tensor, state) -> Tensor:
            z, ladj = state
            with torch.enable_grad():
                z = z.requires_grad_()
                dz = self(t, z, zcemb, xcemb, s)

            jacobian = torch.autograd.grad(dz, z, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # Scaled down for the solver; undone (* 1e2) on return.
            return dz, trace * 1e-2

        ladj = torch.zeros_like(z[..., 0])
        z0, ladj = odeint_adjoint(
            augmented,
            (z, ladj),
            torch.tensor([t, 1.0], device=z.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)

        # NOTE(review): the prior is over the whole state space; confirm this is intended.
        return prior._log_prob(z0[-1], zc) + ladj[-1] * 1e2

    def log_probS(self, zS, xS, t, prior, indices=None):
        """Sum of per-step log-densities of the latent subsequence ``zS`` given ``xS``."""
        if indices is None:
            indices = torch.arange(self.S, device=xS.device) / self.S
        xcemb = self._embed_x(xS, indices)

        hidden = None
        out = self.log_prob(zS[0], None, hidden, xcemb, indices[0], t, prior)
        for s in range(1, self.S):
            hidden = self.rnnz.recurse(zS[s - 1], hidden, indices[s])
            out = out + self.log_prob(zS[s], zS[s - 1], hidden.view(-1, self.num_hidden_z * self.num_layers), xcemb, indices[s], t, prior)
        return out



class FlowMatchingLoss(nn.Module):
    """Conditional flow-matching loss with a jointly decoded latent trajectory.

    ``vt`` is the observation vector field and ``rt`` the latent posterior, which
    must provide ``decodeS`` and, unless ``fixz``, ``forwardS``. ``prior`` must
    provide ``rsample(nbatch, S, device=...)``. ``alpha`` scales the adaptive
    (loss-normalised) regularisers. ``sig_min`` is the minimum path noise.
    """

    def __init__(self, vt: nn.Module, rt: nn.Module, prior, alpha=0.01, sig_min=1e-4, fixz=False):
        super().__init__()

        self.vt = vt
        self.rt = rt
        self.prior = prior
        self.sig_min = sig_min
        self.alpha = alpha
        self.fixz = fixz

    def forward_entropy(self, zS, xS, t):
        """Negative mean log-density of ``zS`` under the posterior ``rt``."""
        return - self.rt.log_probS(zS, xS, t, self.prior).mean()

    def forward(self, x: Tensor, indices=None) -> Tensor:
        """Compute the scalar training loss.

        Args:
            x: observations, time-major: (S, nbatch, ...) — axis 0 is the step.
            indices: step positions of length S (defaults to ``arange(S) / S``).

        Returns:
            Scalar loss: flow-matching error, plus a smoothness penalty on the
            latent second differences, plus (unless ``fixz``) a penalty on the
            latent vector field returned by ``rt.forwardS``.
        """
        S, nbatch = x.shape[0], x.shape[1]
        if indices is None:
            indices = torch.arange(S, device=x.device) / S
        # One flow time per batch, shared across steps and samples.
        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * _t

        x1 = torch.randn_like(x)
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x1
        ut = (1 - self.sig_min) * x1 - x

        # Latents are time-major: (S, nbatch, z_dim).
        z1 = self.prior.rsample(nbatch, S, device=x.device)
        x_flat = xt.flatten(start_dim=2)
        zt = self.rt.decodeS(z1, x_flat, _t[0], indices=indices)

        # NOTE(review): residuals are summed over steps before squaring, so
        # errors at different steps can cancel; confirm this is intended.
        fm_loss = self.vt(_t, xt[0], zt[0], indices[0]) - ut[0]
        # Tensor accumulator so that S <= 2 (no second differences) still works.
        reg_sq = torch.zeros_like(zt[0])
        for s in range(1, S):
            fm_loss = fm_loss + (self.vt(_t, xt[s], zt[s], indices[s]) - ut[s])
            if s > 1:
                reg_sq = reg_sq + ((zt[s] - zt[s - 1]) - (zt[s - 1] - zt[s - 2])) ** 2
        loss = fm_loss.square().mean(-1).mean() / S

        # Smoothness penalty, rescaled to the magnitude of the main loss.
        reg_loss2 = reg_sq.mean() / S
        beta2 = loss.detach() / (torch.abs(reg_loss2.detach()) + 1e-8)

        if self.fixz:
            return loss + reg_loss2 * self.alpha * beta2

        # Penalise the magnitude of the latent vector field along the trajectory.
        rt = self.rt.forwardS(zt, x_flat, t=_t[0], indices=indices)
        reg_loss = rt.square().mean()
        beta = loss.detach() / (torch.abs(reg_loss.detach()) + 1e-8)
        return loss + reg_loss * self.alpha * beta + reg_loss2 * self.alpha * beta2

class FlowMatchingLossc(nn.Module):
    """Flow-matching loss for image-like observations with a decoded latent trajectory.

    ``vt(t, x_s, z_s)`` is the observation vector field and ``rt`` the latent
    posterior, which must provide ``decodeS`` and, unless ``fixz``, ``forwardS``.
    With ``flowcnn`` the raw frames are passed to ``rt``; otherwise they are
    flattened from axis 2 on.
    """

    def __init__(self, vt: nn.Module, rt: nn.Module, prior, alpha=0.01, sig_min=1e-4, flowcnn=False, fixz=False):
        super().__init__()

        self.vt = vt
        self.rt = rt
        self.prior = prior
        self.sig_min = sig_min
        self.alpha = alpha
        self.cnn = flowcnn
        self.fixz = fixz

    def forward_entropy(self, zS, xS, t):
        """Negative mean log-density of ``zS`` under the posterior ``rt``."""
        return - self.rt.log_probS(zS, xS, t, self.prior).mean()

    def forward(self, x: Tensor, indices=None) -> Tensor:
        """Compute the scalar training loss.

        Args:
            x: observations, time-major: (S, nbatch, C, H, W).
            indices: step positions forwarded to ``rt`` (None lets ``rt`` use its default).

        Returns:
            Scalar loss: flow-matching error, plus a smoothness penalty on the
            latent second differences, plus (unless ``fixz``) a penalty on
            ``rt.forwardS``.
        """
        S, nbatch = x.shape[0], x.shape[1]
        # One flow time per batch, shared across steps and samples.
        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * _t

        x1 = torch.randn_like(x)
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x1
        ut = (1 - self.sig_min) * x1 - x

        z1 = self.prior.rsample(nbatch, S, device=x.device)
        # The CNN encoder consumes raw frames; the MLP encoder flattened ones.
        x_in = xt if self.cnn else xt.flatten(start_dim=2)
        zt = self.rt.decodeS(z1, x_in, t=_t[0], indices=indices)
        if not self.fixz:
            rt = self.rt.forwardS(zt, x_in, t=_t[0], indices=indices)

        # NOTE(review): residuals are summed over steps before squaring, so
        # errors at different steps can cancel; confirm this is intended.
        fm_loss = self.vt(_t, xt[0], zt[0]) - ut[0]  # (nbatch, C, H, W)
        # Tensor accumulator so that S <= 2 (no second differences) still works.
        reg_sq = torch.zeros_like(zt[0])
        for s in range(1, S):
            fm_loss = fm_loss + (self.vt(_t, xt[s], zt[s]) - ut[s])
            if s > 1:
                reg_sq = reg_sq + ((zt[s] - zt[s - 1]) - (zt[s - 1] - zt[s - 2])) ** 2

        # Squared L2 norm per sample over (C, H, W), averaged over the batch.
        loss = torch.mean(torch.norm(fm_loss / S, p=2, dim=(1, 2, 3)) ** 2)
        # Smoothness penalty, rescaled to the magnitude of the main loss.
        reg_loss2 = reg_sq.mean() / S
        beta2 = loss.detach() / (torch.abs(reg_loss2.detach()) + 1e-8)

        if self.fixz:
            return loss + reg_loss2 * self.alpha * beta2

        reg_loss = rt.square().mean()
        beta = loss.detach() / (torch.abs(reg_loss.detach()) + 1e-8)
        return loss + reg_loss * self.alpha * beta + reg_loss2 * self.alpha * beta2



class FlowMatchingLosscnn(nn.Module):
    """Flow-matching loss for image-like observations with a latent smoothness penalty.

    ``vt(t, x_s, z_s)`` is the observation vector field. ``rt.decodeS`` produces
    the latent trajectory. With ``flowcnn`` the raw frames are passed to ``rt``;
    otherwise they are flattened from axis 2 on.
    """

    def __init__(self, vt: nn.Module, rt: nn.Module, prior, flowcnn=False, alpha=0.01, sig_min=1e-4):
        super().__init__()

        self.vt = vt
        self.rt = rt
        self.prior = prior
        self.sig_min = sig_min
        self.alpha = alpha
        self.cnn = flowcnn

    def forward_entropy(self, zS, xS, t):
        """Negative mean log-density of ``zS`` under the posterior ``rt``."""
        return - self.rt.log_probS(zS, xS, t, self.prior).mean()

    def forward(self, x: Tensor, indices=None) -> Tensor:
        """Compute the scalar training loss.

        Args:
            x: observations, time-major 5-D tensor (S, nbatch, C, H, W).
            indices: step positions forwarded to ``rt.decodeS``.

        Returns:
            Scalar flow-matching loss plus a loss-normalised penalty on the
            latent second differences.
        """
        S, nbatch, _, _, _ = x.shape
        # One flow time per batch, shared across steps and samples.
        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * _t

        x1 = torch.randn_like(x)
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x1
        ut = (1 - self.sig_min) * x1 - x

        # Latents are time-major: (S, nbatch, z_dim).
        z1 = self.prior.rsample(nbatch, S, device=x.device)
        x_in = xt if self.cnn else xt.flatten(start_dim=2)
        zt = self.rt.decodeS(z1, x_in, t=_t[0], indices=indices)

        # NOTE(review): residuals are summed over steps before squaring, so
        # errors at different steps can cancel; confirm this is intended.
        fm_loss = self.vt(_t, xt[0], zt[0]) - ut[0]  # same shape as one frame batch
        # Tensor accumulator so that S <= 2 (no second differences) still works.
        reg_sq = torch.zeros_like(zt[0])
        for s in range(1, S):
            fm_loss = fm_loss + (self.vt(_t, xt[s], zt[s]) - ut[s])
            if s > 1:
                reg_sq = reg_sq + ((zt[s] - zt[s - 1]) - (zt[s - 1] - zt[s - 2])).square()
        loss = fm_loss.square().mean(-1).mean() / S

        # Smoothness penalty, rescaled to the magnitude of the main loss.
        reg_loss2 = reg_sq.mean() / S
        beta2 = loss.detach() / (torch.abs(reg_loss2.detach()) + 1e-8)
        return loss + reg_loss2 * self.alpha * beta2


if __name__ == "__main__":
    # Smoke-test training loop on the synthetic LDS dataset, with periodic
    # sampling and plotting to gen_figs/.
    import os

    batch_size = 5
    num_samples = 1_000
    dat_dir = "data"
    n = 2
    p = 10
    S = 50
    seed = 123
    dataloader = DataLoader(
        LDS(num_samples, n, p, S, dat_dir, sin=True, gen=True, plot=True, seed=seed),
        batch_size=batch_size, shuffle=True)

    z_feature_dim = n
    x_feature_dim = p
    n_gen = 2
    n_epoch = 2  # 400

    llk_net = LLK(x_feature_dim, z_feature_dim, hidden_features=[64] * 3)
    pos_net = CNF(x_feature_dim, z_feature_dim, S, hidden_features=[32] * 2)
    prior = LatentDynamicalSystem(z_feature_dim)
    # Training
    loss = FlowMatchingLoss(llk_net, pos_net, prior)
    optimizer = torch.optim.Adam(
        itertools.chain(
                prior.parameters(),
                llk_net.parameters(),
                pos_net.parameters()
                ),
        lr=0.001)
    # savefig below fails if the output directory does not exist.
    os.makedirs("gen_figs", exist_ok=True)
    step = 0
    torch.autograd.set_detect_anomaly(True)
    for epoch in tqdm(range(n_epoch)):
        for i, (x, y) in enumerate(dataloader):
            step += 1

            prior.train()
            pos_net.train()
            llk_net.train()
            # NOTE(review): the loss treats axis 0 as time. If the loader yields
            # (batch, S, p), reshape() interleaves samples and transpose(0, 1)
            # is needed instead — confirm the LDS output layout.
            _loss = loss(x.reshape(S, -1, x_feature_dim))
            _loss.backward()
            print("step=", step, " loss=", _loss)

            optimizer.step()
            optimizer.zero_grad()

            if step % 20 == 0:
                # Sampling
                prior.eval()
                pos_net.eval()
                llk_net.eval()
                with torch.no_grad():
                    z0 = prior.sample(n_gen, S)  # (S, ngen, z_dim)
                    print("z0", z0.shape)
                    x0 = torch.randn(S, n_gen, x_feature_dim)
                    x1 = llk_net.decodeS(x0, z0)
                    print("x1", x1.shape)
                    z1 = pos_net.decodeS(z0, x1)

                # x1/z1 are presumably time-major (S, n_gen, dim) like x0/z0, so
                # swap axes (a plain reshape would scramble steps and samples).
                x1_np = np.swapaxes(x1.cpu().detach().numpy(), 0, 1).reshape(n_gen, S, x_feature_dim)
                z1_np = np.swapaxes(z1.cpu().detach().numpy(), 0, 1).reshape(n_gen, S, z_feature_dim)

                fig, axes = plt.subplots(2, 1, figsize=(8, 4))
                axes[0].imshow(x1_np[0].T, cmap="gray", origin="lower", aspect=.6)
                axes[0].set_xlabel("data")
                axes[1].plot(np.arange(S), z1_np[0])
                axes[1].set_xlabel("latent")
                axes[1].set_xmargin(0)
                plt.tight_layout()
                plt.savefig("gen_figs/lds_gen_step{}.png".format(step))
                plt.close()



