import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
from torch.cuda.amp import autocast, GradScaler
import math
import numpy as np
import tqdm
import gc
from src.visualization.logging import plot_training_state, create_training_movie, plot_training_state_continuous
from src.models.larrp_multimodal_cnn import AdaptiveRankReducedAE_CNN, MMSimData
import matplotlib.pyplot as plt
import pandas as pd
import os
from transformers import T5Tokenizer
from transformers import BertTokenizer

def train_continuous_multimodal_ae(train_loader, val_loader, model, device, latent_dim=None, args=None,
                                   epochs=100, early_stopping=50, lr=0.001, batch_size=128, wd=1e-5,
                                   initial_rank_ratio=1.0, min_rank=10, rank_schedule=None,
                                   rank_reduction_frequency=10, rank_reduction_threshold=0.01,
                                   warmup_epochs=0, patience=10, reduce_on_best_loss='rsquare',
                                   r_square_threshold=0.9, threshold_type='relative',
                                   compressibility_type='linear', reduction_criterion='r_squared',
                                   verbose=True, model_name=None, lr_schedule='constant',
                                   decision_metric='R2', input_shapes=None, end_lr=1e-5,
                                   save_frequency=None, modality_keys=None, mixed_precision=False,
                                   save_dir=None):
    """
    Train a multimodal autoencoder with continuous modalities (e.g., image-depth, image-image).

    Similar to train_overcomplete_ae but adapted for continuous data without text/tokenization.
    Each epoch trains on the summed per-modality MSE and evaluates the same loss on
    ``val_loader``. At epoch ``start_epoch + warmup_epochs`` the per-modality R² is
    measured on ``train_loader`` and used to derive per-modality thresholds. At every
    epoch listed in ``rank_schedule``, the adaptive layers' ranks are then reduced or
    increased depending on how the current R² compares with those thresholds.

    Parameters:
    - train_loader, val_loader: DataLoader objects returning dict batches. ``batch[key]``
      for each key in ``modality_keys`` must be a tensor with samples on dim 0. An
      optional ``'idx'`` / ``'index'`` entry orders the returned latent representations.
    - model: Multimodal autoencoder, optionally already wrapped in ``nn.DataParallel``.
      ``model(*modalities)`` must return one reconstruction per modality. The model must
      expose ``adaptive_layers`` (items with ``active_dims``, ``min_rank``, ``max_rank``),
      ``encode``, ``reduce_rank`` and ``increase_rank``. ``get_total_rank`` is optional
      (0 is recorded if absent).
    - device: torch device used for all tensors.
    - args: Optional namespace. ``args.multi_gpu`` and ``args.gpu_ids`` (comma-separated)
      enable DataParallel.
    - modality_keys: List of dict keys for modalities (default ['image', 'depth']).
    - save_dir: Directory to save model checkpoints (default './03_results/models').
    - save_frequency: Periodic checkpoint/plot interval in epochs (None -> 10; <= 0 disables).
    - patience: Number of consecutive rank steps without changes before early stopping.
      Also used as the cool-down (in scheduled rank steps) after a rank increase.
    - mixed_precision: Run forward/loss under bf16 autocast with a GradScaler.
    - latent_dim, early_stopping, batch_size, initial_rank_ratio, min_rank,
      decision_metric, input_shapes: accepted for signature compatibility with
      train_overcomplete_ae but not used here.
    - All other parameters same as train_overcomplete_ae.

    Returns:
        (model, reps, avg_train_loss, last_rsquare, rank_history,
         [train_losses, val_losses], sorted_indices)
        - ``model`` is returned as passed in or as wrapped here.
        - ``reps`` holds one tensor per adaptive layer (shared first), each truncated
          to its final rank.
        - ``rank_history`` is ``{}`` if rank reduction never started, e.g. when
          ``epochs <= start_epoch + warmup_epochs``.
    """
    import contextlib

    # Default modality keys / save dir / args
    if modality_keys is None:
        modality_keys = ['image', 'depth']
    if save_dir is None:
        save_dir = './03_results/models'
    os.makedirs(save_dir, exist_ok=True)
    if args is None:
        from types import SimpleNamespace
        args = SimpleNamespace()

    n_modalities = len(modality_keys)
    multi_gpu = getattr(args, 'multi_gpu', False)

    # nn.DataParallel does not forward plain attributes or custom methods (epoch,
    # adaptive_layers, encode, reduce_rank, ...), so those always go through the
    # unwrapped `base_model`; forward passes and train()/eval() go through `model`.
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    # Initialize epoch counter if not already set (allows resuming from base_model.epoch)
    if not hasattr(base_model, 'epoch'):
        base_model.epoch = 0
    start_epoch = base_model.epoch

    model.to(device)
    print(f"Model is on device: {next(model.parameters()).device}")

    # Handle multi-GPU setup if needed
    if multi_gpu and not isinstance(model, nn.DataParallel):
        try:
            if hasattr(args, 'gpu_ids') and args.gpu_ids:
                model = nn.DataParallel(model, device_ids=[int(gpu) for gpu in args.gpu_ids.split(',')])
                if verbose:
                    print(f"Using DataParallel with GPUs: {args.gpu_ids}")
        except Exception as e:
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            model = base_model.to(device)
    # Reflect what actually happened: multi_gpu requested without gpu_ids leaves the
    # model unwrapped, and `model.module` would then fail.
    multi_gpu = isinstance(model, nn.DataParallel)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Mixed precision scaler
    scaler = GradScaler() if mixed_precision else None
    if mixed_precision and verbose:
        print("Using mixed precision training (BF16)")

    scheduler = _continuous_ae_build_scheduler(optimizer, lr, end_lr, epochs, warmup_epochs, lr_schedule)

    # Default rank reduction schedule if none provided (absolute epoch numbers)
    if rank_schedule is None:
        rank_schedule = list(range(warmup_epochs + rank_reduction_frequency,
                                   epochs, rank_reduction_frequency))

    initial_squares = [None] * n_modalities
    current_rsquare_per_mod = [None] * n_modalities
    min_rsquares = []
    max_rsquares = []
    start_reduction = False
    bottom_reached = False  # never set to True in this function; guards rank increases
    break_counter = 0
    rank_history = {}  # filled once rank reduction starts

    # Number of recent R² measurements considered by the windowed decision rule.
    # Clamped to >= 1: with patience < 2 the unclamped value is 0, and r_squares[-0:]
    # would silently select the entire history.
    decision_window = max(1, min(10, int(patience / 2)))

    # None -> every 10 epochs; a non-positive frequency disables periodic saving
    # (it would otherwise raise ZeroDivisionError in the modulo below).
    effective_save_freq = save_frequency if save_frequency is not None else 10

    # Initialize plotting variables
    last_batch_data = None
    last_batch_labels = None
    if model_name is not None:
        plot_save_dir = "./03_results/plots/temp_latent_plots/" + model_name
        os.makedirs(plot_save_dir, exist_ok=True)
    else:
        plot_save_dir = None

    train_losses = []
    val_losses = []
    r_squares = []
    patience_counter = 0
    # Index of the last epoch run; stays bound even if the loop below never executes.
    epoch = start_epoch - 1

    pbar = tqdm.tqdm(range(start_epoch, epochs))
    for epoch in pbar:
        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        val_loss = 0.0
        per_modality_losses = [0.0] * n_modalities

        for batch in train_loader:
            optimizer.zero_grad()

            modality_tensors = [batch[key].to(device, non_blocking=True) for key in modality_keys]

            # Keep a detached copy of the latest batch for plotting
            if plot_save_dir is not None:
                last_batch_data = [x.clone().detach() for x in modality_tensors]

            amp_context = (autocast(dtype=torch.bfloat16) if mixed_precision
                           else contextlib.nullcontext())
            with amp_context:
                x_hat = model(*modality_tensors)
                loss, batch_mod_losses = _continuous_ae_recon_loss(x_hat, modality_tensors, device)
            for i, mod_loss in enumerate(batch_mod_losses):
                per_modality_losses[i] += mod_loss

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            train_loss += loss.item()

        # Average losses (guard against an empty loader)
        n_train_batches = max(1, len(train_loader))
        train_loss /= n_train_batches
        per_modality_losses = [total / n_train_batches for total in per_modality_losses]
        train_losses.append(train_loss)

        # ---- Validation phase ----
        model.eval()
        with torch.no_grad():
            for batch_val in val_loader:
                modality_tensors_val = [batch_val[key].to(device, non_blocking=True)
                                        for key in modality_keys]
                x_val_hat = model(*modality_tensors_val)
                val_loss += sum(F.mse_loss(recon, original, reduction='mean').item()
                                for recon, original in zip(x_val_hat, modality_tensors_val))
        val_loss /= max(1, len(val_loader))
        val_losses.append(val_loss)

        if scheduler is not None:
            scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        log_dict = {
            'loss': round(train_loss, 4),
            'lr': f'{current_lr:.2e}',
            'mod_losses': [round(l, 3) for l in per_modality_losses],
            'ranks': _continuous_ae_active_dims(base_model) or [],
            'current_rsquare': [round(r, 3) if r is not None else 'N/A'
                                for r in current_rsquare_per_mod],
            'patience': patience_counter,
        }
        pbar.set_postfix(log_dict)

        # ---- Initialize rank reduction after warmup ----
        if not start_reduction and epoch == start_epoch + warmup_epochs:
            rank_history = {
                'total_rank': [_continuous_ae_total_rank(base_model)],
                'ranks': [_continuous_ae_rank_string(base_model)],
                'epoch': [epoch],  # the epoch whose losses are recorded below
                'loss': [train_losses[-1]],
                'val_loss': [val_losses[-1]],
            }
            start_reduction = True
            break_counter = 0
            min_rsquares = []

            if verbose:
                print("   Computing R² from training data...")

            # R² is measured on the training loader
            direct_r_squared_values = compute_direct_r_squared_continuous(
                model, train_loader, device, multi_gpu, modality_keys, verbose=verbose
            )
            for i, r_squared_val in enumerate(direct_r_squared_values):
                initial_squares[i] = r_squared_val
                min_rsquares.append(_continuous_ae_rsquare_threshold(
                    r_squared_val, r_square_threshold, threshold_type))
                current_rsquare_per_mod[i] = r_squared_val
                rank_history[f'rsquare {i}'] = [r_squared_val]

            max_rsquares = initial_squares.copy()

            if verbose:
                print(f"Initial R-squared values: {[rank_history[f'rsquare {i}'] for i in range(n_modalities)]}, "
                      f"setting {threshold_type} thresholds to {min_rsquares}")

        # ---- Apply rank changes at scheduled epochs ----
        if (epoch in rank_schedule) and start_reduction and (break_counter == 0):
            if model_name:
                latest_path = os.path.join(save_dir, f"{model_name}_latest.pt")
                torch.save(base_model.state_dict(), latest_path)
                if verbose:
                    print(f"Saved checkpoint: {latest_path}")

            if reduce_on_best_loss == 'rsquare':
                layers = base_model.adaptive_layers
                # With a loss criterion lower values are better, so comparisons flip.
                loss_criterion = (compressibility_type == 'direct'
                                  and reduction_criterion in ['train_loss', 'val_loss'])
                modalities_to_reduce = []
                modalities_to_increase = []

                current_rsquares = list(compute_direct_r_squared_continuous(
                    model, train_loader, device, multi_gpu, modality_keys, verbose=False
                ))
                for i, r_squared_val in enumerate(current_rsquares):
                    current_rsquare_per_mod[i] = r_squared_val
                r_squares.append(current_rsquares)

                # Raise thresholds whenever a modality reaches a new best R²
                update_max = False
                for i, r in enumerate(current_rsquares):
                    if r > max_rsquares[i]:
                        max_rsquares[i] = r
                        update_max = True
                if update_max:
                    for i in range(n_modalities):
                        min_rsquares[i] = _continuous_ae_rsquare_threshold(
                            max_rsquares[i], r_square_threshold, threshold_type)
                    print(f"Updated min R² thresholds to {min_rsquares}")

                # Decide which modalities to reduce or increase
                if len(r_squares) >= decision_window and patience_counter >= decision_window:
                    # Stagnating: increase only if below threshold for the whole window
                    for i in range(n_modalities):
                        recent = [r[i] for r in r_squares[-decision_window:]]
                        if loss_criterion:
                            if all(r > min_rsquares[i] for r in recent) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                        else:
                            if all(r < min_rsquares[i] for r in recent) and not bottom_reached:
                                modalities_to_increase.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                else:
                    # r_squares was just appended to, so at least one measurement exists
                    for i in range(n_modalities):
                        if loss_criterion:
                            if current_rsquare_per_mod[i] < min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] > min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)
                        else:
                            if current_rsquare_per_mod[i] > min_rsquares[i]:
                                modalities_to_reduce.append(i)
                            elif current_rsquare_per_mod[i] < min_rsquares[i] and not bottom_reached:
                                modalities_to_increase.append(i)

                # If every modality can shrink, cap each layer's max rank near its current rank
                if len(modalities_to_reduce) == n_modalities:
                    current_ranks = [layer.active_dims for layer in layers]
                    total_current = sum(current_ranks)
                    for layer, cr in zip(layers, current_ranks):
                        layer.max_rank = min(total_current, max(int(1.5 * cr), cr + 1), layer.max_rank)
                    print(f"Adjusting maximum ranks to {[layer.max_rank for layer in layers]}")
                # If every modality needs more capacity, forbid going below the current ranks
                if len(modalities_to_increase) == n_modalities:
                    for layer in layers:
                        layer.min_rank = layer.active_dims
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in layers]}")

                # Map modalities to layer ids: layer 0 is shared, layer i + 1 is modality i
                layers_to_reduce = []
                layers_to_increase = []
                if modalities_to_reduce and modalities_to_increase:
                    # Mixed decision: never shrink the shared layer, grow it with the struggling modalities
                    layers_to_reduce = [i + 1 for i in modalities_to_reduce]
                    layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                    layers[0].min_rank = layers[0].active_dims + 1
                    for i in modalities_to_increase:
                        layers[i + 1].min_rank = layers[i + 1].active_dims + 1
                    print(f"Adjusting minimum ranks to {[layer.min_rank for layer in layers]}")
                else:
                    if modalities_to_increase:
                        if len(modalities_to_increase) == n_modalities:
                            layers_to_increase = list(range(len(layers)))
                        else:
                            layers_to_increase = [0] + [i + 1 for i in modalities_to_increase]
                            for i in modalities_to_increase:
                                layers[i + 1].min_rank = layers[i + 1].active_dims + 1
                            layers[0].min_rank = layers[0].active_dims + 1
                            print(f"Adjusting minimum ranks to {[layer.min_rank for layer in layers]}")
                    if modalities_to_reduce:
                        if len(modalities_to_reduce) == n_modalities:
                            # Every modality is above threshold: the shared layer may shrink too
                            layers_to_reduce = [0] + [i + 1 for i in modalities_to_reduce]
                        else:
                            layers_to_reduce = [i + 1 for i in modalities_to_reduce]

                if verbose:
                    metric_label = reduction_criterion if loss_criterion else "R-squared"
                    print(f"{metric_label} values: {current_rsquares}, reducing rank for modalities {modalities_to_reduce}, layers {layers_to_reduce}, increasing rank for modalities {modalities_to_increase}, layers {layers_to_increase}")
                    print(f"Rank reduction decision: reduce={layers_to_reduce}, increase={layers_to_increase}")

                any_changes_made = False
                if layers_to_reduce:
                    if base_model.reduce_rank(reduction_ratio=0.9, threshold=rank_reduction_threshold,
                                              layer_ids=layers_to_reduce):
                        any_changes_made = True
                if layers_to_increase:
                    if base_model.increase_rank(increase_ratio=1.1, layer_ids=layers_to_increase):
                        any_changes_made = True
                        break_counter = patience  # give the model time to learn the added dimensions

                if any_changes_made:
                    patience_counter = 0
                else:
                    patience_counter += 1

            # Update rank history
            rank_history['total_rank'].append(_continuous_ae_total_rank(base_model))
            rank_history['ranks'].append(_continuous_ae_rank_string(base_model))
            rank_history['epoch'].append(epoch)
            if reduce_on_best_loss == 'rsquare':
                for i in range(n_modalities):
                    rank_history[f'rsquare {i}'].append(current_rsquare_per_mod[i])
            rank_history['loss'].append(train_loss)
            rank_history['val_loss'].append(val_loss)

            if last_batch_data is not None and plot_save_dir is not None:
                plot_training_state_continuous(model, last_batch_data, last_batch_labels, epoch,
                                               multi_gpu, plot_save_dir, device, verbose=verbose)

        elif (epoch in rank_schedule) and start_reduction and (break_counter > 0):
            # Cool-down after a rank increase: skip this scheduled step
            break_counter -= 1

        # ---- Periodic checkpoint saving ----
        if (model_name is not None and effective_save_freq > 0
                and (epoch + 1) % effective_save_freq == 0):
            checkpoint_path = os.path.join(save_dir, f"{model_name}.pt")
            checkpoint = {
                'model_state_dict': base_model.state_dict(),
                'active_dims': _continuous_ae_active_dims(base_model),
                'epoch': epoch,
            }
            torch.save(checkpoint, checkpoint_path)
            print(f"Saved checkpoint at epoch {epoch+1} to {checkpoint_path}")

            # Plots are best-effort diagnostics; never abort training over them
            try:
                _continuous_ae_save_recon_plots(model, train_loader, val_loader, modality_keys,
                                                device, multi_gpu, checkpoint_path, epoch)
            except Exception as e:
                print(f"Warning: Could not save intermediate reconstruction plots at epoch {epoch+1}: {e}")

        # Early stopping
        if patience_counter >= patience:
            if verbose:
                print(f"Early stopping at epoch {epoch} with patience counter {patience_counter} >= patience {patience}")
            break

    # ---- Final latent representations of the training data ----
    final_ranks = [layer.active_dims for layer in base_model.adaptive_layers]
    model.eval()
    reps, sorted_indices = _continuous_ae_final_latents(base_model, train_loader, device,
                                                        modality_keys, final_ranks)
    print(f"Computed final latent representations for training data (n={len(sorted_indices)}).")

    if last_batch_data is not None and plot_save_dir is not None:
        plot_training_state_continuous(model, last_batch_data, last_batch_labels, epoch,
                                       multi_gpu, plot_save_dir, device, verbose=verbose)

    if plot_save_dir is not None:
        create_training_movie(plot_save_dir)

    if model_name:
        model_path = os.path.join(save_dir, f"{model_name}.pt")
        checkpoint = {
            'model_state_dict': base_model.state_dict(),
            'active_dims': _continuous_ae_active_dims(base_model),
            'epoch': epoch,
            'final_ranks': final_ranks,
        }
        torch.save(checkpoint, model_path)
        if verbose:
            print(f"Saved trained model to {model_path}")

    avg_train_loss = np.mean(train_losses[-5:]) if train_losses else float('nan')
    last_rsquare = r_squares[-1] if r_squares else [None] * n_modalities

    return model, reps, avg_train_loss, last_rsquare, rank_history, [train_losses, val_losses], sorted_indices


def _continuous_ae_build_scheduler(optimizer, lr, end_lr, epochs, warmup_epochs, lr_schedule):
    """Build the (optional) warmup + main LR scheduler; returns None for no scheduling.

    The warmup first tries LinearLR with start_factor=10.0. torch's LinearLR rejects
    start factors above 1, so in practice the LambdaLR fallback is what runs: a
    linear ramp from 10*lr to end_lr over warmup_epochs.
    """
    warmup_scheduler = None
    main_scheduler = None

    if warmup_epochs > 0:
        try:
            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=10.0, end_factor=max(float(end_lr) / float(lr), 0.0),
                total_iters=warmup_epochs
            )
            print(f"Using LinearLR warmup for {warmup_epochs} epochs from lr {lr*10.0} to {end_lr}")
        except Exception:
            def _warmup_lambda(step):
                t = float(min(step + 1, warmup_epochs)) / float(max(1, warmup_epochs))
                return 10.0 + (float(end_lr) / float(lr) - 10.0) * t
            warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmup_lambda)
            print(f"Using LambdaLR warmup for {warmup_epochs} epochs from lr {lr*10.0} to {end_lr}")

    # Main scheduler after warmup
    if lr_schedule == 'cosine':
        main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=end_lr
        )
        print(f"Using CosineAnnealingLR from lr {lr} to {end_lr} over {epochs - warmup_epochs} epochs")
    elif lr_schedule == 'linear':
        try:
            main_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=end_lr / lr,
                total_iters=max(1, epochs - warmup_epochs)
            )
        except Exception:
            # e.g. end_lr > lr, which LinearLR rejects
            main_scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer, lr_lambda=lambda step: max(end_lr / lr,
                1.0 - (step + 1) / float(max(1, epochs - warmup_epochs)))
            )
        print(f"Using LinearLR from lr {lr} to {end_lr} over {epochs - warmup_epochs} epochs")
    elif lr_schedule == 'constant' or lr_schedule is None:
        # After warmup, reset to base learning rate and keep it constant
        main_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=1.0, total_iters=max(1, epochs - warmup_epochs)
        )
        print(f"Using constant learning rate of {lr} for main training phase")
    else:
        print(f"Using constant learning rate of {lr} for main training phase")

    if warmup_epochs > 0 and main_scheduler is not None:
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_scheduler, main_scheduler],
            milestones=[warmup_epochs]
        )
    if warmup_epochs > 0:
        return warmup_scheduler
    return main_scheduler


def _continuous_ae_recon_loss(reconstructions, targets, device):
    """Return (summed per-modality MSE tensor, list of per-modality MSE floats)."""
    loss = torch.tensor(0.0, device=device)
    per_modality = []
    for recon, original in zip(reconstructions, targets):
        mod_loss = F.mse_loss(recon, original, reduction='mean')
        loss += mod_loss
        per_modality.append(mod_loss.item())
    return loss, per_modality


def _continuous_ae_rsquare_threshold(reference, r_square_threshold, threshold_type):
    """Minimum acceptable R² derived from a reference R² ('relative' or 'absolute')."""
    if threshold_type == 'relative':
        return reference * r_square_threshold
    if threshold_type == 'absolute':
        return reference - r_square_threshold
    raise ValueError(f"Unknown threshold_type: {threshold_type}")


def _continuous_ae_active_dims(base_model):
    """Active rank of each adaptive layer, or None if the model has no adaptive_layers."""
    if not hasattr(base_model, 'adaptive_layers'):
        return None
    return [layer.active_dims for layer in base_model.adaptive_layers]


def _continuous_ae_total_rank(base_model):
    """Model's total rank via get_total_rank(), or 0 if the model does not provide it."""
    return base_model.get_total_rank() if hasattr(base_model, 'get_total_rank') else 0


def _continuous_ae_rank_string(base_model):
    """Comma-separated active ranks of the adaptive layers (used in rank_history)."""
    return ', '.join(str(layer.active_dims) for layer in base_model.adaptive_layers)


def _continuous_ae_save_recon_plots(model, train_loader, val_loader, modality_keys, device,
                                    multi_gpu, checkpoint_path, epoch, n_plot=6):
    """Save original-vs-reconstruction grids (one PNG per modality) for train and val.

    Rows are: train original, train reconstruction, val original, val reconstruction.
    Columns are the first ``n_plot`` samples of each loader's first batch (fewer if
    that batch is smaller). PNGs are written next to the checkpoint, with
    './03_results/models/' mapped to './03_results/train_plots/'.
    Leaves ``model`` in train mode.
    """
    train_batch = next(iter(train_loader))
    val_batch = next(iter(val_loader))
    train_modalities = [train_batch[key][:n_plot].to(device) for key in modality_keys]
    val_modalities = [val_batch[key][:n_plot].to(device) for key in modality_keys]

    net = model.module if (multi_gpu and hasattr(model, 'module')) else model
    model.eval()
    try:
        with torch.no_grad():
            train_recon = net(*train_modalities)
            val_recon = net(*val_modalities)
    finally:
        model.train()

    # Strip only the trailing extension; str.replace('.pt', ...) would also rewrite
    # any '.pt' occurring elsewhere in the path.
    plot_stem = os.path.splitext(checkpoint_path)[0].replace(
        './03_results/models/', './03_results/train_plots/')
    rows = (('Train Orig', train_modalities), ('Train Recon', train_recon),
            ('Val Orig', val_modalities), ('Val Recon', val_recon))

    for mod_idx, mod_key in enumerate(modality_keys):
        n_cols = min(len(source[mod_idx]) for _, source in rows)
        if n_cols == 0:
            continue
        plot_path = f"{plot_stem}_{mod_key}_recon_epoch{epoch+1}.png"
        plot_dir = os.path.dirname(plot_path)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)

        # squeeze=False keeps `axes` 2-D even when n_cols == 1
        fig, axes = plt.subplots(4, n_cols, figsize=(n_cols * 1.5, 6), squeeze=False)
        try:
            for row, (label, source) in enumerate(rows):
                for col in range(n_cols):
                    _continuous_ae_show_image(axes[row, col], source[mod_idx][col],
                                              label if col == 0 else None)
            fig.tight_layout()
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)  # never leak figures, even if plotting fails
        print(f"Saved intermediate {mod_key} reconstructions to {plot_path}")


def _continuous_ae_show_image(ax, sample, label=None):
    """Draw one sample: 3 leading channels as RGB, otherwise channel 0 with 'inferno'.

    Assumes channel-first samples, as implied by the shape[0] == 3 check.
    """
    img = sample.detach().float().cpu().numpy()
    if img.shape[0] == 3:  # RGB
        ax.imshow(np.transpose(img, (1, 2, 0)))
    else:  # single channel (e.g. depth)
        ax.imshow(img[0], cmap='inferno')
    # Hide ticks and frame without ax.axis('off'), which would also hide the row label.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    if label is not None:
        ax.set_ylabel(label, fontsize=10)


def _continuous_ae_final_latents(base_model, data_loader, device, modality_keys, final_ranks):
    """Encode every batch of ``data_loader`` and return (reps, sorted_indices).

    ``reps[0]`` is the shared representation and ``reps[j]`` (j >= 1) the specific
    representation of modality j - 1, each truncated to ``final_ranks[j]`` columns.
    If batches carry an 'idx'/'index' entry, rows are ordered by sorted unique index;
    for duplicated indices the last occurrence wins. Otherwise rows keep loader order.
    The caller is responsible for putting the model in eval mode.
    """
    all_reps = [[] for _ in final_ranks]
    all_indices = []

    with torch.no_grad():
        for batch in data_loader:
            modality_tensors = [batch[key].to(device, non_blocking=True) for key in modality_keys]
            indices = batch.get('idx', batch.get('index', None))

            batch_reps = base_model.encode(*modality_tensors)
            # batch_reps = (h_shared, [h_mod1_spec, h_mod2_spec, ...])
            batch_rep_list = [batch_reps[0].detach().cpu()] + \
                             [h.detach().cpu() for h in batch_reps[1]]
            for j, rank in enumerate(final_ranks):
                all_reps[j].append(batch_rep_list[j][:, :rank])

            if indices is not None:
                all_indices.append(indices)

            del modality_tensors, batch_reps
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    stacked = [torch.cat(chunks, dim=0) for chunks in all_reps]

    if all_indices:
        all_indices = torch.cat(all_indices)
        unique_indices = torch.unique(all_indices, sorted=True)
        # Row of the last occurrence of each index. Gathering avoids assuming every
        # batch has exactly loader.batch_size rows (batch_size is None with a batch_sampler).
        last_row = {int(idx): row for row, idx in enumerate(all_indices.tolist())}
        rows = torch.tensor([last_row[int(idx)] for idx in unique_indices.tolist()],
                            dtype=torch.long)
        reps = [s[rows].to(torch.get_default_dtype()) for s in stacked]
        sorted_indices = unique_indices.cpu().numpy()
    else:
        # No indices in batch, just concatenate
        reps = stacked
        sorted_indices = np.arange(len(reps[0]))

    return reps, sorted_indices


def _as_float64_rows(tensor):
    """Return *tensor* as a detached CPU float64 matrix with one row per sample.

    R² statistics are accumulated in float64 so that summing many float16 /
    float32 elements does not overflow or lose precision.
    """
    return tensor.detach().flatten(start_dim=1).to(device='cpu', dtype=torch.float64)


def compute_direct_r_squared_continuous(model, data_loader, device, multi_gpu=False, 
                                       modality_keys=None, verbose=False):
    """
    Compute R² for continuous multimodal data (e.g., image-depth).
    Adapted from compute_direct_r_squared for continuous modalities.

    R² is computed per modality as ``1 - SS_res / SS_tot``. ``SS_tot`` is measured
    against the per-feature mean over the whole loader. The result is clipped
    at 0.0, and is 0.0 when the modality has (near-)zero variance.

    Two passes are made over ``data_loader``: one to compute the feature means,
    one to run the model and accumulate squared errors. Both passes must see the
    same data. Statistics are accumulated in float64 on the CPU.

    Parameters:
    - model: autoencoder called as ``model(*modality_tensors)``. It is assumed to
      return an iterable of reconstructions in the same order as
      ``modality_keys`` (TODO confirm against the model definition).
    - data_loader: iterable of dict batches indexed by ``modality_keys``.
    - device: device the model inputs are moved to.
    - multi_gpu: accepted for API compatibility; not used here.
    - modality_keys: batch dict keys, defaults to ``['image', 'depth']``.
    - verbose: print the per-modality R² values.

    Returns:
    - list of float R² values, one per modality. If the loader yields no
      samples, every value is 0.0.

    The model's train/eval mode is restored before returning.
    """
    if modality_keys is None:
        modality_keys = ['image', 'depth']

    n_modalities = len(modality_keys)

    # Pass 1: per-feature means. No model call is needed, so data stays on the CPU.
    n_samples = 0
    modality_sums = [0.0] * n_modalities
    for batch in data_loader:
        n_samples += batch[modality_keys[0]].shape[0]
        for i, key in enumerate(modality_keys):
            modality_sums[i] = modality_sums[i] + _as_float64_rows(batch[key]).sum(dim=0)

    if n_samples == 0:
        # Nothing to evaluate; use the same 0.0 as the degenerate-variance case.
        if verbose:
            print("compute_direct_r_squared_continuous: data_loader yielded no samples.")
        return [0.0] * n_modalities

    modality_means = [s / n_samples for s in modality_sums]

    # Pass 2: residual and total sums of squares.
    sum_squared_residuals = [0.0] * n_modalities
    sum_squared_total = [0.0] * n_modalities
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in data_loader:
                modality_tensors = [batch[key].to(device) for key in modality_keys]
                reconstructions = model(*modality_tensors)

                for i, (original, recon) in enumerate(zip(modality_tensors, reconstructions)):
                    orig_rows = _as_float64_rows(original)
                    recon_rows = _as_float64_rows(recon)
                    sum_squared_residuals[i] += ((orig_rows - recon_rows) ** 2).sum().item()
                    sum_squared_total[i] += ((orig_rows - modality_means[i]) ** 2).sum().item()
    finally:
        # Do not leak eval mode into the caller (e.g. a training loop).
        model.train(was_training)

    r_squared_values = []
    for i in range(n_modalities):
        if sum_squared_total[i] > 1e-6:
            r_squared = 1.0 - (sum_squared_residuals[i] / (sum_squared_total[i] + 1e-9))
            r_squared = max(0.0, r_squared)
        else:
            # Near-constant modality: R² is undefined, report 0.0.
            r_squared = 0.0

        if verbose:
            print(f"Modality {i}: R²={r_squared:.4f}")

        r_squared_values.append(r_squared)

    return r_squared_values