import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum
from utils.helpers import VectorMLP, VectorUNet
from utils.helpers import SinusoidalPosEmb, linear_beta_schedule, cosine_beta_schedule, extract

class MLPDiffusion(nn.Module):
    """
    Plain MLP denoiser for action diffusion.

    The network is fed a noisy action, the diffusion timestep and the
    conditioning state, and emits a tensor with ``action_dim`` features.
    Whether that output is read as noise or as a clean action is decided by
    the caller.
    """

    def __init__(self, state_dim, action_dim, time_dim=32, hidden_dim=256):
        """
        Build the timestep encoder and the denoising trunk.

        Args:
            state_dim (int): Number of state features
            action_dim (int): Number of action features
            time_dim (int): Width of the timestep embedding
            hidden_dim (int): Width of every hidden layer in the trunk
        """
        super().__init__()

        # Timestep encoder: SinusoidalPosEmb features refined by a small MLP.
        time_layers = [
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, 2 * time_dim),
            nn.Mish(),
            nn.Linear(2 * time_dim, time_dim),
        ]
        self.time_mlp = nn.Sequential(*time_layers)

        # The trunk consumes [action | state | time embedding] and runs
        # three hidden Linear+Mish stages before the output projection.
        # Layers are created in the same order as they appear in the
        # Sequential so seeded initialisation and state_dict keys are stable.
        widths = [action_dim + state_dim + time_dim] + [hidden_dim] * 3
        trunk = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Linear(fan_in, fan_out))
            trunk.append(nn.Mish())
        trunk.append(nn.Linear(hidden_dim, action_dim))
        self.mlp = nn.Sequential(*trunk)

    def forward(self, x, t, s):
        """
        Run the denoiser on a batch.

        Args:
            x (torch.Tensor): Noisy actions [batch_size, action_dim]
            t (torch.Tensor): Timesteps [batch_size]
            s (torch.Tensor): States [batch_size, state_dim]

        Returns:
            torch.Tensor: Network output [batch_size, action_dim]
        """
        # Timesteps are cast to float before being embedded.
        time_features = self.time_mlp(t.float())

        # Feature order is fixed: action, then state, then time embedding.
        features = torch.cat([x, s, time_features], dim=-1)
        return self.mlp(features)


class UNetDiffusion(nn.Module):
    """
    UNet-style denoiser for action diffusion.

    The network is a stack of fully connected stages arranged as
    encoder -> bottleneck -> decoder. Each decoder stage receives the output
    of the previous stage concatenated with the matching encoder output
    (skip connection).
    """

    @staticmethod
    def _stage(in_features, out_features):
        """Return one Linear + Mish stage."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.Mish(),
        )

    def __init__(self, state_dim, action_dim, time_dim=32, hidden_dim=256):
        """
        Build the timestep encoder and the encoder/decoder stages.

        Args:
            state_dim (int): Number of state features
            action_dim (int): Number of action features
            time_dim (int): Width of the timestep embedding
            hidden_dim (int): Base width; the deeper stages use twice this
        """
        super().__init__()

        wide = hidden_dim * 2

        # Timestep encoder (same layout as the MLP denoiser).
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 2),
            nn.Mish(),
            nn.Linear(time_dim * 2, time_dim),
        )

        # Encoder widens the features: input -> hidden -> 2 * hidden.
        self.encoder1 = self._stage(action_dim + state_dim + time_dim, hidden_dim)
        self.encoder2 = self._stage(hidden_dim, wide)

        # Bottleneck keeps the widest representation.
        self.bottleneck = self._stage(wide, wide)

        # decoder2 sees bottleneck (2h) + encoder2 skip (2h) = 4h features.
        self.decoder2 = self._stage(wide * 2, hidden_dim)

        # decoder1 sees decoder2 (h) + encoder1 skip (h) = 2h features and
        # projects straight to the action dimension (no activation).
        self.decoder1 = nn.Sequential(
            nn.Linear(wide, action_dim),
        )

    def forward(self, x, t, s):
        """
        Run the denoiser on a batch.

        Args:
            x (torch.Tensor): Noisy actions [batch_size, action_dim]
            t (torch.Tensor): Timesteps [batch_size]
            s (torch.Tensor): States [batch_size, state_dim]

        Returns:
            torch.Tensor: Network output [batch_size, action_dim]
        """
        time_features = self.time_mlp(t.float())
        features = torch.cat([x, s, time_features], dim=-1)

        # Encoder outputs are kept around for the skip connections.
        skip1 = self.encoder1(features)
        skip2 = self.encoder2(skip1)

        hidden = self.bottleneck(skip2)

        # Each decoder stage consumes [previous stage | encoder skip].
        hidden = self.decoder2(torch.cat([hidden, skip2], dim=-1))
        return self.decoder1(torch.cat([hidden, skip1], dim=-1))


class DiffusionPolicy(nn.Module):
    """
    Diffusion Policy model for imitation learning.

    Implements a DDPM-style diffusion model that generates actions
    conditioned on states.

    The diffusion process works in two phases:
    1. Forward process: gradually adds noise to actions until they become pure noise
    2. Reverse process: denoises step by step to recover an action

    Mathematical foundation:
    - Forward: q(x_t | x_{t-1}) = N(x_t; √(1-β_t) x_{t-1}, β_t I)
    - Reverse: p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), Σ_θ(x_t, t))
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 model_type='mlp',
                 max_action=1.0,
                 beta_schedule='cosine',
                 n_timesteps=100,
                 clip_denoised=True,
                 predict_epsilon=True,
                 hidden_dim=256,
                 time_dim=32):
        """
        Initialize the diffusion policy model.

        Args:
            state_dim (int): Dimension of the state space
            action_dim (int): Dimension of the action space
            model_type (str): Denoising model, 'mlp' or 'unet' (case-insensitive)
            max_action (float): Actions are clamped to [-max_action, max_action]
            beta_schedule (str): Noise schedule type ('linear' or 'cosine')
            n_timesteps (int): Number of diffusion timesteps
            clip_denoised (bool): Whether to clamp the predicted x_0 while sampling
            predict_epsilon (bool): Whether the network predicts noise (True)
                or the clean action (False)
            hidden_dim (int): Hidden dimension for the denoising model
            time_dim (int): Time embedding dimension

        Raises:
            ValueError: If ``model_type`` or ``beta_schedule`` is not recognised.
        """
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon

        # Choose the architecture of the denoising function.
        model_key = model_type.lower()
        if model_key == 'mlp':
            self.model = MLPDiffusion(state_dim, action_dim, time_dim, hidden_dim)
        elif model_key == 'unet':
            self.model = UNetDiffusion(state_dim, action_dim, time_dim, hidden_dim)
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

        # Noise schedule (β_t values).
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(n_timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(n_timesteps)
        else:
            raise ValueError(f"Unknown beta_schedule: {beta_schedule}")

        # α_t = 1 - β_t and the cumulative products α̅_t = ∏_{i<=t} α_i.
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # α̅_{t-1}, with α̅ before the first step defined as 1. new_ones keeps
        # the dtype/device of the schedule instead of assuming float32 on CPU.
        alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])

        self.n_timesteps = int(n_timesteps)

        # Buffers are part of the module state (moved by .to(), saved in the
        # state_dict) but are not trainable parameters.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # Constants for the forward process q(x_t | x_0):
        # x_t = √α̅_t * x_0 + √(1-α̅_t) * ε, where ε ~ N(0, I)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # Constants for the posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # The variance is 0 at the first step, so clamp before taking the log.
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(torch.clamp(posterior_variance, min=1e-20)))
        # torch.sqrt (not np.sqrt) keeps everything in torch, matching the
        # buffers above and avoiding NumPy's CPU-only tensor interop.
        self.register_buffer('posterior_mean_coef1',
                             betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    def predict_start_from_noise(self, x_t, t, noise):
        """
        Recover the predicted clean action x_0 from the network output.

        With ``predict_epsilon=True`` the network output is treated as noise:
        x_0 = √(1/α̅_t) * x_t - √(1/α̅_t - 1) * ε.
        Otherwise the network output is already x_0 and is returned as is.

        Args:
            x_t (torch.Tensor): Noisy action at timestep t
            t (torch.Tensor): Timestep indices
            noise (torch.Tensor): Network output (noise or x_0, see above)

        Returns:
            torch.Tensor: Predicted clean action x_0
        """
        if not self.predict_epsilon:
            return noise
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """
        Compute the parameters of the posterior q(x_{t-1} | x_t, x_0).

        Args:
            x_start (torch.Tensor): Clean action x_0
            x_t (torch.Tensor): Noisy action at timestep t
            t (torch.Tensor): Timestep indices

        Returns:
            tuple: (posterior_mean, posterior_variance, posterior_log_variance_clipped)
        """
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, state):
        """
        Compute mean and variance of the learned reverse step p_θ(x_{t-1} | x_t).

        The denoising model's output is converted to a prediction of x_0,
        optionally clamped, and plugged into the closed-form posterior.

        Args:
            x (torch.Tensor): Current noisy action x_t
            t (torch.Tensor): Current timestep indices
            state (torch.Tensor): Conditioning state

        Returns:
            tuple: (model_mean, posterior_variance, posterior_log_variance)
        """
        x_recon = self.predict_start_from_noise(x, t=t, noise=self.model(x, t, state))

        # Keep the predicted clean action inside the valid action range.
        if self.clip_denoised:
            x_recon.clamp_(-self.max_action, self.max_action)

        return self.q_posterior(x_start=x_recon, x_t=x, t=t)

    def p_sample(self, x, t, state):
        """
        Perform one reverse step: sample x_{t-1} from p_θ(x_{t-1} | x_t).

        Args:
            x (torch.Tensor): Current noisy action x_t
            t (torch.Tensor): Current timestep indices, one per batch element
            state (torch.Tensor): Conditioning state

        Returns:
            torch.Tensor: Less noisy action x_{t-1}
        """
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, state=state)

        noise = torch.randn_like(x)

        # No noise is added when t == 0 so the final step is deterministic.
        # The mask is shaped (batch, 1, ..., 1) to broadcast against x.
        nonzero_mask = (t != 0).float().reshape(x.shape[0], *([1] * (x.dim() - 1)))

        # std = exp(0.5 * log variance)
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    def p_sample_loop(self, state, shape):
        """
        Run the full reverse process, starting from Gaussian noise.

        Args:
            state (torch.Tensor): Conditioning state
            shape (tuple): Shape of the action tensor to generate; ``shape[0]``
                is used as the batch size

        Returns:
            torch.Tensor: Generated action after all reverse steps
        """
        # Sampling happens on whatever device the schedule buffers live on.
        device = self.betas.device
        batch_size = shape[0]

        # Start from pure noise: x_T ~ N(0, I)
        x = torch.randn(shape, device=device)

        # Iteratively denoise: x_T → x_{T-1} → ... → x_1 → x_0
        for i in reversed(range(self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, timesteps, state)

        return x

    def sample(self, state):
        """
        Sample actions from the diffusion policy given states.

        Args:
            state (torch.Tensor): Input states [batch_size, state_dim]

        Returns:
            torch.Tensor: Generated actions [batch_size, action_dim], clamped
            to [-max_action, max_action]
        """
        shape = (state.shape[0], self.action_dim)
        action = self.p_sample_loop(state, shape)
        return action.clamp_(-self.max_action, self.max_action)

    def q_sample(self, x_start, t, noise=None):
        """
        Forward diffusion: draw x_t from q(x_t | x_0) = N(√α̅_t * x_0, (1-α̅_t) * I).

        Args:
            x_start (torch.Tensor): Clean actions x_0
            t (torch.Tensor): Timestep indices
            noise (torch.Tensor, optional): Noise to add (sampled if None)

        Returns:
            torch.Tensor: Noisy actions x_t
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, state, t):
        """
        Compute the diffusion training loss for a batch of actions.

        The regression target depends on the prediction mode:
        - predict_epsilon=True: MSE between the network output and the added noise
        - predict_epsilon=False: MSE between the network output and x_0

        Args:
            x_start (torch.Tensor): Clean expert actions
            state (torch.Tensor): Conditioning states
            t (torch.Tensor): Timestep indices

        Returns:
            torch.Tensor: Scalar training loss
        """
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        prediction = self.model(x_noisy, t, state)

        target = noise if self.predict_epsilon else x_start
        return F.mse_loss(prediction, target)

    def loss(self, action, state):
        """
        Compute the training loss for a batch of state-action pairs.

        One timestep is drawn uniformly from [0, n_timesteps) for every
        example so the model is trained across all noise levels.

        Args:
            action (torch.Tensor): Expert actions [batch_size, action_dim]
            state (torch.Tensor): States [batch_size, state_dim]

        Returns:
            torch.Tensor: Average loss for the batch
        """
        batch_size = len(action)
        t = torch.randint(0, self.n_timesteps, (batch_size,),
                          device=action.device, dtype=torch.long)
        return self.p_losses(action, state, t)

    def forward(self, state):
        """
        Forward pass: generate actions from states (alias for ``sample``).

        Args:
            state (torch.Tensor): Input states

        Returns:
            torch.Tensor: Generated actions
        """
        return self.sample(state)
    
"""Model architectures and preconditioning schemes used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

class IMMDiffusion(torch.nn.Module):
    """
    IMM Diffusion Model for Conditional Action Generation
    
    This class combines:
    - IMMPrecond: Preconditioning and sampling scheme
    - IMMLoss: Maximum Mean Discrepancy-based training loss
    - CFG: Optional Classifier-Free Guidance for inference
    
    Structure follows reference Diffusion/Consistency models:
    - loss(action, state, ...): Compute training loss
    - sample(state, ...): Generate actions from states
    - forward(state): Wrapper calling sample()
    """
    
    def __init__(
        self,
        model,
        state_dim,
        action_dim,
        max_action,
        device="cuda" if torch.cuda.is_available() else "cpu",
        # IMMPrecond parameters
        noise_schedule="fm",
        sigma_data=0.5, 
        f_type="euler_fm",
        T=0.994,
        eps=0.001,  
        temb_type='identity', 
        time_scale=1000.,
        # IMMLoss parameters
        mmd_sigma=1, 
        sample_t_mode="lognormal",
        P_mean=-1.1,
        P_std=2.0, 
        matrix_size=16, 
        sample_repeat=1,
        k=12,
        a=2,
        b=4, 
        min_tr_gap=None,
        # CFG parameters
        cfg_scale=None,
        **model_kwargs,
    ):
        """
        Store the configuration and build the denoising network.

        Args:
            model (str): Network type, "mlp" or "unet" (case-insensitive)
            state_dim (int): Dimension of the conditioning state
            action_dim (int): Dimension of the generated action
            max_action (float): Stored as ``self.max_action``
            device (str): Stored as ``self.device``; defaults to "cuda" when
                available at import time, else "cpu"
            noise_schedule (str): "fm" or "vp_cosine"
            sigma_data (float): Scale applied to sampled noise and used in
                the preconditioning coefficients
            f_type (str): Update rule used by ``_denoise_step``; "euler_fm"
                requires ``noise_schedule == "fm"``
            T (float): Largest timestep; also defines ``nt_high``
            eps (float): Smallest timestep; also defines ``nt_low``
            temb_type (str): How (t, s) are turned into network time inputs
            time_scale (float): Multiplier applied to the network time inputs
            mmd_sigma, sample_t_mode, P_mean, P_std, matrix_size,
            sample_repeat, k, a, b, min_tr_gap: Loss-related settings; they
                are only stored here, not used by this constructor
            cfg_scale (float | None): Stored as ``self.cfg_scale``
            **model_kwargs: Forwarded to the network constructor

        Raises:
            ValueError: If ``model`` or ``noise_schedule`` is unknown, or if
                ``f_type == "euler_fm"`` is combined with a schedule other
                than "fm".
        """
        super().__init__()

        # Validate the configuration before any work is done. Explicit
        # exceptions are used instead of ``assert`` so the checks survive
        # ``python -O``.
        model_key = model.lower()
        if model_key not in ("mlp", "unet"):
            raise ValueError(f"Unknown model_type: {model}")
        if noise_schedule not in ("fm", "vp_cosine"):
            raise ValueError(f"Unknown noise_schedule: {noise_schedule}")
        if f_type == 'euler_fm' and noise_schedule != 'fm':
            raise ValueError(
                f"f_type 'euler_fm' requires noise_schedule 'fm', got {noise_schedule!r}"
            )

        # Basic dimensions and device
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.device = device

        # Noise schedule parameters
        self.noise_schedule = noise_schedule
        self.T = T
        self.eps = eps
        self.sigma_data = sigma_data
        self.f_type = f_type

        # Noise-level bounds corresponding to eps and T, computed in float64
        # and kept both as Python floats and as buffers.
        self.nt_low = self.get_log_nt(torch.tensor(self.eps, dtype=torch.float64)).exp().item()
        self.nt_high = self.get_log_nt(torch.tensor(self.T, dtype=torch.float64)).exp().item()

        self.register_buffer('nt_low_tensor', torch.tensor(self.nt_low, dtype=torch.float64))
        self.register_buffer('nt_high_tensor', torch.tensor(self.nt_high, dtype=torch.float64))

        # Build the denoising network.
        model_cls = VectorMLP if model_key == "mlp" else VectorUNet
        self.model = model_cls(state_dim, action_dim, **model_kwargs)

        # Report the network size in millions of parameters.
        print('# Mparams:', sum(p.numel() for p in self.model.parameters()) / 1000000)

        # Time embedding parameters
        self.time_scale = time_scale
        self.temb_type = temb_type

        # MMD loss parameters
        self.P_mean = P_mean
        self.P_std = P_std
        self.min_tr_gap = min_tr_gap
        self.mmd_sigma = mmd_sigma
        self.sample_t_mode = sample_t_mode
        self.matrix_size = matrix_size
        self.sample_repeat = sample_repeat
        self.a = a
        self.b = b
        self.k = k

        # Classifier-Free Guidance
        self.cfg_scale = cfg_scale

    # ============================================================================
    # Noise Schedule and Time Transformation
    # ============================================================================
    
    def get_logsnr(self, t):
        """
        Compute the log signal-to-noise ratio of the configured schedule at ``t``.

        The computation is done in float64 and the result is cast back to
        the dtype of ``t``.

        Args:
            t (torch.Tensor): Timesteps

        Returns:
            torch.Tensor: log-SNR with the same dtype as ``t``

        Raises:
            ValueError: If ``self.noise_schedule`` is not "vp_cosine" or "fm".
        """
        in_dtype = t.dtype
        t64 = t.to(torch.float64)
        if self.noise_schedule == "vp_cosine":
            # alpha = cos(pi t / 2), sigma = sin(pi t / 2) -> -2 log tan(pi t / 2)
            logsnr = -2 * torch.log(torch.tan(t64 * torch.pi * 0.5))
        elif self.noise_schedule == "fm":
            # alpha = 1 - t, sigma = t -> 2 (log(1 - t) - log t)
            logsnr = 2 * ((1 - t64).log() - t64.log())
        else:
            # Previously an unknown schedule surfaced as an UnboundLocalError.
            raise ValueError(f"Unknown noise_schedule: {self.noise_schedule}")
        return logsnr.to(in_dtype)
    
    def get_log_nt(self, t):
        """
        Return the log noise level at ``t``, defined as ``-0.5 * logsnr(t)``.

        Args:
            t (torch.Tensor): Timesteps

        Returns:
            torch.Tensor: Log noise level, same dtype as ``t``
        """
        return -0.5 * self.get_logsnr(t)
    
    def get_alpha_sigma(self, t):
        """
        Return the signal and noise coefficients (alpha_t, sigma_t) at ``t``.

        - "fm": alpha_t = 1 - t, sigma_t = t
        - "vp_cosine": alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2)

        Args:
            t (torch.Tensor): Timesteps

        Returns:
            tuple: (alpha_t, sigma_t)

        Raises:
            ValueError: If ``self.noise_schedule`` is not "fm" or "vp_cosine".
        """
        if self.noise_schedule == 'fm':
            return 1 - t, t
        if self.noise_schedule == 'vp_cosine':
            angle = t * torch.pi * 0.5
            return torch.cos(angle), torch.sin(angle)
        # Previously an unknown schedule surfaced as an UnboundLocalError.
        raise ValueError(f"Unknown noise_schedule: {self.noise_schedule}")
    
    def nt_to_t(self, nt):
        """
        Convert a noise level ``nt`` back to a timestep ``t``.

        - "vp_cosine": t = arctan(nt) / (pi / 2)
        - "fm": t = nt / (1 + nt)

        The computation is done in float64, NaNs (e.g. from inf / inf) are
        mapped to 1, and the result is cast back to the dtype of ``nt``.

        Args:
            nt (torch.Tensor): Noise levels

        Returns:
            torch.Tensor: Timesteps with the same dtype as ``nt``

        Raises:
            ValueError: If the schedule is unknown, or if any resulting
                timestep exceeds 1.
        """
        in_dtype = nt.dtype
        nt64 = nt.to(torch.float64)
        if self.noise_schedule == "vp_cosine":
            t = torch.arctan(nt64) / (torch.pi * 0.5)
        elif self.noise_schedule == "fm":
            t = nt64 / (1 + nt64)
        else:
            # Previously an unknown schedule surfaced as an UnboundLocalError.
            raise ValueError(f"Unknown noise_schedule: {self.noise_schedule}")
        t = torch.nan_to_num(t, nan=1).to(in_dtype)
        # The original guard required the schedule to be both "vp*" AND "fm",
        # which can never hold, so out-of-range timesteps were never reported.
        # Both supported schedules are defined on t <= 1.
        if t.numel() > 0 and t.max() > 1:
            raise ValueError(f"t out of range: {t.min().item()}, {t.max().item()}")
        return t

    def get_logsnr_prime(self, t):
        """
        Derivative of the log-SNR with respect to ``t``.

        - "vp_cosine": -pi / (sin(pi t / 2) * cos(pi t / 2))
        - "fm": -2 / ((1 - t) * t)

        Args:
            t (torch.Tensor): Timesteps

        Returns:
            torch.Tensor: d(logsnr)/dt evaluated at ``t``

        Raises:
            ValueError: If ``self.noise_schedule`` is not "vp_cosine" or "fm".
        """
        if self.noise_schedule == "vp_cosine":
            return -1 * torch.pi / (torch.sin(t * torch.pi * 0.5) * torch.cos(t * torch.pi * 0.5))
        if self.noise_schedule == "fm":
            return -2 * (1 / (1 - t) / t)
        # Previously an unknown schedule silently returned None.
        raise ValueError(f"Unknown noise_schedule: {self.noise_schedule}")

    # ============================================================================
    # Forward and Reverse Diffusion Process
    # ============================================================================

    def add_noise(self, y, t, noise=None):
        """
        Forward diffusion: mix a clean action with noise at time ``t``.

        Args:
            y (torch.Tensor): Clean action
            t (torch.Tensor): Timesteps, broadcastable against ``y``
            noise (torch.Tensor, optional): Noise to mix in. When omitted,
                standard normal noise scaled by ``sigma_data`` is drawn.

        Returns:
            tuple: (alpha_t * y + sigma_t * noise, noise)
        """
        eps = torch.randn_like(y) * self.sigma_data if noise is None else noise
        signal_scale, noise_scale = self.get_alpha_sigma(t)
        noisy = signal_scale * y + noise_scale * eps
        return noisy, eps

    def ddim(self, yt, y, t, s, noise=None):
        """
        DDIM-style move of a sample from time ``t`` to time ``s``.

        Args:
            yt: Sample at time ``t``
            y: Clean-sample estimate
            t: Current timestep
            s: Target timestep
            noise: Optional explicit noise. When given, the result is
                ``alpha_s * y + sigma_s * noise`` and ``yt`` is not used.

        Returns:
            Sample at time ``s``
        """
        alpha_t, sigma_t = self.get_alpha_sigma(t)
        alpha_s, sigma_s = self.get_alpha_sigma(s)
        if noise is not None:
            return alpha_s * y + sigma_s * noise
        # Without explicit noise, the noise implied by (yt, y) is reused.
        return (alpha_s - alpha_t * sigma_s / sigma_t) * y + sigma_s / sigma_t * yt

    def simple_edm_sample_function(self, yt, y, t, s):
        """
        EDM-style update from time ``t`` to time ``s``.

        Returns ``c_skip * yt + c_out * y`` where both coefficients are built
        from the schedule coefficients at ``t`` and ``s`` and ``sigma_data``.

        Args:
            yt: Sample at time ``t``
            y: Network output
            t (torch.Tensor): Current timestep
            s (torch.Tensor): Target timestep

        Returns:
            Sample at time ``s``
        """
        a_t, sig_t = self.get_alpha_sigma(t)
        a_s, sig_s = self.get_alpha_sigma(s)
        norm_sq = a_t**2 + sig_t**2
        skip_coef = (a_t * a_s + sig_t * sig_s) / norm_sq
        out_coef = - (a_s * sig_t - a_t * sig_s) * norm_sq.rsqrt() * self.sigma_data
        return skip_coef * yt + out_coef * y
    
    def euler_fm_sample_function(self, yt, y, t, s):
        """
        Euler step for the flow-matching schedule from time ``t`` to ``s``.

        Computes ``yt - (t - s) * sigma_data * y``. Only valid when the
        noise schedule is 'fm' (asserted).

        Args:
            yt: Sample at time ``t``
            y: Network output
            t: Current timestep
            s: Target timestep

        Returns:
            Sample at time ``s``
        """
        assert self.noise_schedule == 'fm'
        step = t - s
        scaled_step = step * self.sigma_data
        return yt - scaled_step * y

    def get_init_noise(self, shape, device):
        """
        Draw the starting noise for sampling.

        Args:
            shape: Shape of the tensor to create
            device: Device on which to create it

        Returns:
            torch.Tensor: Standard normal noise scaled by ``sigma_data``
        """
        return torch.randn(shape, device=device) * self.sigma_data

    # ============================================================================
    # Denoising Network (Internal)
    # ============================================================================

    def compute_Fx(self, x, t, s, state, **model_kwargs):
        """
        Run the network on a preconditioned input.

        The input is scaled by ``c_in = 1 / (sqrt(alpha_t^2 + sigma_t^2) * sigma_data)``
        and the two time inputs are derived from ``temb_type``:

        - 'identity': (t * time_scale, s * time_scale)
        - 'stride':   (t * time_scale, (t - s) * time_scale)

        Args:
            x (torch.Tensor): Noisy action at time ``t``
            t (torch.Tensor): Current timestep
            s (torch.Tensor): Target timestep
            state: Conditioning state, passed to the network as ``state=``
            **model_kwargs: Forwarded to the network

        Returns:
            Raw network output

        Raises:
            ValueError: If ``self.temb_type`` is not 'identity' or 'stride'.
        """
        if self.temb_type == 'identity':
            second_time = s
        elif self.temb_type == 'stride':
            second_time = t - s
        else:
            # Previously an unknown type surfaced as an UnboundLocalError.
            raise ValueError(f"Unknown temb_type: {self.temb_type}")

        alpha_t, sigma_t = self.get_alpha_sigma(t)
        c_in = (alpha_t ** 2 + sigma_t ** 2).rsqrt() / self.sigma_data
        c_noise_t = t * self.time_scale
        c_noise_s = second_time * self.time_scale

        # The network receives the time inputs flattened to 1-D.
        return self.model(
            (c_in * x),
            c_noise_t.flatten(),
            c_noise_s.flatten(),
            state=state,
            **model_kwargs,
        )

    def _denoise_step(self, x, t, s=None, state=None, **model_kwargs):
        """
        Single denoising step from time t to time s.
        Internal method - use sample() for complete sampling.

        The network is evaluated in float32, its output is cast to the dtype
        of ``t``, and the update rule selected by ``self.f_type`` is applied:

        - "identity":   ``ddim``
        - "simple_edm": ``simple_edm_sample_function``
        - "euler_fm":   ``euler_fm_sample_function``

        Args:
            x: noisy action at time t
            t: current timestep
            s: target timestep (s < t); required despite the ``None`` default
            state: conditioning state
            **model_kwargs: forwarded to the network

        Returns:
            denoised action at time s

        Raises:
            ValueError: if ``s`` is None
            NotImplementedError: if ``self.f_type`` is not one of the above
        """
        # Fail fast with a clear message instead of an AttributeError on
        # ``None.to(...)`` further down.
        if s is None:
            raise ValueError("s (target timestep) must be provided")
        # Reject an unknown update rule before paying for a forward pass.
        if self.f_type not in ("identity", "simple_edm", "euler_fm"):
            raise NotImplementedError

        out_dtype = t.dtype
        F_x = self.compute_Fx(
            x.to(torch.float32),
            t.to(torch.float32),
            s.to(torch.float32),
            state=state,
            **model_kwargs,
        ).to(out_dtype)

        if self.f_type == "identity":
            return self.ddim(x, F_x, t, s)
        if self.f_type == "simple_edm":
            return self.simple_edm_sample_function(x, F_x, t, s)
        return self.euler_fm_sample_function(x, F_x, t, s)

    # ============================================================================
    # Sampling (Inference)
    # ============================================================================
    
    def sample(self, state, num_steps=16, clip_action=True, cfg_scale=None, 
               discretization='edm'):
        """
        Sample actions from states using iterative denoising.
        Similar to Diffusion.sample() and Consistency.sample() in reference.
        
        Args:
            state: conditioning state [batch, state_dim]
            num_steps: number of denoising steps (default: 8, ignored if discretization=None)
            clip_action: whether to clip actions to max_action (default: True)
            cfg_scale: Classifier-Free Guidance scale. If None, uses self.cfg_scale
            discretization: timestep discretization strategy
                - 'uniform': uniform spacing in time
                - 'edm': EDM-style non-uniform spacing (Karras et al.)
                - None: custom timesteps via mid_nt parameter
        
        Returns:
            sampled actions [batch, action_dim]
            
        Note:
            Gradients are automatically tracked when called in training mode.
            No need to explicitly set requires_grad on initial noise - gradients
            will automatically flow from model parameters after the first step.
        """
        batch_size = state.shape[0]
        
        # Time step discretization
        if discretization == 'uniform':
            timesteps = torch.linspace(self.T, self.eps, num_steps + 1, device=self.device)
        elif discretization == 'edm':
            # EDM-style non-uniform discretization (Karras et al.)
            nt_min = self.get_log_nt(torch.tensor(self.eps, dtype=torch.float64)).exp().item()
            nt_max = self.get_log_nt(torch.tensor(self.T, dtype=torch.float64)).exp().item()
            rho = 7  # Optimal value from EDM paper
            
            step_indices = torch.arange(num_steps + 1, dtype=torch.float64, device=self.device)
            nt_steps = (nt_max**(1/rho) + step_indices/num_steps * 
                       (nt_min**(1/rho) - nt_max**(1/rho)))**rho
            timesteps = self.nt_to_t(nt_steps).to(torch.float32)
        elif discretization is None:
            # Custom timesteps via manually specified noise levels
            if mid_nt is None:
                mid_nt = []
            # Convert noise levels to time steps
            mid_t = [self.nt_to_t(torch.as_tensor(nt, dtype=torch.float64, device=self.device)).item() 
                     for nt in mid_nt]
            # Construct timesteps: [T, mid_t_1, mid_t_2, ..., mid_t_k, eps]
            timesteps = torch.tensor([self.T] + list(mid_t) + [self.eps], 
                                    dtype=torch.float32, device=self.device)
        else:
            raise ValueError(f"Unknown discretization: {discretization}")
        
        # Initialize from noise (no need for requires_grad=True)
        # After first model call, output will have gradients from model parameters
        current_action = torch.randn((batch_size, self.action_dim), device=self.device) * self.max_action
        
        # Determine guidance scale
        guidance_scale = cfg_scale if cfg_scale is not None else None
        
        # Restart sampling: denoise to near-clean, then add noise and repeat
        for i in range(num_steps):
            t = timesteps[i]
            t_tensor = torch.full((batch_size, 1), t, device=self.device)
            eps_tensor = torch.full((batch_size, 1), self.eps, device=self.device)
            
            # Denoise directly to near-clean (t → eps)
            if (guidance_scale is not None) and (guidance_scale > 0):
                action_cfg = torch.cat([current_action, current_action], dim=0)
                uncond_state = torch.zeros_like(state)
                state_cfg = torch.cat([uncond_state, state], dim=0)
                t_cfg = torch.cat([t_tensor, t_tensor], dim=0)
                eps_cfg = torch.cat([eps_tensor, eps_tensor], dim=0)
                
                output_cfg = self._denoise_step(x=action_cfg, t=t_cfg, s=eps_cfg, state=state_cfg)
                uncond_output = output_cfg[:batch_size]
                cond_output = output_cfg[batch_size:]
                current_action = uncond_output + guidance_scale * (cond_output - uncond_output)
            else:
                current_action = self._denoise_step(x=current_action, t=t_tensor, s=eps_tensor, state=state)
            
            # Add noise for next iteration (except last step)
            if i < num_steps - 1:
                next_t = timesteps[i + 1]
                current_action, _ = self.add_noise(current_action, 
                                                    torch.full((batch_size, 1), next_t, device=self.device))
        
        action = current_action
        return action.clamp_(-self.max_action, self.max_action) if clip_action else action

    # ============================================================================
    # Training Loss (MMD-based)
    # ============================================================================
    
    def get_kernel_weight(self, t, s, a=1, b=0):
        """
        Compute the two weights used by the MMD loss for a (t, s) pair.

        Args:
            t: source timesteps (tensor)
            s: target timesteps (tensor)
            a: exponent applied to alpha_t in the outer weight (default: 1)
            b: offset that logsnr(t) is subtracted from inside the sigmoid (default: 0)

        Returns:
            (w, wout):
                w    - kernel weight, whose formula is selected by self.f_type
                wout - outer loss weight built from alpha_t, sigma_t, logsnr(t)
                       and the derivative of logsnr at t

        Raises:
            NotImplementedError: if self.f_type is not 'identity',
                'simple_edm' or 'euler_fm'.
        """
        logsnr_t = self.get_logsnr(t)
        alpha_t, sigma_t = self.get_alpha_sigma(t)
        alpha_s, sigma_s = self.get_alpha_sigma(s)

        parameterization = self.f_type
        if parameterization not in ('identity', 'simple_edm', 'euler_fm'):
            raise NotImplementedError

        if parameterization == 'euler_fm':
            # Inverse of the time gap, scaled by the data std
            w = 1 / (t - s).abs() / self.sigma_data
        elif parameterization == 'simple_edm':
            cross_term = alpha_s * sigma_t - alpha_t * sigma_s
            w = (alpha_t**2 + sigma_t**2).sqrt() / cross_term.abs() / self.sigma_data
        else:
            # 'identity'
            w = (sigma_t / (alpha_s * sigma_t - sigma_s * alpha_t)).abs()

        neg_dlogsnr_dt = -self.get_logsnr_prime(t)
        gate = (b - logsnr_t).sigmoid()
        wout = alpha_t ** a / (alpha_t**2 + sigma_t**2) * 0.5 * neg_dlogsnr_dt * gate
        return w, wout

    def sample_eta_t(self, batch_size, device, log_low=None, log_high=None, low=None, high=None, sample_mode="lognormal", P_mean=-1.1, P_std=2.0, **kwargs):
        """
        Sample noise levels for training.

        Args:
            batch_size: number of noise levels to draw
            device: device on which the samples are created
            log_low, log_high: bounds on log(nt), used by "lognormal" mode.
                Scalars or tensors; default to -inf / +inf (no truncation).
            low, high: bounds on nt, used by "uniform" mode. Scalars or
                tensors; default to self.nt_low / self.nt_high.
            sample_mode: "lognormal" draws log(nt) from a Normal(P_mean, P_std)
                truncated to [log_low, log_high] via inverse-CDF sampling;
                "uniform" draws t uniformly between nt_to_t(low) and
                nt_to_t(high) and maps it back with get_log_nt.
            P_mean, P_std: mean / std of the normal over log(nt)
            **kwargs: accepted and ignored

        Returns:
            (nt, log_nt): float64 tensors of shape [batch_size, 1, 1, 1]

        Raises:
            ValueError: if sample_mode is neither "lognormal" nor "uniform".
        """
        def as_bound(value):
            # Promote a bound to a float64 tensor on `device`; scalar bounds are
            # broadcast to one entry per sample with shape [batch_size, 1, 1, 1].
            bound = torch.as_tensor(value, device=device, dtype=torch.float64)
            if bound.ndim == 0:
                bound = bound.unsqueeze(0).expand(batch_size).reshape(-1, 1, 1, 1)
            return bound

        if sample_mode == "lognormal":
            log_low = as_bound(-float("inf") if log_low is None else log_low)
            log_high = as_bound(float("inf") if log_high is None else log_high)
            dist = torch.distributions.Normal(
                loc=torch.full_like(log_high, P_mean),
                scale=torch.full_like(log_high, P_std),
            )
            # Inverse-CDF sampling restricted to [cdf(log_low), cdf(log_high)]
            cdf = torch.rand([batch_size, 1, 1, 1], device=device) * (dist.cdf(log_high) - dist.cdf(log_low)) + dist.cdf(log_low)
            log_nt = dist.icdf(cdf)
            nt = log_nt.exp()
        elif sample_mode == "uniform":
            # Default bounds go through the same tensor promotion as explicit
            # ones, so nt_to_t always receives a float64 tensor.
            high = as_bound(self.nt_high if high is None else high)
            low = as_bound(self.nt_low if low is None else low)
            high_t = self.nt_to_t(high)
            low_t = self.nt_to_t(low)
            t = (torch.rand([batch_size, 1, 1, 1], device=device, dtype=torch.float64) * (high_t - low_t) + low_t)
            log_nt = self.get_log_nt(t)
            nt = log_nt.exp()
        else:
            raise ValueError(f"Unknown sample_t_mode: {sample_mode}")
        return nt, log_nt

    def nt_to_nr(self, nt):
        """
        Map a noise level nt to its reference noise level nr.

        nt is shifted down by (nt_high - nt_low) / 2**k and the result is
        clamped to [nt_low, nt_high].

        Args:
            nt: noise levels (tensor)

        Returns:
            nr: reference noise levels, same shape as nt
        """
        shift = (self.nt_high - self.nt_low) * (1 / 2) ** self.k
        return torch.clamp(nt - shift, min=self.nt_low, max=self.nt_high)
    
    def sample_trs(self, t_bs, device):
        """
        Draw a timestep triplet (t, r, s) for MMD training.

        Noise level nt is sampled in [nt_low, nt_high]; ns is sampled in
        [nt_low, nt] and additionally capped at nt; nr is derived from nt via
        nt_to_nr. All three are then converted to times with nt_to_t.

        Args:
            t_bs: number of triplets to draw
            device: device on which to sample

        Returns:
            (t, r, s): time tensors. r is optionally capped at
            t - min_tr_gap, then raised to at least s and at least self.eps.
        """
        nt_low, nt_high = self.nt_low, self.nt_high
        sampler_cfg = dict(sample_mode=self.sample_t_mode, P_mean=self.P_mean, P_std=self.P_std)

        # nt over the full range, then ns conditioned on ns <= nt
        nt, log_nt = self.sample_eta_t(
            t_bs, device,
            log_low=np.log(nt_low), log_high=np.log(nt_high),
            low=nt_low, high=nt_high,
            **sampler_cfg,
        )
        ns, _ = self.sample_eta_t(
            t_bs, device,
            log_low=np.log(nt_low), log_high=log_nt,
            low=nt_low, high=nt,
            **sampler_cfg,
        )
        ns = torch.minimum(ns, nt).clamp(min=nt_low)
        nr = self.nt_to_nr(nt)

        t, r, s = (self.nt_to_t(level) for level in (nt, nr, ns))

        # Sanity check: the noise-level <-> time mapping must round-trip
        for time_value, level in ((t, nt), (r, nr), (s, ns)):
            assert torch.allclose(self.get_log_nt(time_value).exp(), level)

        gap = self.min_tr_gap
        if gap is not None:
            r = torch.minimum(r, torch.clamp(t - gap, min=nt_low))
        # Keep r no smaller than s and no smaller than eps
        r = torch.maximum(r, s).clamp(min=self.eps)
        return t, r, s
    
    def kernel_fn(self, x, y, flatten_dim, w):
        """
        Evaluate the distance-based kernel exp(-w * d(x, y)).

        d(x, y) is the L2 distance over the dimensions from `flatten_dim`
        onward (squared distance floored at 1e-10 before the square root),
        divided by the number of flattened elements and by self.mmd_sigma.
        Note that the distance is not squared in the exponent.

        Args:
            x, y: broadcastable tensors
            flatten_dim: first dimension to flatten and reduce over
            w: scalar or broadcastable tensor multiplying the distance

        Returns:
            kernel values, with the dimensions from flatten_dim on reduced away
        """
        squared_dist = ((x - y) ** 2).flatten(flatten_dim).sum(-1)
        dist = torch.clamp_min(squared_dist, 1e-10).sqrt()
        num_features = np.prod(y.shape[flatten_dim:])
        scaled_dist = dist / num_features / self.mmd_sigma
        return torch.exp(-scaled_dist * w)
        
    def kernel(self, x, y, w=None):
        """
        Compute the pairwise kernel matrix between two groups of samples.

        x gets a new axis at position 2 and y a new axis at position 1, so
        kernel_fn is evaluated on every (x_i, y_j) pair through broadcasting.

        Args:
            x, y: sample tensors; dimensions from index 2 onward are treated
                as features (kernel_fn reduces from dim 3 after unsqueezing)
            w: optional weight indexed along dim 0; expanded with two trailing
                axes. Defaults to 1 when omitted.

        Returns:
            the output of self.kernel_fn on the broadcast pair
        """
        weight = 1 if w is None else w[:, None, None]
        return self.kernel_fn(x.unsqueeze(2), y.unsqueeze(1), flatten_dim=3, w=weight)
      
    def get_mmd_loss(self, f_st, f_sr, w, wout):
        """
        Compute the MMD loss between two groups of samples.

        Per group: mean k(f_st, f_st) + mean k(f_sr, f_sr) - 2 * mean k(f_st, f_sr),
        where each mean is taken over the pairwise kernel matrix (dims 1 and 2).

        Args:
            f_st: first sample set (predictions)
            f_sr: second sample set (targets)
            w: kernel weight forwarded to self.kernel
            wout: optional per-group multiplier on the loss; skipped if None

        Returns:
            (loss, logs): mean loss over groups, and a dict with the detached
            per-group similarity terms.
        """
        # Pairwise kernel matrices, each averaged over its two sample axes
        within_pred = self.kernel(f_st, f_st, w=w).mean((1, 2))
        within_target = self.kernel(f_sr, f_sr, w=w).mean((1, 2))
        between = self.kernel(f_st, f_sr, w=w).mean((1, 2))

        per_group = within_pred + within_target - 2 * between
        if wout is not None:
            per_group = wout * per_group

        logs = {
            "inter_sample_sim": within_pred.detach(),
            "inter_gt_sim": within_target.detach(),
            "cross_sim": between.detach(),
        }
        return per_group.mean(), logs
    
    def loss(self, action, state, device=None, cfg_dropout_prob=0.0, **kwargs):
        """
        Compute training loss (MMD-based).
        Similar to Diffusion.loss() and Consistency.loss() in reference.

        Args:
            action: expert actions [batch, action_dim]
            state: conditioning states [batch, state_dim]
            device: computation device (default: self.device)
            cfg_dropout_prob: probability of dropping state for CFG training
            **kwargs: accepted and ignored

        Returns:
            loss: scalar loss value (NaN replaced via torch.nan_to_num)
            logs: dict of logging metrics

        Note:
            Timesteps (t, r, s) are sampled once per group of
            self.matrix_size consecutive samples. The model output from
            (y_t, t) is compared, via the MMD loss, with a gradient-free
            output from (y_r, r), both targeting time s.
            NOTE(review): `action` is repeated by self.sample_repeat while
            `state` is not; this presumably requires sample_repeat == 1
            unless the caller pre-expands `state` - confirm.
        """
        if device is None:
            device = self.device
        state = state.to(device)
        # Keep actions on the same device as the states and sampled timesteps
        action = action.to(device)

        # CFG training: randomly mask state to learn unconditional generation
        if (self.cfg_scale is not None) and cfg_dropout_prob > 0 and self.training:
            # Create mask: 1 = keep state, 0 = drop state
            mask = (torch.rand(state.shape[0], 1, device=device) > cfg_dropout_prob).float()
            state = state * mask  # Zero out dropped states

        # One (t, r, s) triplet per group, repeated for every group member
        current_matrix_size = self.matrix_size
        t, r, s = self.sample_trs(state.shape[0] // current_matrix_size, device=state.device)
        t = t.repeat_interleave(current_matrix_size, dim=0)
        s = s.repeat_interleave(current_matrix_size, dim=0)
        r = r.repeat_interleave(current_matrix_size, dim=0)
        action = action.repeat_interleave(self.sample_repeat, dim=0)

        t_col = t.flatten().unsqueeze(1)
        s_col = s.flatten().unsqueeze(1)
        r_col = r.flatten().unsqueeze(1)

        # Noisy sample at t, and the sample at r built from the same noise
        yt, noise_t = self.add_noise(action, t_col)
        yr = self.ddim(yt, action, t_col, r_col, noise=noise_t)

        # Save the RNG state before the first model call and restore it before
        # the second so both calls draw the same random numbers (e.g. a shared
        # dropout mask). Use the generator of the device the tensors live on;
        # the CUDA-only API cannot be used on a CPU-only setup.
        on_cuda = state.device.type == "cuda"
        if on_cuda:
            rng_state = torch.cuda.get_rng_state(state.device)
        else:
            rng_state = torch.get_rng_state()

        f_st = self._denoise_step(yt, t_col, s_col, state=state)

        if on_cuda:
            torch.cuda.set_rng_state(rng_state, state.device)
        else:
            torch.set_rng_state(rng_state)

        # The branch starting from y_r is the target: no gradients
        with torch.no_grad():
            f_sr = self._denoise_step(yr, r_col, s_col, state=state)

        # Regroup the flat batch into [num_groups, matrix_size, ...]
        f_st = rearrange(f_st, "(b m) ... -> b m ...", m=current_matrix_size)
        f_sr = rearrange(f_sr, "(b m) ... -> b m ...", m=current_matrix_size)
        t = rearrange(t, "(b m) ... -> b m ...", m=current_matrix_size)
        r = rearrange(r, "(b m) ... -> b m ...", m=current_matrix_size)
        s = rearrange(s, "(b m) ... -> b m ...", m=current_matrix_size)

        # Timesteps are identical inside a group, so the first member suffices
        wt, wtout = self.get_kernel_weight(t[:, 0].flatten(), s[:, 0].flatten(), a=self.a, b=self.b)
        loss_value, loss_logs = self.get_mmd_loss(f_st, f_sr, wt, wtout)
        logs = {
            "r_t_ratio": r[:, 0] / t[:, 0],
            "s_t_ratio": s[:, 0] / t[:, 0],
            "t_r_diff": t[:, 0] - r[:, 0],
            'loss': loss_value,
            **loss_logs,
        }
        if torch.isnan(loss_value).any():
            print("Nan in loss")
            loss_value = torch.nan_to_num(loss_value)
        logs["ts"] = t[:, 0].flatten()
        return loss_value, logs

    # ============================================================================
    # Forward Pass (Main Interface)
    # ============================================================================
    
    def forward(self, state, *args, **kwargs):
        """
        Generate actions for the given states by delegating to sample().

        Args:
            state: input states [batch, state_dim]
            *args, **kwargs: forwarded unchanged to self.sample()

        Returns:
            whatever self.sample() returns (generated actions)
        """
        generated = self.sample(state, *args, **kwargs)
        return generated

# class IMMPrecond(torch.nn.Module):

#     def __init__(
#         self,
#         model,
#         state_dim,
#         action_dim,
#         max_action,
#         device = "cuda" if torch.cuda.is_available() else "cpu",
#         noise_schedule="fm",
#         sigma_data=0.5, 
#         f_type="euler_fm",
#         T=0.994,
#         eps=0.001,  
#         temb_type='identity', 
#         time_scale=1000.,  
#         **model_kwargs,  # Keyword arguments for the underlying model.
#     ):
#         super().__init__()
#         self.state_dim = state_dim
#         self.action_dim = action_dim
#         self.max_action = max_action
#         self.device = device
#         self.noise_schedule = noise_schedule
#         self.T = T
#         self.eps = eps
#         self.sigma_data = sigma_data
#         self.f_type = f_type

        
#         self.nt_low = self.get_log_nt(torch.tensor(self.eps, dtype=torch.float64)).exp().numpy().item()
#         self.nt_high = self.get_log_nt(torch.tensor(self.T, dtype=torch.float64)).exp().numpy().item()
        
#         # Add the buffer registrations here:
#         self.register_buffer('nt_low_tensor', torch.tensor(self.nt_low, dtype=torch.float64))
#         self.register_buffer('nt_high_tensor', torch.tensor(self.nt_high, dtype=torch.float64))
        
#         assert model.lower() == "mlp" or model.lower() == "unet"
#         if model.lower() == "mlp":
#             model = VectorMLP
#         elif model.lower() == "unet":
#             model = VectorUNet
#         else:
#             raise ValueError(f"Unknown model_type: {model}")
            
#         self.model = model(
#             state_dim,
#             action_dim,
#             **model_kwargs
#         )
        
#         print('# Mparams:', sum(p.numel() for p in self.model.parameters()) / 1000000)
        
        
#         self.time_scale = time_scale 
#         self.temb_type = temb_type
        
#         if self.f_type == 'euler_fm':
#             assert self.noise_schedule == 'fm'
          

#     def get_logsnr(self, t):
#         dtype = t.dtype
#         t = t.to(torch.float64)
#         if self.noise_schedule == "vp_cosine":
#             logsnr = -2 * torch.log(torch.tan(t * torch.pi * 0.5))
 
#         elif self.noise_schedule == "fm":
            
#             logsnr = 2 * ((1 - t).log() - t.log())
            
#         logsnr = logsnr.to(dtype)
#         return logsnr
    
#     def get_log_nt(self, t):
#         logsnr_t = self.get_logsnr(t)
#         return -0.5 * logsnr_t
    
#     def get_alpha_sigma(self, t): 
#         if self.noise_schedule == 'fm':
#             alpha_t = (1 - t)
#             sigma_t = t
#         elif self.noise_schedule == 'vp_cosine': 
#             alpha_t = torch.cos(t * torch.pi * 0.5)
#             sigma_t = torch.sin(t * torch.pi * 0.5)
            
#         return alpha_t, sigma_t 

#     def add_noise(self, y, t, noise=None):

#         if noise is None:
#             noise = torch.randn_like(y) * self.sigma_data

#         alpha_t, sigma_t = self.get_alpha_sigma(t)
         
#         return alpha_t * y + sigma_t * noise, noise 

#     def ddim(self, yt, y, t, s, noise=None):
#         alpha_t, sigma_t = self.get_alpha_sigma(t)
#         alpha_s, sigma_s = self.get_alpha_sigma(s)
        

#         if noise is None: 
            
#             ys = (alpha_s -   alpha_t * sigma_s / sigma_t) * y + sigma_s / sigma_t * yt
#         else:
#             ys = alpha_s * y + sigma_s * noise
#         return ys
  
   

#     def simple_edm_sample_function(self, yt, y, t, s):
#         alpha_t, sigma_t = self.get_alpha_sigma(t)
#         alpha_s, sigma_s = self.get_alpha_sigma(s)
         
#         c_skip = (alpha_t * alpha_s + sigma_t * sigma_s) / (alpha_t**2 + sigma_t**2)

#         c_out = - (alpha_s * sigma_t - alpha_t * sigma_s) * (alpha_t**2 + sigma_t**2).rsqrt() * self.sigma_data
        
#         return c_skip * yt + c_out * y
    
#     def euler_fm_sample_function(self, yt, y, t, s):
#         assert self.noise_schedule == 'fm'   
#         return  yt - (t - s) * self.sigma_data *  y 
          
#     def nt_to_t(self, nt):
#         dtype = nt.dtype
#         nt = nt.to(torch.float64)
#         if self.noise_schedule == "vp_cosine":
#             t = torch.arctan(nt) / (torch.pi * 0.5) 
 
#         elif self.noise_schedule == "fm":
#             t = nt / (1 + nt)
            
#         t = torch.nan_to_num(t, nan=1)

#         t = t.to(dtype)
            

        
#         if (
#             self.noise_schedule.startswith("vp")
#             and self.noise_schedule == "fm"
#             and t.max() > 1
#         ):
#             raise ValueError(f"t out of range: {t.min().item()}, {t.max().item()}")
#         return t

#     def get_init_noise(self, shape, device):
        
#         noise = torch.randn(shape, device=device) * self.sigma_data
#         return noise

#     def compute_Fx(
#         self,
#         x,
#         t,
#         s,
#         state,
#         **model_kwargs,
#     ):
        
#         alpha_t, sigma_t = self.get_alpha_sigma(t)
    
#         c_in = (alpha_t ** 2 + sigma_t**2 ).rsqrt() / self.sigma_data  
#         if self.temb_type == 'identity': 

#             c_noise_t = t  * self.time_scale
#             c_noise_s = s  * self.time_scale
            
#         elif self.temb_type == 'stride':

#             c_noise_t = t * self.time_scale
#             c_noise_s = (t - s) * self.time_scale
            
#         F_x = self.model( 
#             (c_in * x),
#             c_noise_t.flatten(),
#             c_noise_s.flatten(),
#             state=state,
#             **model_kwargs)
        
#         return F_x

    
#     def forward(
#         self,
#         x,
#         t,
#         s=None,
#         state=None,
#         **model_kwargs,
#     ):
#         dtype = t.dtype  
            
#         F_x = self.compute_Fx(
#             x.to(torch.float32),
#             t.to(torch.float32),
            
            
            
            
#             s.to(torch.float32),
#             state=state,
#             **model_kwargs
#         ).to(dtype) 
         
#         if self.f_type == "identity":
#             F_x  =  self.ddim(x, F_x, t, s)  
#         elif self.f_type == "simple_edm": 
#             F_x = self.simple_edm_sample_function(x, F_x , t, s)   
#         elif self.f_type == "euler_fm": 
#             F_x = self.euler_fm_sample_function(x, F_x, t, s)  
#         else:
#             raise NotImplementedError
 
#         return F_x


#     ######################################################### for training


#     def get_logsnr_prime(self, t):
#         if self.noise_schedule == "vp_cosine":
#             return (
#                 -1
#                 * torch.pi
#                 / (torch.sin(t * torch.pi * 0.5) * torch.cos(t * torch.pi * 0.5))
#             )
 
#         elif self.noise_schedule == "fm":
            
#             return -2 * (1 / (1 - t) / t)

            
#     def get_kernel_weight(self, t, s, a=1, b=0): 
             
#         logsnr_t = self.get_logsnr(t)

#         alpha_t, sigma_t = self.get_alpha_sigma(t) 
#         alpha_s, sigma_s = self.get_alpha_sigma(s) 
         

#         if self.f_type == 'identity':
#             w =   (sigma_t / (alpha_s * sigma_t - sigma_s * alpha_t )).abs()
#         elif self.f_type == 'simple_edm':
#             w =  (alpha_t**2 + sigma_t**2).sqrt() / (alpha_s * sigma_t - alpha_t * sigma_s).abs() / self.sigma_data
#         elif self.f_type == 'euler_fm':
#             w = 1 / (t - s).abs() / self.sigma_data
#         else:
#             raise NotImplementedError 


#         neg_dlogsnr_dt = - self.get_logsnr_prime(t)
                
#         wout =  alpha_t ** a / (alpha_t**2 + sigma_t**2)  * 0.5 * neg_dlogsnr_dt * (b - logsnr_t).sigmoid()
         
#         return w, wout

    
#     def sample_eta_t(
#         self,
#         batch_size,
#         device,
#         log_low=None,
#         log_high=None,
#         low=None,
#         high=None,
#         sample_mode="lognormal",
#         P_mean=-1.1,
#         P_std=2.0,
#         **kwargs,
#     ):
#         if sample_mode == "lognormal":

#             log_low = log_low if log_low is not None else -float("inf")
#             log_low = torch.as_tensor(log_low, device=device, dtype=torch.float64)
#             if log_low.ndim == 0:
#                 log_low = log_low.unsqueeze(0).expand(batch_size).reshape(-1, 1, 1, 1)

#             log_high = log_high if log_high is not None else float("inf")
#             log_high = torch.as_tensor(log_high, device=device, dtype=torch.float64)
#             if log_high.ndim == 0:
#                 log_high = log_high.unsqueeze(0).expand(batch_size).reshape(-1, 1, 1, 1)
#             dist = torch.distributions.Normal(
#                 loc=torch.full_like(log_high, P_mean),
#                 scale=torch.full_like(log_high, P_std),
#             )

#             cdf = torch.rand([batch_size, 1, 1, 1], device=device) * (
#                 dist.cdf(log_high) - dist.cdf(log_low)
#             ) + dist.cdf(log_low)

#             log_nt = dist.icdf(cdf)
#             nt = log_nt.exp()

#         elif sample_mode == "uniform":
#             if high is None:
#                 high = self.nt_high
#             else:
#                 high = torch.as_tensor(high, device=device, dtype=torch.float64)
#                 if high.ndim == 0:
#                     high = high.unsqueeze(0).expand(batch_size).reshape(-1, 1, 1, 1)

#             if low is None:
#                 low = self.nt_low
#             else:
#                 low = torch.as_tensor(low, device=device, dtype=torch.float64)
#                 if low.ndim == 0:
#                     low = low.unsqueeze(0).expand(batch_size).reshape(-1, 1, 1, 1)

#             high_t = self.nt_to_t(high)
#             low_t = self.nt_to_t(low)

#             t = (
#                 torch.rand([batch_size, 1, 1, 1], device=device, dtype=torch.float64) * (high_t - low_t)
#                 + low_t
#             )

#             log_nt = self.get_log_nt(t)
#             nt = log_nt.exp()
#         else:
#             raise ValueError(f"Unknown sample_t_mode: {sample_mode}")

#         return nt, log_nt
    
    
    
#     @torch.no_grad()
#     def sample_action(self, state, num_steps=8, clip_action=True):
#         """
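#         Generate an action from the given state.
#         This runs a complete reverse sampling loop.

#         :param state: Input state tensor of shape (batch_size, state_dim).
#         :param num_steps: Number of reverse sampling steps. More steps give higher quality but slower sampling.
#         :param clip_action: Whether to clip the final action to the max_action range.
#         :return: Generated action tensor of shape (batch_size, action_dim).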
#         """
#         batch_size = state.shape[0]
        
#         # 1. Define the time sequence, from T down to eps
#         timesteps = torch.linspace(self.T, self.eps, num_steps + 1, device=self.device)
        
#         # 2. Initialize the action, starting from pure noise
#         # Shape is (batch_size, action_dim)
#         current_action = torch.randn(
#             (batch_size, self.action_dim), 
#             device=self.device
#         ) * self.sigma_data # use sigma_data as the initial noise scale

#         # 3. Iterative denoising loop
#         for i in range(num_steps):
#             t = timesteps[i]
#             s = timesteps[i + 1]
            
#             # Expand the scalar times t and s to match the batch size, keeping a 2D shape (batch_size, 1)
#             t_tensor = torch.full((batch_size, 1), t, device=self.device)
#             s_tensor = torch.full((batch_size, 1), s, device=self.device)

#             # 4. Call the forward method to perform a single denoising step
#             # Pass state as a keyword argument
#             current_action = self.forward(
#                 x=current_action, 
#                 t=t_tensor, 
#                 s=s_tensor,
#                 state=state
#             )

#         # After the loop ends, current_action is the denoised clean action
        
#         action = current_action
#         return action.clamp_(-self.max_action, self.max_action) if clip_action else action


# class IMMLoss:

#     def __init__(
#         self,  
#         sigma=1, 
#         sample_t_mode="lognormal",
#         P_mean=-1.1,
#         P_std=2.0, 
#         matrix_size=16, 
#         sample_repeat=1,
#         k=12,
#         a=2,
#         b=4, 
#         min_tr_gap=None,  
#         **kwargs,
#     ):  
#         super().__init__()

#         self.P_mean = P_mean
#         self.P_std = P_std
#         self.min_tr_gap = min_tr_gap
#         self.sigma = sigma
#         self.sample_t_mode = sample_t_mode 
#         self.matrix_size = matrix_size
#         self.sample_repeat = sample_repeat
#         self.a = a
#         self.b = b
#         self.k = k
            
#     def nt_to_nr(self, nt, net):

#         u = (net.nt_high - net.nt_low) * (1/ 2) ** self.k
#         nr = (nt -  u ).clamp(min=net.nt_low, max=net.nt_high) 
#         return nr
    
#     def kernel_fn(self, x, y, flatten_dim, w):
        
#         loss = (
#                 torch.clamp_min(
#                     ((x - y) ** 2).flatten(flatten_dim).sum(-1)  , 1e-8
#                 )
#             ).sqrt()   / (np.prod(y.shape[flatten_dim:])) / self.sigma
            
            
#         ret = torch.exp(-loss * w) 
#         return ret 
#     def kernel(
#         self,
#         x,
#         y,   
#         w=None, 
#     ):

#         # x: (t, b, ...)
#         # y: (t, b, ...)

#         x = x.unsqueeze(2)  # (t, b, 1, ...)
#         y = y.unsqueeze(1)  # (t, 1, b, ...)
#         if w is None:
#             w = 1
#         else:
#             w = w[:, None, None]
  
            
#         ret = self.kernel_fn(x, y,  flatten_dim=3,  w=w )

#         return ret

#     def sample_trs(self, t_bs, net, device ):

#         high = net.nt_high 
#         low = net.nt_low 
#         nt, log_nt = net.sample_eta_t(
#             t_bs,
#             device,
#             log_low=np.log(low),
#             log_high=np.log(high),
#             low=low,
#             high=high,
#             sample_mode=self.sample_t_mode,
#             P_mean=self.P_mean,
#             P_std=self.P_std,
#         ) 
 
#         ns_upper = nt
#         logns_upper = log_nt
          
#         ns, log_ns = net.sample_eta_t(
#             t_bs,
#             device,
#             log_low=np.log(net.nt_low),
#             log_high=logns_upper,
#             low=net.nt_low,
#             high=ns_upper,
#             sample_mode=self.sample_t_mode,
#             P_mean=self.P_mean,
#             P_std=self.P_std,
#         ) 
#         ns = torch.minimum(ns, nt).clamp(min=net.nt_low)
                
        
#         nr = self.nt_to_nr(nt, net) 
                   
            
#         t = net.nt_to_t(nt) 
#         r = net.nt_to_t(nr)
#         s = net.nt_to_t(ns)
            
#         assert torch.allclose(net.get_log_nt(t).exp(), nt)
#         assert torch.allclose(net.get_log_nt(r).exp(), nr)
#         assert torch.allclose(net.get_log_nt(s).exp(), ns)
 
#         if self.min_tr_gap is not None: 
#             max_r = torch.clamp(t - self.min_tr_gap, min=net.nt_low,  )
         
#             r  = torch.minimum(r , max_r ) 
#         # makes sure s<r<t  
#         r  = torch.maximum(r , s ).clamp(min=net.eps)  
#         return t, r, s
      
     
#     def get_loss(self,  f_st, f_sr,   w, wout,    ):
          
#         # MMD 
#         inter_sample_sim = self.kernel(
#             f_st ,
#             f_st , 
#             w=w,
#         ) 
            

#         inter_gt_sim = self.kernel(
#             f_sr ,
#             f_sr , 
#             w=w,
#         )  
        
            
#         cross_sim = self.kernel(
#             f_st ,
#             f_sr , 
#             w=w,
#         ) 
             
#         inter_sample_sim =inter_sample_sim.mean((1, 2)) 
#         cross_sim = cross_sim.mean((1, 2))
#         inter_gt_sim = inter_gt_sim.mean((1, 2))
            
#         loss = inter_sample_sim + inter_gt_sim - 2 * cross_sim 
            
#         if wout is not None:
#             loss = wout * loss
               
#         logs = {
#             "inter_sample_sim": inter_sample_sim.detach(),
#             "inter_gt_sim": inter_gt_sim.detach(),
#             "cross_sim": cross_sim.detach(),
#         }
#         return loss.mean(), logs
     
    
#     def __call__(
#         self,
#         net,
#         action,
#         state,
#         device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), 
#         **kwargs
#     ): 
        
#         state = state.to(device)
#         current_matrix_size = self.matrix_size 
#         # t ~ p(t) and r ~ p(r|t, iters) (Mapping fn)

#         t, r, s = self.sample_trs(state.shape[0]  // current_matrix_size, net, device=state.device,)
#         t = t.repeat_interleave(current_matrix_size, dim=0)
#         s = s.repeat_interleave(current_matrix_size, dim=0)
#         r = r.repeat_interleave(current_matrix_size, dim=0)   
        
#         action = action.repeat_interleave(self.sample_repeat, dim=0)
        
#         t_flat = t.flatten()
#         s_flat = s.flatten()
#         r_flat = r.flatten()
        
#         yt, noise_t = net.add_noise(action, t_flat.unsqueeze(1))
          
 
#         yr = net.ddim(yt, action, t_flat.unsqueeze(1), r_flat.unsqueeze(1), noise= noise_t)
                 

#         # Shared Dropout Mask
#         rng_state = torch.cuda.get_rng_state()
         

#         f_st = net(
#             yt,
#             t_flat.unsqueeze(1),
#             s_flat.unsqueeze(1),
#             state=state
#         ) 
         
        
#         torch.cuda.set_rng_state(rng_state) 
        
#         with torch.no_grad() :   
#             f_sr = net(
#                 yr,
#                 r_flat.unsqueeze(1),
#                 s_flat.unsqueeze(1),
#                 state=state,
#             ) 

#         f_st = rearrange(f_st, "(b m) ... -> b m ...", m=current_matrix_size)
#         f_sr = rearrange(f_sr, "(b m) ... -> b m ...", m=current_matrix_size) 
#         yt = rearrange(yt, "(b m) ... -> b m ...", m=current_matrix_size)
#         yr = rearrange(yr, "(b m) ... -> b m ...", m=current_matrix_size)
#         t = rearrange(t, "(b m) ... -> b m ...", m=current_matrix_size)
#         r = rearrange(r, "(b m) ... -> b m ...", m=current_matrix_size) 
#         s = rearrange(s, "(b m) ... -> b m ...", m=current_matrix_size)
#         wt, wtout = net.get_kernel_weight(
#             t[:, 0].flatten(), 
#             s[:, 0].flatten(),   
#             a=self.a, 
#             b=self.b, 
#         )    
#         loss, loss_logs = self.get_loss(f_st, f_sr , wt, wtout)
#         logs = {
#             "r_t_ratio": r[:, 0] / t[:, 0], 
#             "s_t_ratio": s[:, 0] / t[:, 0],
#             "t_r_diff": t[:, 0] - r[:, 0],
#             'loss': loss,
#             **loss_logs
#         } 

#         if torch.isnan(loss).any():
#             print("Nan in loss")
#             loss = torch.nan_to_num(loss) 
 
#         logs["ts"] = t[:, 0].flatten()  
                   
#         return loss, logs