import math
import matplotlib.pyplot as plt
from functools import partial
import itertools
import numpy as np
from tqdm import tqdm
from typing import *
from pylab import cm

import torch
from torch import Tensor, vmap
from torch.func import grad_and_value, jacrev, vmap
import torch.nn as nn
from torch.nn.functional import leaky_relu, sigmoid, softmax
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
import torch.nn.utils as nn_utils
from torch.distributions import Dirichlet, Categorical, Normal, Uniform
from torchdiffeq import odeint_adjoint

from zuko.utils import odeint
from zuko.distributions import DiagNormal
from unet import *

from dataloader.dataloader_mnist import inv_transform

# from dataloader_pinwheel import *

# Global torch configuration applied when this module is imported (module-level
# side effects): print tensors with 3 decimals, and make float64 the default
# floating dtype, so float tensors/parameters created without an explicit dtype
# (including the nn.Linear weights and torch.Tensor([...]) time grids in this
# file) are double precision.
torch.set_printoptions(precision=3)
torch.set_default_dtype(torch.float64)


def sum_except_batch(x):
    """Sum every dimension of ``x`` except the leading (batch) one.

    Args:
        x: tensor of shape ``(B, ...)``.

    Returns:
        Tensor of shape ``(B,)``.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed/permuted
    # tensors, work too; it is still a zero-copy view when x is contiguous.
    return x.reshape(x.size(0), -1).sum(-1)

def log_normal(x: Tensor) -> Tensor:
    """Log-density of a standard normal, summed over the last dimension.

    Computes ``-0.5 * sum_i (x_i**2 + log(2*pi))`` along ``dim=-1``.
    """
    per_coordinate = x.square() + math.log(2 * math.pi)
    return torch.sum(per_coordinate, dim=-1).mul(-0.5)

def first_eigen_proj(x):
    """Project samples onto their leading principal direction.

    Args:
        x: array-like of shape ``(n_samples, n_features)`` (numpy array or CPU
            tensor without grad).

    Returns:
        1-D numpy array of length ``n_samples``: ``x @ v`` where ``v`` is the
        unit eigenvector of the sample covariance with the largest eigenvalue.
        Note the *uncentered* ``x`` is projected (as before), and the sign of
        ``v`` is arbitrary.
    """
    x = np.asarray(x)

    # Step 1: covariance of the centered data (atleast_2d so a single feature,
    # for which np.cov returns a 0-d array, is still a 1x1 matrix).
    x_centered = x - x.mean(axis=0, keepdims=True)
    cov_matrix = np.atleast_2d(np.cov(x_centered, rowvar=False))

    # Step 2: the covariance is symmetric, so use eigh: it guarantees real
    # eigenpairs (np.linalg.eig may return complex values for symmetric input).
    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)

    # Step 3: eigenvector paired with the largest eigenvalue.
    first_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]

    # Step 4: project the original (uncentered) samples onto it.
    return np.dot(x, first_eigenvector)


class SoftplusParameterization(nn.Module):
    """Parametrization that maps an unconstrained tensor to positive values.

    Intended for ``torch.nn.utils.parametrize``: ``forward(X) = softplus(X)``.
    """

    def forward(self, X):
        positive = F.softplus(X)
        return positive

class ScaledSigmoid(nn.Module):
    """Sigmoid rescaled from ``(0, 1)`` to ``(-1, 1)``: ``2 * sigmoid(x) - 1``."""

    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        unit = self.sigmoid(x)
        return unit * 2 - 1

def logit_trans(data, eps=1e-6):
    """Logit transform of values in ``[0, 1]``.

    ``data`` is first squeezed into ``[eps, 1 - eps]`` so the logit stays
    finite at the endpoints, then ``log(p) - log(1 - p)`` is returned.
    """
    squeezed = (1 - 2 * eps) * data + eps
    log_p = torch.log(squeezed)
    log_one_minus_p = torch.log1p(-squeezed)
    return log_p - log_one_minus_p


class MLP(nn.Sequential):
    """Fully connected network ``in_features -> *hidden_features -> out_features``.

    Every layer is an ``nn.Linear`` (optionally weight-normalised). Each hidden
    layer is followed by at most one of BatchNorm1d / LayerNorm / Dropout
    (checked in that priority order) and then the activation ``fct``. The
    output layer is a bare ``nn.Linear`` with no norm/dropout/activation.

    Args:
        in_features: input width.
        out_features: output width.
        hidden_features: widths of the hidden layers.
        fct: activation module (the same instance is reused after every hidden layer).
        batch_norm: add ``nn.BatchNorm1d`` after each hidden Linear.
        dropout: add ``nn.Dropout(p)`` after each hidden Linear (if no norm is used).
        weight_norm: wrap every Linear with ``torch.nn.utils.weight_norm``.
        layer_norm: add ``nn.LayerNorm`` after each hidden Linear (if no batch norm).
        p: dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
        fct=nn.Tanh(),
        batch_norm=False,
        dropout=False,
        weight_norm=False,
        layer_norm=False,
        p=0.2
    ):
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            linear_layer = nn.Linear(a, b)
            if weight_norm:
                linear_layer = nn_utils.weight_norm(linear_layer)
            layers.append(linear_layer)
            if batch_norm:
                layers.append(nn.BatchNorm1d(b))
            elif layer_norm:
                layers.append(nn.LayerNorm(b))
            elif dropout:
                layers.append(nn.Dropout(p=p))
            layers.append(fct)

        # Each group is [linear, (norm | dropout)?, fct]. Strip whatever follows
        # the final Linear. weight_norm wraps the Linear in place and adds no
        # module, so it must not influence how much is trimmed (the old
        # `not weight_norm or ...` condition dropped the output Linear itself
        # when weight_norm=True, and kept a norm/dropout on the output otherwise).
        trailing = 2 if (batch_norm or layer_norm or dropout) else 1
        super().__init__(*layers[:-trailing])



class ShiftedTanh(torch.nn.Module):
    """Affine-rescaled tanh: ``a * tanh(x) + b``, with range ``(b - a, b + a)``."""

    def __init__(self, a=5, b=0.5):
        super().__init__()
        self.a, self.b = a, b

    def forward(self, x):
        scaled = self.a * torch.tanh(x)
        return scaled + self.b


class SineActivation(torch.nn.Module):
    """Sinusoidal activation ``sin(omega * x)``."""

    def __init__(self, omega=10):
        super().__init__()
        self.omega = omega

    def forward(self, x):
        phase = x * self.omega
        return phase.sin()

class GumbelSoftmax(nn.Module):
    """Gumbel-Softmax relaxation of a categorical distribution.

    ``forward`` only returns ``softmax(logits)``. ``sample`` draws a relaxed
    sample, or with ``hard=True`` a straight-through one-hot sample.

    Args:
        c_dim: number of categories (stored; not used by the methods below).
        temperature: softmax temperature applied to the perturbed logits.
        hard: if True, ``sample`` returns one-hot vectors that carry the
            gradient of the soft sample (straight-through estimator).
    """

    def __init__(self, c_dim, temperature=1.0, hard=False):
        super(GumbelSoftmax, self).__init__()
        self.c_dim = c_dim
        self.temperature = temperature
        self.hard = hard

    def sample_gumbel(self, shape, eps=1e-20, device=None):
        """Draw standard Gumbel noise ``-log(-log U)``, ``U ~ Uniform[0, 1)``.

        ``eps`` keeps both logarithms finite. ``device=None`` uses the default
        device (previous behaviour); callers should pass the logits' device.
        """
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits):
        """Relaxed sample ``softmax((logits + g) / temperature)`` over the last dim."""
        # Draw the noise on the logits' device so CUDA inputs do not hit a
        # CPU/GPU mismatch (dtype still follows normal type promotion).
        y = logits + self.sample_gumbel(logits.size(), device=logits.device)
        return F.softmax(y / self.temperature, dim=-1)

    def sample(self, logits):
        """Straight-through Gumbel-Softmax sample.

        Args:
            logits: tensor of shape ``[*, n_class]``.

        Returns:
            Tensor of shape ``[*, n_class]``: the soft sample, or when
            ``self.hard`` a one-hot vector with the soft sample's gradient.
        """
        y = self.gumbel_softmax_sample(logits)

        if not self.hard:
            return y

        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        # Forward value is the one-hot; the gradient flows through the soft y.
        y_hard = (y_hard - y).detach() + y
        return y_hard

    def forward(self, logits):
        """Return ``softmax(logits)`` over the last dimension (no sampling)."""
        prob = F.softmax(logits, dim=-1)
        return prob


class GaussianMixtureComponent(nn.Module):
    """Amortised diagonal Gaussian over ``z`` given ``pi``, ``x`` and time ``t``.

    A hyper-network maps [sinusoidal time features, Linear(pi), embedding of x]
    to a mean and a softplus-transformed scale of a ``DiagNormal`` with
    ``z_features`` dimensions. With ``cnn=True``, x is embedded by ``Unet_nzt``
    followed by ``cnn_fc``; otherwise by a single ``nn.Linear``.
    """

    def __init__(self, K, z_features, x_features, freqs=2, hidden_dim=784, cnn=False, in_ch=3, mod_ch=8, **kwargs):
        super().__init__()
        self.K = K
        self.hidden_dim = hidden_dim
        self.cnn = cnn

        # Sub-modules are created in the same order as before, so parameter
        # initialisation consumes the RNG identically.
        if not cnn:
            self.embx = nn.Linear(x_features, hidden_dim)
        else:
            self.embx = Unet_nzt(in_ch, mod_ch=mod_ch, out_ch=1)
            # Projects the flattened U-Net output (x_features**2 values) to hidden_dim.
            self.cnn_fc = nn.Linear(x_features ** 2, hidden_dim)

        self.emb = nn.Linear(K, hidden_dim)
        hyper_in = 2 * hidden_dim + 2 * freqs
        self.hyper = nn.Sequential(MLP(hyper_in, 2 * z_features, **kwargs))
        self.z_features = z_features
        self.register_buffer('freqs', torch.pi * torch.arange(1, freqs + 1))

    def forward(self, pi: Tensor, x: Tensor, t: Tensor):
        """Return the ``DiagNormal`` over z for the given ``pi``, ``x`` and ``t``."""
        angles = self.freqs * t.unsqueeze(-1)
        time_feat = torch.cat((angles.cos(), angles.sin()), dim=-1).expand(*pi.shape[:-1], -1)
        class_feat = self.emb(pi)
        data_feat = self.embx(x)
        if self.cnn:
            data_feat = self.cnn_fc(data_feat.view(len(x), -1))
        params = self.hyper(torch.cat((time_feat, class_feat, data_feat), dim=-1))

        loc, raw_scale = params.chunk(2, dim=-1)
        return DiagNormal(loc, F.softplus(raw_scale))

    @staticmethod
    def _time_or_zero(x: Tensor, t: Tensor):
        """Return ``t``, or a scalar 0 on x's device when ``t`` is None."""
        return torch.tensor(0., device=x.device) if t is None else t

    def rsample(self, pi: Tensor, x: Tensor, t: Tensor = None):
        """Reparameterised sample of z (``t`` defaults to 0)."""
        return self(pi, x, self._time_or_zero(x, t)).rsample()

    def sample(self, pi: Tensor, x: Tensor, t: Tensor = None):
        """Non-differentiable sample of z (``t`` defaults to 0)."""
        return self(pi, x, self._time_or_zero(x, t)).sample()

    def log_prob(self, pi: Tensor, x: Tensor, z: Tensor, t: Tensor = None):
        """Log-density of ``z`` under the predicted Gaussian (``t`` defaults to 0)."""
        return self(pi, x, self._time_or_zero(x, t)).log_prob(z)

class update_GaussianMixtureComponent(nn.Module):
    """Gaussian over z whose mean is an MLP of the class weights ``pi``.

    The scale is fixed at 0.1 in every one of the ``z_features`` dimensions.
    """

    def __init__(self, K, z_features, **kwargs):
        super().__init__()
        self.K = K
        self.hyper = nn.Sequential(MLP(K, z_features, **kwargs))
        self.z_features = z_features

    def forward(self, pi: Tensor):
        """Return ``DiagNormal(MLP(pi), 0.1)`` on pi's device."""
        loc = self.hyper(pi)
        scale = torch.full((self.z_features,), 0.1, device=pi.device)
        return DiagNormal(loc, scale)

    def sample(self, pi, size):
        """Draw ``size`` samples and return the first one (index 0 of the sample dim)."""
        draws = self(pi).sample(size)
        return draws[0]

    def rsample(self, pi, size):
        """Reparameterised version of ``sample``."""
        draws = self(pi).rsample(size)
        return draws[0]

    def log_prob(self, pi, z):
        """Log-density of ``z`` under the Gaussian predicted from ``pi``."""
        dist = self(pi)
        return dist.log_prob(z)


class fixed_GaussianMixtureComponent():
    """Non-trainable mixture of ``K`` Gaussians with well-separated means.

    Means are drawn once at construction by rejection sampling (see
    ``maximin_init``). Every component is a ``DiagNormal`` with scale 0.1 in
    each of the ``z_features`` dimensions. Components are selected by the
    argmax of a (one-hot / soft) assignment vector ``idx``.
    """

    def __init__(self, K, z_features, device):
        super().__init__()

        self.K = K
        self.d = z_features
        self.device = device
        self.mu = self.maximin_init()

    def maximin_init(self, min_dist=1.5, max_tries=100_000):
        """Draw ``K`` means that are pairwise at least ``min_dist`` apart.

        The first mean is drawn from Uniform[0, 1)^d. Each later mean is drawn
        from N(0, I) and redrawn until it is at least ``min_dist`` away from
        every mean chosen so far.

        Args:
            min_dist: minimum Euclidean distance between any two means.
            max_tries: maximum draws per mean before giving up.

        Returns:
            Tensor of shape ``(K, d)``.

        Raises:
            RuntimeError: if a valid mean is not found within ``max_tries``
                draws (e.g. ``min_dist`` too large for ``d`` / ``K``), instead
                of looping forever.
        """
        points = [torch.rand(self.d, device=self.device)]

        for _ in range(1, self.K):
            new_point = torch.randn(self.d, device=self.device)
            tries = 1
            while torch.min(torch.stack([torch.norm(new_point - p) for p in points])) < min_dist:
                if tries >= max_tries:
                    raise RuntimeError(
                        f"maximin_init: no point at distance >= {min_dist} from the "
                        f"{len(points)} existing means after {max_tries} draws"
                    )
                new_point = torch.randn(self.d, device=self.device)
                tries += 1

            points.append(new_point)

        return torch.stack(points)

    def _component(self, idx):
        """Gaussian for the component selected by ``idx.argmax(-1)``."""
        return DiagNormal(self.mu[idx.argmax(-1)], torch.ones(self.d).to(self.device) * .1)

    def forward(self):
        """Batch of all ``K`` component Gaussians."""
        return DiagNormal(self.mu, torch.ones(self.K, self.d).to(self.device) * .1)

    def sample(self, idx, size):
        """Draw ``size`` samples from the selected component; return the first."""
        return self._component(idx).sample(size)[0]

    def rsample(self, idx, size):
        """Reparameterised version of ``sample``."""
        return self._component(idx).rsample(size)[0]

    def log_prob(self, idx, z):
        """Log-density of ``z`` under the selected component."""
        return self._component(idx).log_prob(z)


class LLK(nn.Module):
    """Velocity field ``v(t, x | z)`` on data space, conditioned on a latent ``z``.

    The network input is [sinusoidal time features, Linear(z), x]. ``decode``
    integrates the field from ``t`` (default 1) down to 0. ``log_prob``
    integrates from ``t`` up to 1 while accumulating the Jacobian trace, giving
    a change-of-variables log-density.
    """

    def __init__(self, x_features: int, z_features: int, freqs: int = 2, **kwargs):
        super().__init__()
        hidden_dim = x_features  # width of the z embedding

        self.embz = nn.Linear(z_features, hidden_dim)
        # Input width = time features (2*freqs) + z embedding (hidden_dim) + x (x_features).
        self.combine_fc = MLP(2*freqs+hidden_dim+x_features, x_features, **kwargs)
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, z: Tensor) -> Tensor:
        """Velocity at time ``t`` for state ``x`` given latent ``z``."""
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)

        # combine_fc was sized for the *embedded* z (hidden_dim). Feeding raw z
        # only fit when z_features == x_features, and it left embz unused.
        zemb = self.embz(z)
        return self.combine_fc(torch.cat((temb, zemb, x), dim=-1))

    def _forward(self, t: Tensor, x: Tensor, z: Tensor) -> Tensor:
        """Return the velocity twice, as ``(out, out)``."""
        out = self.forward(t, x, z)
        return out, out

    def encode(self, x: Tensor, z: Tensor = None) -> Tensor:
        """Integrate ``x`` from t=0 to t=1 with zuko's ``odeint``.

        The field needs the conditioning latent, so ``z`` is bound with
        ``functools.partial``.

        Raises:
            TypeError: if ``z`` is not given.
        """
        if z is None:
            raise TypeError("LLK.encode() requires the conditioning latent z")
        return odeint(partial(self, z=z), x, 0.0, 1.0, phi=self.parameters())

    def log_prob(self, x: Tensor, z: Tensor, t, prior) -> Tensor:
        """Change-of-variables log-density of ``x`` at time ``t`` given ``z``.

        Integrates ``(x, log-det)`` from ``t`` to 1 with ``odeint_adjoint`` and
        returns ``prior.log_prob(x_1).sum(-1) + integrated trace``.
        """
        # Batched identity rows for is_grads_batched: one vector-Jacobian
        # product per output coordinate yields the full Jacobian.
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        z = z.clone().detach().requires_grad_(True)

        def augmented(t: Tensor, state) -> Tensor:
            x, adj = state
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x, z)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # The trace is scaled by 1e-2 here and by 1e2 after the solve.
            # NOTE(review): presumably to keep this state on a scale suited to
            # the shared tolerances — confirm.
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        x0, ladj = odeint_adjoint(
            augmented,
            (x, ladj),
            # Same dtype as before (default dtype), but on the data's device.
            torch.tensor([t, 1.0], dtype=torch.get_default_dtype(), device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-7, rtol=1e-7)
        return prior.log_prob(x0[-1]).sum(-1) + ladj[-1] * 1e2

    def decode(self, x: Tensor, z: Tensor, t=None) -> Tensor:
        """Integrate the field from ``t`` (default 1) to 0 with dopri5; return the end state."""
        if t is None:
            t = 1.
        z = z.clone().detach().requires_grad_(True)

        def solve_ldensity_i(t, x):
            """Velocity with the latent ``z`` bound."""
            return self(t, x, z)

        xt = odeint_adjoint(
            solve_ldensity_i,
            x,
            torch.tensor([t, 0.], dtype=torch.get_default_dtype(), device=x.device),
            adjoint_params=self.parameters(),
            method="dopri5",
            atol=1e-8, rtol=1e-8)[-1]
        return xt

class CatNF_fixed(nn.Module):
    """Amortised categorical encoder ``x -> k`` class logits with Gumbel-Softmax sampling.

    Logits come from ``MLP -> Tanh`` (so they lie in (-1, 1)) applied to either
    the flattened input or, with ``cnn=True``, the flattened ``Unet_nzt``
    output (``x_features**2`` values).

    Args:
        x_features: input width (or image side length when ``cnn=True``).
        k: number of classes.
        temp: Gumbel-Softmax temperature.
        hard: if True, ``_sample`` returns straight-through one-hot samples.
        cnn, in_ch, mod_ch: U-Net feature extractor options.
        hidden_dim: unused here; kept for interface compatibility.
        freqs: size of the (currently unused) ``freqs_t`` buffer.
    """

    def __init__(self, x_features: int, k: int, temp=1., hard=False, cnn=False, in_ch=3, hidden_dim=784, freqs=2, mod_ch=8, **kwargs):
        super(CatNF_fixed, self).__init__()
        self.k = k
        self.temp = temp
        self.hard = hard
        self.cnn = cnn
        # Not read by any method; kept so the state_dict layout is unchanged.
        self.register_buffer('freqs_t', torch.arange(1, freqs + 1) * torch.pi)

        if self.cnn:
            # The boolean flag is replaced by the feature extractor module.
            self.cnn = Unet_nzt(in_ch, mod_ch=mod_ch, out_ch=1)
            self.fc = nn.Sequential(
                MLP(x_features**2, k, **kwargs),
                nn.Tanh()
                )
        else:
            self.fc = nn.Sequential(
                MLP(x_features, k, **kwargs),
                nn.Tanh()
                )

    def _sample_gumbel(self, shape, device, eps=1e-20):
        """Standard Gumbel noise on ``device``; ``eps`` keeps both logs finite."""
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def _gumbel_softmax_sample(self, logits):
        """Relaxed sample ``softmax((logits + g) / temp)`` over the last dim."""
        y = logits + self._sample_gumbel(logits.size(), logits.device)
        return F.softmax(y / self.temp, dim=-1)

    def _sample(self, logits):
        """Straight-through Gumbel-Softmax sample of shape ``[*, n_class]``.

        Returns the soft sample, or when ``self.hard`` a one-hot vector that
        carries the soft sample's gradient.
        """
        y = self._gumbel_softmax_sample(logits)

        if not self.hard:
            return y

        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        # Forward value is the one-hot; the gradient flows through the soft y.
        y_hard = (y_hard - y).detach() + y
        return y_hard

    def _forward(self, logits):
        """Softmax over the last dimension."""
        prob = F.softmax(logits, dim=-1)
        return prob

    def _logits(self, x: Tensor) -> Tensor:
        """Class logits for ``x`` (shared by ``forward`` and ``rsample``)."""
        if self.cnn:
            # Flatten the U-Net output per sample. The previous forward() used
            # view(len(x), self.mod_ch), but mod_ch is never stored, so it raised
            # AttributeError.
            return self.fc(self.cnn(x).view(len(x), -1))
        return self.fc(x.flatten(start_dim=1))

    def forward(self, x: Tensor):
        """Class probabilities ``softmax(logits(x))``."""
        return self._forward(self._logits(x))

    def rsample(self, x: Tensor, logits=None) -> Tensor:
        """Return ``(logits, sample)``.

        If ``x`` is None and ``logits`` is given, those logits are sampled
        directly; otherwise logits are computed from ``x`` (any passed
        ``logits`` are ignored).
        """
        if x is None and logits is not None:
            return logits, self._sample(logits)
        logits = self._logits(x)
        return logits, self._sample(logits)


class CatNF(nn.Module):
    """Continuous flow over class logits/probabilities ``z`` conditioned on ``x``.

    The field is ``-MLP([time features, Linear(x), Linear(z)])``. ``decode``
    integrates it from t=1 to ``t``; ``log_prob`` integrates from ``t`` to 1
    with trace accumulation.
    """

    def __init__(self, x_features: int, z_features: int, freqs: int = 2, **kwargs):
        super(CatNF, self).__init__()
        self.k = z_features  # number of hidden classes
        latent_dim = x_features
        self.embx = nn.Linear(x_features, z_features)
        self.embpi = nn.Linear(z_features, latent_dim)
        self.fc = MLP(
                z_features+latent_dim+2*freqs,
                z_features,
                **kwargs
                )

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, z: Tensor, x: Tensor):
        """Negated MLP output: the velocity ``dz/dt`` at time ``t``."""
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        xemb = self.embx(x)
        zemb = self.embpi(z)

        out = self.fc(torch.cat((temb, xemb, zemb), dim=-1))
        return -out

    def decode(self, z: Tensor, x: Tensor, t=None) -> Tensor:
        """Integrate ``z`` from t=1 to ``t`` (default 0) with dopri5; return the end state."""
        if t is None:
            t = 0.
        x = x.clone().detach().requires_grad_(True)

        zt = odeint_adjoint(
            partial(self, x=x),
            z,
            # Same dtype as before (default dtype), but on the data's device.
            torch.tensor([1., t], dtype=torch.get_default_dtype(), device=z.device),
            adjoint_params=self.parameters(),
            method="dopri5",
            atol=1e-8, rtol=1e-8)[-1]
        return zt

    def log_prob(self, p: Tensor, x: Tensor, t, prior) -> Tensor:
        """Change-of-variables log-density of ``p`` at time ``t`` (default 0).

        Returns ``prior.log_prob(p_1) + integrated trace``.
        """
        if t is None:
            t = 0.
        # Batched identity rows: one VJP per output coordinate gives the Jacobian.
        I = torch.eye(p.shape[-1], dtype=p.dtype, device=p.device)
        I = I.expand(*p.shape, p.shape[-1]).movedim(-1, 0)

        x = x.clone().detach().requires_grad_(True)

        def augmented(t: Tensor, state) -> Tensor:
            p, ladj = state
            with torch.enable_grad():
                p = p.requires_grad_()
                # NOTE(review): this integrates -self(...) from t to 1, whereas
                # decode() integrates +self(...) from 1 to t, so the two are not
                # inverses (CNF.log_prob uses +self). Confirm the intended sign.
                dp = -self(t, p, x)

            jacobian = torch.autograd.grad(dp, p, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # Scaled by 1e-2 here and by 1e2 after the solve.
            return dp, trace * 1e-2

        ladj = torch.zeros_like(p[..., 0])
        pt, ladj = odeint_adjoint(
            augmented,
            (p, ladj),
            torch.tensor([t, 1.0], dtype=torch.get_default_dtype(), device=p.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8
            )
        return prior.log_prob(pt[-1]) + ladj[-1] * 1e2

class CNF(nn.Module):
    """Continuous normalizing flow on the latent ``z``, conditioned on ``pi`` and ``x``.

    ``forward`` returns the negated output of ``fc1`` applied to
    [sinusoidal time features, embedding of x, z, Linear(pi)]; this is used as
    ``dz/dt``. ``x`` is embedded either by a small conv net plus Linear
    (``cnn=True``) or by a Linear on the flattened input. ``decode`` integrates
    from t=1 to ``t``; ``log_prob`` integrates from ``t`` to 1 while
    accumulating the Jacobian trace.
    """
    def __init__(self, x_features: int, z_features: int, k: int, freqs: int = 2, cnn=False, in_ch=3, hidden_dim=784, **kwargs):
        super().__init__()
        self.cnn = cnn
        # Flattened size of the conv features after two stride-2 max-pools with
        # 64 channels; assumes square x_features x x_features images — TODO confirm.
        self.mod_ch = 64 * (x_features // 4) * (x_features // 4)
        if cnn:
            # The boolean flag is replaced by the feature-extractor module.
            self.cnn = nn.Sequential(
                nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1),  # Output: 32 x 28 x 28 (for 28x28 input)
                nn.SiLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # Output: 32 x 14 x 14
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # Output: 64 x 14 x 14
                nn.SiLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # Output: 64 x 7 x 7
            )
            # Fully connected layer to project CNN features to the hidden dimension
            self.embx = nn.Linear(self.mod_ch, hidden_dim)
        else:
            # Linear embedding of the flattened input (x_features**2 * in_ch values).
            self.embx = nn.Sequential(
                nn.Linear(x_features**2*in_ch, hidden_dim),
                )
        self.embpi = nn.Linear(k, hidden_dim)
        self.fc1 = nn.Sequential(
            # Input width = x embedding + pi embedding (2*hidden_dim) + z + time features.
            MLP(
                2*hidden_dim+z_features+2*freqs, z_features, 
                **kwargs
                ),
            )

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, z:Tensor, pi:Tensor, x: Tensor) -> Tensor:
        """Velocity ``dz/dt`` at time ``t`` given class weights ``pi`` and data ``x``."""
        # Sinusoidal time features, broadcast over z's leading (batch) dims.
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*z.shape[:-1], -1)
        piemb = self.embpi(pi)
        if self.cnn:
            _xemb = self.cnn(x)
            xemb = self.embx(_xemb.view(len(x), self.mod_ch))
        else:
            xemb = self.embx(x.flatten(start_dim=1))

        out = self.fc1(torch.cat((temb, xemb, z, piemb), dim=-1))
        
        return -out

    def decode(self, z: Tensor, pi: Tensor, x: Tensor, t=None) -> Tensor:
        """Integrate ``z`` from t=1 to ``t`` (default 0); return the end state.

        ``x`` and ``pi`` are detached copies bound into the field, so no
        gradient flows back into the caller's x or pi through the solve.
        """
        if t is None:
            t = 0.
        x = x.clone().detach().requires_grad_(True)
        pi = pi.clone().detach().requires_grad_(True)
        zt = odeint_adjoint(
            partial(self, pi=pi, x=x), 
            z, 
            torch.tensor([1., t], device=x.device), 
            adjoint_params=self.parameters(), 
            atol=1e-8, rtol=1e-8)[-1]
        return zt

    def decode_with_trajectory(self, z: Tensor, pi: Tensor, x: Tensor, t=None, num_points=100) -> tuple:
        """Like ``decode`` but also returns the states the field was evaluated at.

        Returns:
            ``(z_end, trajectory, times)``. ``trajectory`` stacks every state
            passed to the field during the forward solve, and ``times`` (a CPU
            tensor) holds the matching time values.

        NOTE(review): states are recorded on every field call made by the
        solver, so they presumably include intermediate solver stages rather
        than just ``num_points`` evenly spaced times — confirm if used for plots.
        """
    
        if t is None:
            t = 0.
        
        x = x.clone().detach().requires_grad_(True)
        pi = pi.clone().detach().requires_grad_(True)
        
        # Create time points
        start_time = 1.
        end_time = t
        time_points = torch.linspace(start_time, end_time, num_points, device=x.device)
        
        # Storage for trajectory
        trajectory = []
        recorded_times = []
        
        # Wrapper for ODE function to collect states
        def ode_func_with_recording(t, state):
            # Record the current state
            trajectory.append(state.detach().clone())
            recorded_times.append(t.item())
            # Call the original function
            return self(t, state, pi=pi, x=x)
        
        # Run the ODE solver
        zt = odeint_adjoint(
            ode_func_with_recording, 
            z, 
            time_points, 
            adjoint_params=self.parameters(), 
            atol=1e-8, rtol=1e-8)
        # Convert lists to tensors
        trajectory_tensor = torch.stack(trajectory)
        recorded_times_tensor = torch.tensor(recorded_times)
        # Return final state and the recorded trajectory
        return zt[-1], trajectory_tensor, recorded_times_tensor

    def log_prob(self, z: Tensor, pi: Tensor, x: Tensor, t, prior) -> Tensor:
        """Change-of-variables log-density of ``z`` at time ``t``.

        Integrates ``(z, log-det)`` from ``t`` to 1 and returns
        ``prior.log_prob(pi, z_1) + integrated trace``.
        """
        with torch.enable_grad():
            # Batched identity rows: one VJP per output coordinate gives the Jacobian.
            I = torch.eye(z.shape[-1], dtype=z.dtype, device=z.device)
            I = I.expand(*z.shape, z.shape[-1]).movedim(-1, 0)

            x = x.clone().detach().requires_grad_(True)
            pi = pi.clone().detach().requires_grad_(True)

            def augmented(t: Tensor, state) -> Tensor:
                z, ladj = state
                with torch.enable_grad():
                    z = z.requires_grad_()
                    dz = self(t, z, pi, x)

                jacobian = torch.autograd.grad(dz, z, I, create_graph=True, is_grads_batched=True)[0]
                trace = torch.einsum('i...i', jacobian)

                # Trace is scaled by 1e-2 during the solve and by 1e2 on return.
                return dz, trace * 1e-2
            
            ladj = torch.zeros_like(z[..., 0]).to(device=x.device)
            zt, ladj = odeint_adjoint(
                augmented, 
                (z, ladj), 
                torch.tensor([t, 1.0], device=x.device), 
                adjoint_params=self.parameters(),
                atol=1e-8, rtol=1e-8)
            priorlog = prior.log_prob(pi, zt[-1])
        return priorlog + ladj[-1] * 1e2


class CNF_z(nn.Module):
    """Continuous normalizing flow on the latent ``z`` conditioned only on ``x``.

    Same construction as ``CNF`` but without a ``pi`` input. The field is
    ``-Tanh(MLP([time features, embedding of x, z]))``; the x embedding is
    bounded by ``ShiftedTanh`` (inside the conv stack, or after the Linear).
    """

    def __init__(self, x_features: int, z_features: int, k: int, freqs: int = 2, cnn=False, in_ch=3, hidden_dim=784, **kwargs):
        super().__init__()
        self.cnn = cnn
        # Flattened conv-feature size after two stride-2 max-pools with 64
        # channels; assumes square x_features x x_features images — TODO confirm.
        self.mod_ch = 64 * (x_features // 4) * (x_features // 4)
        if cnn:
            # The boolean flag is replaced by the feature-extractor module.
            self.cnn = nn.Sequential(
                nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1),  # 32 x 28 x 28 for 28x28 input
                nn.SiLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 32 x 14 x 14
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 64 x 14 x 14
                nn.SiLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 64 x 7 x 7
                ShiftedTanh()
            )
            # Project the flattened CNN features to the hidden dimension.
            self.embx = nn.Linear(self.mod_ch, hidden_dim)
        else:
            self.embx = nn.Sequential(
                nn.Linear(x_features**2*in_ch, hidden_dim),
                ShiftedTanh()
                )

        # Not used by forward(); kept so the parameter / state_dict layout is unchanged.
        self.embpi = nn.Linear(k, hidden_dim)
        self.fc1 = nn.Sequential(
            MLP(
                hidden_dim+z_features+2*freqs, z_features,
                **kwargs
                ),
            nn.Tanh()
            )
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, z: Tensor, x: Tensor) -> Tensor:
        """Velocity ``dz/dt`` at time ``t`` given data ``x``."""
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*z.shape[:-1], -1)

        if self.cnn:
            _xemb = self.cnn(x)
            xemb = self.embx(_xemb.view(len(x), self.mod_ch))
        else:
            xemb = self.embx(x.flatten(start_dim=1))

        out = self.fc1(torch.cat((temb, xemb, z), dim=-1))
        return -out

    def decode(self, z: Tensor, x: Tensor, t=None) -> Tensor:
        """Integrate ``z`` from t=1 to ``t`` (default 0); return the end state.

        Only ``x`` is bound into the field. The previous version also passed
        ``pi=pi``, but ``pi`` is undefined here (NameError) and ``forward``
        takes no ``pi``.
        """
        if t is None:
            t = 0.
        x = x.clone().detach().requires_grad_(True)
        zt = odeint_adjoint(
            partial(self, x=x),
            z,
            torch.tensor([1., t], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]
        return zt

    def decode_with_trajectory(self, z: Tensor, x: Tensor, t=None, num_points=100) -> tuple:
        """Like ``decode`` but also returns the states the field was evaluated at.

        Returns:
            ``(z_end, trajectory, times)``: ``trajectory`` stacks every state
            passed to the field during the forward solve, and ``times`` (CPU)
            holds the matching time values. NOTE(review): these presumably
            include intermediate solver stages, not just ``num_points``
            evenly spaced times — confirm.
        """
        if t is None:
            t = 0.

        x = x.clone().detach().requires_grad_(True)

        time_points = torch.linspace(1., t, num_points, device=x.device)

        trajectory = []
        recorded_times = []

        def ode_func_with_recording(t, state):
            # Record every state/time the solver evaluates, then defer to forward().
            trajectory.append(state.detach().clone())
            recorded_times.append(t.item())
            return self(t, state, x=x)

        zt = odeint_adjoint(
            ode_func_with_recording,
            z,
            time_points,
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        trajectory_tensor = torch.stack(trajectory)
        recorded_times_tensor = torch.tensor(recorded_times)
        return zt[-1], trajectory_tensor, recorded_times_tensor

    def log_prob(self, z: Tensor, x: Tensor, t, prior) -> Tensor:
        """Change-of-variables log-density of ``z`` at time ``t`` given ``x``.

        Integrates ``(z, log-det)`` from ``t`` to 1 and returns
        ``prior.log_prob(z_1) + integrated trace``. The previous version called
        ``prior.log_prob(pi, ...)`` with an undefined ``pi`` (NameError).
        """
        with torch.enable_grad():
            # Batched identity rows: one VJP per output coordinate gives the Jacobian.
            I = torch.eye(z.shape[-1], dtype=z.dtype, device=z.device)
            I = I.expand(*z.shape, z.shape[-1]).movedim(-1, 0)

            x = x.clone().detach().requires_grad_(True)

            def augmented(t: Tensor, state) -> Tensor:
                z, ladj = state
                with torch.enable_grad():
                    z = z.requires_grad_()
                    dz = self(t, z, x)

                jacobian = torch.autograd.grad(dz, z, I, create_graph=True, is_grads_batched=True)[0]
                trace = torch.einsum('i...i', jacobian)

                # Scaled by 1e-2 during the solve and by 1e2 on return.
                return dz, trace * 1e-2

            ladj = torch.zeros_like(z[..., 0]).to(device=x.device)
            zt, ladj = odeint_adjoint(
                augmented,
                (z, ladj),
                torch.tensor([t, 1.0], device=x.device),
                adjoint_params=self.parameters(),
                atol=1e-8, rtol=1e-8)
            # NOTE(review): assumes `prior` is an unconditional distribution over z
            # (FlowMatchingLoss_z calls priorz.rsample() with no arguments) — confirm.
            priorlog = prior.log_prob(zt[-1])
        return priorlog + ladj[-1] * 1e2


class FlowMatchingLoss_fixed(nn.Module):
    """Flow-matching loss with an amortised categorical encoder and a cluster regulariser.

    Args:
        vt: velocity network, called as ``vt(t, y_t, z_t)``.
        Rt: module whose ``rsample(y_t)`` returns ``(logits, relaxed one-hot)``.
        rt: object whose ``forward(one_hot, y_t)`` returns a distribution with ``rsample()``.
        priorpi, priorz: stored but not used by ``forward`` (commented as Dirichlet / Gaussian).
        priory: distribution sampled for the noise endpoint ``y1``.
        k: number of classes (stored).
        beta: temperature applied to the logits before the softmax giving ``pit``.
        sig_min: minimum noise level of the interpolation path.
        eps: stabiliser for the class-mass division in ``forward_var``.
    """

    def __init__(self, vt: nn.Module, Rt: nn.Module, rt, priorpi, priorz, priory, k: int, beta=1., sig_min=1e-4, eps=1e-8):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        self.priory = priory
        self.priorpi = priorpi  # dirichlet here
        self.priorz = priorz  # gaussian here
        self.sig_min = sig_min
        self.beta = beta
        self.eps = eps
        self.k = k

    def forward_var(self, pi, z):
        """Negative root of the summed squared distances between soft class centroids.

        Args:
            pi: soft assignments, assumed ``(N, K)``.
            z: latents, assumed ``(N, D)``.

        Returns:
            Scalar ``-sqrt(sum_{i,j} ||mu_i - mu_j||^2)``, where ``mu`` holds
            the ``pi``-weighted means of ``z`` per class.
        """
        weighted_z_sum = torch.matmul(pi.T, z)
        class_sum = pi.sum(dim=0, keepdim=True).T
        mu = weighted_z_sum / (class_sum + self.eps)
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)  # (K, K, D) pairwise centroid differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()
        return - torch.sqrt(between_class_distances)

    def forward_entropy(self, pi):
        """Per-row Shannon entropy ``-sum(pi * log pi)``.

        The log argument is clamped to the smallest normal float, so zero
        probabilities contribute 0 (instead of 0 * -inf = nan) and the gradient
        stays finite. Values for ordinary probabilities are unchanged.
        """
        return -(pi * pi.clamp_min(torch.finfo(pi.dtype).tiny).log()).sum(-1)

    def forward(self, x: Tensor, eps=1e-8) -> Tensor:
        """Flow-matching loss plus an adaptively weighted cluster regulariser.

        ``x`` is assumed ``(N, D)`` — TODO confirm. ``eps`` is unused (kept for
        interface compatibility).
        """
        # One time value per batch, broadcast to every row.
        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * _t
        y1 = self.priory.sample((len(x),)).to(x.device)
        y = x
        # Linear interpolation path between data y and noise y1, with target velocity ut.
        yt = (1 - t) * y + (self.sig_min + (1 - self.sig_min) * t) * y1
        ut = (1 - self.sig_min) * y1 - y

        logitst, ztidx = self.Rt.rsample(yt)
        pit = softmax(logitst/self.beta, dim=-1)
        zt = self.rt.forward(ztidx, yt).rsample()

        fm_loss = (self.vt(_t, yt, zt) - ut).square().mean(-1)
        loss = fm_loss.mean()
        # Small entropy term minus the across-batch variance of the soft assignments.
        reg_loss = self.forward_entropy(pit).mean() * 0.001 - pit.var(0).mean() * 1
        # Rescale the regulariser to the loss's magnitude (no gradient through the ratio).
        alpha = loss.detach() / (torch.abs(reg_loss.detach()) + 1e-8)
        return loss + alpha * 0.001 * reg_loss


class FlowMatchingLoss_z(nn.Module):
    """Flow-matching loss with a continuous latent ``z`` obtained by decoding ``priorz``.

    Args:
        vt: velocity network, called as ``vt(t, y_t, z_t)``.
        rt: flow with ``decode(z1, y_t, t)`` (e.g. ``CNF_z``) mapping a prior
            sample to ``z_t``.
        priorz: distribution whose ``rsample()`` gives ``z1`` (commented as Gaussian).
        priory: stored; ``forward`` draws its noise from Uniform[0, 1) instead.
        k, beta, alpha, tau: stored hyper-parameters (not used by ``forward``).
        sig_min: minimum noise level of the interpolation path.
        eps: stabiliser for the class-mass division in ``forward_var``.
    """

    def __init__(self, vt: nn.Module, rt, priorz, priory, k: int, sig_min=1e-4, beta=0.1, alpha=0.1, tau=1., eps=1e-8):
        super().__init__()

        self.vt = vt
        self.rt = rt
        self.priory = priory
        self.priorz = priorz  # gaussian here
        self.sig_min = sig_min
        self.k = k
        self.beta = beta
        self.alpha = alpha
        self.tau = tau
        self.eps = eps

    def forward_var(self, pi, z):
        """Negative root of the summed squared distances between soft class centroids.

        Args:
            pi: soft assignments, assumed ``(N, K)``.
            z: latents, assumed ``(N, D)``.
        """
        weighted_z_sum = torch.matmul(pi.T, z)
        class_sum = pi.sum(dim=0, keepdim=True).T
        mu = weighted_z_sum / (class_sum + self.eps)
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)  # (K, K, D) pairwise centroid differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()
        return - torch.sqrt(between_class_distances)

    def forward_entropy(self, pi):
        """Per-row Shannon entropy ``-sum(pi * log pi)``.

        The log argument is clamped to the smallest normal float, so zero
        probabilities contribute 0 instead of nan and the gradient stays finite.
        """
        return -(pi * pi.clamp_min(torch.finfo(pi.dtype).tiny).log()).sum(-1)

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Mean squared flow-matching error on the interpolation path.

        ``eps`` and ``hard`` are unused (kept for interface compatibility).
        """
        # One time value per batch, broadcast to every row.
        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * _t
        y1 = torch.rand_like(x)  # already on x's device
        y = x
        yt = (1 - t) * y + (self.sig_min + (1 - self.sig_min) * t) * y1
        ut = (1 - self.sig_min) * y1 - y

        z1 = self.priorz.rsample().to(device=x.device)
        zt = self.rt.decode(z1, yt, _t[0])  # push the prior sample to time t given yt

        fm_loss = (self.vt(_t, yt, zt) - ut).square().mean(-1)
        return fm_loss.mean()
        

class FlowMatchingLoss(nn.Module):
    """Flow-matching loss with a relaxed categorical assignment and a class-conditional latent.

    ``forward`` returns the flow-matching MSE plus a regularizer on the soft
    assignments ``pit`` (entropy and across-batch variance) and on the spread
    of the per-class latent centroids. The regularizer is adaptively rescaled
    relative to the loss.
    """

    def __init__(self, vt: nn.Module, Rt: nn.Module, rt, priorpi, priorz, priory, k: int, sig_min=1e-4, beta=0.1, alpha=0.1, tau=1., eps=1e-8):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        self.priory = priory
        self.priorpi = priorpi # dirichlet here
        self.priorz = priorz # gaussian here
        self.sig_min = sig_min
        self.k = k
        self.beta = beta
        self.alpha = alpha
        self.tau = tau
        self.eps = eps

    def forward_var(self, pi, z):
        """Return minus the root of the summed squared distances between soft class centroids.

        Assumes ``pi`` is ``(n, k)`` soft assignments and ``z`` is ``(n, d)``,
        so ``pi.T @ z`` is ``(k, d)``. TODO: confirm with callers. Note that
        the gradient of ``sqrt`` is infinite if all centroids coincide.
        """
        weighted_z_sum = torch.matmul(pi.T, z)  # (k, d)
        class_sum = pi.sum(dim=0, keepdim=True).T  # (k, 1) soft class sizes
        mu = weighted_z_sum / (class_sum + self.eps)  # (k, d) soft centroids
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)  # (k, k, d) pairwise differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()  # scalar total squared distance
        return - torch.sqrt(between_class_distances)

    def forward_entropy(self, pi):
        """Return the Shannon entropy of each row of ``pi`` (reduced over the last dim).

        Only the log argument is clamped to ``self.eps``. Exact zeros, e.g. from
        an underflowing ``softmax(logits / beta)``, therefore contribute 0
        instead of ``0 * log(0) = nan``. Entries ``>= eps`` give the same value
        and gradient as before.
        """
        return -(pi * torch.log(torch.clamp(pi, min=self.eps))).sum(-1)

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Return the regularized flow-matching loss for the batch ``x``.

        Args:
            x: data batch; the last dim is the feature dim.
            eps: unused; kept for signature compatibility.
            hard: forwarded to ``F.gumbel_softmax`` (straight-through one-hot if True).

        Returns:
            Scalar tensor: ``loss + scale * 0.001 * reg_loss``, where ``scale``
            is the detached ratio ``loss / |reg_loss|``.
        """
        _t = torch.rand(1).to(x.device)  # one time shared by the batch
        t = torch.ones_like(x[..., 0, None]) * _t
        y1 = torch.rand_like(x).to(x.device)
        y = x # use Sinusoidal Activation so no longer need to convert to real value
        yt = (1 - t) * y + (self.sig_min + (1 - self.sig_min) * t) * y1
        ut = (1 - self.sig_min) * y1 - y  # d(yt)/dt

        # Endpoint: prior logits -> relaxed one-hot -> class-conditional latent.
        logits1 = self.priorpi.rsample((len(x),)).to(x.device)
        z1idx = F.gumbel_softmax(logits1, tau=self.tau, hard=hard)
        z1 = self.priorz.rsample(z1idx, (1,))

        # Time t: decode logits and latent conditioned on yt.
        logitst = self.Rt.decode(logits1, yt, _t[0])
        ztidx = F.gumbel_softmax(logitst, tau=self.tau, hard=hard)
        zt = self.rt.decode(z1, ztidx, yt, _t[0]) # given soft probability, sample z
        pit = softmax(logitst/self.beta, dim=-1)  # temperature-sharpened assignments

        fm_loss = (self.vt(_t, yt, zt) - ut).square().mean(-1)
        loss = fm_loss.mean()
        reg_loss = self.forward_entropy(pit).mean() * 0.001 - pit.var(0).mean()* 1 + self.forward_var(pit, zt) * 0.0001
        # Detached rescale so the added term's magnitude tracks 0.001 * loss
        # (a local factor, unrelated to self.alpha).
        scale = loss.detach() / (torch.abs(reg_loss.detach()) + 1e-8)

        return loss + scale * 0.001 * reg_loss


class FlowMatchingLossCNN(nn.Module):
    """Flow-matching loss for image inputs with a relaxed categorical assignment and latent.

    ``Rt`` and ``rt`` receive the noisy sample flattened from dim 1; ``vt``
    receives it unflattened. The returned value is the flow-matching MSE plus
    an adaptively rescaled regularizer on the soft assignments and on the
    latent centroids.
    """
    # fix z distribution and vary pi distribution
    def __init__(self, vt: nn.Module, Rt: nn.Module, rt, priorpi, priorz, priory, k: int, sig_min=1e-4, beta=0.1, alpha=0.1, tau=1., eps=1e-8):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        # self.xemb = xemb
        self.priory = priory
        self.priorpi = priorpi # dirichlet here
        self.priorz = priorz # gaussian here
        self.sig_min = sig_min
        self.k = k
        self.beta = beta
        self.alpha = alpha
        self.tau = tau
        self.eps = eps

    def forward_var(self, pi, z):
        """Return minus the root of the summed squared distances between soft class centroids.

        Assumes ``pi`` is ``(n, k)`` and ``z`` is ``(n, d)``. TODO: confirm
        with callers. The gradient of ``sqrt`` is infinite if all centroids
        coincide.
        """
        weighted_z_sum = torch.matmul(pi.T, z)  # (k, d)
        class_sum = pi.sum(dim=0, keepdim=True).T  # (k, 1)
        mu = weighted_z_sum / (class_sum + self.eps)  # (k, d) soft centroids
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)  # (k, k, d) pairwise differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()  # scalar
        return - torch.sqrt(between_class_distances)

    def forward_entropy(self, pi):
        """Return the Shannon entropy of each row of ``pi`` (reduced over the last dim).

        The log argument is clamped to ``self.eps`` so exact zeros contribute 0
        rather than ``nan``. Entries ``>= eps`` are unaffected.
        """
        return -(pi * torch.log(torch.clamp(pi, min=self.eps))).sum(-1)

    def forward_norm(self, vt):
        """Return the batch mean of the squared L2 norm of ``vt`` flattened from dim 1."""
        return torch.mean(torch.norm(vt.flatten(start_dim=1), dim=-1)**2)

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Return the regularized flow-matching loss for the batch ``x``.

        Args:
            x: data batch (assumed image-shaped; ``flatten(start_dim=1)`` is
                used for the conditioning networks).
            eps: unused; kept for signature compatibility.
            hard: forwarded to ``F.gumbel_softmax``.
        """
        _t = torch.rand(1).to(x.device)  # one time shared by the batch
        t = torch.ones_like(x[..., 0, None]) * _t
        y1 = self.priory.sample((len(x),)).to(x.device)
        y = x # use Sinusoidal Activation so no longer need to convert to real value
        yt = (1 - t) * y + (self.sig_min + (1 - self.sig_min) * t) * y1
        ut = (1 - self.sig_min) * y1 - y  # d(yt)/dt

        logits1 = self.priorpi.rsample((len(x),)).to(x.device)
        z1idx = F.gumbel_softmax(logits1, tau=self.tau, hard=hard)
        z1 = self.priorz.rsample(z1idx, (1,))

        logitst = self.Rt.decode(logits1, yt.flatten(start_dim=1), _t[0])
        pit = softmax(logitst/self.beta, dim=-1)
        ztidx = F.gumbel_softmax(logitst, tau=self.tau, hard=hard)
        zt = self.rt.decode(z1, ztidx, yt.flatten(start_dim=1), _t[0]) # given soft probability, sample z

        vt = self.vt(_t, yt, zt)

        fm_loss = (vt - ut).square().mean(-1)
        loss = fm_loss.mean()
        reg_loss = self.forward_entropy(pit).mean() * 0.001 - pit.var(0).mean()* 1 + self.forward_var(pit, zt) * 0.0001
        # Detached rescale so the added term's magnitude tracks 0.001 * loss.
        scale = loss.detach() / (torch.abs(reg_loss.detach()) + 1e-8)
        return loss + scale * reg_loss * 0.001


class FlowMatchingLossCNN_fixpi(nn.Module):
    """Flow-matching loss whose assignment sampler ``Rt`` is conditioned only on ``y_t``.

    ``forward`` returns the flow-matching MSE plus three adaptively rescaled
    regularizers: assignment entropy, negative across-batch assignment
    variance, and mean latent norm.
    """
    # fix z distribution and vary pi distribution
    def __init__(self, vt: nn.Module, Rt: nn.Module, rt, priorpi, priorz, priory, k: int, sig_min=1e-4, beta=0.1, alpha=0.1, eps=1e-8):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        # self.xemb = xemb
        self.priory = priory
        self.priorpi = priorpi # dirichlet here
        self.priorz = priorz # gaussian here
        self.sig_min = sig_min
        self.k = k
        self.beta = beta
        self.alpha = alpha
        self.eps = eps

    def forward_var(self, pi, z):
        """Return minus the root of the summed squared distances between soft class centroids.

        Assumes ``pi`` is ``(n, k)`` and ``z`` is ``(n, d)``. TODO: confirm
        with callers.
        """
        weighted_z_sum = torch.matmul(pi.T, z)  # (k, d)
        class_sum = pi.sum(dim=0, keepdim=True).T  # (k, 1)
        mu = weighted_z_sum / (class_sum + self.eps)  # (k, d) soft centroids
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)  # (k, k, d) pairwise differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()  # scalar
        return - torch.sqrt(between_class_distances)

    def forward_entropy(self, pi, eps=1e-8):
        """Return the entropy of each row of ``pi``; ``eps`` inside the log avoids ``log(0)``."""
        return -(pi * torch.log(pi+eps)).sum(-1)

    def forward_KL(self,pi,z,x,t):
        """Return the batch mean of ``rt.log_prob(...) - priorz.log_prob(pi, z)``."""
        return (self.rt.log_prob(z,pi,x,t,self.priorz) - self.priorz.log_prob(pi,z)).mean()

    def forward_norm(self, vt):
        """Return the batch mean of the squared L2 norm of ``vt`` flattened from dim 1."""
        return torch.mean(torch.norm(vt.flatten(start_dim=1), dim=-1)**2)

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Return the regularized flow-matching loss for the batch ``x``.

        ``eps`` and ``hard`` are unused; they are kept for signature parity
        with the sibling losses.
        """
        _t = torch.rand(1).to(x.device)  # one time shared by the batch
        t = torch.ones_like(x[..., 0, None]) * _t
        y1 = self.priory.sample((len(x),)).to(x.device)
        # Data endpoint: x is used directly, as in the sibling losses.
        # (Previously `y` was never bound here, so every call raised NameError.)
        y = x
        yt = (1 - t) * y + (self.sig_min + (1 - self.sig_min) * t) * y1
        ut = (1 - self.sig_min) * y1 - y  # d(yt)/dt

        logits1 = self.priorpi.rsample((len(x),)).to(x.device)
        _, z1idx = self.Rt.rsample(None, logits1)
        z1 = self.priorz.rsample(z1idx, (1,))

        logitst, ztidx = self.Rt.rsample(yt)
        pit = softmax(logitst/self.beta, dim=-1)
        zt = self.rt.decode(z1, ztidx, yt, _t[0]) # given soft probability, sample z
        vt = self.vt(_t, yt, zt)

        fm_loss = (vt - ut).square().mean(-1)
        loss = fm_loss.mean()
        reg_loss1 = self.forward_entropy(pit).mean()
        reg_loss2 = - pit.var(0).mean()
        reg_loss3 = torch.sqrt((zt**2).sum(-1)).mean()
        # Detached per-term rescales so each regularizer is measured relative to the loss.
        beta1 = loss.detach() / (torch.abs(reg_loss1.detach()) + 1e-8)
        beta2 = loss.detach() / (torch.abs(reg_loss2.detach()) + 1e-8)
        beta3 = loss.detach() / (torch.abs(reg_loss3.detach()) + 1e-8)
        return loss + self.alpha * 0.01 * beta1 * reg_loss1 +  self.alpha * beta2 * reg_loss2 + beta3 * reg_loss3 * 0.001






class FlowMatchingLossCNN_fixed(nn.Module):
    """Flow-matching loss where ``Rt`` and ``rt`` are both conditioned on ``(y_t, t)``.

    If ``cnn`` is False, the conditioning input is ``y_t`` flattened from
    dim 1; otherwise ``y_t`` is passed as-is. The returned value is the
    flow-matching MSE plus an adaptively rescaled negative-variance term on
    the soft assignments.
    """
    # fix z distribution and vary pi distribution
    def __init__(self, vt: nn.Module, Rt: nn.Module, rt, priorpi, priorz, priory, k: int, sig_min=1e-4, beta=0.1, alpha=0.1, tau=1., eps=1e-8, cnn=False):
        super().__init__()
        # Networks (registration order kept for parameters()/state_dict()).
        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        # Priors.
        self.priory = priory
        self.priorpi = priorpi  # dirichlet here
        self.priorz = priorz  # gaussian here
        # Hyper-parameters.
        self.sig_min = sig_min
        self.k = k
        self.beta = beta
        self.alpha = alpha
        self.tau = tau
        self.eps = eps
        self.cnn = cnn

    def forward_var(self, pi, z):
        """Return minus the root of the summed squared distances between soft class centroids.

        Assumes ``pi`` is ``(n, k)`` and ``z`` is ``(n, d)``. TODO: confirm
        with callers.
        """
        sizes = pi.sum(dim=0, keepdim=True).T  # (k, 1)
        centroids = torch.matmul(pi.T, z) / (sizes + self.eps)  # (k, d)
        gaps = centroids.unsqueeze(0) - centroids.unsqueeze(1)  # (k, k, d)
        return -gaps.pow(2).sum(dim=2).sum().sqrt()

    def forward_entropy(self, pi):
        """Return the Shannon entropy of each row of ``pi``."""
        return -(pi * pi.log()).sum(dim=-1)

    def forward_norm(self, vt):
        """Return the batch mean of the squared L2 norm of ``vt`` flattened from dim 1."""
        flat = vt.flatten(start_dim=1)
        return flat.norm(dim=-1).pow(2).mean()

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Return ``loss + w * alpha * (-Var_batch(pi_t).mean())`` for the batch ``x``.

        ``w`` is the detached ratio ``loss / |reg|``. ``hard`` is forwarded to
        ``F.gumbel_softmax``; ``eps`` is unused.
        """
        t_draw = torch.rand(1).to(x.device)  # one time shared by the batch
        t = torch.ones_like(x[..., 0, None]) * t_draw
        noise = self.priory.sample((len(x),)).to(x.device)
        sig = self.sig_min
        y_t = (1 - t) * x + (sig + (1 - sig) * t) * noise
        u_t = (1 - sig) * noise - x

        cond = y_t if self.cnn else y_t.flatten(start_dim=1)

        # Drawn for RNG parity; the prior logits are not used in this variant.
        self.priorpi.rsample((len(x),)).to(x.device)

        logits_t, _ = self.Rt.rsample(cond, t_draw[0])  # Rt's own sample is discarded
        pi_t = softmax(logits_t / self.beta, dim=-1)

        onehot_t = F.gumbel_softmax(logits_t, tau=self.tau, hard=hard)
        z_t = self.rt.forward(onehot_t, cond, t_draw[0]).rsample()

        velocity = self.vt(t_draw, y_t, z_t)
        loss = (velocity - u_t).square().mean(-1).mean()

        spread = - pi_t.var(0).mean()
        spread_weight = loss.detach() / (torch.abs(spread).detach() + 1e-8)
        # Entropy term and its weight: computed, but the term is disabled in the return value.
        entropy = self.forward_entropy(pi_t).mean()
        entropy_weight = loss.detach() / (torch.abs(entropy).detach() + 1e-8)
        return loss + spread_weight * self.alpha * spread



class FlowMatchingLoss_fixz(nn.Module):
    """Flow-matching loss with a latent drawn from ``rt.forward(assignment, y_t)``.

    The returned value is the flow-matching MSE plus ``0.001`` times a fixed
    (not adaptively rescaled) regularizer on the soft assignments.
    """
    # fix z distribution and vary pi distribution
    def __init__(self, vt: nn.Module, Rt: nn.Module, rt, priorpi, priorz, priory, k: int, sig_min=1e-4, beta=0.1, alpha=0.1, tau=1., eps=1e-8):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        # self.xemb = xemb
        self.priory = priory
        self.priorpi = priorpi # dirichlet here
        self.priorz = priorz # gaussian here
        self.sig_min = sig_min
        self.k = k
        self.beta = beta
        self.alpha = alpha
        self.tau = tau
        self.eps = eps


    def forward_var(self, pi, z):
        """Return minus the root of the summed squared distances between soft class centroids.

        Assumes ``pi`` is ``(n, k)`` and ``z`` is ``(n, d)``. TODO: confirm
        with callers.
        """
        weighted_z_sum = torch.matmul(pi.T, z)  # (k, d)
        class_sum = pi.sum(dim=0, keepdim=True).T  # (k, 1)
        mu = weighted_z_sum / (class_sum + self.eps)  # (k, d) soft centroids
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)  # (k, k, d) pairwise differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()  # scalar
        return - torch.sqrt(between_class_distances)

    def forward_entropy(self, pi):
        """Return the Shannon entropy of each row of ``pi`` (reduced over the last dim).

        The log argument is clamped to ``self.eps`` so exact zeros contribute 0
        rather than ``nan``. Entries ``>= eps`` are unaffected.
        """
        return -(pi * torch.log(torch.clamp(pi, min=self.eps))).sum(-1)

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Return ``loss + 0.001 * reg_loss`` for the batch ``x``.

        ``hard`` is forwarded to ``F.gumbel_softmax``; ``eps`` is unused.
        """
        _t = torch.rand(1).to(x.device)  # one time shared by the batch
        t = torch.ones_like(x[..., 0, None]) * _t
        y1 = self.priory.sample((len(x),)).to(x.device)
        y = x # use Sinusoidal Activation so no longer need to convert to real value
        yt = (1 - t) * y + (self.sig_min + (1 - self.sig_min) * t) * y1
        ut = (1 - self.sig_min) * y1 - y  # d(yt)/dt

        logits1 = self.priorpi.rsample((len(x),)).to(x.device)

        logitst = self.Rt.decode(logits1, yt, _t[0])
        ztidx = F.gumbel_softmax(logitst, tau=self.tau, hard=hard)
        pit = softmax(logitst/self.beta, dim=-1)
        zt = self.rt.forward(ztidx, yt).rsample()

        fm_loss = (self.vt(_t, yt, zt) - ut).square().mean(-1)
        loss = fm_loss.mean()
        reg_loss = self.forward_entropy(pit).mean()*.001 - pit.var(0).mean()* 1.
        # NOTE(review): unlike the sibling losses, the regularizer is not rescaled
        # by loss / |reg_loss| here (the ratio was computed but never used).
        return loss + reg_loss * 0.001


class FlowMatchingLoss_fixpi(nn.Module):
    """Flow-matching loss where ``Rt`` samples assignments from ``y_t`` alone.

    The returned value is the flow-matching MSE plus an adaptively rescaled
    regularizer (negative across-batch variance plus entropy) on the soft
    assignments.
    """
    # fix pi and vary z
    def __init__(self, vt: nn.Module, Rt: nn.Module, rt, priorpi, priorz, priory, k: int, sig_min=1e-4, beta=0.1, alpha=0.1, tau=1., eps=1e-8):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        # self.xemb = xemb
        self.priory = priory
        self.priorpi = priorpi # dirichlet here
        self.priorz = priorz # gaussian here
        self.sig_min = sig_min
        self.k = k
        self.beta = beta
        self.alpha = alpha
        self.tau = tau
        self.eps = eps

    def forward_var(self, pi, z):
        """Return minus the root of the summed squared distances between soft class centroids.

        Assumes ``pi`` is ``(n, k)`` and ``z`` is ``(n, d)``. TODO: confirm
        with callers.
        """
        weighted_z_sum = torch.matmul(pi.T, z)  # (k, d)
        class_sum = pi.sum(dim=0, keepdim=True).T  # (k, 1)
        mu = weighted_z_sum / (class_sum + self.eps)  # (k, d) soft centroids
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)  # (k, k, d) pairwise differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()  # scalar
        return - torch.sqrt(between_class_distances)

    def forward_entropy(self, pi):
        """Return the Shannon entropy of each row of ``pi`` (reduced over the last dim).

        The log argument is clamped to ``self.eps`` so exact zeros contribute 0
        rather than ``nan``. Entries ``>= eps`` are unaffected.
        """
        return -(pi * torch.log(torch.clamp(pi, min=self.eps))).sum(-1)

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Return the regularized flow-matching loss for the batch ``x``.

        ``eps`` and ``hard`` are unused; they are kept for signature parity.
        """
        _t = torch.rand(1).to(x.device)  # one time shared by the batch
        t = torch.ones_like(x[..., 0, None]) * _t
        y1 = self.priory.sample((len(x),)).to(x.device)
        y = x # use Sinusoidal Activation so no longer need to convert to real value
        yt = (1 - t) * y + (self.sig_min + (1 - self.sig_min) * t) * y1
        ut = (1 - self.sig_min) * y1 - y  # d(yt)/dt

        logits1 = self.priorpi.rsample((len(x),)).to(x.device)
        _, z1idx = self.Rt.rsample(None, logits1)
        z1 = self.priorz.rsample(z1idx, (1,))

        logitst, ztidx = self.Rt.rsample(yt)
        pit = softmax(logitst/self.beta, dim=-1)
        zt = self.rt.decode(z1, ztidx, yt, _t[0]) # given soft probability, sample z

        fm_loss = (self.vt(_t, yt, zt) - ut).square().mean(-1)
        loss = fm_loss.mean()
        reg_loss = - pit.var(0).mean()* 1.+ self.forward_entropy(pit).mean()*.01
        # Detached rescale so the added term's magnitude tracks 0.001 * loss.
        scale = loss.detach() / (torch.abs(reg_loss.detach()) + 1e-8)
        return loss + scale * reg_loss * 0.001




class cnnLLK(nn.Module):
    """Image velocity field ``v(t, x, z)`` used as a conditional continuous flow.

    ``forward`` returns a tensor shaped like ``x``. ``x`` is unpacked as
    ``(B, C, H, W)``, and the heads assume ``H == W == x_features`` and
    ``C == in_ch``. ``model`` selects the backbone: ``"resnet"``, ``"unet"``,
    or (any other value) a two-stage CNN followed by an MLP head.
    """
    def __init__(self, x_features: int, z_features: int, freqs: int = 2, in_ch: int = 1, mod_ch: int = 128, hidden_dim = 784, num_blocks=4, model="cnn", droprate=0.2, **kwargs):
        super().__init__()
        # Sinusoidal time-embedding frequencies pi, 2*pi, ..., freqs*pi.
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)
        self.x_features = x_features
        self.in_ch = in_ch
        self.model = model
        self.base_ch = mod_ch
        self.num_blocks = num_blocks
        self.shiftedtanh = ShiftedTanh()
        if model == "resnet":
            # x -> base_ch features -> residual blocks -> back to in_ch channels.
            self.input_proj = nn.Conv2d(self.in_ch, self.base_ch, 3, padding=1)
            self.res_blocks = nn.Sequential(
                *[ResBlock_nzt(self.base_ch, self.base_ch, droprate) for _ in range(self.num_blocks)])
            self.output_proj = nn.Sequential(
                nn.Conv2d(self.base_ch, self.in_ch, 3, padding=1),
                # ShiftedTanh()
                )
            # z -> in_ch spatial maps; time embedding -> one spatial map (see forward).
            self.latent_fc = nn.Linear(z_features, x_features**2*in_ch)
            self.temb_fc = nn.Linear(2*freqs, x_features**2)
            # Fuses [x features (in_ch), z maps (in_ch), t map (1)] into in_ch channels.
            # Registered under two names (combine_fc and fusion); kept for checkpoint compatibility.
            self.combine_fc = self.fusion = nn.Sequential(
                    nn.Conv2d(in_ch*2+1, in_ch, kernel_size=3, padding=1),
                )

        elif model == "unet":
            self.fc = Unet_nzt(self.in_ch, mod_ch, in_ch)

            self.cnn_fc = nn.Linear(self.in_ch * (x_features)**2, hidden_dim)
            self.latent_fc = nn.Linear(z_features, hidden_dim)
            self.combine_fc = MLP(hidden_dim*2+2*freqs, x_features**2*in_ch, **kwargs)

        else:
            self.cnn = nn.Sequential(
                nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1),  # 32 x H x W
                nn.Softplus(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 32 x H/2 x W/2
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 64 x H/2 x W/2
                nn.Softplus(),
                nn.MaxPool2d(kernel_size=2, stride=2)  # 64 x H/4 x W/4
            )
            # Fully connected layer to project CNN features to the hidden dimension
            self.cnn_fc = nn.Linear(64 * (x_features // 4) * (x_features // 4), hidden_dim)
            self.latent_fc = nn.Linear(z_features, hidden_dim)
            # Combine the processed inputs and output the vector field
            self.combine_fc = MLP(hidden_dim+hidden_dim+2*freqs, x_features**2*in_ch, **kwargs)


    def forward(self, t: Tensor, x: Tensor, z:Tensor):
        """Evaluate the velocity at time ``t`` for images ``x`` conditioned on latents ``z``.

        Args:
            t: time; assumed to be a 0-d or shape-``(1,)`` tensor, broadcast
                over the batch.
            x: images, unpacked as ``(B, C, H, W)``.
            z: latents; assumed ``(B, z_features)``. TODO: confirm with callers.

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        _t = self.freqs * t[..., None]
        temb = torch.cat((_t.cos(), _t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:1], -1)  # (B, 2*freqs)
        B,C,H,W = x.shape

        if self.model == "resnet":
            xemb = self.input_proj(x)
            xemb = self.res_blocks(xemb)
            xemb = self.output_proj(xemb)  # (B, in_ch, H, W)
            # latent_fc yields x_features**2 * in_ch values per sample: lay them out as
            # in_ch spatial maps so they can be concatenated with xemb on the channel dim.
            zemb = self.latent_fc(z).view(B, self.in_ch, H, W)
            # temb_fc yields x_features**2 values per sample: one extra spatial channel.
            temb = self.temb_fc(temb).view(B, 1, H, W)
            # in_ch + in_ch + 1 channels, matching combine_fc's Conv2d(in_ch*2+1, in_ch).
            out = self.combine_fc(torch.cat([xemb, zemb, temb], dim=1))
            return out
        elif self.model == "unet":
            xemb = self.fc(x)
            xemb = self.cnn_fc(xemb.view((B, -1)))
        else:
            # Extract image features with the CNN: (B, 64, H/4, W/4).
            xemb = self.cnn(x)
            # Flatten and map to hidden_dim: (B, hidden_dim).
            xemb = self.cnn_fc(xemb.view(B, -1))

        # Process latent state: (B, hidden_dim)
        zemb = self.latent_fc(z)

        # Combine all features: (B, 2 * hidden_dim + 2 * freqs)
        combined_features = torch.cat([xemb, zemb, temb], dim=-1)
        vector_field = self.combine_fc(combined_features)  # (B, in_ch * x_features**2)

        # reshape to same as x
        return vector_field.view(B, C, H, W)


    def _forward(self, t: Tensor, x: Tensor, z:Tensor) -> tuple:
        """Return the velocity twice as a pair (presumably for an API expecting ``(value, aux)``)."""
        out = self.forward(t, x, z)
        return out, out

    def encode(self, x: Tensor) -> Tensor:
        """Integrate the field from t = 0 to t = 1 with zuko's ``odeint``.

        NOTE(review): ``forward`` requires ``z``, but it is not bound here, so
        the solver's ``f(t, x)`` call looks like it cannot succeed. Confirm
        before relying on this method.
        """
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, x: Tensor, z: Tensor, t=None) -> Tensor:
        """Integrate ``dx/ds = v(s, x, z)`` from ``s = t`` (default 1) down to ``s = 0``.

        ``z`` is detached, so no gradient flows back to the caller's ``z``.
        Gradients w.r.t. this module's parameters use the adjoint method.
        """
        if t is None:
            t = 1.
        z = z.clone().detach().requires_grad_(True)
        # Solve with an adaptive solver
        xt = odeint_adjoint(
            partial(self, z=z),
            x,
            torch.Tensor([t, 0.]),
            adjoint_params=itertools.chain(
                self.parameters()
                ),
            method="dopri5",
            atol=1e-8, rtol=1e-8)[-1]
        return xt

    def decode_with_trajectory(self, x: Tensor, z: Tensor, t=None, num_points=100) -> tuple:
        """Integrate from ``s = 0`` to ``s = t`` (default 1), recording every RHS evaluation.

        Returns ``(final_state, states, times)``. The recorded states are the
        solver's function evaluations (including intermediate stages), not
        only the ``num_points`` output times.
        """
        if t is None:
            t = 1.

        z = z.clone().detach().requires_grad_(True)

        # Create time points
        start_time = 0.
        end_time = t
        time_points = torch.linspace(start_time, end_time, num_points, device=x.device)

        # Storage for trajectory
        trajectory = []
        recorded_times = []

        # Wrapper for ODE function to collect states
        def ode_func_with_recording(t, state):
            # Record the current state
            trajectory.append(state.detach().clone())
            recorded_times.append(t.item())
            # Call the original function
            return self(t, state, z=z)

        # Run the ODE solver
        xt = odeint_adjoint(
            ode_func_with_recording,
            x,
            time_points,
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        # Convert lists to tensors
        trajectory_tensor = torch.stack(trajectory)
        recorded_times_tensor = torch.tensor(recorded_times)
        # Return final state and the recorded trajectory
        return xt[-1], trajectory_tensor, recorded_times_tensor

    @staticmethod
    def exact_trace(f, y):
        """Exact Jacobian trace: one autograd call per non-batch element of ``y``."""
        dims = y.size()[1:]
        tr_dzdx = 0.0
        dim_ranges = [range(d) for d in dims]
        for idcs in itertools.product(*dim_ranges):
            batch_idcs = (slice(None),) + idcs
            tr_dzdx += torch.autograd.grad(f[batch_idcs].sum(), y, create_graph=True)[0][batch_idcs]
        return tr_dzdx

    @staticmethod
    def hutch_trace(f, y, e):
        """Hutchinson's estimator for the Jacobian trace: ``sum(e * (e^T df/dy))`` per sample."""
        # With e ~ Rademacher (Bernoulli on -1/+1 with 50/50 chance).
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx * e
        approx_tr_dzdx = sum_except_batch(e_dzdx_e)
        return approx_tr_dzdx

    def log_prob(self, x: Tensor, z: Tensor, t, source) -> Tensor:
        """Estimate ``log p_t(x | z)`` by integrating from ``t`` to 1 with a Hutchinson trace.

        Returns ``source.log_prob(x_1)`` summed over dims (1, 2, 3), plus the
        accumulated trace. The trace state is scaled by 1e-3 during
        integration and rescaled by 1e3 at the end.
        """
        z = z.clone().detach().requires_grad_(True)
        # One Rademacher probe, reused for every solver step.
        e = torch.randint(low=0, high=2, size=x.size()).to(x.device) * 2 - 1
        def augmented(t: Tensor, state) -> tuple:
            x, adj = state
            with torch.enable_grad():
                x.requires_grad_(True)
                dx = self(t, x, z)
                trace = self.hutch_trace(dx,x,e)
            return dx, trace * 1e-3

        ladj = x.new_zeros(x.shape[0])
        x0, ladj = odeint_adjoint(
            augmented,
            (x, ladj),
            torch.Tensor([t, 1.0]),
            adjoint_params=self.parameters(),
            atol=1e-7, rtol=1e-7)
        return source.log_prob(x0[-1]).sum(dim=(1, 2, 3)) + ladj[-1] * 1e3


class EmbedFC(nn.Module):
    """Generic small fully connected embedding: Linear -> Tanh -> Linear.

    Inputs are reshaped with ``view(-1, input_dim)`` before projection, so
    the output is ``(N, emb_dim)``, where ``N = x.numel() // input_dim``.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.Tanh(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        flat = x.view(-1, self.input_dim)
        return self.fc(flat)


class LLK_R(nn.Module):
    """MLP velocity field ``v(t, x, z)`` for a continuous flow over ``x`` conditioned on ``z``.

    The MLP input is ``[cos/sin time features, z, x]`` concatenated on the last
    dimension; its output has ``x_features`` entries.
    """
    def __init__(self, x_features: int, z_features: int, k: int, freqs: int = 2, **kwargs):
        # NOTE(review): ``k`` is accepted but not used by this constructor.
        super().__init__()

        self.fc = MLP(2 * freqs + x_features + z_features, x_features, **kwargs)
        # Time-embedding frequencies pi, 2*pi, ..., freqs*pi.
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, z:Tensor) -> Tensor:
        """Return the MLP velocity for ``(t, x, z)``; the time features are broadcast to ``x``'s leading dims."""
        # embed t
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        h = self.fc(torch.cat((temb, z, x), dim=-1))
        return h # raw MLP output; no squashing is applied here

    def _forward(self, t: Tensor, x: Tensor, z:Tensor) -> tuple:
        """Return the velocity twice as a pair (presumably for an API expecting ``(value, aux)``)."""
        out = self.forward(t, x, z)
        return out, out

    def decode(self, x: Tensor, z: Tensor, t=None) -> Tensor:
        """Integrate ``dx/ds = v(s, x, z)`` from ``s = t`` (default 1) down to ``s = 0``; return the end state.

        ``z`` is detached, so no gradient flows back to the caller's ``z``.
        Parameter gradients use the adjoint method (dopri5, atol = rtol = 1e-8).
        """
        if t is None:
            t = 1.
        z = z.clone().detach().requires_grad_(True)
        
        # Solve with an adaptive solver
        def solve_ldensity_i(t, x):
            """
            ODE right-hand side: the velocity at (t, x) with the captured z.
            """
            vtdt = self(t, x, z)
            return vtdt

        xt = odeint_adjoint(
            solve_ldensity_i, 
            x, 
            torch.Tensor([t, 0.]), 
            adjoint_params=itertools.chain(
                # vt.net.parameters(),
                self.parameters()
                ),
            method="dopri5",
            atol=1e-8, rtol=1e-8)[-1]
        return xt

    def decode_with_trajectory(self, x: Tensor, z: Tensor, t=None, num_points=100) -> tuple:
        """Integrate from ``s = 0`` to ``s = t`` (default 1) and record every RHS evaluation.

        NOTE: the direction is 0 -> t, the opposite of ``decode``. Returns
        ``(final_state, states, times)``. The recorded states are the solver's
        function evaluations (including intermediate stages), not just the
        ``num_points`` output times. A backward pass through the adjoint may
        append further entries.
        """
    
        if t is None:
            t = 1.
        
        z = z.clone().detach().requires_grad_(True)
        
        # Create time points
        start_time = 0.
        end_time = t
        time_points = torch.linspace(start_time, end_time, num_points, device=x.device)
        
        # Storage for trajectory
        trajectory = []
        recorded_times = []
        
        # Wrapper for ODE function to collect states
        def ode_func_with_recording(t, state):
            # Record the current state
            trajectory.append(state.detach().clone())
            recorded_times.append(t.item())
            # Call the original function
            return self(t, state, z=z)
        
        # Run the ODE solver
        xt = odeint_adjoint(
            ode_func_with_recording, 
            x, 
            time_points, 
            adjoint_params=self.parameters(), 
            atol=1e-8, rtol=1e-8)
        # Convert lists to tensors
        trajectory_tensor = torch.stack(trajectory)
        recorded_times_tensor = torch.tensor(recorded_times)
        # Return final state and the recorded trajectory
        return xt[-1], trajectory_tensor, recorded_times_tensor
        

    def log_prob(self, x: Tensor, z: Tensor, t=0.) -> Tensor:
        """Return ``log_normal(x_1) + int_t^1 tr(dv/dx) ds`` via an exact Jacobian trace.

        The state ``(x, ladj)`` is integrated from ``t`` to 1. The trace is
        scaled by 1e-2 during integration and by 1e2 at the end, presumably to
        keep the log-det state small for the solver's error control.
        """
        # Batched identity cotangents: one basis vector per feature, moved to dim 0,
        # so a single batched autograd call yields the full Jacobian.
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        z = z.clone().detach().requires_grad_(True)
        def augmented(t: Tensor, state) -> tuple:
            x, adj = state
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x, z)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            # Contract the first (cotangent) and last (feature) axes -> per-sample trace.
            trace = torch.einsum('i...i', jacobian)

            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        x0, ladj = odeint_adjoint(
            augmented, 
            (x, ladj), 
            torch.Tensor([t, 1.0]), 
            adjoint_params=self.parameters(), 
            atol=1e-7, rtol=1e-7)
        return log_normal(x0[-1]) + ladj[-1] * 1e2


class pi_LLK_R(nn.Module):
    """Velocity field ``v(t, x, pi)`` conditioned on an assignment vector ``pi`` of length ``k``.

    ``x`` and ``pi`` are each embedded linearly to ``k`` features and
    concatenated with a sinusoidal time embedding. The result is fed to an
    MLP that outputs ``x_features`` values.
    """

    def __init__(self, x_features: int, k: int, freqs: int = 2, **kwargs):
        super().__init__()
        width = k  # both embeddings use k features
        self.embx = nn.Linear(x_features, width)
        self.embpi = nn.Linear(k, width)
        in_dim = 2 * freqs + 2 * width  # cos/sin time features + pi and x embeddings
        self.fc = MLP(in_dim, x_features, **kwargs)
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, pi: Tensor) -> Tensor:
        phase = self.freqs * t[..., None]
        time_feat = torch.cat((phase.cos(), phase.sin()), dim=-1).expand(*x.shape[:-1], -1)
        features = torch.cat((time_feat, self.embpi(pi), self.embx(x)), dim=-1)
        return self.fc(features)

    def _forward(self, t: Tensor, x: Tensor, pi: Tensor) -> Tensor:
        velocity = self.forward(t, x, pi)
        return velocity, velocity

    def decode(self, x: Tensor, pi: Tensor, t=None) -> Tensor:
        """Integrate ``dx/ds = v(s, x, pi)`` from ``s = t`` (default 1) to ``s = 0``; return the end state.

        ``pi`` is detached, so no gradient flows back to the caller's ``pi``.
        """
        start = 1. if t is None else t
        frozen_pi = pi.clone().detach().requires_grad_(True)
        path = odeint_adjoint(
            partial(self, pi=frozen_pi),
            x,
            torch.Tensor([start, 0.]),
            adjoint_params=tuple(self.parameters()),
            method="dopri5",
            atol=1e-8,
            rtol=1e-8,
        )
        return path[-1]


class CatNF_R(nn.Module):
    """Continuous flow over a ``k``-dimensional latent ``z`` (``k = z_features``), conditioned on ``x``."""
    def __init__(self, x_features: int, z_features: int, freqs: int = 2, **kwargs):
        super(CatNF_R, self).__init__()
        self.k = z_features # hidden classes
        # An MLP mapping [time features, x embedding, z embedding] to k outputs.
        hidden_dim = z_features
        self.embx = nn.Linear(x_features, hidden_dim)
        self.embz = nn.Linear(z_features, hidden_dim)
        self.fc = nn.Sequential(
            MLP(
                2 * freqs + hidden_dim*2,
                self.k,
                **kwargs),
            )
        # Time-embedding frequencies pi, 2*pi, ..., freqs*pi.
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    # Kolmogorov backward (generating equation)
    def forward(self, t: Tensor, z:Tensor, x: Tensor):
        """Return ``-MLP([time features, emb(x), emb(z)])`` (the network output, negated)."""
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        xemb = self.embx(x)
        zemb = self.embz(z)
        return -self.fc(torch.cat((temb, xemb, zemb), dim=-1))

    def decode(self, z: Tensor, x: Tensor, t=None) -> Tensor:
        """Integrate ``dz/ds = forward(s, z, x)`` from ``s = 1`` to ``s = t`` (default 0); return the end state.

        ``x`` is detached, so no gradient flows back to the caller's ``x``.
        """
        if t is None:
            t = 0.
        x = x.clone().detach().requires_grad_(True)
        # Solve with an adaptive solver
        zt = odeint_adjoint(
            partial(self, x=x),
            z,
            torch.Tensor([1., t]),
            adjoint_params=itertools.chain(
                self.parameters()
                ),
            atol=1e-8, rtol=1e-8)[-1]
        return zt

    def log_prob(self, z: Tensor, x: Tensor, t, prior) -> Tensor:
        """Return ``prior.log_prob(z_1) + accumulated trace`` from integrating ``z`` from ``t`` to 1.

        The exact Jacobian trace of the integrated field is accumulated,
        scaled by 1e-2 during integration and by 1e2 at the end.

        NOTE(review): this integrates ``-forward`` from t to 1, whereas
        ``decode`` integrates ``+forward`` from 1 to t. Confirm that the sign
        convention is intended.
        """
        if t is None:
            t = 0.
        # Batched identity cotangents over z's last dim (previously referenced an
        # undefined name `p`, which raised NameError on every call).
        I = torch.eye(z.shape[-1], dtype=z.dtype, device=z.device)
        I = I.expand(*z.shape, z.shape[-1]).movedim(-1, 0)

        x = x.clone().detach().requires_grad_(True)
        def augmented(t: Tensor, state) -> tuple:
            z, ladj = state
            with torch.enable_grad():
                z = z.requires_grad_()
                dz = -self(t, z, x)

            jacobian = torch.autograd.grad(dz, z, I, create_graph=True, is_grads_batched=True)[0]
            # Contract cotangent axis with feature axis -> per-sample trace.
            trace = torch.einsum('i...i', jacobian)

            return dz, trace * 1e-2

        ladj = torch.zeros_like(z[..., 0])
        z0, ladj = odeint_adjoint(
            augmented,
            (z, ladj),
            torch.Tensor([t, 1.0]),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8
            )
        return prior.log_prob(z0[-1]) + ladj[-1] * 1e2

class CNF_R(nn.Module):
    """Conditional CNF over ``z`` given a class vector ``pi`` and an input ``x``.

    The field is ``-fc1([temb, embx(x), embz(z), embpi(pi)])``. ``temb`` holds
    cos/sin features of ``t`` at angular frequencies ``pi, 2*pi, ..., freqs*pi``.
    ``decode`` integrates the field from t=1 down to ``t``, and ``log_prob``
    integrates it from ``t`` up to 1.

    Args:
        x_features: Size of the last dim of ``x``.
        z_features: Size of the last dim of ``z`` (and of the field's output).
        k: Size of the last dim of ``pi``; also the width of every embedding.
        freqs: Number of sinusoidal frequencies used to embed ``t``.
        **kwargs: Forwarded to ``MLP`` (e.g. ``hidden_features``).
    """

    def __init__(self, x_features: int, z_features: int, k: int, freqs: int = 2, **kwargs):
        super().__init__()
        hidden_dim = k
        # Creation order (embpi, embz, embx, fc1) is kept so that seeded
        # parameter initialisation is unchanged.
        self.embpi = nn.Linear(k, hidden_dim)
        self.embz = nn.Linear(z_features, hidden_dim)
        self.embx = nn.Linear(x_features, hidden_dim)
        # Input width = time features (2 * freqs) + three embeddings of width
        # hidden_dim. The nn.Sequential wrapper keeps state_dict keys
        # ("fc1.0.*") stable for existing checkpoints.
        self.fc1 = nn.Sequential(
            MLP(2 * freqs + hidden_dim * 3, z_features, **kwargs),
        )

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, z: Tensor, pi: Tensor, x: Tensor) -> Tensor:
        """Evaluate the (negated) network output at time ``t`` for state ``z``."""
        # Sinusoidal time embedding, broadcast over the leading dims of x.
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:-1], -1)
        xemb = self.embx(x)
        zemb = self.embz(z)
        piemb = self.embpi(pi)
        out = self.fc1(torch.cat((temb, xemb, zemb, piemb), dim=-1))
        return -out

    def decode(self, z: Tensor, pi: Tensor, x: Tensor, t=None) -> Tensor:
        """Integrate ``z`` from t=1 down to ``t`` (default 0), conditioned on ``pi`` and ``x``.

        ``pi`` and ``x`` are replaced by detached copies, so no gradient reaches
        the caller's tensors.
        """
        if t is None:
            t = 0.
        x = x.clone().detach().requires_grad_(True)
        pi = pi.clone().detach().requires_grad_(True)
        zt = odeint_adjoint(
            partial(self, pi=pi, x=x),
            z,
            torch.tensor([1., t], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)[-1]
        return zt

    def decode_with_trajectory(self, z: Tensor, pi: Tensor, x: Tensor, t=None, num_points=100) -> tuple:
        """Like ``decode``, but also return every state the ODE function was called with.

        Returns:
            ``(z_t, trajectory, times)``:
            - ``trajectory`` stacks the detached state passed to each
              right-hand-side evaluation.
            - ``times`` holds the matching evaluation times, as a CPU tensor.

        NOTE(review): the recorded points are solver evaluations, including
        intermediate stages. They are not the ``num_points`` evenly spaced
        states the solver returns, so with an adaptive solver ``times`` need
        not be monotone.
        """
        if t is None:
            t = 0.

        x = x.clone().detach().requires_grad_(True)
        pi = pi.clone().detach().requires_grad_(True)

        time_points = torch.linspace(1., t, num_points, device=x.device)

        trajectory = []
        recorded_times = []

        def ode_func_with_recording(t, state):
            # Record the state/time of this evaluation, then evaluate the field.
            trajectory.append(state.detach().clone())
            recorded_times.append(t.item())
            return self(t, state, pi=pi, x=x)

        zt = odeint_adjoint(
            ode_func_with_recording,
            z,
            time_points,
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        trajectory_tensor = torch.stack(trajectory)
        recorded_times_tensor = torch.tensor(recorded_times)
        return zt[-1], trajectory_tensor, recorded_times_tensor

    def log_prob(self, z: Tensor, pi: Tensor, x: Tensor, t, prior) -> Tensor:
        """Log-density of ``z`` at time ``t`` given ``pi`` and ``x``.

        The state is integrated with the Jacobian trace of the field from ``t``
        (``None`` means 0, as in ``decode``) up to 1. The result is
        ``prior.log_prob(pi, z_1)`` plus the accumulated trace.
        """
        if t is None:
            t = 0.

        # Identity basis for a batched Jacobian computation via autograd.
        I = torch.eye(z.shape[-1], dtype=z.dtype, device=z.device)
        I = I.expand(*z.shape, z.shape[-1]).movedim(-1, 0)

        x = x.clone().detach().requires_grad_(True)
        pi = pi.clone().detach().requires_grad_(True)

        def augmented(t: Tensor, state) -> Tensor:
            z, ladj = state
            with torch.enable_grad():
                z = z.requires_grad_()
                dz = self(t, z, pi, x)

            jacobian = torch.autograd.grad(dz, z, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # Scaled down here and back up by 1e2 below.
            return dz, trace * 1e-2

        ladj = torch.zeros_like(z[..., 0]).to(device=x.device)
        zt, ladj = odeint_adjoint(
            augmented,
            (z, ladj),
            torch.tensor([t, 1.0], device=x.device),
            adjoint_params=self.parameters(),
            atol=1e-8, rtol=1e-8)
        return prior.log_prob(pi, zt[-1]) + ladj[-1] * 1e2


class pi_FlowMatchingLoss_R(nn.Module):
    """Flow-matching loss whose conditioning code comes from the logit flow ``Rt``.

    A single t ~ U(0, 1) is drawn per call. ``xt`` interpolates data (t=0) and
    Gaussian noise (t=1). Logits are decoded by ``Rt`` from a ``priorpi`` sample,
    and a Gumbel-softmax relaxation of them conditions ``vt``, which is regressed
    onto the path velocity ``ut``.

    Args:
        vt: Velocity network, called as ``vt(t, xt, zt)``.
        Rt: Logit flow, called as ``Rt.decode(logits1, xt, t)``.
        priorpi: Prior over the initial logits; must support ``rsample``.
        k: Number of classes (stored only).
        sig_min: Minimum noise scale of the interpolation path.
        beta: Stored temperature; ``forward`` uses its own ``beta`` argument.
        tau: Gumbel-softmax temperature.
    """

    def __init__(self, vt: nn.Module, Rt: nn.Module, priorpi, k: int=None, sig_min=1e-4, beta=0.1, tau=1.):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.priorpi = priorpi
        self.sig_min = sig_min
        self.k = k
        self.beta = beta
        self.tau = tau

    def forward_entropy(self, pi):
        """Shannon entropy ``-sum(pi * log(pi))`` over the last dim.

        Exact zeros contribute 0, the limit of x*log(x), instead of
        NaN from ``0 * log(0)``. Gradients stay finite there too. Values with
        ``pi >= finfo.tiny`` are unaffected by the clamp.
        """
        tiny = torch.finfo(pi.dtype).tiny
        return -(pi * torch.log(pi.clamp_min(tiny))).sum(-1)

    def forward_MI(self, p, x, t):
        # NOTE(review): `self.rt` and `self.prior` are never assigned in
        # __init__, so calling this raises AttributeError as written.
        return self.rt.log_prob(p, x, t, self.prior) - self.prior.log_prob(p)

    def forward(self, x: Tensor, beta: float) -> Tensor:
        """Return the flow-matching loss plus the regulariser on ``pit``.

        The regulariser is a small entropy penalty on ``pit`` minus its
        per-class variance across the batch.
        """
        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * _t
        x1 = torch.randn_like(x)
        # Interpolant between data (t=0) and noise (t=1) and its time derivative.
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x1
        ut = (1 - self.sig_min) * x1 - x
        logits1 = self.priorpi.rsample((len(x),)).to(x.device)
        logitst = self.Rt.decode(logits1, xt, _t[0])
        pit = softmax(logitst / beta, dim=-1)
        zt = F.gumbel_softmax(logitst, tau=self.tau, hard=False)

        fm_loss = (self.vt(_t, xt, zt) - ut).square().mean(-1)
        loss = fm_loss.mean()
        return loss + self.forward_entropy(pit).mean() * 0.001 - pit.var(0).mean()


class FlowMatchingLoss_R(nn.Module):
    """Conditional flow-matching loss with learned class logits and latent ``z``.

    Each call does the following:
    - Draws one t ~ U(0, 1) for the whole batch.
    - Builds ``xt`` between data (t=0) and Gaussian noise (t=1).
    - Obtains class logits from ``Rt`` and a latent ``zt`` from ``rt``.
    - Regresses ``vt(t, xt, zt)`` onto the path velocity ``ut``.
    - Adds a regulariser on the class probabilities ``pit``, rescaled to the
      size of the loss.

    Args:
        vt: Velocity network, called as ``vt(t, xt, zt)``.
        Rt: Logit model (``decode`` or ``rsample`` depending on ``fix_pi``).
        rt: Latent model (``decode`` or ``rsample`` depending on ``fix_z``).
        priorpi: Prior over initial logits; must support ``rsample``.
        priorz: Prior over the initial latent.
        k: Number of classes (stored only).
        fix_z: Draw ``zt`` with ``rt.rsample`` instead of ``rt.decode``.
        fix_pi: Draw logits with ``Rt.rsample`` instead of ``Rt.decode``.
        sig_min: Minimum noise scale of the interpolation path.
        beta: Softmax temperature for ``pit``.
        tau: Gumbel-softmax temperature.
        eps: Stabiliser for the centroid division in ``forward_var``.
    """

    def __init__(
        self, vt: nn.Module, Rt: nn.Module, rt: nn.Module, priorpi, priorz,
        k: int=None, fix_z=False, fix_pi=False,
        sig_min=1e-4, beta=0.1, tau=1., eps=1e-8):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.rt = rt
        self.priorpi = priorpi
        self.priorz = priorz
        self.sig_min = sig_min
        self.k = k
        self.tau = tau
        self.beta = beta
        self.eps = eps
        self.fix_z = fix_z
        self.fix_pi = fix_pi

        print("prior form z", priorz)

    def forward_MI(self, p, x, t):
        # NOTE(review): `self.prior` is never assigned (only `priorpi` and
        # `priorz` are), so calling this raises AttributeError as written.
        return self.rt.log_prob(p, x, t, self.prior) - self.prior.log_prob(p)

    def forward_entropy(self, pi):
        """Shannon entropy ``-sum(pi * log(pi))`` over the last dim.

        Exact zeros contribute 0 instead of NaN from ``0 * log(0)``, and
        gradients stay finite. Values ``>= finfo.tiny`` are unaffected.
        """
        tiny = torch.finfo(pi.dtype).tiny
        return -(pi * torch.log(pi.clamp_min(tiny))).sum(-1)

    def forward_var(self, pi, z):
        """Negative total distance between soft class centroids of ``z``.

        The code assumes ``pi`` is (N, k) soft assignments and ``z`` is (N, d)
        (TODO confirm). It returns
        ``-sqrt(sum_{a,b} ||mu_a - mu_b||^2)``, where ``mu`` holds the
        ``pi``-weighted class means.
        """
        weighted_z_sum = torch.matmul(pi.T, z)               # (k, d)
        class_sum = pi.sum(dim=0, keepdim=True).T            # (k, 1)
        mu = weighted_z_sum / (class_sum + self.eps)         # (k, d) centroids
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)          # (k, k, d) pairwise differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()  # scalar total
        return - torch.sqrt(between_class_distances)

    def forward(self, x: Tensor, hard=False) -> Tensor:
        """Return the flow-matching loss plus the rescaled ``pit`` regulariser."""
        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * _t
        x1 = torch.randn_like(x)
        # Interpolant between data (t=0) and noise (t=1) and its time derivative.
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x1
        ut = (1 - self.sig_min) * x1 - x

        if self.fix_z and self.fix_pi:
            logitst, ztidx = self.Rt.rsample(xt)
            zt = self.rt.rsample(ztidx, xt)
            pit = F.gumbel_softmax(logitst, tau=self.tau, hard=hard)
        elif self.fix_z:
            logits1 = self.priorpi.rsample((len(x),)).to(x.device)
            logitst = self.Rt.decode(logits1, xt, _t[0])
            pit = softmax(logitst/self.beta, dim=-1)
            ztidx = F.gumbel_softmax(logitst, tau=self.tau, hard=hard)
            zt = self.rt.rsample(ztidx, xt)
        elif self.fix_pi:
            logits1 = self.priorpi.rsample((len(x),)).to(x.device)
            _, z1idx = self.Rt.rsample(None, logits1)
            z1 = self.priorz.rsample(z1idx, (1,)).to(x.device)

            logitst, ztidx = self.Rt.rsample(xt)
            pit = softmax(logitst/self.beta, dim=-1)
            zt = self.rt.decode(z1, ztidx, xt)
        else:
            logits1 = self.priorpi.rsample((len(x),)).to(x.device)
            z1idx = F.gumbel_softmax(logits1, tau=self.tau, hard=hard)
            z1 = self.priorz.rsample(z1idx, (1,)).to(x.device)

            logitst = self.Rt.decode(logits1, xt, _t[0])
            pit = softmax(logitst/self.beta, dim=-1)
            ztidx = F.gumbel_softmax(logitst, tau=self.tau, hard=hard)
            zt = self.rt.decode(z1, ztidx, xt)

        fm_loss = (self.vt(_t, xt, zt) - ut).square().mean(-1)

        loss = fm_loss.mean()
        # Reward per-class variance of pit across the batch; small entropy penalty.
        reg_loss = - pit.var(0).mean() + self.forward_entropy(pit).mean() * 0.001
        # Detached ratio rescales the regulariser to the magnitude of the FM loss.
        alpha = loss.detach() / (torch.abs(reg_loss.detach()) + 1e-8)
        return loss + alpha * 0.001 * reg_loss

class fixed_FlowMatchingLoss_R(nn.Module):
    """Flow-matching loss where ``zt`` is sampled from the distribution ``r`` returns.

    Logits are decoded by ``Rt`` from a ``priorpi`` sample and turned into
    ``pit``. A Gumbel-softmax sample of them is passed with ``xt`` to ``r``. The
    returned distribution's ``rsample()`` conditions ``vt``, which is regressed
    onto the path velocity.

    Args:
        vt: Velocity network, called as ``vt(t, xt, zt)``.
        Rt: Logit flow, called as ``Rt.decode(pi1, xt, t)``.
        r: Module returning a distribution with ``rsample()`` from ``(ztidx, xt)``.
        priorpi: Prior over initial logits; must support ``rsample``.
        priorz: Stored only in this class.
        k: Number of classes (stored only).
        sig_min: Minimum noise scale of the interpolation path.
        beta: Stored temperature; ``forward`` uses its own ``beta`` argument.
        tau: Gumbel-softmax temperature.
        eps: Stabiliser for the centroid division in ``forward_var``.
    """

    def __init__(self, vt: nn.Module, Rt: nn.Module, r: nn.Module, priorpi, priorz, k: int=None, sig_min=1e-4, beta=0.1, tau=1., eps=1e-8):
        super().__init__()

        self.vt = vt
        self.Rt = Rt
        self.r = r
        self.priorpi = priorpi
        self.priorz = priorz
        self.sig_min = sig_min
        self.k = k
        self.tau = tau
        self.beta = beta
        self.eps = eps

        print("prior form z", priorz)

    def forward_MI(self, p, x, t):
        # NOTE(review): this class stores `self.r`, not `self.rt`, and never
        # sets `self.prior`, so calling this raises AttributeError as written.
        return self.rt.log_prob(p, x, t, self.prior) - self.prior.log_prob(p)

    def forward_var(self, pi, z):
        """Negative total distance between soft class centroids of ``z``.

        The code assumes ``pi`` is (N, k) and ``z`` is (N, d) (TODO confirm).
        """
        weighted_z_sum = torch.matmul(pi.T, z)               # (k, d)
        class_sum = pi.sum(dim=0, keepdim=True).T            # (k, 1)
        mu = weighted_z_sum / (class_sum + self.eps)         # (k, d) centroids
        mu_diff = mu.unsqueeze(0) - mu.unsqueeze(1)          # (k, k, d) pairwise differences
        between_class_distances = torch.sum(mu_diff ** 2, dim=2).sum()  # scalar total
        return - torch.sqrt(between_class_distances)

    def forward_entropy(self, pi):
        """Shannon entropy ``-sum(pi * log(pi))`` over the last dim.

        Exact zeros contribute 0 instead of NaN from ``0 * log(0)``, and
        gradients stay finite. Values ``>= finfo.tiny`` are unaffected.
        """
        tiny = torch.finfo(pi.dtype).tiny
        return -(pi * torch.log(pi.clamp_min(tiny))).sum(-1)

    def forward(self, x: Tensor, beta: float, hard=False) -> Tensor:
        """Return the FM loss plus entropy, variance and centroid-spread terms on ``pit``."""
        _t = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * _t
        x1 = torch.randn_like(x)
        # Interpolant between data (t=0) and noise (t=1) and its time derivative.
        xt = (1 - t) * x + (self.sig_min + (1 - self.sig_min) * t) * x1
        ut = (1 - self.sig_min) * x1 - x

        pi1 = self.priorpi.rsample((len(x),)).to(x.device)
        logits = self.Rt.decode(pi1, xt, _t[0]) / beta
        pit = softmax(logits, dim=-1)

        # log_softmax equals log(softmax) mathematically but stays finite when
        # softmax underflows to 0. pit.log() would give -inf there, and its
        # gradient would be NaN.
        ztidx = F.gumbel_softmax(F.log_softmax(logits, dim=-1), tau=self.tau, hard=hard)
        # Call the module (not .forward) so registered hooks run.
        zt = self.r(ztidx, xt).rsample()
        fm_loss = (self.vt(_t, xt, zt) - ut).square().mean(-1)
        loss = fm_loss.mean()
        return loss + self.forward_entropy(pit).mean()*0.001 - pit.var(0).mean() + self.forward_var(pit, zt) * 0.0001


def to_one_hot(z, num_classes):
    """
    Transform a batch of categorical values into one-hot encoded form.

    Args:
    - z (torch.Tensor or sequence): Shape (batch_size,) with integer class ids.
      Any integer dtype is accepted; indices are converted to int64 as
      ``scatter_`` requires.
    - num_classes (int): The number of unique categories/classes.

    Returns:
    - torch.Tensor: Shape (batch_size, num_classes), default floating dtype,
      on the same device as ``z``.

    Raises:
    - RuntimeError: From ``scatter_`` if an index is outside [0, num_classes).
    """
    if not isinstance(z, torch.Tensor):
        z = torch.tensor(z)

    index = z.long().unsqueeze(1)
    # Allocate on z's device so CUDA inputs don't hit a device mismatch.
    one_hot = torch.zeros(z.size(0), num_classes, device=z.device)
    one_hot.scatter_(1, index, 1)
    return one_hot


if __name__ == '__main__':
    # Demo script with three stages:
    # 1. Train llk_net / pi_net / z_net jointly with FlowMatchingLoss_R on
    #    PINWHEEL data.
    # 2. Sample x, class probabilities and z from the priors.
    # 3. Scatter-plot the generated x, coloured by z1.
    #
    # NOTE(review): `DataLoader` and `PINWHEEL` are not imported in the visible
    # header (the pinwheel import there is commented out). They presumably come
    # from `from unet import *` or another module; confirm before running.
    # NOTE(review): the header calls torch.set_default_dtype(torch.float64), so
    # the networks are float64. This presumably requires PINWHEEL to yield
    # float64 batches; verify.

    # data, _ = make_moons(16384, noise=0.05)
    # data = torch.from_numpy(data).float()

    batch_size = 1000
    num_per_class = 10_000
    dat_dir = "data"
    num_classes = 5
    dataloader = DataLoader(PINWHEEL(num_per_class, dat_dir, num_classes=num_classes, gen=True, plot=False), batch_size=batch_size, shuffle=True)
    z_feature_dim = 1
    x_feature_dim = 2
    n_gen = 1000
    n_epoch= 50 # 400
    # Softmax temperature used when sampling below. FlowMatchingLoss_R is
    # constructed without `beta`, so training uses its default (0.1); keep the
    # two in sync.
    beta = 0.1

    # Positional mapping into FlowMatchingLoss_R below:
    # llk_net -> vt (velocity for x), pi_net -> Rt (class logits),
    # z_net -> rt (latent z).
    # LLK_R, CatNF_R and CNF are defined outside this chunk.
    # NOTE(review): z_net.decode(z0, pi1, x1) below and the loss's
    # rt.decode(z1, ztidx, xt) use CNF_R.decode's (z, pi, x) argument order.
    # Confirm that CNF has the same decode signature.
    llk_net = LLK_R(x_feature_dim, z_feature_dim, hidden_features=[120] * 3)
    pi_net = CatNF_R(x_feature_dim, num_classes, hidden_features=[64] * 2)
    z_net = CNF(x_feature_dim, z_feature_dim, num_classes, hidden_features=[64] * 2)
    # priorpi = Dirichlet(torch.ones(num_classes)/num_classes)
    priorpi =DiagNormal(torch.zeros(num_classes), torch.ones(num_classes)) 
    priorz = DiagNormal(torch.zeros(z_feature_dim), torch.ones(z_feature_dim))
    loss = FlowMatchingLoss_R(llk_net, pi_net, z_net, priorpi, priorz, num_classes)
    # Training: a single Adam optimizer over all three networks.
    optimizer = torch.optim.Adam(
        itertools.chain(
                llk_net.parameters(),
                pi_net.parameters(),
                z_net.parameters()
                ), 
        lr=1e-3)


    for epoch in tqdm(range(n_epoch)):
        # Labels y are loaded but not used by the loss.
        for i, (x, y) in enumerate(dataloader):
            # print("class", y[1])
            # emb = y.view(-1,1).to(torch.float32)
            # one_hot_y = to_one_hot(y, z_feature_dim)
            # print(one_hot_y)
            loss(x).backward()

            optimizer.step()
            optimizer.zero_grad()

    # Sampling: draw prior samples, decode x, then class probabilities, then z.
    with torch.no_grad():
        # z1 = torch.randint(0,5,(16384,1)).to(torch.float32)
        # prior pmf
        pi0 = priorpi.sample((n_gen,))
        z0 = priorz.sample((n_gen,))
        x0 = torch.randn(n_gen, x_feature_dim)
        x1 = llk_net.decode(x0, z0)
        pi1 = pi_net.decode(pi0, x1)
        # Decoded logits -> probabilities, same tempered softmax as training.
        pi1 = softmax(pi1/beta, dim=-1)
        z1 = z_net.decode(z0, pi1, x1)
        
        # z1 = torch.multinomial(q1, 1).view(-1)

    # Plot generated x coloured by the decoded latent z1.
    cmap = plt.colormaps['gist_rainbow']
    plt.figure(dpi=150)
    # plt.hist2d(*x.T, bins=64)
    plt.scatter(x1[:,0], x1[:,1], c=z1, marker=".", cmap=cmap)
    plt.colorbar()
    plt.show()
    plt.close()

