"""
Example script for running intrinsic dimensionality analysis on the NInFEA dataset.

This script demonstrates how to:
1. Load the preprocessed NInFEA dataset
2. Prepare it for multimodal analysis
3. Run LARRP analysis on the multimodal data

Usage:
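    python 051_ninfea_multimodal_analysis.py --gpu 0 --data_mode padding --seed 0

The number of training epochs is not a command-line option; it is set in
``TrainArgs`` inside ``main()``.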
"""

import argparse
import os
import sys
import random
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Add src to path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.ninfea_loader import (
    load_ninfea_dataset2, 
    prepare_ninfea_for_multimodal_analysis,
    get_ninfea_data_info,
    prepare_ninfea_on_demand
)
from src.models.larrp_ninfea import AdaptiveRankReducedAE_NInFEA
from src.functions.train_larrp_ninfea import train_overcomplete_ae_with_pretrained
from src.data.loading import PaddedMultimodalDataWithMasks, PairedMultimodalData

def get_max_shapes_per_modality(dataset, exclude_keys=('record_id', 'sampling_freq_hz')):
    """
    Analyze dataset to find maximum and minimum shapes for each modality across all records.

    Shapes are compared axis by axis: for every axis index, the maximum
    (resp. minimum) size is taken over all records whose value has that axis.
    Each value is assigned a shape as follows:

    - ``torch.Tensor`` / ``np.ndarray``: its ``.shape``
    - ``list``: ``(len(value),)``
    - anything else (scalars, strings, ...): ``(1,)``

    Args:
        dataset: Iterable of records (dicts) with multimodal data.
        exclude_keys: Record keys that are metadata rather than modalities and
            are ignored. Defaults to ``('record_id', 'sampling_freq_hz')``.

    Returns:
        tuple: (max_shapes_dict, min_shapes_dict) mapping modality names to
        shapes. Shapes are always plain tuples of ints.
    """
    observed = {}
    for record in dataset:
        for key, value in record.items():
            if key in exclude_keys:
                continue
            if isinstance(value, (torch.Tensor, np.ndarray)):
                # torch.Size is a tuple subclass; normalise to a plain tuple so the
                # returned dicts always hold the same type.
                shape = tuple(value.shape)
            elif isinstance(value, list):
                shape = (len(value),)
            else:
                shape = (1,)
            observed.setdefault(key, set()).add(shape)

    max_shapes_per_modality = {}
    min_shapes_per_modality = {}
    for key, shape_set in observed.items():
        n_dims = max(len(shape) for shape in shape_set)
        max_shape, min_shape = [], []
        for axis in range(n_dims):
            # Only records whose value actually has this axis contribute to it.
            sizes = [shape[axis] for shape in shape_set if len(shape) > axis]
            max_shape.append(max(sizes))
            min_shape.append(min(sizes))
        max_shapes_per_modality[key] = tuple(max_shape)
        min_shapes_per_modality[key] = tuple(min_shape)

    return max_shapes_per_modality, min_shapes_per_modality

def prepare_padded_modalities(dataset, max_shapes_per_modality):
    """
    Prepare padded tensors for each modality where all samples have the same shape.

    Every modality gets a float32 tensor of shape ``(n_samples, *max_shape)``
    initialised to zeros. Each sample's data is copied into the leading corner
    of its slot, with no squeezing or flattening, so 3D samples keep their layout.

    Args:
        dataset: List of records (dicts) with multimodal data.
        max_shapes_per_modality: Dictionary with maximum shapes for each modality,
            e.g. as returned by ``get_max_shapes_per_modality``.

    Returns:
        tuple: (padded_modalities, sequence_lengths)
            - padded_modalities: Dictionary with modality names as keys and padded tensors as values
            - sequence_lengths: Dictionary with modality names as keys and a list with exactly
              one entry per sample, aligned with the tensors' first axis. An entry is
              ``(height, width)`` for 3D samples, the first-axis length for other samples,
              and ``0`` when the sample is missing, is not a NumPy array / torch tensor,
              or could not be copied into the padded tensor.
    """
    n_samples = len(dataset)
    padded_modalities = {}
    sequence_lengths = {}

    # Zero-initialised padded tensors and empty length lists per modality.
    for modality, max_shape in max_shapes_per_modality.items():
        sequence_lengths[modality] = []
        padded_modalities[modality] = torch.zeros((n_samples,) + tuple(max_shape), dtype=torch.float32)

    for sample_idx, record in enumerate(dataset):
        for modality, max_shape in max_shapes_per_modality.items():
            data = record[modality] if modality in record else None

            if isinstance(data, np.ndarray):
                data = torch.from_numpy(data).float()
            if not isinstance(data, torch.Tensor):
                # Missing (None) or non-array data (lists, scalars): keep the zero
                # padding and record length 0 so the list stays aligned with sample_idx.
                sequence_lengths[modality].append(0)
                continue

            actual_shape = data.shape
            # Copy into the leading corner of the slot; never index past the padded rank.
            region = (sample_idx,) + tuple(slice(0, size) for size in actual_shape[:len(max_shape)])
            try:
                padded_modalities[modality][region] = data
            except Exception as e:
                print(f"Warning: Could not pad sample {sample_idx} for modality {modality}: {e}")
                print(f"  Expected max shape: {max_shapes_per_modality[modality]}")
                print(f"  Actual data shape: {actual_shape}")
                # Nothing was copied, so treat this sample as missing rather than valid.
                sequence_lengths[modality].append(0)
                continue

            if len(actual_shape) == 3:
                # 3D data: (height, width, channels) -> store (height, width)
                seq_length = (actual_shape[0], actual_shape[1])
            else:
                # 2D data: (seq_length, channels) -> store seq_length
                seq_length = actual_shape[0]
            sequence_lengths[modality].append(seq_length)

    return padded_modalities, sequence_lengths

def prepare_cropped_modalities(dataset, min_shapes_per_modality):
    """
    Crop every sample of every modality to that modality's minimum shape.

    Each modality becomes a single float32 tensor of shape
    ``(len(dataset), *min_shape)``. Sample ``i`` receives the leading corner
    ``data[:min_shape[0], :min_shape[1], ...]`` of its own data, so no padding
    masks are required afterwards.

    Samples whose modality entry is absent, ``None``, or not a NumPy array /
    torch tensor are left as zeros. A sample that cannot be cropped into the
    target slot triggers a printed warning and is also left as zeros.

    Args:
        dataset: List of records (dicts) with multimodal data.
        min_shapes_per_modality: Mapping modality name -> minimum shape tuple.

    Returns:
        dict: Modality name -> cropped tensor.
    """
    n_records = len(dataset)
    result = {
        name: torch.zeros((n_records,) + shape, dtype=torch.float32)
        for name, shape in min_shapes_per_modality.items()
    }

    for idx, record in enumerate(dataset):
        for name, shape in min_shapes_per_modality.items():
            value = record[name] if name in record else None
            if value is None:
                # Absent modality: the zero-initialised slot is kept as-is.
                continue

            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value).float()
            elif not isinstance(value, torch.Tensor):
                continue  # Unsupported type: keep zeros.

            # Leading-corner window of the minimum size along every axis.
            window = tuple(slice(0, size) for size in shape)
            try:
                result[name][idx] = value[window]
            except Exception as err:
                print(f"Warning: Could not crop sample {idx} for modality {name}: {err}")
                print(f"  Expected min shape: {shape}")
                print(f"  Actual data shape: {value.shape}")

    return result

def get_modality_statistics(mm_data, sequence_lengths=None):
    """
    Print summary statistics for each modality, restricted to unpadded data.

    For every sample only its valid region is inspected:

    - tuple length ``(h, w)`` with both > 0: ``tensor[i, :h, :w]``
    - numeric length ``L > 0``: ``tensor[i, :L]``
    - any other length (e.g. 0 for a missing sample): the sample is skipped

    When ``sequence_lengths`` is None (e.g. cropped data) every sample uses
    length ``tensor.shape[1]``, i.e. it is treated as fully valid.

    Printed per modality: shape, number of inspected elements, fraction of
    exact zeros, and count/mean/std/min/max of the non-zero values.

    Args:
        mm_data: Dict mapping modality name to a tensor whose first axis indexes samples.
        sequence_lengths: Optional dict mapping modality name to a per-sample list of
            lengths, as produced by ``prepare_padded_modalities``.

    Returns:
        None. All output goes to stdout.
    """
    print("\nData distribution statistics (unpadded data):")
    for modality_idx, (modality, tensor) in enumerate(mm_data.items()):
        print(f"\n{modality} (Modality {modality_idx}):")
        print(f"  - Shape: {tensor.shape}")

        total_elements = 0
        total_zeros = 0
        # Keep non-zero values as tensor chunks and concatenate once at the end;
        # round-tripping through a Python list of floats is very slow and
        # memory-hungry for long multichannel signals.
        nonzero_chunks = []

        for sample_idx in range(tensor.shape[0]):
            if sequence_lengths is not None:
                seq_len = sequence_lengths[modality][sample_idx]
            else:
                seq_len = tensor.shape[1]

            if isinstance(seq_len, tuple):
                h, w = seq_len
                if h <= 0 or w <= 0:
                    continue
                actual_data = tensor[sample_idx, :h, :w].flatten()
            elif isinstance(seq_len, (int, float)) and seq_len > 0:
                # Indexing only the leading axis works for any tensor rank >= 2.
                actual_data = tensor[sample_idx, :int(seq_len)].flatten()
            else:
                continue

            total_elements += actual_data.numel()
            total_zeros += (actual_data == 0).sum().item()
            nonzero = actual_data[actual_data != 0]
            if nonzero.numel() > 0:
                nonzero_chunks.append(nonzero)

        if total_elements == 0:
            print("  - No actual data found (all samples have zero length)")
            continue

        zero_fraction = total_zeros / total_elements
        print(f"  - Total elements (unpadded): {total_elements:,}")
        print(f"  - Zero fraction: {zero_fraction:.3f} ({total_zeros:,}/{total_elements:,})")

        if not nonzero_chunks:
            print("  - No non-zero values found!")
            continue

        nonzero_values = torch.cat(nonzero_chunks)
        mean_val = nonzero_values.mean().item()
        std_val = nonzero_values.std().item()
        min_val = nonzero_values.min().item()
        max_val = nonzero_values.max().item()

        print(f"  - Non-zero values: {nonzero_values.numel():,}")
        print(f"  - Mean (non-zero): {mean_val:.6f}")
        print(f"  - Std (non-zero): {std_val:.6f}")
        print(f"  - Min/Max (non-zero): {min_val:.6f} / {max_val:.6f}")

def normalize_modalities_to_0_1(mm_data, sequence_lengths, return_params=False):
    """
    Min-max normalise every modality to [0, 1] per channel (the last axis).

    The per-channel minimum and maximum are computed only over the valid
    (unpadded) region of each sample, as described by ``sequence_lengths``:

    - tuple lengths ``(h, w)`` select ``tensor[i, :h, :w]``
    - positive numeric lengths ``L`` select ``tensor[i, :L]``
    - other lengths (e.g. 0 for a missing sample) are ignored

    The resulting affine map ``(x - min) / (max - min)`` is then applied to the
    whole tensor and clamped to [0, 1]. Padded positions are therefore
    transformed too: zero padding becomes ``clamp(-min / range, 0, 1)``, which
    is not necessarily 0.

    Constant channels are divided by 1 instead of 0. Channels with no valid
    data at all are given the identity transform (min 0, max 1) instead of
    producing NaN.

    The input dict is updated in place and also returned.

    Args:
        mm_data: Dict mapping modality name to a float tensor whose first axis indexes
            samples and whose last axis indexes channels.
        sequence_lengths: Dict mapping modality name to a per-sample list of lengths,
            as produced by ``prepare_padded_modalities``.
        return_params: If True, also return the per-modality normalisation parameters.

    Returns:
        ``mm_data`` with normalised tensors. If ``return_params`` is True, returns
        ``(mm_data, params)`` instead, where ``params[modality]`` has the keys
        ``'min'``, ``'max'`` and ``'range'`` (per-channel tensors).
    """
    normalization_params = {}

    for modality, tensor in list(mm_data.items()):
        print(f"Normalizing {modality}...")

        n_channels = tensor.shape[-1]
        global_min = torch.full((n_channels,), float('inf'), device=tensor.device)
        global_max = torch.full((n_channels,), float('-inf'), device=tensor.device)

        # Per-channel min/max over the valid region of every sample.
        for sample_idx in range(tensor.shape[0]):
            seq_len = sequence_lengths[modality][sample_idx]
            if isinstance(seq_len, tuple):
                h, w = seq_len
                if h <= 0 or w <= 0:
                    continue
                sample_data = tensor[sample_idx, :h, :w]
            elif isinstance(seq_len, (int, float)) and seq_len > 0:
                sample_data = tensor[sample_idx, :int(seq_len)]
            else:
                continue
            # Collapse every non-channel axis so min/max are strictly per channel.
            sample_data = sample_data.reshape(-1, n_channels)
            sample_min, _ = sample_data.min(dim=0)
            sample_max, _ = sample_data.max(dim=0)
            global_min = torch.minimum(global_min, sample_min)
            global_max = torch.maximum(global_max, sample_max)

        # Channels never observed still hold +/-inf, which would turn the whole
        # channel into NaN; give them the identity transform instead.
        observed = torch.isfinite(global_min) & torch.isfinite(global_max)
        global_min = torch.where(observed, global_min, torch.zeros_like(global_min))
        global_max = torch.where(observed, global_max, torch.ones_like(global_max))

        value_range = global_max - global_min
        normalization_params[modality] = {
            'min': global_min,
            'max': global_max,
            'range': value_range,
        }

        # Broadcast the per-channel vectors over all leading axes, whatever the rank.
        broadcast_shape = (1,) * (tensor.dim() - 1) + (-1,)
        min_expanded = global_min.view(broadcast_shape)
        range_expanded = value_range.view(broadcast_shape)
        # Avoid division by zero for constant channels.
        range_expanded = torch.where(range_expanded == 0, torch.ones_like(range_expanded), range_expanded)

        mm_data[modality] = torch.clamp((tensor - min_expanded) / range_expanded, 0, 1)

        print(f"  - {modality}: normalized using per-channel min/max")

    if return_params:
        return mm_data, normalization_params
    return mm_data

def normalize_modalities_to_0_1_cropped(mm_data, return_params=False):
    """
    Normalize cropped modalities to [0,1] range per channel.

    Cropped data contains no padding, so the per-channel (last axis) minimum
    and maximum are taken over the entire tensor. No sequence lengths are
    needed. Each tensor is mapped by ``(x - min) / (max - min)`` and clamped
    to [0, 1]. Constant channels are divided by 1 instead of 0. Modalities
    whose tensor has no elements are left unchanged.

    The input dict is updated in place and also returned.

    Args:
        mm_data: Dict mapping modality name to a float tensor whose last axis
            indexes channels.
        return_params: If True, also return the per-modality normalisation parameters.

    Returns:
        ``mm_data`` with normalised tensors. If ``return_params`` is True, returns
        ``(mm_data, params)`` instead, where ``params[modality]`` has the keys
        ``'min'``, ``'max'`` and ``'range'`` (per-channel tensors).
    """
    normalization_params = {}

    for modality, tensor in list(mm_data.items()):
        print(f"Normalizing {modality}...")

        if tensor.numel() == 0:
            # min/max reductions are undefined on empty tensors; nothing to normalise.
            print(f"  - {modality}: empty tensor, skipped")
            continue

        n_channels = tensor.shape[-1]
        # Pool every axis except the channel axis, regardless of tensor rank.
        per_channel = tensor.reshape(-1, n_channels)
        global_min, _ = per_channel.min(dim=0)
        global_max, _ = per_channel.max(dim=0)

        value_range = global_max - global_min
        normalization_params[modality] = {
            'min': global_min,
            'max': global_max,
            'range': value_range,
        }

        # Broadcast the per-channel vectors over all leading axes.
        broadcast_shape = (1,) * (tensor.dim() - 1) + (-1,)
        min_expanded = global_min.view(broadcast_shape)
        range_expanded = value_range.view(broadcast_shape)
        # Avoid division by zero for constant channels.
        range_expanded = torch.where(range_expanded == 0, torch.ones_like(range_expanded), range_expanded)

        # Clamp guards against tiny numerical excursions outside [0, 1].
        mm_data[modality] = torch.clamp((tensor - min_expanded) / range_expanded, 0, 1)

        print(f"  - {modality}: normalized using per-channel min/max")

    if return_params:
        return mm_data, normalization_params
    return mm_data

def main():
    """
    Main function for NInFEA multimodal training with flexible data preparation.
    
    Supports two data preparation modes:
    - 'padding': Pad all samples to maximum shapes with masking (default)
    - 'cropping': Crop all samples to minimum shapes without masking
    
    The training function automatically detects dataset type and handles masking accordingly.
    """
    parser = argparse.ArgumentParser(description='NInFEA Multimodal Intrinsic Dimensionality Analysis')
    parser.add_argument('--gpu', type=int, default=0, help='GPU to use for computation')
    parser.add_argument('--r_square_threshold', type=float, default=0.05, help='R-squared threshold')
    parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
    parser.add_argument('--data_mode', type=str, choices=['padding', 'cropping'], default='padding', 
                       help='Data preparation mode: padding (with masks) or cropping (to min size)')
    args = parser.parse_args()

    # Set device
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create output directory
    output_dir = f"03_results/ninfea_multimodal/{args.data_mode}_r2-{args.r_square_threshold}_seed-{args.seed}"
    os.makedirs(output_dir, exist_ok=True)

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    print("Loading NInFEA dataset...")
    try:
        # Try to load processed dataset, fallback to on-demand processing
        dataset = load_ninfea_dataset2(standardize=True)
        print(f"Loaded {len(dataset)} records")
        
        # Analyze shapes across all records
        max_shapes_per_modality, min_shapes_per_modality = get_max_shapes_per_modality(dataset)
        print("Max shapes per modality:")
        for key, shape in max_shapes_per_modality.items():
            print(f"  - {key}: {shape}")
        print("Min shapes per modality:")
        for key, shape in min_shapes_per_modality.items():
            print(f"  - {key}: {shape}")
        
        # Prepare data based on chosen mode
        if args.data_mode == 'padding':
            print("Preparing padded modalities for training...")
            multimodal_data, sequence_lengths = prepare_padded_modalities(dataset, max_shapes_per_modality)
            
            # Get modality statistics before normalization
            get_modality_statistics(multimodal_data, sequence_lengths)
            
            # Normalize each modality to [0,1] range per channel
            print("\nNormalizing modalities to [0,1] range per channel...")
            multimodal_data = normalize_modalities_to_0_1(multimodal_data, sequence_lengths)
        else:  # cropping mode
            print("Preparing cropped modalities for training...")
            multimodal_data = prepare_cropped_modalities(dataset, min_shapes_per_modality)
            sequence_lengths = None  # No sequence lengths needed for cropped data
            
            # Get modality statistics before normalization
            get_modality_statistics(multimodal_data)
            
            # Normalize each modality to [0,1] range 
            print("\nNormalizing modalities to [0,1] range...")
            multimodal_data = normalize_modalities_to_0_1_cropped(multimodal_data)
        
        print(f"Processed modalities shapes ({args.data_mode} mode):")
        for modality, tensor in multimodal_data.items():
            print(f"  - {modality}: {tensor.shape}")
            print(f"    Memory usage: {tensor.numel() * tensor.element_size() / 1024**2:.1f} MB")
            
        if args.data_mode == 'padding' and sequence_lengths is not None:
            print("Sequence lengths per modality:")
            for modality, lengths in sequence_lengths.items():
                print(f"  - {modality}: min={min(lengths)}, max={max(lengths)}, mean={np.mean(lengths):.1f}")
        else:
            print("Data cropped to minimum shapes - no sequence length variation")
        
            
    except FileNotFoundError:
        print("Raw NInFEA data not found. Running data download first...")
        # Run the preparation script
        import subprocess
        prep_script = "02_paper_experiments/050_NInFEA_prep.py"
        subprocess.run([sys.executable, prep_script], check=True)
        
        # Try loading again (will use on-demand processing if file is too large)
        dataset = load_ninfea_dataset2(standardize=True)
        multimodal_data = prepare_ninfea_for_multimodal_analysis(
            dataset, max_length=1000
        )

    # Prepare modalities for training (keep on CPU for DataLoader)
    modality_names = list(multimodal_data.keys())
    # kick out the second modality (maternal ECG)
    modality_names.pop(1)
    modality_tensors = [multimodal_data[name] for name in modality_names]
    
    print(f"\nTraining multimodal autoencoder with {len(modality_names)} modalities...")
    print(f"Modalities: {modality_names}")
    
    # Create train/val split 
    n_samples = len(next(iter(multimodal_data.values())))
    n_train = int(0.8 * n_samples)
    
    print(f"Total samples: {n_samples}, Training samples: {n_train}")
    
    # Split data for train and validation
    train_tensors = {}
    val_tensors = {}
    
    for modality in modality_names:
        # Split tensors
        train_tensors[modality] = multimodal_data[modality][:n_train]
        val_tensors[modality] = multimodal_data[modality][n_train:]
    
    # Create datasets based on data mode
    if args.data_mode == 'padding':
        # Split sequence lengths for padding mode
        train_sequence_lengths = {}
        val_sequence_lengths = {}
        
        for modality in modality_names:
            train_sequence_lengths[modality] = sequence_lengths[modality][:n_train]
            val_sequence_lengths[modality] = sequence_lengths[modality][n_train:]
        
        # Create padded datasets with masks
        train_dataset = PaddedMultimodalDataWithMasks(
            padded_tensors=train_tensors,
            sequence_lengths=train_sequence_lengths,
            modality_names=modality_names
        )
        
        val_dataset = PaddedMultimodalDataWithMasks(
            padded_tensors=val_tensors,
            sequence_lengths=val_sequence_lengths,
            modality_names=modality_names
        )
    else:  # cropping mode
        # Create cropped datasets without masks - convert to list format
        #train_data_list = []
        #val_data_list = []
        
        #for i in range(n_train):
        #    sample = [train_tensors[modality][i] for modality in modality_names]
        #    train_data_list.append(sample)
            
        #for i in range(n_samples - n_train):
        #    sample = [val_tensors[modality][i] for modality in modality_names]
        #    val_data_list.append(sample)
        
        #train_dataset = PairedMultimodalData(train_data_list,modality_names)
        #val_dataset = PairedMultimodalData(val_data_list,modality_names)
        train_dataset = PairedMultimodalData(train_tensors,modality_names)
        val_dataset = PairedMultimodalData(val_tensors,modality_names)
    
    print(f"Created train dataset with {len(train_dataset)} samples")
    print(f"Created validation dataset with {len(val_dataset)} samples")
    
    # Training configuration
    class TrainArgs:
        def __init__(self):
            self.latent_dim = 100
            self.batch_size = 8
            self.lr = 1e-4
            self.weight_decay = 0
            self.dropout = 0.1
            self.epochs = 5000
            self.ae_depth = 2
            self.conv_depth = 3
            self.ae_hidden = 512
            self.rank_or_sparse = 'rank'
            self.multi_gpu = False
            self.gpu_ids = ''
            self.gpu = args.gpu

    train_args = TrainArgs()

    # Method hyperparameters
    method_hyperparameters = {
        "r_square_thresholds": args.r_square_threshold,
        "early_stopping": 100,
        "rank_reduction_frequencies": 10,
        "rank_reduction_thresholds": 0.01,
        "patiences": 10,
    }

    # Prepare modality tensors for NInFEA training (keep 3D structure for conv layers)
    training_data = []
    for tensor in modality_tensors:
        # NInFEA expects (batch, seq_length, n_channels)
        #if len(tensor.shape) == 2:
        #    # Add channel dimension if missing: (batch, seq_length) -> (batch, seq_length, 1)
        #    tensor = tensor.unsqueeze(-1)
        #elif len(tensor.shape) > 3:
        #    # Flatten extra dimensions: (batch, seq, h, w, ...) -> (batch, seq, h*w*...)
        #    batch_size, seq_length = tensor.shape[:2]
        #    tensor = tensor.view(batch_size, seq_length, -1)
        training_data.append(tensor)

    # Debug: Print tensor shapes for NInFEA training
    print("\nTensor shapes for NInFEA training:")
    for i, tensor in enumerate(training_data):
        print(f"  Modality {i} ({modality_names[i]}): {tensor.shape}")

    # Train the multimodal autoencoder with datasets
    # Pass the train and validation datasets (masked or unmasked based on data mode)
    results = train_overcomplete_ae_with_pretrained(
        train_dataset,
        val_dataset,
        train_args.latent_dim,
        device,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        hidden_dim=train_args.ae_hidden,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        conv_depth=train_args.conv_depth,
        early_stopping=method_hyperparameters["early_stopping"],
        initial_rank_ratio=1.0,
        rank_reduction_frequency=method_hyperparameters["rank_reduction_frequencies"],
        rank_reduction_threshold=method_hyperparameters["rank_reduction_thresholds"],
        warmup_epochs=method_hyperparameters["early_stopping"],
        patience=method_hyperparameters["patiences"],
        min_rank=1,
        r_square_threshold=method_hyperparameters["r_square_thresholds"],
        threshold_type='absolute',
        verbose=True,
        model_name=None,
        pretrained_name=f"ninfea_pretrained_seed-{args.seed}",
        lr_schedule='linear'
    )
    
    # Unpack results
    model, representations, train_loss, r_squares, rank_history, loss_curves = results

    # save the model and representations
    model_file = os.path.join(output_dir, 'ninfea_multimodal_model.pth')
    torch.save(model.state_dict(), model_file)
    print(f"Model saved to: {model_file}")
    for i, rep in enumerate(representations):
        rep_file = os.path.join(output_dir, f'ninfea_multimodal_representation_modality-{i}.npy')
        np.save(rep_file, rep)
        print(f"Representation for rep {i} saved to: {rep_file}")

    #print("R² values:", r_squares)
    #print("Rank history:", rank_history)
    
    print(f"\nTraining completed!")
    print(f"Final loss: {train_loss:.6f}")
    
    # Handle R² formatting (could be list or single value)
    if isinstance(r_squares, list):
        if r_squares and all(r is not None for r in r_squares):
            r_squares_str = [f"{r:.6f}" for r in r_squares]
            print(f"Final R² per modality: {r_squares_str}")
            print(f"Mean R²: {sum(r_squares)/len(r_squares):.6f}")
        else:
            print(f"Final R²: Not available (no rank reduction performed)")
    else:
        if r_squares is not None:
            print(f"Final R²: {r_squares:.6f}")
        else:
            print(f"Final R²: Not available")
            
    print(f"Final ranks: {rank_history['ranks'][-1] if rank_history['ranks'] else 'N/A'}")
    
    # Save results
    # Handle R² values - could be list or single value
    if isinstance(r_squares, list) and r_squares and all(r is not None for r in r_squares):
        # Use per-modality R² values
        r_square_values = r_squares + [sum(r_squares)/len(r_squares)]  # Add mean for 'total'
    else:
        # Use a default or single value for all
        default_r_square = r_squares if r_squares is not None else 0.0
        r_square_values = [default_r_square] * (len(modality_names) + 1)
    
    # Handle final ranks - parse from string if needed
    if rank_history['ranks']:
        final_ranks_str = rank_history['ranks'][-1]
        try:
            # Try to parse comma-separated rank values
            final_ranks = [int(x.strip()) for x in final_ranks_str.split(',')]
            # Pad or truncate to match number of modalities + total
            while len(final_ranks) < len(modality_names) + 1:
                final_ranks.append(0)
            final_ranks = final_ranks[:len(modality_names) + 1]
        except:
            # Fallback to zeros if parsing fails
            final_ranks = [0] * (len(modality_names) + 1)
    else:
        final_ranks = [0] * (len(modality_names) + 1)
    
    results_df = pd.DataFrame({
        'modality': modality_names + ['total'],
        'final_rank': final_ranks,
        'final_loss': [train_loss] * (len(modality_names) + 1),
        'final_r_square': r_square_values,
        'epochs': [train_args.epochs] * (len(modality_names) + 1),
        'latent_dim': [train_args.latent_dim] * (len(modality_names) + 1)
    })
    
    output_file = os.path.join(output_dir, 'ninfea_multimodal_results.csv')
    results_df.to_csv(output_file, index=False)
    print(f"Results saved to: {output_file}")
    
    # Create visualization with 3 subplots in one row
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # 1. Training and validation loss curves
    # Always use rank_history for loss data as it has both train and val losses
    if rank_history and 'loss' in rank_history and 'val_loss' in rank_history:
        epochs = rank_history.get('epoch', [])
        train_losses = rank_history.get('loss', [])
        val_losses = rank_history.get('val_loss', [])
        
        train_loss_curve, val_loss_curve = loss_curves
        if train_losses:
            #axes[0].plot(epochs, train_losses, 'b-o', linewidth=2, label='Training Loss')
            axes[0].plot(train_loss_curve, 'b-', linewidth=2, label='Training Loss')
        if val_losses:
            #axes[0].plot(epochs, val_losses, 'r-o', linewidth=2, label='Validation Loss')
            axes[0].plot(val_loss_curve, 'r-', linewidth=2, label='Validation Loss')
        
        axes[0].set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Reconstruction Loss')
        axes[0].grid(True, alpha=0.3)
        axes[0].set_yscale('log')  # Log scale for better visualization
        axes[0].legend()
    elif loss_curves and len(loss_curves) > 0:
        # Fallback to detailed loss curves if rank_history is not available
        axes[0].plot(loss_curves, 'b-', linewidth=2, label='Training Loss')
        axes[0].set_title('Training Loss Curve', fontsize=14, fontweight='bold')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Reconstruction Loss')
        axes[0].grid(True, alpha=0.3)
        axes[0].set_yscale('log')  # Log scale for better visualization
        axes[0].legend()
    else:
        axes[0].text(0.5, 0.5, 'No loss data available', 
                    ha='center', va='center', transform=axes[0].transAxes, fontsize=12)
        axes[0].set_title('Loss Curves', fontsize=14, fontweight='bold')
    
    # 2. Per-subspace rank evolution
    if rank_history and 'ranks' in rank_history and rank_history['ranks']:
        epochs = rank_history.get('epoch', list(range(len(rank_history['ranks']))))
        
        # Parse rank strings to get individual subspace ranks
        try:
            all_ranks_per_epoch = []
            for rank_str in rank_history['ranks']:
                if isinstance(rank_str, str):
                    individual_ranks = [int(x.strip()) for x in rank_str.split(',')]
                    all_ranks_per_epoch.append(individual_ranks)
            
            if all_ranks_per_epoch:
                # Determine number of subspaces from first entry
                num_subspaces = len(all_ranks_per_epoch[0])
                
                # Define subspace names based on actual number found
                if num_subspaces <= len(modality_names) + 1:
                    # Standard case: global + per-modality
                    subspace_names = ['global_shared'] + modality_names[:num_subspaces-1]
                else:
                    # More subspaces than expected - use generic names
                    subspace_names = [f'subspace_{i}' for i in range(num_subspaces)]
                
                colors = plt.cm.tab10(np.linspace(0, 1, num_subspaces))
                
                # Plot each subspace rank evolution
                for subspace_idx in range(num_subspaces):
                    subspace_ranks = [ranks[subspace_idx] for ranks in all_ranks_per_epoch if subspace_idx < len(ranks)]
                    plot_epochs = epochs[:len(subspace_ranks)]
                    
                    if len(subspace_ranks) == 1:
                        # Single point - use scatter
                        axes[1].scatter(plot_epochs[0], subspace_ranks[0], 
                                      color=colors[subspace_idx], 
                                      label=subspace_names[subspace_idx] if subspace_idx < len(subspace_names) else f'subspace_{subspace_idx}',
                                      s=60, alpha=0.8)
                    else:
                        # Multiple points - use line plot
                        axes[1].plot(plot_epochs, subspace_ranks, 
                                   color=colors[subspace_idx], 
                                   label=subspace_names[subspace_idx] if subspace_idx < len(subspace_names) else f'subspace_{subspace_idx}',
                                   marker='o', linewidth=2, markersize=4, alpha=0.8)
                
                axes[1].set_title('Rank Evolution by Subspace', fontsize=14, fontweight='bold')
                axes[1].set_xlabel('Training Epoch')
                axes[1].set_ylabel('Effective Rank')
                axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
                axes[1].grid(True, alpha=0.3)
            else:
                axes[1].text(0.5, 0.5, 'No rank data available', 
                            ha='center', va='center', transform=axes[1].transAxes, fontsize=12)
                axes[1].set_title('Rank Evolution')
                
        except Exception as e:
            axes[1].text(0.5, 0.5, f'Rank parsing error:\n{str(e)}', 
                        ha='center', va='center', transform=axes[1].transAxes, fontsize=10)
            axes[1].set_title('Rank Evolution (Error)')
    else:
        axes[1].text(0.5, 0.5, 'No rank history available', 
                    ha='center', va='center', transform=axes[1].transAxes, fontsize=12)
        axes[1].set_title('Rank Evolution')
    
    # 3. R² scores by modality  
    if isinstance(r_squares, list) and r_squares:
        # Check if we have any non-None R² values
        r_square_per_modality = r_squares[:len(modality_names)] if len(r_squares) >= len(modality_names) else r_squares
        valid_r_squares = [r for r in r_square_per_modality if r is not None]
        
        if valid_r_squares:
            # We have some valid R² values
            colors = plt.cm.Set3(np.linspace(0, 1, len(r_square_per_modality)))
            
            bars = axes[2].bar(range(len(r_square_per_modality)), 
                              [r if r is not None else 0 for r in r_square_per_modality], 
                              color=colors, alpha=0.8, edgecolor='black', linewidth=1)
            
            # Add value labels on bars
            for i, (bar, r_val) in enumerate(zip(bars, r_square_per_modality)):
                height = bar.get_height()
                if r_val is not None:
                    axes[2].text(bar.get_x() + bar.get_width()/2., height + 0.01,
                                f'{r_val:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
                else:
                    axes[2].text(bar.get_x() + bar.get_width()/2., 0.01,
                                'N/A', ha='center', va='bottom', fontsize=10, fontweight='bold')
            
            axes[2].set_title('R² Scores by Modality', fontsize=14, fontweight='bold')
            axes[2].set_xlabel('Modality')
            axes[2].set_ylabel('R² Reconstruction Score')
            axes[2].set_xticks(range(len(r_square_per_modality)))
            # Use modality names if we have the right number, otherwise use indices
            if len(r_square_per_modality) == len(modality_names):
                axes[2].set_xticklabels(modality_names, rotation=45, ha='right')
            else:
                axes[2].set_xticklabels([f'Modality {i}' for i in range(len(r_square_per_modality))], rotation=45, ha='right')
            axes[2].grid(True, alpha=0.3, axis='y')
            axes[2].set_ylim(0, max(valid_r_squares) * 1.1)
        else:
            # All R² values are None
            axes[2].text(0.5, 0.5, 'R² not available\n(all values are None)', 
                        ha='center', va='center', transform=axes[2].transAxes, fontsize=12)
            axes[2].set_title('R² Scores by Modality', fontsize=14, fontweight='bold')
        
    elif r_squares is not None and not isinstance(r_squares, list):
        # Single R² value - show as single bar
        axes[2].bar([0], [r_squares], color='skyblue', alpha=0.8, edgecolor='black', linewidth=1)
        axes[2].text(0, r_squares + 0.01, f'{r_squares:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
        axes[2].set_title('Overall R² Score', fontsize=14, fontweight='bold')
        axes[2].set_ylabel('R² Reconstruction Score')
        axes[2].set_xticks([0])
        axes[2].set_xticklabels(['Overall'])
        axes[2].grid(True, alpha=0.3, axis='y')
    else:
        # No R² data available
        axes[2].text(0.5, 0.5, 'R² not available\n(no rank reduction performed)', 
                    ha='center', va='center', transform=axes[2].transAxes, fontsize=12)
        axes[2].set_title('R² Scores', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plot_file = os.path.join(output_dir, 'ninfea_multimodal_analysis.png')
    plt.savefig(plot_file, dpi=150, bbox_inches='tight')
    print(f"Plots saved to: {plot_file}")
    plt.close()
    
    print(f"\nAnalysis complete! Results saved in: {output_dir}")

if __name__ == "__main__":
    # Invoke main() only when this file is executed directly as a script;
    # importing the module does not trigger the analysis.
    main()
