"""
Example script for running intrinsic dimensionality analysis on the NInFEA dataset.

This script demonstrates how to:
1. Load the preprocessed NInFEA dataset
2. Prepare it for multimodal analysis
3. Run LARRP analysis on the multimodal data

Usage:
    python 051_ninfea_multimodal_analysis.py --gpu 0 --data_mode padding --mods 0 1
"""

import argparse
import os
import sys
import random
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Add the project root to sys.path so that the `src` package can be imported
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.ninfea_loader import (
    load_ninfea_dataset, 
    prepare_ninfea_for_multimodal_analysis,
    get_ninfea_data_info,
    prepare_ninfea_on_demand
)
from src.models.larrp_ninfea import AdaptiveRankReducedAE_NInFEA
from src.functions.train_larrp_ninfea import train_overcomplete_ae_with_pretrained
from src.data.loading import PaddedMultimodalDataWithMasks, PairedMultimodalData

def get_max_shapes_per_modality(dataset):
    """
    Collect the per-dimension largest and smallest shape of every modality.

    Every record key other than 'record_id' and 'sampling_freq_hz' is treated
    as a modality. Tensors and arrays contribute their `.shape`, lists
    contribute `(len(value),)` and any other value contributes `(1,)`.

    Args:
        dataset: Iterable of dict-like records with multimodal data

    Returns:
        tuple: (max_shapes_dict, min_shapes_dict), both keyed by modality name.
            When a modality was seen with shapes of different rank, each
            dimension is reduced over the shapes that actually have it.
    """
    metadata_keys = ('record_id', 'sampling_freq_hz')

    # Gather the set of distinct shapes observed for each modality
    observed = {}
    for record in dataset:
        for name, value in record.items():
            if name in metadata_keys:
                continue
            if isinstance(value, (torch.Tensor, np.ndarray)):
                shape = value.shape
            elif isinstance(value, list):
                shape = (len(value),)
            else:
                shape = (1,)
            observed.setdefault(name, set()).add(shape)

    max_shapes_per_modality = {}
    min_shapes_per_modality = {}
    for name, shape_set in observed.items():
        if len(shape_set) == 1:
            # Only one shape was ever seen: it is both the maximum and the minimum
            only_shape = next(iter(shape_set))
            max_shapes_per_modality[name] = only_shape
            min_shapes_per_modality[name] = only_shape
            continue

        # Reduce each dimension over the shapes that are long enough to have it
        n_dims = max(len(shape) for shape in shape_set)
        sizes_per_dim = [
            [shape[d] for shape in shape_set if len(shape) > d]
            for d in range(n_dims)
        ]
        max_shapes_per_modality[name] = tuple(max(sizes) for sizes in sizes_per_dim)
        min_shapes_per_modality[name] = tuple(min(sizes) for sizes in sizes_per_dim)

    return max_shapes_per_modality, min_shapes_per_modality

def prepare_padded_modalities(dataset, max_shapes_per_modality):
    """
    Prepare padded tensors for each modality where all samples have the same shape.

    Each sample is copied into the top-left corner of a zero tensor of shape
    `(n_samples, *max_shape)`; everything outside the sample's extent stays zero.

    Args:
        dataset: List of records with multimodal data
        max_shapes_per_modality: Dictionary with maximum shapes for each modality

    Returns:
        tuple: (padded_modalities, sequence_lengths)
            - padded_modalities: Dictionary with modality names as keys and padded
              float32 tensors as values
            - sequence_lengths: Dictionary with modality names as keys and a list
              with exactly one entry per sample as values. The entry is
              `(height, width)` for 3D samples, the size of the first axis
              otherwise, and 0 when the sample is missing, is not a
              tensor/array, or could not be copied into the padded tensor.
    """
    n_samples = len(dataset)
    padded_modalities = {}
    sequence_lengths = {}

    # Allocate one zero tensor (n_samples, *max_shape) and one length list per modality.
    # The original dimensionality is kept as is (no squeezing or flattening).
    for modality, max_shape in max_shapes_per_modality.items():
        sequence_lengths[modality] = []
        padded_modalities[modality] = torch.zeros(
            (n_samples,) + tuple(max_shape), dtype=torch.float32
        )

    for sample_idx, record in enumerate(dataset):
        for modality, max_shape in max_shapes_per_modality.items():
            data = record[modality] if modality in record else None

            # Convert to tensor if needed
            if isinstance(data, np.ndarray):
                data = torch.from_numpy(data).float()

            if not isinstance(data, torch.Tensor):
                # Missing or non-tensor data: still record a length so that
                # sequence_lengths[modality][i] always refers to sample i.
                sequence_lengths[modality].append(0)
                continue

            actual_shape = data.shape
            if len(actual_shape) == 3:
                # 3D data: store the extent of the first two axes
                seq_length = (actual_shape[0], actual_shape[1])
            else:
                # Otherwise store the extent of the first axis
                seq_length = actual_shape[0]

            # Destination region: this sample, from the origin up to the sample's
            # own size along every axis the padded tensor has.
            target = (sample_idx,) + tuple(
                slice(0, actual_size) for actual_size, _ in zip(actual_shape, max_shape)
            )

            try:
                padded_modalities[modality][target] = data
            except Exception as e:
                print(f"Warning: Could not pad sample {sample_idx} for modality {modality}: {e}")
                print(f"  Expected max shape: {max_shapes_per_modality[modality]}")
                print(f"  Actual data shape: {actual_shape}")
                # Nothing was written, so the row is all padding: mark it as empty
                seq_length = 0

            sequence_lengths[modality].append(seq_length)

    return padded_modalities, sequence_lengths

def find_middle_line(image_tensor: np.ndarray) -> int:
    """
    Locate the horizontal baseline of a Pulsed-Wave Doppler (PWD) image.

    The baseline is assumed to be the brightest horizontal line: the image is
    reduced to grayscale and the row whose summed intensity is largest wins.

    Args:
        image_tensor: A NumPy array with shape (height, width, channels)
                      representing the PWD image.

    Returns:
        int: The y-coordinate (row index) of the detected baseline.

    Raises:
        ValueError: If the input is not a 3D NumPy array.
    """
    is_3d_array = isinstance(image_tensor, np.ndarray) and image_tensor.ndim == 3
    if not is_3d_array:
        raise ValueError("Input must be a 3D NumPy array (height, width, channels).")

    # Luminosity grayscale conversion: 0.299 R + 0.587 G + 0.114 B,
    # applied to the first three channels only.
    luminosity_weights = [0.299, 0.587, 0.114]
    first_three_channels = image_tensor[..., :3]
    brightness = np.dot(first_three_channels, luminosity_weights)

    # One total per row; the brightest row is taken to be the baseline
    brightness_per_row = brightness.sum(axis=1)
    return int(brightness_per_row.argmax())

def prepare_cropped_modalities(dataset, min_shapes_per_modality):
    """
    Prepare cropped tensors for each modality where all samples are cropped to minimum shape.
    This eliminates the need for padding and masks.

    All modalities are cropped from the origin, except 'pwd_image': its
    vertical window is centred on the baseline returned by `find_middle_line`
    (clamped at the top of the image). If that window runs past the bottom of
    the image, the missing rows are filled with zeros.

    Args:
        dataset: List of records with multimodal data
        min_shapes_per_modality: Dictionary with minimum shapes for each modality

    Returns:
        dict: Dictionary with modality names as keys and float32 tensors of shape
            `(n_samples, *min_shape)` as values. Samples that are missing, are
            not tensors/arrays, or fail to crop are left as zeros (a warning is
            printed for failed crops).
    """
    n_samples = len(dataset)

    # One zero tensor (n_samples, *min_shape) per modality
    cropped_modalities = {
        modality: torch.zeros((n_samples,) + tuple(min_shape), dtype=torch.float32)
        for modality, min_shape in min_shapes_per_modality.items()
    }

    for sample_idx, record in enumerate(dataset):
        for modality, min_shape in min_shapes_per_modality.items():
            if modality not in record or record[modality] is None:
                continue  # No data for this modality and sample - leave as zeros

            data = record[modality]
            # Convert to tensor if needed
            if isinstance(data, np.ndarray):
                data = torch.from_numpy(data).float()
            elif not isinstance(data, torch.Tensor):
                continue  # Skip non-tensor data

            if modality == 'pwd_image':
                # Centre the vertical crop window on the detected baseline
                baseline_y = find_middle_line(data.numpy())
                start_y = max(0, baseline_y - min_shape[0] // 2)
                crop_slices = (
                    slice(start_y, start_y + min_shape[0]),
                    slice(0, min_shape[1]),
                    slice(0, min_shape[2]),
                )
            else:
                crop_slices = tuple(slice(0, min_size) for min_size in min_shape)

            try:
                # Slicing stops at the end of the image, so a window that runs
                # past the bottom simply yields fewer rows than requested.
                cropped_data = data[crop_slices]
                if modality == 'pwd_image' and cropped_data.shape[0] < min_shape[0]:
                    # Zero-pad the missing rows at the bottom. Width and channels
                    # are already cropped, exactly as in the unpadded case.
                    available_rows = cropped_data
                    cropped_data = torch.zeros(tuple(min_shape), dtype=torch.float32)
                    cropped_data[:available_rows.shape[0]] = available_rows
                cropped_modalities[modality][sample_idx] = cropped_data
            except Exception as e:
                print(f"Warning: Could not crop sample {sample_idx} for modality {modality}: {e}")
                print(f"  Expected min shape: {min_shape}")
                print(f"  Actual data shape: {data.shape}")

    return cropped_modalities

def get_modality_statistics(mm_data, sequence_lengths=None):
    """
    Print summary statistics of the valid (unpadded) region of every modality.

    For each modality the element count, the fraction of exact zeros and the
    mean / std / min / max of the non-zero values are printed.

    Args:
        mm_data: Dictionary with modality names as keys and tensors as values;
            the first axis is indexed per sample.
        sequence_lengths: Optional dictionary with modality names as keys and one
            entry per sample as values. A tuple `(h, w)` restricts the sample to
            `[:h, :w, :]`, a positive number `n` restricts it to `[:n, :]`, and
            any other entry (e.g. 0) excludes the sample. If None, the full
            extent of axis 1 is used for every sample.

    Returns:
        None. Results are written to stdout only.
    """
    print("\nData distribution statistics (unpadded data):")
    for modality_idx, (modality, tensor) in enumerate(mm_data.items()):
        print(f"\n{modality} (Modality {modality_idx}):")
        print(f"  - Shape: {tensor.shape}")

        total_elements = 0
        total_zeros = 0
        # Non-zero values are kept as tensor chunks instead of Python floats, so
        # large modalities (e.g. images) do not explode into huge Python lists.
        nonzero_chunks = []

        for sample_idx in range(tensor.shape[0]):
            if sequence_lengths is not None:
                seq_len = sequence_lengths[modality][sample_idx]
            else:
                seq_len = tensor.shape[1]

            # Select the valid region of this sample, or skip the sample entirely
            if isinstance(seq_len, tuple):
                h, w = seq_len
                if not (h > 0 and w > 0):
                    continue
                actual_data = tensor[sample_idx, :h, :w, :]
            elif isinstance(seq_len, (int, float)) and seq_len > 0:
                actual_data = tensor[sample_idx, :int(seq_len), :]
            else:
                continue

            actual_data = actual_data.flatten()
            total_elements += actual_data.numel()
            total_zeros += (actual_data == 0).sum().item()
            nonzero_data = actual_data[actual_data != 0]
            if nonzero_data.numel() > 0:
                nonzero_chunks.append(nonzero_data)

        if total_elements == 0:
            print(f"  - No actual data found (all samples have zero length)")
            continue

        zero_fraction = total_zeros / total_elements
        print(f"  - Total elements (unpadded): {total_elements:,}")
        print(f"  - Zero fraction: {zero_fraction:.3f} ({total_zeros:,}/{total_elements:,})")

        if not nonzero_chunks:
            print(f"  - No non-zero values found!")
            continue

        # Statistics are computed in float32
        nonzero_values = torch.cat(nonzero_chunks).float()
        mean_val = nonzero_values.mean().item()
        std_val = nonzero_values.std().item()
        min_val = nonzero_values.min().item()
        max_val = nonzero_values.max().item()

        print(f"  - Non-zero values: {nonzero_values.numel():,}")
        print(f"  - Mean (non-zero): {mean_val:.6f}")
        print(f"  - Std (non-zero): {std_val:.6f}")
        print(f"  - Min/Max (non-zero): {min_val:.6f} / {max_val:.6f}")

def normalize_modalities_to_0_1(mm_data, sequence_lengths):
    """
    Min-max normalize padded modalities to [0,1] per channel (last axis).

    The per-channel minimum and maximum are taken over the valid (unpadded)
    region of every sample only. The scaling is then applied to the entire
    tensor, padding included, and the result is clamped to [0,1].

    Args:
        mm_data: Dictionary with modality names as keys and padded tensors as
            values, with samples on the first axis and channels on the last.
            The dictionary is updated in place.
        sequence_lengths: Dictionary with modality names as keys and one entry
            per sample as values. A tuple `(h, w)` marks `[:h, :w, :]` as valid,
            a positive number `n` marks `[:n, :]` as valid, and any other entry
            (e.g. 0) excludes the sample from the min/max computation.

    Returns:
        dict: The same `mm_data` dictionary with normalized tensors. A modality
            without any valid sample is left unchanged, because no min/max can
            be derived for it (scaling with the +/-inf start values would turn
            the whole tensor into NaN).
    """
    for modality, tensor in mm_data.items():
        print(f"Normalizing {modality}...")

        n_channels = tensor.shape[-1]
        global_min = torch.full((n_channels,), float('inf'))
        global_max = torch.full((n_channels,), float('-inf'))
        found_valid_sample = False

        # Accumulate per-channel min/max over the valid region of every sample
        for sample_idx in range(tensor.shape[0]):
            seq_len = sequence_lengths[modality][sample_idx]

            if isinstance(seq_len, tuple):
                h, w = seq_len
                if not (h > 0 and w > 0):
                    continue
                sample_data = tensor[sample_idx, :h, :w, :]
            elif isinstance(seq_len, (int, float)) and seq_len > 0:
                sample_data = tensor[sample_idx, :int(seq_len), :]
            else:
                continue

            # Collapse all leading axes so that min/max run over every position
            sample_data = sample_data.reshape(-1, n_channels)
            global_min = torch.minimum(global_min, sample_data.min(dim=0).values)
            global_max = torch.maximum(global_max, sample_data.max(dim=0).values)
            found_valid_sample = True

        if not found_valid_sample:
            print(f"  - {modality}: no valid (unpadded) data found, left unchanged")
            continue

        # Shape the per-channel vectors so they broadcast along the last axis
        if len(tensor.shape) == 4:
            broadcast_shape = (1, 1, 1, -1)
        else:
            broadcast_shape = (1, 1, -1)
        min_expanded = global_min.view(broadcast_shape)
        range_expanded = (global_max - global_min).view(broadcast_shape)

        # Avoid division by zero for constant channels
        range_expanded = torch.where(
            range_expanded == 0, torch.ones_like(range_expanded), range_expanded
        )

        # (x - min) / (max - min); the clamp also pulls padding values that fall
        # outside the valid data range back into [0,1]
        tensor_normalized = torch.clamp((tensor - min_expanded) / range_expanded, 0, 1)
        mm_data[modality] = tensor_normalized

        print(f"  - {modality}: normalized using per-channel min/max")
    return mm_data

def normalize_modalities_to_0_1_cropped(mm_data):
    """
    Min-max normalize cropped modalities to [0,1] per channel (last axis).

    Cropped data has no padding, so the per-channel minimum and maximum are
    taken over every element and no sequence lengths are needed. The
    dictionary is updated in place and also returned.
    """
    for modality in list(mm_data):
        tensor = mm_data[modality]
        print(f"Normalizing {modality}...")

        # Collapse every axis except the channel axis to get per-channel extrema
        n_channels = tensor.shape[-1]
        flat = tensor.reshape(-1, n_channels)
        channel_min = flat.min(dim=0).values
        channel_max = flat.max(dim=0).values

        # 4-axis tensors (batch, height, width, channels) need one more leading
        # singleton axis than (batch, seq_length, channels) tensors
        broadcast_shape = (1, 1, 1, -1) if len(tensor.shape) == 4 else (1, 1, -1)
        offset = channel_min.view(broadcast_shape)
        span = (channel_max - channel_min).view(broadcast_shape)

        # Constant channels have zero span: divide by one instead
        span = torch.where(span == 0, torch.ones_like(span), span)

        # (x - min) / (max - min), clamped to guard against numerical drift
        mm_data[modality] = torch.clamp((tensor - offset) / span, 0, 1)

        print(f"  - {modality}: normalized using per-channel min/max")

    return mm_data

def main():
    """
    Main function for NInFEA multimodal training with flexible data preparation.
    
    Supports two data preparation modes:
    - 'padding': Pad all samples to maximum shapes with masking (default)
    - 'cropping': Crop all samples to minimum shapes without masking
    
    The training function automatically detects dataset type and handles masking accordingly.
    """
    parser = argparse.ArgumentParser(description='NInFEA Multimodal Intrinsic Dimensionality Analysis')
    parser.add_argument('--gpu', type=int, default=0, help='GPU to use for computation')
    parser.add_argument('--r_square_threshold', type=float, default=0.05, help='R-squared threshold')
    parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
    parser.add_argument('--data_mode', type=str, choices=['padding', 'cropping'], default='padding', 
                       help='Data preparation mode: padding (with masks) or cropping (to min size)')
    # add a list of integers for the modalities to use (0-3)
    parser.add_argument('--mods', type=int, nargs=2, default=[0,1], help='List of two modality indices to use (0-3)')
    args = parser.parse_args()

    # Set device
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create output directory
    output_dir = f"03_results/ninfea_multimodal/mods{args.mods[0]}-{args.mods[1]}_{args.data_mode}_r2-{args.r_square_threshold}_seed-{args.seed}"
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    print("Loading NInFEA dataset...")
    try:
        # Try to load processed dataset, fallback to on-demand processing
        dataset = load_ninfea_dataset(standardize=True)
        print(f"Loaded {len(dataset)} records")
        
        # Analyze shapes across all records
        max_shapes_per_modality, min_shapes_per_modality = get_max_shapes_per_modality(dataset)
        print("Max shapes per modality:")
        for key, shape in max_shapes_per_modality.items():
            print(f"  - {key}: {shape}")
        print("Min shapes per modality:")
        for key, shape in min_shapes_per_modality.items():
            print(f"  - {key}: {shape}")
        
        # Prepare data based on chosen mode
        if args.data_mode == 'padding':
            print("Preparing padded modalities for training...")
            multimodal_data, sequence_lengths = prepare_padded_modalities(dataset, max_shapes_per_modality)
            
            # Get modality statistics before normalization
            get_modality_statistics(multimodal_data, sequence_lengths)
            
            # Normalize each modality to [0,1] range per channel
            print("\nNormalizing modalities to [0,1] range per channel...")
            multimodal_data = normalize_modalities_to_0_1(multimodal_data, sequence_lengths)
        else:  # cropping mode
            print("Preparing cropped modalities for training...")
            multimodal_data = prepare_cropped_modalities(dataset, min_shapes_per_modality)
            sequence_lengths = None  # No sequence lengths needed for cropped data
            
            # Get modality statistics before normalization
            get_modality_statistics(multimodal_data)
            
            # Normalize each modality to [0,1] range 
            print("\nNormalizing modalities to [0,1] range...")
            multimodal_data = normalize_modalities_to_0_1_cropped(multimodal_data)
        
        # limit the modalities to the selected 2 mods
        all_modality_names = list(multimodal_data.keys())
        selected_modality_names = [all_modality_names[i] for i in args.mods]
        multimodal_data = {modality: multimodal_data[modality] for modality in selected_modality_names}
        if sequence_lengths is not None:
            sequence_lengths = {modality: sequence_lengths[modality] for modality in selected_modality_names}
        
        print(f"Processed modalities shapes ({args.data_mode} mode):")
        for modality, tensor in multimodal_data.items():
            print(f"  - {modality}: {tensor.shape}")
            print(f"    Memory usage: {tensor.numel() * tensor.element_size() / 1024**2:.1f} MB")
            
        if args.data_mode == 'padding' and sequence_lengths is not None:
            print("Sequence lengths per modality:")
            for modality, lengths in sequence_lengths.items():
                print(f"  - {modality}: min={min(lengths)}, max={max(lengths)}, mean={np.mean(lengths):.1f}")
        else:
            print("Data cropped to minimum shapes - no sequence length variation")        
    except FileNotFoundError:
        print("Raw NInFEA data not found. Running data download first...")
        # Run the preparation script
        import subprocess
        prep_script = "02_paper_experiments/050_NInFEA_prep.py"
        subprocess.run([sys.executable, prep_script], check=True)
        
        # Try loading again (will use on-demand processing if file is too large)
        dataset = load_ninfea_dataset(standardize=True)
        multimodal_data = prepare_ninfea_for_multimodal_analysis(
            dataset, max_length=1000
        )
    exit()

    # Prepare modalities for training (keep on CPU for DataLoader)
    modality_names = list(multimodal_data.keys())
    modality_tensors = [multimodal_data[name] for name in modality_names]
    
    print(f"\nTraining multimodal autoencoder with {len(modality_names)} modalities...")
    print(f"Modalities: {modality_names}")
    
    # Create train/val split 
    n_samples = len(next(iter(multimodal_data.values())))
    n_train = int(0.8 * n_samples)
    
    print(f"Total samples: {n_samples}, Training samples: {n_train}")
    
    # Split data for train and validation
    train_tensors = {}
    val_tensors = {}
    
    for modality in modality_names:
        # Split tensors
        train_tensors[modality] = multimodal_data[modality][:n_train]
        val_tensors[modality] = multimodal_data[modality][n_train:]
    
    # Create datasets based on data mode
    if args.data_mode == 'padding':
        # Split sequence lengths for padding mode
        train_sequence_lengths = {}
        val_sequence_lengths = {}
        
        for modality in modality_names:
            train_sequence_lengths[modality] = sequence_lengths[modality][:n_train]
            val_sequence_lengths[modality] = sequence_lengths[modality][n_train:]
        
        # Create padded datasets with masks
        train_dataset = PaddedMultimodalDataWithMasks(
            padded_tensors=train_tensors,
            sequence_lengths=train_sequence_lengths,
            modality_names=modality_names
        )
        
        val_dataset = PaddedMultimodalDataWithMasks(
            padded_tensors=val_tensors,
            sequence_lengths=val_sequence_lengths,
            modality_names=modality_names
        )
    else:  # cropping mode
        # Create cropped datasets without masks - convert to list format
        #train_data_list = []
        #val_data_list = []
        
        #for i in range(n_train):
        #    sample = [train_tensors[modality][i] for modality in modality_names]
        #    train_data_list.append(sample)
            
        #for i in range(n_samples - n_train):
        #    sample = [val_tensors[modality][i] for modality in modality_names]
        #    val_data_list.append(sample)
        
        #train_dataset = PairedMultimodalData(train_data_list,modality_names)
        #val_dataset = PairedMultimodalData(val_data_list,modality_names)
        train_dataset = PairedMultimodalData(train_tensors,modality_names)
        val_dataset = PairedMultimodalData(val_tensors,modality_names)
    
    print(f"Created train dataset with {len(train_dataset)} samples")
    print(f"Created validation dataset with {len(val_dataset)} samples")
    
    # Training configuration
    class TrainArgs:
        def __init__(self):
            self.latent_dim = 100
            self.batch_size = 8
            self.lr = 1e-4 # 0-3, 0-2, 0-1, 1-2
            #self.lr = 1e-3 # 2-3, 1-3
            self.weight_decay = 0
            self.dropout = 0.1
            self.epochs = 5000
            self.ae_depth = 2
            self.conv_depth = 3
            self.ae_hidden = 512
            self.rank_or_sparse = 'rank'
            self.multi_gpu = False
            self.gpu_ids = ''
            self.gpu = args.gpu
            #self.loss_weights = [0.1, 1.0] # for images

    train_args = TrainArgs()

    # Method hyperparameters
    method_hyperparameters = {
        "r_square_thresholds": args.r_square_threshold,
        "early_stopping": 100,
        "rank_reduction_frequencies": 10,
        "rank_reduction_thresholds": 0.01,
        "patiences": 10,
    }

    # Prepare modality tensors for NInFEA training (keep 3D structure for conv layers)
    training_data = []
    for tensor in modality_tensors:
        # NInFEA expects (batch, seq_length, n_channels)
        #if len(tensor.shape) == 2:
        #    # Add channel dimension if missing: (batch, seq_length) -> (batch, seq_length, 1)
        #    tensor = tensor.unsqueeze(-1)
        #elif len(tensor.shape) > 3:
        #    # Flatten extra dimensions: (batch, seq, h, w, ...) -> (batch, seq, h*w*...)
        #    batch_size, seq_length = tensor.shape[:2]
        #    tensor = tensor.view(batch_size, seq_length, -1)
        training_data.append(tensor)

    # Debug: Print tensor shapes for NInFEA training
    print("\nTensor shapes for NInFEA training:")
    for i, tensor in enumerate(training_data):
        print(f"  Modality {i} ({modality_names[i]}): {tensor.shape}")

    # Train the multimodal autoencoder with datasets
    # Pass the train and validation datasets (masked or unmasked based on data mode)
    results = train_overcomplete_ae_with_pretrained(
        train_dataset,
        val_dataset,
        train_args.latent_dim,
        device,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        hidden_dim=train_args.ae_hidden,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        conv_depth=train_args.conv_depth,
        early_stopping=method_hyperparameters["early_stopping"],
        initial_rank_ratio=1.0,
        rank_reduction_frequency=method_hyperparameters["rank_reduction_frequencies"],
        rank_reduction_threshold=method_hyperparameters["rank_reduction_thresholds"],
        warmup_epochs=method_hyperparameters["early_stopping"],
        patience=method_hyperparameters["patiences"],
        min_rank=1,
        r_square_threshold=method_hyperparameters["r_square_thresholds"],
        threshold_type='absolute',
        verbose=True,
        model_name=None,
        pretrained_name=f"ninfea_pretrained_mods{args.mods[0]}-{args.mods[1]}_seed-{args.seed}",
        lr_schedule='linear'
    )
    
    # Unpack results
    model, representations, train_loss, r_squares, rank_history, loss_curves = results

    print("R² values:", r_squares)
    print("Rank history:", rank_history)
    
    print(f"\nTraining completed!")
    print(f"Final loss: {train_loss:.6f}")
    
    # Handle R² formatting (could be list or single value)
    if isinstance(r_squares, list):
        if r_squares and all(r is not None for r in r_squares):
            r_squares_str = [f"{r:.6f}" for r in r_squares]
            print(f"Final R² per modality: {r_squares_str}")
            print(f"Mean R²: {sum(r_squares)/len(r_squares):.6f}")
        else:
            print(f"Final R²: Not available (no rank reduction performed)")
    else:
        if r_squares is not None:
            print(f"Final R²: {r_squares:.6f}")
        else:
            print(f"Final R²: Not available")
            
    print(f"Final ranks: {rank_history['ranks'][-1] if rank_history['ranks'] else 'N/A'}")
    
    # Save results
    # Handle R² values - could be list or single value
    if isinstance(r_squares, list) and r_squares and all(r is not None for r in r_squares):
        # Use per-modality R² values
        r_square_values = r_squares + [sum(r_squares)/len(r_squares)]  # Add mean for 'total'
    else:
        # Use a default or single value for all
        default_r_square = r_squares if r_squares is not None else 0.0
        r_square_values = [default_r_square] * (len(modality_names) + 1)
    
    # Handle final ranks - parse from string if needed
    if rank_history['ranks']:
        final_ranks_str = rank_history['ranks'][-1]
        try:
            # Try to parse comma-separated rank values
            final_ranks = [int(x.strip()) for x in final_ranks_str.split(',')]
            # Pad or truncate to match number of modalities + total
            while len(final_ranks) < len(modality_names) + 1:
                final_ranks.append(0)
            final_ranks = final_ranks[:len(modality_names) + 1]
        except:
            # Fallback to zeros if parsing fails
            final_ranks = [0] * (len(modality_names) + 1)
    else:
        final_ranks = [0] * (len(modality_names) + 1)
    
    results_df = pd.DataFrame({
        'modality': modality_names + ['total'],
        'final_rank': final_ranks,
        'final_loss': [train_loss] * (len(modality_names) + 1),
        'final_r_square': r_square_values,
        'epochs': [train_args.epochs] * (len(modality_names) + 1),
        'latent_dim': [train_args.latent_dim] * (len(modality_names) + 1)
    })
    
    output_file = os.path.join(output_dir, 'ninfea_multimodal_results.csv')
    results_df.to_csv(output_file, index=False)
    print(f"Results saved to: {output_file}")
    
    # Create visualization with 3 subplots in one row
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # 1. Training and validation loss curves
    # The presence of 'loss'/'val_loss' in rank_history only selects this branch;
    # the curves actually plotted come from loss_curves (train, val).
    if rank_history and 'loss' in rank_history and 'val_loss' in rank_history:
        epochs = rank_history.get('epoch', [])
        train_losses = rank_history.get('loss', [])
        val_losses = rank_history.get('val_loss', [])
        
        train_loss_curve, val_loss_curve = loss_curves
        if train_losses:
            #axes[0].plot(epochs, train_losses, 'b-o', linewidth=2, label='Training Loss')
            axes[0].plot(train_loss_curve, 'b-', linewidth=2, label='Training Loss')
        if val_losses:
            #axes[0].plot(epochs, val_losses, 'r-o', linewidth=2, label='Validation Loss')
            axes[0].plot(val_loss_curve, 'r-', linewidth=2, label='Validation Loss')
        
        axes[0].set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Reconstruction Loss')
        axes[0].grid(True, alpha=0.3)
        axes[0].set_yscale('log')  # Log scale for better visualization
        axes[0].legend()
    elif loss_curves and len(loss_curves) > 0:
        # Fallback to detailed loss curves if rank_history is not available
        axes[0].plot(loss_curves, 'b-', linewidth=2, label='Training Loss')
        axes[0].set_title('Training Loss Curve', fontsize=14, fontweight='bold')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Reconstruction Loss')
        axes[0].grid(True, alpha=0.3)
        axes[0].set_yscale('log')  # Log scale for better visualization
        axes[0].legend()
    else:
        axes[0].text(0.5, 0.5, 'No loss data available', 
                    ha='center', va='center', transform=axes[0].transAxes, fontsize=12)
        axes[0].set_title('Loss Curves', fontsize=14, fontweight='bold')
    
    # 2. Per-subspace rank evolution
    if rank_history and 'ranks' in rank_history and rank_history['ranks']:
        epochs = rank_history.get('epoch', list(range(len(rank_history['ranks']))))
        
        # Parse rank strings to get individual subspace ranks
        try:
            all_ranks_per_epoch = []
            for rank_str in rank_history['ranks']:
                if isinstance(rank_str, str):
                    individual_ranks = [int(x.strip()) for x in rank_str.split(',')]
                    all_ranks_per_epoch.append(individual_ranks)
            
            if all_ranks_per_epoch:
                # Determine number of subspaces from first entry
                num_subspaces = len(all_ranks_per_epoch[0])
                
                # Define subspace names based on actual number found
                if num_subspaces <= len(modality_names) + 1:
                    # Standard case: global + per-modality
                    subspace_names = ['global_shared'] + modality_names[:num_subspaces-1]
                else:
                    # More subspaces than expected - use generic names
                    subspace_names = [f'subspace_{i}' for i in range(num_subspaces)]
                
                colors = plt.cm.tab10(np.linspace(0, 1, num_subspaces))
                
                # Plot each subspace rank evolution
                for subspace_idx in range(num_subspaces):
                    subspace_ranks = [ranks[subspace_idx] for ranks in all_ranks_per_epoch if subspace_idx < len(ranks)]
                    plot_epochs = epochs[:len(subspace_ranks)]
                    
                    if len(subspace_ranks) == 1:
                        # Single point - use scatter
                        axes[1].scatter(plot_epochs[0], subspace_ranks[0], 
                                      color=colors[subspace_idx], 
                                      label=subspace_names[subspace_idx] if subspace_idx < len(subspace_names) else f'subspace_{subspace_idx}',
                                      s=60, alpha=0.8)
                    else:
                        # Multiple points - use line plot
                        axes[1].plot(plot_epochs, subspace_ranks, 
                                   color=colors[subspace_idx], 
                                   label=subspace_names[subspace_idx] if subspace_idx < len(subspace_names) else f'subspace_{subspace_idx}',
                                   marker='o', linewidth=2, markersize=4, alpha=0.8)
                
                axes[1].set_title('Rank Evolution by Subspace', fontsize=14, fontweight='bold')
                axes[1].set_xlabel('Training Epoch')
                axes[1].set_ylabel('Effective Rank')
                axes[1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
                axes[1].grid(True, alpha=0.3)
            else:
                axes[1].text(0.5, 0.5, 'No rank data available', 
                            ha='center', va='center', transform=axes[1].transAxes, fontsize=12)
                axes[1].set_title('Rank Evolution')
                
        except Exception as e:
            axes[1].text(0.5, 0.5, f'Rank parsing error:\n{str(e)}', 
                        ha='center', va='center', transform=axes[1].transAxes, fontsize=10)
            axes[1].set_title('Rank Evolution (Error)')
    else:
        axes[1].text(0.5, 0.5, 'No rank history available', 
                    ha='center', va='center', transform=axes[1].transAxes, fontsize=12)
        axes[1].set_title('Rank Evolution')
    
    # 3. R² scores by modality  
    if isinstance(r_squares, list) and r_squares:
        # Check if we have any non-None R² values
        r_square_per_modality = r_squares[:len(modality_names)] if len(r_squares) >= len(modality_names) else r_squares
        valid_r_squares = [r for r in r_square_per_modality if r is not None]
        
        if valid_r_squares:
            # We have some valid R² values
            colors = plt.cm.Set3(np.linspace(0, 1, len(r_square_per_modality)))
            
            bars = axes[2].bar(range(len(r_square_per_modality)), 
                              [r if r is not None else 0 for r in r_square_per_modality], 
                              color=colors, alpha=0.8, edgecolor='black', linewidth=1)
            
            # Add value labels on bars
            for i, (bar, r_val) in enumerate(zip(bars, r_square_per_modality)):
                height = bar.get_height()
                if r_val is not None:
                    axes[2].text(bar.get_x() + bar.get_width()/2., height + 0.01,
                                f'{r_val:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
                else:
                    axes[2].text(bar.get_x() + bar.get_width()/2., 0.01,
                                'N/A', ha='center', va='bottom', fontsize=10, fontweight='bold')
            
            axes[2].set_title('R² Scores by Modality', fontsize=14, fontweight='bold')
            axes[2].set_xlabel('Modality')
            axes[2].set_ylabel('R² Reconstruction Score')
            axes[2].set_xticks(range(len(r_square_per_modality)))
            # Use modality names if we have the right number, otherwise use indices
            if len(r_square_per_modality) == len(modality_names):
                axes[2].set_xticklabels(modality_names, rotation=45, ha='right')
            else:
                axes[2].set_xticklabels([f'Modality {i}' for i in range(len(r_square_per_modality))], rotation=45, ha='right')
            axes[2].grid(True, alpha=0.3, axis='y')
            axes[2].set_ylim(0, max(valid_r_squares) * 1.1)
        else:
            # All R² values are None
            axes[2].text(0.5, 0.5, 'R² not available\n(all values are None)', 
                        ha='center', va='center', transform=axes[2].transAxes, fontsize=12)
            axes[2].set_title('R² Scores by Modality', fontsize=14, fontweight='bold')
        
    elif r_squares is not None and not isinstance(r_squares, list):
        # Single R² value - show as single bar
        axes[2].bar([0], [r_squares], color='skyblue', alpha=0.8, edgecolor='black', linewidth=1)
        axes[2].text(0, r_squares + 0.01, f'{r_squares:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
        axes[2].set_title('Overall R² Score', fontsize=14, fontweight='bold')
        axes[2].set_ylabel('R² Reconstruction Score')
        axes[2].set_xticks([0])
        axes[2].set_xticklabels(['Overall'])
        axes[2].grid(True, alpha=0.3, axis='y')
    else:
        # No R² data available
        axes[2].text(0.5, 0.5, 'R² not available\n(no rank reduction performed)', 
                    ha='center', va='center', transform=axes[2].transAxes, fontsize=12)
        axes[2].set_title('R² Scores', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plot_file = os.path.join(output_dir, 'ninfea_multimodal_analysis.png')
    plt.savefig(plot_file, dpi=150, bbox_inches='tight')
    print(f"Plots saved to: {plot_file}")
    plt.close()
    
    print(f"\nAnalysis complete! Results saved in: {output_dir}")

# Run the analysis only when this file is executed directly as a script.
if __name__ == "__main__":
    main()
