from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily, Independent, Normal
import abc
import torch 
import torchdyn
from torchdyn.core import NeuralODE
import numpy as np

import torch
import torch.nn.functional as F
from torch.distributions import Categorical, MixtureSameFamily, Independent, Normal



def diagonal_collapsing_cov(n: int, k: int, a: float = 0.1) -> torch.Tensor:
    """Return an ``n x n`` diagonal covariance that squeezes and stretches axes.

    The first ``k`` axes get variance ``a``, the next ``k`` axes get ``1 / a``,
    and every remaining axis keeps variance 1.

    Args:
        n: Total dimension.
        k: Number of axes to squeeze (and, separately, to stretch).
        a: Variance assigned to the squeezed axes.

    Raises:
        ValueError: If ``2 * k`` exceeds ``n``.
    """
    if n < 2 * k:
        raise ValueError(f"Cannot collapse {k} dims in total dimension {n} (2*k must be ≤ n).")

    squeezed = slice(None, k)
    stretched = slice(k, 2 * k)

    spectrum = torch.ones(n)
    spectrum[squeezed] = a
    spectrum[stretched] = 1.0 / a
    return torch.diag(spectrum)

def make_standard_gmm(ncomp: int, dim: int, seed: int = None):
    """
    Build a random GMM with ``ncomp`` equally weighted components in R^dim, then standardize it:
      - Overall mean = 0
      - Total variance trace(Cov) = dim, i.e. the average per-axis variance is 1.
        The full covariance is generally NOT exactly I, because the spread of the
        component means is anisotropic.

    All components share one isotropic scale sigma, chosen so that
    ``dim * sigma**2 + sum_k pi_k * ||mu_k||**2 == dim``. If the randomly drawn means
    alone already exceed that budget, they are shrunk toward the origin. Sigma then
    sits at its floor (sqrt(1e-6)) while the total variance still equals ``dim``.

    Args:
      ncomp: Number of mixture components.
      dim:   Dimension of the space.
      seed:  If given, the global torch RNG is seeded with it first.

    Returns:
      gmm: a torch.distributions.MixtureSameFamily instance
    """
    if seed is not None:
        torch.manual_seed(seed)

    # 1) Uniform mixture logits → equal weights after softmax
    logits = torch.zeros(ncomp)

    # 2) Random initial locations
    locs = torch.randn(ncomp, dim)
    # This draw is unused (all scales are overwritten below). It is kept so the
    # global RNG stream, and therefore all later sampling, is unchanged.
    torch.randn(ncomp, dim)

    # 3) Mixture weights π_k
    weights = torch.softmax(logits, dim=0)  # [ncomp]

    # 4) Mixture mean μ̄ = ∑ π_k μ_k, then center: μ_k ← μ_k − μ̄
    mu_bar = (weights.unsqueeze(1) * locs).sum(dim=0)  # [dim]
    locs_centered = locs - mu_bar

    # 5) Between-component variance ∑ π_k ‖μ_k − μ̄‖²
    var_between = (weights * locs_centered.pow(2).sum(dim=1)).sum()

    # 6) Solve for σ² so the total variance is dim:
    #      d*σ² + var_between = d  ⇒  σ² = (d − var_between) / d
    min_sigma2 = 1e-6
    raw_sigma2 = (dim - var_between) / dim
    if raw_sigma2 < min_sigma2:
        # The means alone use up the whole variance budget. Rescale them so that
        # var_between' + dim * min_sigma2 == dim. Otherwise clamping σ² would leave
        # the total variance above dim.
        shrink = torch.sqrt(dim * (1.0 - min_sigma2) / var_between)
        locs_centered = locs_centered * shrink
    sigma2 = torch.clamp(raw_sigma2, min=min_sigma2)
    sigma = sigma2.sqrt()

    # 7) Every component uses the same isotropic scale σ
    scales_standard = sigma.expand(ncomp, dim)

    # 8) Build the standardized GMM
    mix = Categorical(logits=logits)
    comp = Independent(Normal(loc=locs_centered, scale=scales_standard), 1)
    gmm = MixtureSameFamily(mix, comp)

    return gmm

def _generate_cov(dim, seed=42):
    """Generate a consistent covariance matrix for a given dimension"""
    # Set seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Generate a random positive definite matrix
    A = torch.randn(dim, dim)
    cov = A @ A.T  # Positive semi-definite
    
    # Add small value to diagonal for numerical stability
    cov += 1e-3 * torch.eye(dim)
    
    # Make it more interpretable by scaling
    cov = cov / torch.norm(cov) * dim
    
    return cov

def rotation_90_degrees(dim=2):
    """Return a ``dim x dim`` matrix that rotates the (axis 0, axis 1) plane by +90°.

    The matrix maps e0 → e1 and e1 → −e0 and leaves every other axis unchanged.
    It is orthogonal with determinant 1.

    Raises:
        ValueError: If ``dim < 2`` (a planar rotation needs two axes).
    """
    if dim < 2:
        raise ValueError(f"A planar rotation needs dim >= 2, got {dim}.")
    rotation_matrix = torch.eye(dim)
    # cos(90°) = 0 on the rotated plane's diagonal. Leaving these at 1 would give
    # [[1, -1], [1, 1]], which is not a rotation.
    rotation_matrix[0, 0] = 0
    rotation_matrix[1, 1] = 0
    # ±sin(90°) off the diagonal
    rotation_matrix[0, 1] = -1
    rotation_matrix[1, 0] = 1
    return rotation_matrix
    

# Isotropic 2-D covariance (0.1 * I). It is used for the dim=2 entries of both
# SourceDistribution.PREDEFINED_PARAMS and TargetDistribution.PREDEFINED_PARAMS.
new_cov = torch.eye(2) / 10

class SourceDistribution(abc.ABC):
    """Source (t = 0) distribution for the synthetic experiments.

    Selected by ``args.source_type``:
      - "normal": MultivariateNormal with the predefined (mean, covariance) for
        ``args.dim``. Raises ValueError if ``args.dim`` has no predefined entry.
      - "gmm":    standardized GMM (``make_standard_gmm``) with
        ``args.components_source`` components.
      - anything else: standard normal N(0, I).

    Attributes:
      distribution: the torch distribution object.
      test_values:  32 samples from ``distribution``.
      x0:           12800 samples from ``distribution``, drawn after ``test_values``.
    """

    # Predefined means and covariances, keyed by dimension, built once at
    # class-definition time so they are shared across experiments.
    # NOTE: _generate_cov reseeds the global torch/numpy RNGs on each call, so the
    # order of these entries affects the global RNG state after import.
    PREDEFINED_PARAMS = {
        # For each dimension, store a tuple of (mean, covariance)
        2: (torch.ones(2) * 5, new_cov),
        10: (torch.ones(10) * 5, _generate_cov(10)),
        30: (torch.ones(30) * 5, _generate_cov(30)),
        100: (torch.ones(100) * 5, _generate_cov(100))
    }

    def __init__(self, args):
        dim = args.dim

        if args.source_type == "normal":
            if dim not in self.PREDEFINED_PARAMS:
                # Fail loudly instead of leaving self.distribution unset.
                raise ValueError(
                    f"No predefined normal source parameters for dim={dim}; "
                    f"supported dims: {sorted(self.PREDEFINED_PARAMS)}"
                )
            mean, cov = self.PREDEFINED_PARAMS[dim]
            self.distribution = MultivariateNormal(loc=mean, covariance_matrix=cov)
        elif args.source_type == "gmm":
            self.distribution = make_standard_gmm(ncomp=args.components_source, dim=args.dim)
        else:
            # For other types, fall back to a standard normal
            self.distribution = MultivariateNormal(
                loc=torch.zeros(dim),
                covariance_matrix=torch.eye(dim)
            )

        # Both draws use the global RNG; keep this order (test values first) for reproducibility.
        self.test_values = self.distribution.sample((32,))
        self.x0 = self.distribution.sample((12800,))


class TargetDistribution(abc.ABC):
    """Target (t = 1) distribution for the synthetic experiments.

    Selected by ``args.target_type``:
      - "gmm":        standardized GMM (``make_standard_gmm``) with
                      ``args.components_target`` components.
      - "normal":     MultivariateNormal with the predefined (mean, covariance) for
                      ``args.dim``. Raises ValueError if ``args.dim`` has no predefined entry.
      - "collapsing": currently a standard normal N(0, I) (see NOTE below).
      - anything else: standard normal N(0, I).

    Attributes:
      distribution: the torch distribution object.
      x1:           12800 samples from ``distribution``.
      test_values:  32 samples from ``distribution``, drawn after ``x1``.
    """

    # Predefined means and covariances, keyed by dimension, built once at
    # class-definition time so they are shared across experiments.
    # NOTE: _generate_cov reseeds the global torch/numpy RNGs on each call, so the
    # order of these entries affects the global RNG state after import.
    PREDEFINED_PARAMS = {
        # For each dimension, store a tuple of (mean, covariance)
        2: (torch.ones(2) * 5, new_cov),
        10: (torch.ones(10) * 5, _generate_cov(10)),
        30: (torch.ones(30) * 5, _generate_cov(30)),
        100: (torch.ones(100) * 5, _generate_cov(100))
    }

    def __init__(self, args):
        dim = args.dim

        if args.target_type == "gmm":
            self.distribution = make_standard_gmm(ncomp=args.components_target, dim=args.dim)
        elif args.target_type == "normal":
            if dim not in self.PREDEFINED_PARAMS:
                # Fail loudly instead of leaving self.distribution unset.
                raise ValueError(
                    f"No predefined normal target parameters for dim={dim}; "
                    f"supported dims: {sorted(self.PREDEFINED_PARAMS)}"
                )
            mean, cov = self.PREDEFINED_PARAMS[dim]
            self.distribution = MultivariateNormal(loc=mean, covariance_matrix=cov)
        elif args.target_type == "collapsing":
            # NOTE(review): this branch yields N(0, I). `diagonal_collapsing_cov` in this
            # module was presumably meant to supply the covariance here. Confirm the
            # intended k / a before wiring it in.
            self.distribution = MultivariateNormal(
                loc=torch.zeros(dim),
                covariance_matrix=torch.eye(dim)
            )
        else:
            # Default target distribution
            self.distribution = MultivariateNormal(
                loc=torch.zeros(dim),
                covariance_matrix=torch.eye(dim)
            )

        # Both draws use the global RNG; keep this order (x1 first) for reproducibility.
        self.x1 = self.distribution.sample((12800,))
        self.test_values = self.distribution.sample((32,))



class Synthetic(abc.ABC):
    """Paired source/target samples for (re)flow experiments.

    Holds fixed draws from the source and target distributions and builds
    ``TensorDataset``s of (x0, x1) pairs. The pairs are either independent
    random pairs (``all_pairs`` / ``update_pairs``) or deterministic pairs
    obtained by integrating a learned vector field with torchdyn's
    ``NeuralODE`` (``forward``, ``forward_source``, ``forward_target``,
    ``backward``).
    """

    # Rows integrated per NeuralODE call, and time points per trajectory
    # (only the final time point is kept).
    _BATCH_SIZE = 500
    _NUM_TIME_POINTS = 50

    def __init__(self, args):
        self.source_distr = SourceDistribution(args)
        self.target_distr = TargetDistribution(args)

        self.original_x0 = self.source_distr.x0.clone()
        self.original_x1 = self.target_distr.x1.clone()

        # Independent (random) pairing of source and target samples.
        self.all_pairs = torch.utils.data.TensorDataset(self.original_x0, self.original_x1)

        # Filled in by forward / forward_source / forward_target / backward.
        self.all_pairs_forward = None
        self.all_pairs_forward_source = None
        self.all_pairs_forward_target = None
        self.all_pairs_backward = None

        self.reflow_nll = None
        self.current_epoch = 0
        self.args = args

    @staticmethod
    def _psd_sqrt(mat):
        """Symmetric square root of a (numerically) PSD matrix via eigendecomposition."""
        sym = (mat + mat.T) / 2
        evals, evecs = torch.linalg.eigh(sym)
        # Clamp tiny negative eigenvalues caused by round-off.
        return evecs @ torch.diag(evals.clamp(min=0).sqrt()) @ evecs.T

    def w2_distance(self, generated_x1):
        """Squared Wasserstein-2 distance between a Gaussian fit of the samples and the target.

        The generated samples are summarized by their sample mean m and unbiased
        sample covariance S. With target N(μ, Σ), the closed form is

            W2² = ‖μ − m‖² + tr(Σ) + tr(S) − 2 tr((Σ^{1/2} S Σ^{1/2})^{1/2})

        (see https://en.wikipedia.org/wiki/Wasserstein_metric).
        The target distribution must expose ``.loc`` and ``.covariance_matrix``
        (e.g. MultivariateNormal). A GMM (MixtureSameFamily) target has neither.

        Args:
            generated_x1: Generated target samples, shape (n, dim), n >= 2.

        Returns:
            Scalar tensor holding the *squared* W2 distance.
        """
        target = self.target_distr.distribution
        # Match the generated samples' device/dtype so the arithmetic below is valid.
        target_mean = target.loc.to(generated_x1)
        target_cov = target.covariance_matrix.to(generated_x1)

        generated_mean = torch.mean(generated_x1, dim=0)
        mean_term = torch.sum((target_mean - generated_mean) ** 2)

        n = generated_x1.size(0)
        centered = generated_x1 - generated_mean.unsqueeze(0)
        sample_cov = (centered.T @ centered) / (n - 1)

        # Bures term. tr((Σ^{1/2} S Σ^{1/2})^{1/2}) equals the sum of the square
        # roots of the eigenvalues of the symmetric PSD matrix Σ^{1/2} S Σ^{1/2}.
        target_sqrt = self._psd_sqrt(target_cov)
        cross = target_sqrt @ sample_cov @ target_sqrt
        cross_eigs = torch.linalg.eigvalsh((cross + cross.T) / 2)
        cross_trace = cross_eigs.clamp(min=0).sqrt().sum()
        cov_term = torch.trace(target_cov) + torch.trace(sample_cov) - 2 * cross_trace

        return mean_term + cov_term

    def _integrate(self, model, args, x_start, t_start, t_end):
        """Integrate every row of ``x_start`` from ``t_start`` to ``t_end`` with a NeuralODE.

        Integration runs on the CPU in batches of ``_BATCH_SIZE``, using dopri5 with
        atol = rtol = 1e-3. Autograd stays enabled only for ``args.method`` in
        {"gmlp", "emlp"}; every other method integrates under ``torch.no_grad()``.

        NOTE: if ``model`` is an ``nn.Module``, ``model.cpu()`` moves it to the CPU
        *in place*, and it is not moved back afterwards.

        Returns:
            (x_start_cpu, x_end): the CPU copy of the start points and the endpoints
            of the integrated trajectories, both with the same shape.
        """
        torch.cuda.empty_cache()
        model_cpu = model.cpu()

        x_start = x_start.cpu()
        num_samples = x_start.shape[0]
        x_end = torch.zeros_like(x_start)

        node = NeuralODE(
            model_cpu,
            solver="dopri5",
            sensitivity="adjoint",
            atol=1e-3,
            rtol=1e-3
        )

        # These methods keep autograd enabled during integration. Presumably their
        # vector field needs gradients internally; confirm against the model code.
        keep_grad = args.method == "gmlp" or args.method == "emlp"

        for i in range(0, num_samples, self._BATCH_SIZE):
            end_idx = min(i + self._BATCH_SIZE, num_samples)
            batch = x_start[i:end_idx]
            t_span = torch.linspace(t_start, t_end, self._NUM_TIME_POINTS)

            if keep_grad:
                traj = node.trajectory(batch, t_span=t_span)
            else:
                with torch.no_grad():
                    traj = node.trajectory(batch, t_span=t_span)

            # Keep only the final time point.
            x_end[i:end_idx] = traj[-1].detach().clone()
            del traj, batch
            torch.cuda.empty_cache()

        del node, model_cpu
        torch.cuda.empty_cache()
        return x_start, x_end

    def forward(self, model, args):
        """Integrate the original x0 samples over t: 0 → 1 and store the (x0, x1) pairs.

        Sets ``self.all_pairs_forward`` and ``self.target_distr.test_values_forward``
        (the first 32 integrated endpoints).
        """
        x0_full, x1_full = self._integrate(model, args, self.original_x0, 0, 1)
        # Deterministic mapping x0 -> integrated(x0)
        self.all_pairs_forward = torch.utils.data.TensorDataset(x0_full.detach(), x1_full.detach())
        self.target_distr.test_values_forward = x1_full[:32].detach().clone()

    def forward_source(self, model, args):
        """Integrate the original x0 samples over t: 0 → 1 and store the (x0, x1) pairs.

        Sets ``self.all_pairs_forward_source`` and
        ``self.target_distr.test_values_forward_source``.
        """
        x0_full, x1_full = self._integrate(model, args, self.original_x0, 0, 1)
        self.all_pairs_forward_source = torch.utils.data.TensorDataset(x0_full.detach(), x1_full.detach())
        self.target_distr.test_values_forward_source = x1_full[:32].detach().clone()

    def forward_target(self, model, args):
        """Integrate the original x0 samples over t: 0 → 1 and store the (x0, x1) pairs.

        Sets ``self.all_pairs_forward_target`` and
        ``self.target_distr.test_values_forward_target``.

        NOTE(review): an earlier docstring described this as "target-source", but the
        code starts from the original *source* samples and integrates forward in time.
        Confirm which is intended.
        """
        x0_full, x1_full = self._integrate(model, args, self.original_x0, 0, 1)
        self.all_pairs_forward_target = torch.utils.data.TensorDataset(x0_full.detach(), x1_full.detach())
        self.target_distr.test_values_forward_target = x1_full[:32].detach().clone()

    def backward(self, model, args):
        """Integrate the original x1 samples over t: 1 → 0 and store the (x0, x1) pairs.

        Sets ``self.all_pairs_backward`` (integrated x0, original x1) and
        ``self.source_distr.test_values_backward`` (the first 32 integrated x0).
        """
        x1_full, x0_full = self._integrate(model, args, self.original_x1, 1, 0)
        # Deterministic mapping backward_integrated(x1) -> x1
        self.all_pairs_backward = torch.utils.data.TensorDataset(x0_full.detach(), x1_full.detach())
        self.source_distr.test_values_backward = x0_full[:32].detach().clone()

    def update_pairs(self):
        """Re-pair the original targets with fresh source samples (training of the k=0 model).

        Draws ``len(original_x1)`` new samples from the source distribution and
        replaces ``self.all_pairs`` with (new_x0, original_x1).
        """
        new_x0 = self.source_distr.distribution.sample((self.original_x1.shape[0],))
        self.all_pairs = torch.utils.data.TensorDataset(new_x0, self.original_x1)

