"""
Improved Stochastic Bridge Matching (SBM) Implementation

This module implements Stochastic Bridge Matching based on principles from 
Schrödinger Bridge theory and Conditional Flow Matching.
"""

import torch
import torch.nn as nn
import numpy as np
import warnings
from typing import Union, Optional, Tuple, Callable
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, pad_t_like_x


class OTPlanSampler:
    """
    Optimal Transport Plan Sampler for stochastic coupling of samples
    from distributions π₀ and π₁.

    The plan is computed with POT (Python Optimal Transport) when that
    library can be imported. Otherwise, or when the solver fails, a random
    coupling is used instead and a warning is issued.
    """

    # OT solvers this sampler knows how to dispatch to.
    _SUPPORTED_METHODS = ("exact", "sinkhorn")

    def __init__(self, method: str = "exact", reg: float = 1.0):
        """
        Initialize the OT Plan Sampler.

        Args:
            method: Method for OT computation ('exact' or 'sinkhorn')
            reg: Regularization parameter for entropic OT

        Raises:
            ValueError: If ``method`` is not a supported OT method.
        """
        # Fail fast: a misspelled method must not silently degrade to a
        # random coupling.
        if method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Unknown OT method: {method}")

        self.method = method
        self.reg = reg

        # Import ot library only when needed to avoid dependency issues
        try:
            import ot
            self.pot = ot
        except ImportError:
            # Fallback to a simple implementation without OT library
            self.pot = None
            warnings.warn(
                "Could not import POT (Python Optimal Transport). "
                "Using a simple permutation for OT coupling.",
                stacklevel=2,
            )

    @staticmethod
    def _random_assignment(n0: int, n1: int) -> torch.Tensor:
        """Return, for each of ``n0`` source rows, a random target index in [0, n1)."""
        if n0 == n1:
            # Equal sizes: keep a one-to-one random matching.
            return torch.randperm(n0)
        return torch.randint(n1, (n0,))

    def _random_plan(self, n0: int, n1: int, device) -> torch.Tensor:
        """Build an (n0, n1) 0/1 matrix with one random non-zero entry per row."""
        plan = torch.zeros((n0, n1))
        if n0 > 0 and n1 > 0:
            plan[torch.arange(n0), self._random_assignment(n0, n1)] = 1.0
        return plan.to(device)

    def compute_plan(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        """
        Compute the OT plan between two batches of samples.

        Args:
            x0: Source samples (n0, ...)
            x1: Target samples (n1, ...); n1 may differ from n0

        Returns:
            Coupling matrix of shape (n0, n1) on ``x0``'s device. When POT is
            unavailable or the solver fails this is a random 0/1 assignment
            matrix rather than a normalized transport plan.

        Raises:
            ValueError: If ``self.method`` is not a supported OT method.
        """
        if self.method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Unknown OT method: {self.method}")

        n0, n1 = x0.shape[0], x1.shape[0]

        # If POT is not available (or there is nothing to couple), use a
        # random assignment
        if self.pot is None or n0 == 0 or n1 == 0:
            return self._random_plan(n0, n1, x0.device)

        # Move to CPU for POT computations; flatten trailing dims because the
        # pairwise distance routine works on 2-D arrays.
        x0_np = x0.detach().reshape(n0, -1).cpu().numpy()
        x1_np = x1.detach().reshape(n1, -1).cpu().numpy()

        # Uniform weights for points
        a = np.ones(n0) / n0
        b = np.ones(n1) / n1

        # Only the solver calls are guarded: numerical failures fall back to a
        # random coupling, configuration errors were already raised above.
        try:
            # Compute cost matrix (squared Euclidean distance)
            M = self.pot.dist(x0_np, x1_np, metric='sqeuclidean')

            if self.method == "exact":
                P = self.pot.emd(a, b, M)
            else:
                P = self.pot.sinkhorn(a, b, M, self.reg)
        except Exception as e:
            warnings.warn(
                f"Error computing OT plan: {e}. Falling back to random permutation.",
                stacklevel=2,
            )
            return self._random_plan(n0, n1, x0.device)

        return torch.from_numpy(np.asarray(P)).float().to(x0.device)

    def sample_plan(self, x0: torch.Tensor, x1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample pairs of points according to the OT plan.

        ``x0.shape[0]`` pairs are drawn with replacement. Pairs whose plan
        entry is zero are never drawn.

        Args:
            x0: Source samples (n0, ...)
            x1: Target samples (n1, ...)

        Returns:
            Tuple of source and target samples, each with ``n0`` rows, paired
            according to the OT plan
        """
        n0, n1 = x0.shape[0], x1.shape[0]

        # A single pair needs no coupling.
        if n0 == 1 and n1 == 1:
            return x0, x1
        # Nothing can be paired if either side is empty.
        if n0 == 0 or n1 == 0:
            return x0[:0], x1[:0]

        plan = self.compute_plan(x0, x1)

        # torch.multinomial accepts zeros but not negative or non-finite
        # weights, so only those are removed; the plan's support is kept.
        plan = torch.nan_to_num(plan, nan=0.0, posinf=0.0, neginf=0.0).clamp_min(0.0)
        row_mass = plan.sum(dim=1)

        if not bool(row_mass.sum() > 0):
            warnings.warn(
                "Error sampling from OT plan: plan has no positive mass. "
                "Falling back to random permutation.",
                stacklevel=2,
            )
            cols = self._random_assignment(n0, n1).to(x1.device)
            return x0, x1[cols]

        # Sample the joint as marginal-then-conditional. This is the same
        # distribution as sampling the flattened plan, but each multinomial
        # call only sees n0 or n1 categories instead of n0 * n1.
        rows = torch.multinomial(row_mass, n0, replacement=True)
        cols = torch.multinomial(plan[rows], 1).squeeze(1)

        return x0[rows], x1[cols]


class ImprovedSBM(ConditionalFlowMatcher):
    """
    Improved Stochastic Bridge Matching (SBM) method.

    This class implements the SB-CFM (Schrödinger Bridge Conditional Flow Matching)
    approach, which combines principles from Conditional Flow Matching and
    Schrödinger Bridge theory.

    Mini-batches are coupled through an ``OTPlanSampler``, a noisy interpolant
    between the coupled endpoints is sampled, and the drift towards the
    endpoint is returned as the regression target. ``sample_trajectory``
    integrates a learned vector field with a fixed-step ODE solver.
    """

    def __init__(
        self,
        sigma: Union[float, int, Callable] = 1.0,
        ot_method: str = "exact",
        noise_schedule: str = "linear",
        ot_reg: Optional[float] = None,
        method: str = "sb"
    ):
        """
        Initialize the Improved SBM method.

        Args:
            sigma: Noise scale parameter (or noise schedule function of t)
            ot_method: Method for OT computation ('exact' or 'sinkhorn')
            noise_schedule: The noise schedule to use ('linear', 'cosine', or 'sigmoid')
            ot_reg: Regularization parameter for entropic OT (defaults to
                2*sigma^2 for a numeric sigma, 1.0 for a callable sigma)
            method: Label stored on the instance; it is not read anywhere
                else in this class

        Raises:
            ValueError: If a numeric ``sigma`` is not strictly positive.
        """
        super().__init__(sigma=sigma)

        # Validate parameters
        sigma_is_number = isinstance(sigma, (int, float))
        if sigma_is_number and sigma <= 0:
            raise ValueError(f"Sigma must be strictly positive, got {sigma}.")
        if sigma_is_number and sigma < 1e-3:
            warnings.warn("Small sigma values may lead to numerical instability.")

        self.noise_schedule = noise_schedule
        self.ot_method = ot_method
        self.method = method

        # Set OT regularization
        if ot_reg is None:
            # Default regularization is 1.0 if sigma is a function
            ot_reg = 2 * self.sigma**2 if sigma_is_number else 1.0

        # Initialize OT plan sampler
        self.ot_sampler = OTPlanSampler(method=ot_method, reg=ot_reg)

    def compute_sigma_t(self, t: torch.Tensor) -> torch.Tensor:
        """
        Compute the noise scale sigma(t) at time t.

        Schedules:
            * ``linear``:  ``sigma * sqrt(2 * t * (1 - t))``; zero at t=0 and t=1.
            * ``cosine``:  ``sigma * sqrt(1 - alpha_t^2)`` with alpha_t going
              from 1 at t=0 to 0 at t=1, so the scale is zero at t=0 and
              equals sigma at t=1.
            * ``sigmoid``: ``2 * sigma * sigmoid(10 * (t - 0.5))``; this does
              not vanish at either endpoint.

        Args:
            t: Time parameter(s) in [0, 1]

        Returns:
            Noise scale sigma(t)

        Raises:
            ValueError: If ``self.noise_schedule`` is not a known schedule.
        """
        sigma = self.sigma
        if callable(sigma):
            # Let a failing schedule function raise its own error; continuing
            # with the un-called function would only fail later and obscurely.
            sigma = sigma(t)

        if self.noise_schedule == "linear":
            # Linear noise schedule: sigma * sqrt(2 * t * (1 - t)).
            # sigma scales the standard deviation, as in the other schedules.
            # The clamp guards against t marginally outside [0, 1].
            return sigma * torch.sqrt(torch.clamp(2 * t * (1 - t), min=0.0))
        elif self.noise_schedule == "cosine":
            # Cosine noise schedule (smoother)
            s = 0.008  # Small offset for numerical stability
            ft = torch.cos((t + s) / (1 + s) * 0.5 * np.pi) ** 2
            ft_0 = torch.cos((torch.zeros_like(t) + s) / (1 + s) * 0.5 * np.pi) ** 2
            ft_1 = torch.cos((torch.ones_like(t) + s) / (1 + s) * 0.5 * np.pi) ** 2

            # Rescale so that alpha_t is exactly 1 at t=0 and 0 at t=1
            alpha_t = (ft - ft_1) / (ft_0 - ft_1)
            # The clamp keeps rounding error from producing sqrt of a negative
            sigma_t = sigma * torch.sqrt(torch.clamp(1 - alpha_t**2, min=0.0))
            return sigma_t
        elif self.noise_schedule == "sigmoid":
            # Sigmoid-based noise schedule
            return sigma * 2 * torch.sigmoid(10 * (t - 0.5))
        else:
            raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")

    def compute_conditional_flow(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        t: torch.Tensor,
        xt: torch.Tensor,
        direction: str = "forward"
    ) -> torch.Tensor:
        """
        Compute the conditional flow field ut at (xt, t).

        Forward:  ``(x1 - xt) / (1 - t + 1e-5)``.
        Any other ``direction`` value: ``(x0 - xt) / (t + 1e-5)``.

        Args:
            x0: Initial samples
            x1: Terminal samples
            t: Time parameter(s)
            xt: Current state at time t
            direction: Direction of the flow ('forward' or 'backward')

        Returns:
            The conditional flow field ut at (xt, t)
        """
        t = pad_t_like_x(t, x0)

        # Compute drift (deterministic part); epsilon keeps the division
        # finite at the time boundary.
        if direction == "forward":
            ut = (x1 - xt) / (1 - t + 1e-5)
        else:
            # NOTE(review): sample_xt's backward mean is t*x0 + (1-t)*x1, i.e.
            # xt is near x1 at t=0, where this denominator is smallest. The
            # two backward branches appear to use different time conventions;
            # confirm which one is intended before relying on 'backward'.
            ut = (x0 - xt) / (t + 1e-5)

        return ut

    def sample_location_and_conditional_flow(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        t: Optional[torch.Tensor] = None,
        return_noise: bool = False,
        direction: str = "forward"
    ) -> Tuple[torch.Tensor, ...]:
        """
        Sample a location xt and corresponding conditional flow ut.

        The inputs are first re-paired by the OT plan sampler, so the returned
        quantities refer to the re-paired batch rather than to the original
        row order of ``x0`` and ``x1``.

        Args:
            x0: Initial samples
            x1: Terminal samples
            t: Time parameter(s) (sampled uniformly in [0, 1) if None)
            return_noise: Whether to return the noise samples
            direction: Direction of the flow ('forward' or 'backward')

        Returns:
            ``(t, xt, ut)``, or ``(t, xt, ut, eps)`` if ``return_noise`` is True:
            t: Time parameter(s)
            xt: Sampled location at time t
            ut: Conditional flow at (xt, t)
            eps: Noise samples
        """
        # Sample from the OT plan
        x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        # Sample time if not provided
        if t is None:
            t = torch.rand(x0.shape[0]).type_as(x0)

        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        # Sample noise
        eps = torch.randn_like(x0)

        # Sample location at time t
        xt = self.sample_xt(x0, x1, t, eps, direction)

        # Compute conditional flow
        ut = self.compute_conditional_flow(x0, x1, t, xt, direction)

        if return_noise:
            return t, xt, ut, eps
        return t, xt, ut

    def sample_xt(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        t: torch.Tensor,
        epsilon: torch.Tensor,
        direction: str = "forward"
    ) -> torch.Tensor:
        """
        Sample from the marginal distribution p(xt|x0,x1).

        The sample is ``mu_t + sigma_t * epsilon`` where ``mu_t`` interpolates
        linearly between the endpoints and ``sigma_t`` comes from
        ``compute_sigma_t``.

        Args:
            x0: Initial samples
            x1: Terminal samples
            t: Time parameter(s)
            epsilon: Standard normal noise
            direction: Direction of the flow ('forward' or 'backward')

        Returns:
            Samples from the marginal distribution at time t
        """
        t = pad_t_like_x(t, x0)

        # Compute deterministic part (mean)
        if direction == "forward":
            mu_t = (1 - t) * x0 + t * x1
        else:
            mu_t = t * x0 + (1 - t) * x1

        # Add stochastic part
        sigma_t = self.compute_sigma_t(t)
        sigma_t = pad_t_like_x(sigma_t, x0)

        return mu_t + sigma_t * epsilon

    def sample_trajectory(
        self,
        x0: torch.Tensor,
        model: nn.Module,
        steps: int = 100,
        eps: float = 1e-3,
        solver: str = "euler",
        return_intermediate: bool = True
    ) -> torch.Tensor:
        """
        Sample a trajectory from x0 to x1 using the learned flow.

        The model is called as ``model(t, x)`` with ``t`` of shape
        ``(batch_size,)``. Integration runs on a uniform grid of ``steps``
        points over ``[eps, 1 - eps]``, and the step size equals the grid
        spacing so the solver stages are evaluated at consistent times.

        Args:
            x0: Initial samples
            model: Neural network model for the vector field
            steps: Number of grid points (``steps - 1`` integration steps);
                with fewer than 2 points no integration is performed
            eps: Small time offset from boundaries for numerical stability
            solver: ODE solver ('euler', 'heun', or 'rk4')
            return_intermediate: Whether to return the full trajectory

        Returns:
            Stacked states of shape ``(number_of_states, *x0.shape)`` if
            ``return_intermediate`` is True, otherwise the final state only

        Raises:
            ValueError: If ``solver`` is unknown (checked when integrating).
        """
        device = x0.device
        batch_size = x0.shape[0]

        # Current state
        x = x0.detach()

        # Initialize trajectory storage if needed
        trajectory = [x] if return_intermediate else None

        if steps >= 2:
            if solver not in ("euler", "heun", "rk4"):
                raise ValueError(f"Unknown solver: {solver}")

            # Setup time grid with small offset from boundary
            t_span = torch.linspace(eps, 1.0 - eps, steps).to(device)

            # Step size must match the spacing of t_span
            dt = (1.0 - 2.0 * eps) / (steps - 1)

            # Integration loop
            for i in range(steps - 1):
                t = t_span[i] * torch.ones(batch_size, dtype=x.dtype, device=device)

                # Get vector field from model
                v = model(t, x)

                if solver == "euler":
                    # Euler method (first-order)
                    x = x + dt * v
                elif solver == "heun":
                    # Heun's method (second-order)
                    x_euler = x + dt * v
                    t_next = t_span[i + 1] * torch.ones_like(t)
                    v_next = model(t_next, x_euler)
                    x = x + 0.5 * dt * (v + v_next)
                else:
                    # 4th-order Runge-Kutta
                    k1 = v
                    t_half = (t_span[i] + t_span[i + 1]) / 2 * torch.ones_like(t)
                    k2 = model(t_half, x + dt * k1 / 2)
                    k3 = model(t_half, x + dt * k2 / 2)
                    t_next = t_span[i + 1] * torch.ones_like(t)
                    k4 = model(t_next, x + dt * k3)
                    x = x + dt * (k1 + 2*k2 + 2*k3 + k4) / 6

                if return_intermediate:
                    trajectory.append(x.detach())

        if return_intermediate:
            return torch.stack(trajectory, dim=0)
        return x


# Utility for creating a hybrid model that combines velocity and score networks
class HybridSBModel(nn.Module):
    """
    Wrapper that bundles a velocity network with an optional score network.

    Both networks are called as ``net(t, x)``. ``sigma`` and ``alpha`` are
    stored on the instance but are not used by ``forward``.
    """

    def __init__(
        self,
        velocity_net: nn.Module,
        score_net: Optional[nn.Module] = None,
        sigma: float = 1.0,
        alpha: float = 0.5
    ):
        """
        Store the networks and hyper-parameters.

        Args:
            velocity_net: Network producing the velocity field
            score_net: Network producing the score function, or None
            sigma: Noise scale parameter
            alpha: Mixing coefficient between velocity and score components
        """
        super().__init__()
        # Networks first, in this order, so sub-module registration order
        # stays velocity -> score.
        self.velocity_net = velocity_net
        self.score_net = score_net
        self.sigma, self.alpha = sigma, alpha

    def forward(self, t, x):
        """
        Evaluate the wrapped network(s) at ``(t, x)``.

        Args:
            t: Time parameter(s)
            x: State at time t

        Returns:
            The velocity alone when no score network is set, otherwise the
            tuple ``(velocity, score)``.
        """
        velocity = self.velocity_net(t, x)

        # Without a score network only the velocity is returned.
        if self.score_net is None:
            return velocity

        return velocity, self.score_net(t, x)
            

# Utility for deriving a drift model from a hybrid model
class SBDriftModel(nn.Module):
    """
    Drift model derived from a hybrid SB model.

    This model combines the velocity and score components into a single
    drift term for use in ODE-based sampling:

        drift = alpha * v(t, x) + (1 - alpha) * sigma_t^2 * s(t, x)

    with ``sigma_t = sigma * sqrt(2 * t * (1 - t))``. Without a score
    function the drift is just ``v(t, x)``.
    """

    def __init__(
        self,
        v: Callable,
        s: Optional[Callable] = None,
        sigma: float = 1.0,
        alpha: float = 0.5
    ):
        """
        Initialize the drift model.

        Args:
            v: Velocity field function, called as ``v(t, x)``
            s: Score function, called as ``s(t, x)`` (optional)
            sigma: Noise scale parameter
            alpha: Mixing coefficient between the velocity and score components
        """
        super().__init__()
        self.v = v
        self.s = s
        self.sigma = sigma
        self.alpha = alpha

    def forward(self, t, x):
        """
        Compute the drift term.

        Args:
            t: Time parameter(s): a Python number, a 0-dim tensor, or a
                tensor with one entry per batch element of ``x``
            x: State at time t

        Returns:
            The combined drift term
        """
        # Compute the velocity component
        velocity = self.v(t, x)

        # Without a score function there is nothing to correct
        if self.s is None:
            return velocity

        # Evaluate the schedule per sample, on x's device and dtype, instead
        # of reading only the first batch element back to the host.
        t_b = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t_b.dim() > 0:
            # One time per batch element, broadcast over the trailing dims of x
            t_b = t_b.reshape(-1, *([1] * (x.dim() - 1)))

        # Simple linear schedule: sigma(t)^2 = sigma^2 * 2 * t * (1 - t).
        # Clamped so a t marginally outside [0, 1] cannot give a negative
        # variance.
        sigma_t_sq = self.sigma ** 2 * torch.clamp(2 * t_b * (1 - t_b), min=0.0)

        # Compute the score component
        score = self.s(t, x)

        # Combine the components
        return self.alpha * velocity + (1 - self.alpha) * sigma_t_sq * score
