"""
Linear Transform Trainer Module
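Contains orthogonal linear transform models with different orthogonality enforcement methods,
energy-shaping loss functions, evaluation metrics, and a trainer that fits and evaluates the models.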
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
import itertools
import torch.nn.functional as F
import time
import math
import random
from tqdm import tqdm

# =============================================================================
# ORTHOGONAL LINEAR TRANSFORM MODELS
# =============================================================================

class OrthogonalLinearTransform(nn.Module):
    """Common base for the linear transforms defined in this module.

    Stores the input and output dimensionality; ``output_dim`` falls back to
    ``input_dim`` when omitted. Subclasses must implement ``forward``.
    """

    def __init__(self, input_dim, output_dim=None):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = input_dim if output_dim is None else output_dim

    def forward(self, x):
        # Abstract: concrete transforms supply the mapping.
        raise NotImplementedError

class SVDProjectionTransform(OrthogonalLinearTransform):
    """Square linear map kept orthogonal by projecting onto the nearest orthogonal matrix.

    ``A`` is a free parameter; after each optimizer step the caller is expected
    to invoke ``project_to_orthogonal``, which replaces ``A`` by ``U @ Vh`` from
    its SVD (or by the ``Q`` factor of a QR decomposition if the SVD fails).
    """

    def __init__(self, input_dim, output_dim=None):
        super().__init__(input_dim, output_dim)
        # Orthogonality is only well-defined here for square matrices.
        if self.input_dim != self.output_dim:
            raise ValueError("For SVD projection, input_dim must equal output_dim")
        start = torch.randn(self.output_dim, self.input_dim)
        self.A = nn.Parameter(self._nearest_orthogonal(start))

    @staticmethod
    def _nearest_orthogonal(matrix):
        """Return U @ Vh from the SVD of ``matrix``; fall back to QR's Q if the SVD fails."""
        try:
            left, _, right_h = torch.linalg.svd(matrix, full_matrices=False)
        except torch.linalg.LinAlgError:
            q_factor, _ = torch.linalg.qr(matrix)
            return q_factor
        return left @ right_h

    def forward(self, x):
        return x @ self.A.T

    def project_to_orthogonal(self):
        """Overwrite ``A`` in place with its nearest orthogonal matrix (no autograd tracking)."""
        with torch.no_grad():
            self.A.data = self._nearest_orthogonal(self.A)

class PenaltyTransform(OrthogonalLinearTransform):
    """Unconstrained square linear map; orthogonality is encouraged by a soft penalty.

    ``A`` starts as a standard-normal random matrix. The trainer adds
    ``lambda * orthogonality_penalty()`` to the loss.
    """

    def __init__(self, input_dim, output_dim=None):
        super().__init__(input_dim, output_dim)
        # The penalty compares A^T A with a square identity, so A must be square.
        if self.output_dim != self.input_dim:
            raise ValueError("For orthogonality penalty, input_dim must equal output_dim")
        self.A = nn.Parameter(torch.randn(self.output_dim, self.input_dim))

    def forward(self, x):
        return x @ self.A.T

    def orthogonality_penalty(self):
        """Return ||A^T A - I||_F^2, which is zero exactly when A has orthonormal columns."""
        weight = self.A
        gram = weight.T @ weight
        identity = torch.eye(self.input_dim, device=weight.device, dtype=weight.dtype)
        deviation = gram - identity
        return deviation.norm(p='fro').pow(2)

class ManifoldTransform(OrthogonalLinearTransform):
    """Square linear map optimized on the orthogonal (square Stiefel) manifold.

    Intended use per step:
      1. project the Euclidean gradient onto the tangent space at ``A``
         (``project_gradient_to_tangent``),
      2. take an optimizer step,
      3. retract ``A`` back onto the manifold (``retract_to_manifold``).
    """

    def __init__(self, input_dim, output_dim=None):
        super().__init__(input_dim, output_dim)
        # Orthogonality is only well-defined here for square matrices.
        if self.input_dim != self.output_dim:
            raise ValueError("For manifold optimization, input_dim must equal output_dim")
        # Start from an orthogonal matrix: the polar factor U @ Vh of a random matrix.
        A = torch.randn(self.output_dim, self.input_dim)
        try:
            U, _, Vt = torch.linalg.svd(A, full_matrices=False)
            self.A = nn.Parameter(U @ Vt)
        except torch.linalg.LinAlgError:
            # Fallback: use QR decomposition if SVD fails
            Q, _ = torch.linalg.qr(A)
            self.A = nn.Parameter(Q)

    def forward(self, x):
        return x @ self.A.T

    def project_gradient_to_tangent(self, grad_A):
        """Project ``grad_A`` onto the tangent space at ``A``: G - A * sym(A^T G).

        For orthogonal ``A`` this equals ``A @ skew(A^T G)``.
        """
        AT_G = self.A.T @ grad_A
        sym_AT_G = 0.5 * (AT_G + AT_G.T)
        return grad_A - self.A @ sym_AT_G

    def retract_to_manifold(self):
        """Retract ``A`` onto the manifold with the QR retraction qf(A).

        ``torch.linalg.qr`` does not make ``diag(R)`` positive (LAPACK's
        Householder QR frequently returns negative diagonal entries). Using the
        raw ``Q`` would therefore flip the signs of arbitrary columns of an
        already-orthogonal ``A`` on every step. Rescaling the columns of ``Q`` by
        ``sign(diag(R))`` gives the unique factor with a positive-diagonal ``R``,
        so an orthogonal ``A`` is left (numerically) unchanged.
        """
        with torch.no_grad():
            Q, R = torch.linalg.qr(self.A)
            signs = torch.sign(torch.diagonal(R))
            # A zero diagonal (rank-deficient A) has no preferred sign; keep the column as is.
            signs = torch.where(signs == 0, torch.ones_like(signs), signs)
            self.A.data = Q * signs  # broadcasts over rows: scales column j by signs[j]

class ExponentialMapTransform(OrthogonalLinearTransform):
    """Orthogonal map parameterized as A = exp(S) with S skew-symmetric.

    ``S`` is derived from an unconstrained parameter ``B`` as ``(B - B^T) / 2``.
    The matrix exponential of a real skew-symmetric matrix is orthogonal with
    determinant +1, so ``A`` stays orthogonal for any value of ``B``.
    """

    def __init__(self, input_dim, output_dim=None):
        super().__init__(input_dim, output_dim)
        # exp(S) is square, so the map must be square too.
        if self.output_dim != self.input_dim:
            raise ValueError("For exponential map, input_dim must equal output_dim")
        self.B = nn.Parameter(torch.randn(self.output_dim, self.input_dim))

    def get_skew_symmetric(self):
        """Return the skew-symmetric part of ``B``: S = (B - B^T) / 2."""
        raw = self.B
        return 0.5 * (raw - raw.T)

    def forward(self, x):
        rotation = torch.matrix_exp(self.get_skew_symmetric())
        return x @ rotation.T

class CayleyTransform(OrthogonalLinearTransform):
    """Orthogonal map given by the Cayley transform of a skew-symmetric matrix.

    ``K = (K_param - K_param^T) / 2`` and ``A = (I - K)(I + K)^{-1}``. Because
    ``I - K`` and ``I + K`` commute, this equals ``(I + K)^{-1}(I - K)``, which is
    what ``torch.linalg.solve(I + K, I - K)`` computes.
    """

    def __init__(self, input_dim, output_dim=None):
        super().__init__(input_dim, output_dim)
        # The Cayley transform yields a square matrix, so the map must be square.
        if self.input_dim != self.output_dim:
            raise ValueError("For Cayley transform, input_dim must equal output_dim")
        # Parameterize skew-symmetric matrix K
        self.K_param = nn.Parameter(torch.randn(self.output_dim, self.input_dim))

    def get_skew_symmetric(self):
        """Get skew-symmetric matrix K = (K_param - K_param^T) / 2."""
        return 0.5 * (self.K_param - self.K_param.T)

    def forward(self, x):
        K = self.get_skew_symmetric()
        I = torch.eye(self.output_dim, device=K.device, dtype=K.dtype)
        try:
            # solve() is more stable than forming an explicit inverse.
            A = torch.linalg.solve(I + K, I - K)
        except RuntimeError:
            # For real skew-symmetric K, I + K is always nonsingular (its
            # eigenvalues are 1 + i*lambda), so this path is only a numerical
            # safeguard. torch.linalg.LinAlgError subclasses RuntimeError; the
            # previous bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            eps = 1e-6
            A = torch.linalg.solve(I + K + eps * I, I - K)
        return x @ A.T

class GivensRotationTransform(OrthogonalLinearTransform):
    """Orthogonal map built as an ordered product of Givens rotations.

    ``A = G_1 G_2 ... G_R``, where ``G_k`` rotates the ``rotation_pairs[k]``
    plane by ``angles[k]``. Each factor is a rotation, so ``A`` is orthogonal
    with determinant +1.
    """

    def __init__(self, input_dim, output_dim=None, num_rotations=None):
        super().__init__(input_dim, output_dim)
        # Rotations act on a single square space.
        if self.input_dim != self.output_dim:
            raise ValueError("For Givens rotations, input_dim must equal output_dim")
        if num_rotations is None:
            # Use all possible pairs for full expressivity
            num_rotations = self.output_dim * (self.output_dim - 1) // 2

        self.num_rotations = num_rotations
        self.angles = nn.Parameter(torch.randn(num_rotations) * 0.1)

        # Pre-compute rotation pairs. When fewer rotations than pairs are requested,
        # take every `step`-th pair of the lexicographic list.
        all_pairs = [(i, j) for i in range(self.output_dim) for j in range(i + 1, self.output_dim)]
        if len(all_pairs) <= num_rotations:
            self.rotation_pairs = all_pairs
        elif num_rotations == 0:
            # Avoid dividing by zero below; no rotations means the identity map.
            self.rotation_pairs = []
        else:
            step = len(all_pairs) // num_rotations
            self.rotation_pairs = [all_pairs[i * step] for i in range(num_rotations)]

    def forward(self, x):
        n = self.output_dim
        # Hold A as a list of column vectors. Right-multiplying by a Givens
        # rotation G(i, j) only mixes columns i and j:
        #   (A G)[:, i] =  c * A[:, i] + s * A[:, j]
        #   (A G)[:, j] = -s * A[:, i] + c * A[:, j]
        # Each step therefore costs O(n) instead of a dense O(n^3) matmul, and
        # autograd keeps O(n) per rotation instead of two n x n matrices.
        cols = list(torch.eye(n, device=x.device, dtype=x.dtype).unbind(dim=1))
        if self.rotation_pairs:
            cos_all = torch.cos(self.angles).to(x.dtype)
            sin_all = torch.sin(self.angles).to(x.dtype)
            for k, (i, j) in enumerate(self.rotation_pairs):
                c, s = cos_all[k], sin_all[k]
                col_i, col_j = cols[i], cols[j]
                cols[i] = c * col_i + s * col_j
                cols[j] = c * col_j - s * col_i
        A = torch.stack(cols, dim=1)
        return x @ A.T

class HouseholderTransform(OrthogonalLinearTransform):
    """Orthogonal map built as an ordered product of Householder reflections.

    ``A = H_1 H_2 ... H_m`` with ``H_k = I - 2 v_k v_k^T``. Here ``v_k`` is row ``k``
    of ``householder_vectors``, normalized by ``norm + 1e-8``. Each ``H_k`` is
    orthogonal, up to that epsilon.
    """

    def __init__(self, input_dim, output_dim=None, num_reflections=None):
        super().__init__(input_dim, output_dim)
        # Reflections act on a single square space.
        if self.input_dim != self.output_dim:
            raise ValueError("For Householder reflections, input_dim must equal output_dim")
        if num_reflections is None:
            num_reflections = self.output_dim

        self.num_reflections = num_reflections
        # Householder vectors (normalized in forward)
        self.householder_vectors = nn.Parameter(torch.randn(num_reflections, self.output_dim))

    def forward(self, x):
        n = self.output_dim
        A = torch.eye(n, device=x.device, dtype=x.dtype)
        for k in range(self.num_reflections):
            v = self.householder_vectors[k]
            v = (v / (v.norm() + 1e-8)).to(x.dtype)  # unit vector (zero stays zero -> H = I)
            # A @ (I - 2 v v^T) == A - 2 (A v) v^T. This rank-1 update costs O(n^2)
            # instead of materializing H and doing an O(n^3) matmul per reflection.
            A = A - 2.0 * torch.outer(A @ v, v)
        return x @ A.T

# =============================================================================
# LOSS FUNCTIONS
# =============================================================================

class LossFunctions:
    """Differentiable losses that shape how energy is spread across output dimensions.

    "Energy" of dimension d is the batch mean of ``v_prime[:, d] ** 2``.
    """

    @staticmethod
    def _mean_energy(v):
        """Per-dimension mean squared value over the batch (dim 0)."""
        return torch.mean(v ** 2, dim=0)

    @staticmethod
    def cumulative_energy_shape(v_prime, target_a):
        """MSE between the actual cumulative energy fraction and 1 - exp(-a * (k + 1) / D).

        Returns a constant 0.0 for batches with at most one row, and the
        sentinel 1e6 if the loss evaluates to NaN.
        """
        device = v_prime.device
        if v_prime.shape[0] <= 1:
            return torch.tensor(0.0, device=device)

        n_dims = v_prime.shape[1]
        energy = LossFunctions._mean_energy(v_prime)
        total = torch.sum(energy)

        if total < 1e-9:
            actual = torch.zeros(n_dims, device=device, dtype=v_prime.dtype)
        else:
            actual = (energy / total).cumsum(dim=0).clamp(0.0, 1.0)

        ks = torch.arange(n_dims, device=device, dtype=torch.float32)
        a = float(target_a)
        target = 1.0 - torch.exp(-a * (ks + 1.0) / n_dims)

        loss = torch.mean((actual - target) ** 2)
        return torch.tensor(1e6, device=device) if torch.isnan(loss) else loss

    @staticmethod
    def exponential_shape_loss(v_prime, target_a):
        """MSE between the normalized per-dimension energy and a normalized exp(-a * k) profile.

        Returns a constant 0.0 for batches with at most one row, and the
        sentinel 1e6 if the loss evaluates to NaN.
        """
        device = v_prime.device
        if v_prime.shape[0] <= 1:
            return torch.tensor(0.0, device=device)

        n_dims = v_prime.shape[1]
        # Floor each dimension's energy to keep the normalization well-defined.
        energy = LossFunctions._mean_energy(v_prime).clamp(min=1e-8)

        ks = torch.arange(n_dims, device=device, dtype=torch.float32)
        a = float(target_a)
        target = torch.exp(-a * ks)

        actual_share = energy / (torch.sum(energy) + 1e-8)
        target_share = target / (torch.sum(target) + 1e-8)

        loss = torch.mean((actual_share - target_share) ** 2)
        return torch.tensor(1e6, device=device) if torch.isnan(loss) else loss

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

class DataLoader:
    """Load a feature matrix from a CSV file (numeric columns or text embeddings)."""

    @staticmethod
    def load_csv_data(csv_path, text_column=None, embedding_model='all-MiniLM-L6-v2'):
        """Load ``csv_path`` and return ``(features, description)``.

        If ``text_column`` is given, that column is embedded with a
        sentence-transformers model named ``embedding_model``. Otherwise every
        numeric column is used as a feature. Features are returned as float32.

        Raises:
            ImportError: ``text_column`` was given but sentence-transformers is not installed.
            ValueError: no ``text_column`` and the CSV has no numeric columns.
        """
        print(f"Loading data from {csv_path}...")
        df = pd.read_csv(csv_path)

        if text_column:
            # Only the import itself is guarded: an ImportError raised later while
            # loading or running the model is a different problem and must not be
            # reported as "package missing".
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as err:
                raise ImportError("sentence-transformers package is required for text embeddings. Install with: pip install sentence-transformers") from err
            print(f"Creating embeddings from column '{text_column}' using {embedding_model}...")
            model = SentenceTransformer(embedding_model)
            texts = df[text_column].astype(str).tolist()
            embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
            return embeddings.astype(np.float32), f"Text embeddings from {csv_path}"

        # Assume all numeric columns are features
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) == 0:
            raise ValueError("No numeric columns found. Please specify text_column for text-based embeddings.")
        data = df[numeric_cols].values.astype(np.float32)
        return data, f"Numeric features from {csv_path}"

class EvaluationMetrics:
    """Metrics comparing original vectors with their transformed counterparts.

    All methods are static and take 2-D NumPy arrays (rows are samples).
    Pair and query sampling draw from NumPy's global RNG.
    """

    # Above this many candidate pairs, evaluate_distance_preservation samples
    # pairs directly instead of listing every combination. Listing them costs
    # O(n_samples^2) memory, roughly 1.25e9 tuples for 50k samples.
    _MAX_ENUMERATED_PAIRS = 1_000_000

    @staticmethod
    def calculate_cumulative_energy(vectors):
        """Return the cumulative fraction of mean squared energy per dimension.

        Entry k is the sum over d <= k of mean(vectors[:, d] ** 2), divided by
        the total. An all-zero input yields a vector of zeros.
        """
        squared_vectors = vectors**2
        avg_energy_per_dim = np.mean(squared_vectors, axis=0)
        total_avg_energy = np.sum(avg_energy_per_dim)
        if total_avg_energy == 0:
            return np.zeros(vectors.shape[1])
        return np.cumsum(avg_energy_per_dim) / total_avg_energy

    @staticmethod
    def _match_columns(data, target_dim):
        """Zero-pad (float32) or truncate the columns of ``data`` to ``target_dim``."""
        dim = data.shape[1]
        if dim == target_dim:
            return data
        if target_dim > dim:
            padding = np.zeros((data.shape[0], target_dim - dim), dtype=np.float32)
            return np.concatenate((data, padding), axis=1)
        return data[:, :target_dim]

    @staticmethod
    def _sample_pairs(n_samples, n_pairs_check):
        """Return up to ``n_pairs_check`` distinct pairs (i, j), i < j < n_samples."""
        total_pairs = n_samples * (n_samples - 1) // 2
        if total_pairs <= max(n_pairs_check, EvaluationMetrics._MAX_ENUMERATED_PAIRS):
            # Small enough to enumerate: sample without replacement from the full list.
            pairs = list(itertools.combinations(range(n_samples), 2))
            if len(pairs) > n_pairs_check:
                chosen = np.random.choice(len(pairs), n_pairs_check, replace=False)
                pairs = [pairs[c] for c in chosen]
            return pairs

        # Too many pairs to enumerate. Rejection-sample distinct unordered pairs;
        # total_pairs > n_pairs_check here, so enough distinct pairs exist.
        seen = set()
        pairs = []
        while len(pairs) < n_pairs_check:
            # Draw only as many candidates as are still needed, so the result never overshoots.
            draws = np.random.randint(0, n_samples, size=(n_pairs_check - len(pairs), 2))
            for a, b in draws:
                if a == b:
                    continue
                pair = (int(a), int(b)) if a < b else (int(b), int(a))
                if pair not in seen:
                    seen.add(pair)
                    pairs.append(pair)
        return pairs

    @staticmethod
    def evaluate_distance_preservation(data_orig, data_transformed, n_pairs_check=1000):
        """Compare pairwise Euclidean distances before and after the transform.

        ``data_orig`` is padded or truncated to ``data_transformed``'s width.
        Pairs that are (near-)duplicates in both spaces are ignored.

        Returns:
            dict with 'mean_abs_diff', 'max_abs_diff' and 'corr_coef' (all NaN
            when no usable pair exists).
        """
        nan_result = {'mean_abs_diff': np.nan, 'max_abs_diff': np.nan, 'corr_coef': np.nan}
        n_samples = min(data_orig.shape[0], data_transformed.shape[0])
        data_orig = EvaluationMetrics._match_columns(data_orig, data_transformed.shape[1])

        pairs = EvaluationMetrics._sample_pairs(n_samples, n_pairs_check)
        if not pairs:
            return nan_result

        pair_idx = np.asarray(pairs, dtype=np.intp)
        first, second = pair_idx[:, 0], pair_idx[:, 1]
        dist_orig = np.linalg.norm(data_orig[first] - data_orig[second], axis=1)
        dist_trans = np.linalg.norm(data_transformed[first] - data_transformed[second], axis=1)

        keep = (dist_orig > 1e-9) | (dist_trans > 1e-9)
        original_distances = dist_orig[keep]
        transformed_distances = dist_trans[keep]
        if original_distances.size == 0:
            return nan_result

        dist_diff = np.abs(original_distances - transformed_distances)
        corr_coef = np.corrcoef(original_distances, transformed_distances)[0, 1] if len(original_distances) > 1 else 1.0

        return {
            'mean_abs_diff': np.mean(dist_diff),
            'max_abs_diff': np.max(dist_diff),
            'corr_coef': corr_coef
        }

    @staticmethod
    def compute_knn_metrics(query_orig, query_transformed, search_orig, search_transformed, k, sample_fraction=0.1):
        """Compute k-NN recall and neighbor-distance statistics for a sample of queries.

        Each query is assumed to be present in the search set, so the first
        returned neighbor (itself) is discarded. Original-space arrays are
        padded or truncated to the transformed width. Distances are always
        measured in that (matched) original space.

        Returns:
            dict with 'recall', 'mean_true_knn_dist', 'mean_pred_knn_dist',
            'mean_true_kth_dist', 'mean_pred_kth_dist' (all NaN if k < 1 or
            there are too few points).
        """
        nan_result = {
            'recall': np.nan,
            'mean_true_knn_dist': np.nan,
            'mean_pred_knn_dist': np.nan,
            'mean_true_kth_dist': np.nan,
            'mean_pred_kth_dist': np.nan
        }
        n_queries = min(query_orig.shape[0], query_transformed.shape[0])
        n_search = min(search_orig.shape[0], search_transformed.shape[0])
        if k < 1 or n_queries < 1 or n_search < k + 1:
            return nan_result

        num_queries = min(n_queries, max(1, int(sample_fraction * n_queries)))
        query_indices = np.random.choice(n_queries, num_queries, replace=False)

        # Queries and search points must share a width, so match both original arrays.
        trans_dim = search_transformed.shape[1]
        search_orig = EvaluationMetrics._match_columns(search_orig, trans_dim)
        query_orig = EvaluationMetrics._match_columns(query_orig, trans_dim)

        nn_orig = NearestNeighbors(n_neighbors=k + 1).fit(search_orig)
        nn_trans = NearestNeighbors(n_neighbors=k + 1).fit(search_transformed)

        # One batched neighbor query per space instead of one call per query.
        sampled_orig = query_orig[query_indices]
        _, idx_orig = nn_orig.kneighbors(sampled_orig)
        _, idx_trans = nn_trans.kneighbors(query_transformed[query_indices])

        recalls = []
        true_knn_dists = []
        pred_knn_dists = []
        true_kth_dists = []
        pred_kth_dists = []

        for query_point, true_row, pred_row in zip(sampled_orig, idx_orig[:, 1:], idx_trans[:, 1:]):
            recalls.append(len(set(true_row) & set(pred_row)) / k)

            dists_true = np.linalg.norm(search_orig[true_row] - query_point, axis=1)
            dists_pred = np.linalg.norm(search_orig[pred_row] - query_point, axis=1)

            true_knn_dists.append(np.mean(dists_true))
            pred_knn_dists.append(np.mean(dists_pred))
            true_kth_dists.append(dists_true[-1])
            pred_kth_dists.append(dists_pred[-1])

        return {
            'recall': float(np.mean(recalls)),
            'mean_true_knn_dist': float(np.mean(true_knn_dists)),
            'mean_pred_knn_dist': float(np.mean(pred_knn_dists)),
            'mean_true_kth_dist': float(np.mean(true_kth_dists)),
            'mean_pred_kth_dist': float(np.mean(pred_kth_dists))
        }

# =============================================================================
# MAIN TRAINER CLASS
# =============================================================================

class LinearTransformTrainer:
    """Train, run and evaluate every orthogonal linear-transform variant on one dataset.

    ``config`` is indexed by section name: ``'dataset'``, ``'training'``,
    ``'penalty'`` and ``'evaluation'``. Each section is an object whose
    upper-case attributes (``RANDOM_SEED``, ``BATCH_SIZE``, ``K_FOR_RECALL``,
    ...) are read by the methods below.
    """

    # run_all_experiments keeps at most this many transformed test rows.
    MAX_EVAL_TEST_SAMPLES = 50000

    def __init__(self, config):
        """Store ``config``, pick a device (MPS, then CUDA, then CPU) and seed all RNGs."""
        self.config = config
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.results = {}
        self.trained_models = {}

        # Set random seeds (NumPy, torch, Python's random, CUDA).
        seed = config['dataset'].RANDOM_SEED
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        print(f"Using device: {self.device}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name()}")
        elif torch.backends.mps.is_available():
            print("GPU: Apple Silicon (MPS)")

    def load_and_prepare_data(self, csv_path, text_column=None):
        """Load the CSV, optionally center/scale it, and split it without shuffling.

        The first TRAIN_FRACTION of rows become ``self.data_train``. The next
        EVAL_FRACTION become ``self.data_test``.

        Raises:
            ValueError: if TRAIN_FRACTION + EVAL_FRACTION > 1.
        """
        dataset_cfg = self.config['dataset']
        features, dataset_description = DataLoader.load_csv_data(
            csv_path, text_column, dataset_cfg.EMBEDDING_MODEL
        )

        print(f"Loaded {dataset_description}. Shape: {features.shape}")

        if dataset_cfg.STANDARDIZE_DATA:
            print("Centering data...")
            # Always centers; also scales to unit variance only when SCALE_DATA is set.
            scaler = StandardScaler(with_std=dataset_cfg.SCALE_DATA)
            data_centered = scaler.fit_transform(features)
        else:
            data_centered = features

        train_fraction = dataset_cfg.TRAIN_FRACTION
        eval_fraction = dataset_cfg.EVAL_FRACTION

        if train_fraction + eval_fraction > 1.0:
            raise ValueError(f"Train fraction ({train_fraction}) + Eval fraction ({eval_fraction}) > 1.0")

        total_samples = data_centered.shape[0]
        train_size = int(train_fraction * total_samples)
        eval_size = int(eval_fraction * total_samples)

        print(f"Splitting data: {train_fraction*100:.0f}% train ({train_size} samples), {eval_fraction*100:.0f}% eval ({eval_size} samples)")

        # Split without shuffling: contiguous leading rows for train, the next block for eval.
        data_train = data_centered[np.arange(train_size)]
        data_test = data_centered[np.arange(train_size, train_size + eval_size)]

        print(f"Train samples: {data_train.shape[0]}, Test samples: {data_test.shape[0]}")

        self.orig_dim = features.shape[1]
        self.data_train = data_train
        self.data_test = data_test
        self.dataset_description = dataset_description

    def train_model(self, model_name, model, data_tensor_train, data_tensor_test,
                   loss_fn, epochs, lr, ortho_lambda=None):
        """Train ``model`` with Adam on ``loss_fn`` and transform both data splits.

        Batches whose loss is non-finite, or which carry no autograd graph, are
        skipped. The latter happens when a shape loss returns a constant for a
        final batch of one row. For models exposing them, the following hooks
        are applied: ``orthogonality_penalty`` (weighted by ``ortho_lambda``),
        ``project_gradient_to_tangent``, and ``project_to_orthogonal`` /
        ``retract_to_manifold``.

        Returns:
            (transformed_train, transformed_test, training_time, training_history, model)
            where the transformed arrays are NumPy and the history holds plain floats.
        """
        print(f"\n--- Training {model_name} ---")
        start_time = time.time()

        dataset = torch.utils.data.TensorDataset(data_tensor_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.config['training'].BATCH_SIZE, shuffle=True)

        # Move the model first so the optimizer is built over the parameters on their final device.
        model.to(self.device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        use_penalty = ortho_lambda is not None and hasattr(model, 'orthogonality_penalty')

        training_history = {'epoch': [], 'loss': [], 'ortho_penalty': []}

        for epoch in range(epochs):
            model.train()
            epoch_loss = 0.0
            epoch_ortho_penalty = 0.0
            valid_batches = 0

            for (batch_data,) in dataloader:
                batch_data = batch_data.to(self.device)
                optimizer.zero_grad()

                loss = loss_fn(model(batch_data))
                ortho_penalty = 0.0

                if use_penalty:
                    penalty = model.orthogonality_penalty()
                    loss = loss + ortho_lambda * penalty
                    # Record a detached float; accumulating the tensor would keep each batch's graph alive.
                    ortho_penalty = penalty.item()

                if not torch.isfinite(loss) or not loss.requires_grad:
                    continue

                loss.backward()

                # Manifold methods: restrict the step to the tangent space.
                if hasattr(model, 'project_gradient_to_tangent'):
                    with torch.no_grad():
                        for param in model.parameters():
                            if param.grad is not None:
                                param.grad = model.project_gradient_to_tangent(param.grad)

                optimizer.step()

                # Re-impose orthogonality after the unconstrained step.
                if hasattr(model, 'project_to_orthogonal'):
                    model.project_to_orthogonal()
                elif hasattr(model, 'retract_to_manifold'):
                    model.retract_to_manifold()

                epoch_loss += loss.item()
                epoch_ortho_penalty += ortho_penalty
                valid_batches += 1

            if valid_batches > 0:
                epoch_loss /= valid_batches
                epoch_ortho_penalty /= valid_batches

                training_history['epoch'].append(epoch + 1)
                training_history['loss'].append(epoch_loss)
                training_history['ortho_penalty'].append(epoch_ortho_penalty)

                if (epoch + 1) % 20 == 0 or epoch == 0 or epoch == epochs - 1:
                    print(f"{model_name} Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.6f}, Ortho Penalty: {epoch_ortho_penalty:.6f}")

        training_time = time.time() - start_time
        print(f"{model_name} training complete. Time: {training_time:.2f}s")

        model.eval()
        with torch.no_grad():
            transformed_train = model(data_tensor_train.to(self.device)).cpu().numpy()
            transformed_test = model(data_tensor_test.to(self.device)).cpu().numpy()

        return transformed_train, transformed_test, training_time, training_history, model

    def _build_model(self, model_type):
        """Instantiate the transform for ``model_type``; return (model, ortho_lambda)."""
        builders = {
            'svd_projection': SVDProjectionTransform,
            'penalty': PenaltyTransform,
            'manifold': ManifoldTransform,
            'exponential': ExponentialMapTransform,
            'cayley': CayleyTransform,
            'givens': GivensRotationTransform,
            'householder': HouseholderTransform,
        }
        if model_type not in builders:
            raise ValueError(f"Unknown model type: {model_type}")
        model = builders[model_type](self.orig_dim)
        # Only the penalty method uses a soft orthogonality weight.
        ortho_lambda = self.config['penalty'].LAMBDA if model_type == 'penalty' else None
        return model, ortho_lambda

    def _build_loss_fn(self):
        """Return a one-argument loss for the configured LOSS_TYPE and TARGET_A."""
        training_cfg = self.config['training']
        loss_impls = {
            'cumulative_energy': LossFunctions.cumulative_energy_shape,
            'exponential': LossFunctions.exponential_shape_loss,
        }
        if training_cfg.LOSS_TYPE not in loss_impls:
            raise ValueError(f"Unknown loss type: {training_cfg.LOSS_TYPE}")
        impl = loss_impls[training_cfg.LOSS_TYPE]
        return lambda v: impl(v, training_cfg.TARGET_A)

    def run_all_experiments(self):
        """Train every transform variant; failures are logged and recorded with None results."""
        print("\n=== Running All Linear Transform Experiments ===")

        data_tensor_train = torch.from_numpy(self.data_train).float()
        data_tensor_test = torch.from_numpy(self.data_test).float()

        models_to_train = [
            ('SVD Projection', 'svd_projection'),
            ('Penalty Method', 'penalty'),
            ('Manifold Optimization', 'manifold'),
            ('Exponential Map', 'exponential'),
            ('Cayley Transform', 'cayley'),
            ('Givens Rotations', 'givens'),
            ('Householder Reflections', 'householder'),
        ]

        for model_name, model_type in models_to_train:
            try:
                model, ortho_lambda = self._build_model(model_type)
                loss_fn = self._build_loss_fn()
                training_cfg = self.config['training']

                train_data, test_data, train_time, history, trained_model = self.train_model(
                    model_name, model, data_tensor_train, data_tensor_test,
                    loss_fn, training_cfg.EPOCHS, training_cfg.LEARNING_RATE,
                    ortho_lambda
                )

                self.trained_models[model_name] = trained_model
                self.results[model_name] = {
                    'train': train_data,
                    'test': test_data[:self.MAX_EVAL_TEST_SAMPLES],  # bound evaluation cost
                    'time': train_time,
                    'history': history
                }

            except Exception as e:
                # Top-level boundary: one failing variant must not abort the others.
                print(f"Error training {model_name}: {str(e)}")
                self.results[model_name] = {
                    'train': None,
                    'test': None,
                    'time': np.nan,
                    'history': None
                }

        return self.results

    def evaluate_results(self):
        """Compute distance, energy and k-NN metrics for every successfully trained method."""
        print("\n=== Evaluating Results ===")
        evaluation_results = {}
        eval_cfg = self.config['evaluation']
        k = eval_cfg.K_FOR_RECALL

        for method_name, result in self.results.items():
            if result['train'] is None or result['test'] is None:
                print(f"Skipping {method_name} (training failed)")
                continue

            print(f"\nEvaluating {method_name}...")

            # The stored transformed test set may be truncated; align the originals
            # row-for-row so original and transformed search spaces match.
            data_test = self.data_test[:result['test'].shape[0]]
            full_orig = np.concatenate([self.data_train, data_test], axis=0)
            full_transformed = np.concatenate([result['train'], result['test']], axis=0)

            test_metrics = {}
            test_metrics['distance'] = EvaluationMetrics.evaluate_distance_preservation(
                data_test, result['test'], self.config['training'].N_PAIRS_CHECK
            )
            test_metrics['energy'] = EvaluationMetrics.calculate_cumulative_energy(result['test'])

            knn_metrics = EvaluationMetrics.compute_knn_metrics(
                data_test, result['test'],
                full_orig, full_transformed,
                k,
                eval_cfg.KNN_SAMPLE_FRACTION
            )
            test_metrics['knn_recall'] = knn_metrics['recall']
            test_metrics['knn_distance_stats'] = knn_metrics

            print(f"    [KNN@{k}] Recall: {knn_metrics['recall']:.4f}")
            print(f"    [KNN@{k}] Mean true KNN dist: {knn_metrics['mean_true_knn_dist']:.4f}, "
                  f"Mean pred KNN dist: {knn_metrics['mean_pred_knn_dist']:.4f}")

            evaluation_results[method_name] = {
                'test_metrics': test_metrics,
                'time': result['time'],
                'output_dim': result['test'].shape[1]
            }

        return evaluation_results

    def get_available_methods(self):
        """Return the names of methods whose training produced train and test outputs."""
        return [name for name, result in self.results.items()
                if result['train'] is not None and result['test'] is not None]
