import os
import copy
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchmetrics.functional import mean_squared_error
from torchmetrics.functional.image import peak_signal_noise_ratio, structural_similarity_index_measure

import torchvision.utils as vutils
from PIL import Image
import time


def save_samples_with_border(x, y_pred, y_star, img_shape, save_path, mode='autoencoder', down_scale=2.0):
    """Save a comparison grid of inputs, predictions and targets to ``save_path``.

    ``img_shape`` is ``(C, H, W)``. Predictions and targets are reshaped to
    ``(-1, C, H, W)``. The inputs are reshaped the same way, except when
    ``mode == 'super'``. In that case they are reshaped to the low-resolution
    size ``(H // down_scale, W // down_scale)`` and pasted into the centre of a
    full-size canvas of ones, framed by a 1-pixel band of zeros, so every image
    in the grid has the same size.

    The grid shows ``x.size(0)`` images per row: inputs, then predictions, then
    targets. It is built by ``make_grid`` with ``normalize=True`` and written
    as an 8-bit image.
    """
    channels, height, width = img_shape

    if mode != 'super':
        inputs = x.view(-1, channels, height, width)
    else:
        small_h = int(height // down_scale)
        small_w = int(width // down_scale)
        small = x.view(-1, channels, small_h, small_w)

        top = (height - small_h) // 2
        left = (width - small_w) // 2
        frame = 1  # frame thickness in pixels

        # Canvas of ones, then a zero-valued rectangle one pixel larger than the
        # patch on every side, then the low-resolution patch itself on top.
        inputs = torch.ones(small.size(0), channels, height, width)
        inputs[:, :, top - frame:top + small_h + frame, left - frame:left + small_w + frame] = 0
        inputs[:, :, top:top + small_h, left:left + small_w] = small

    stacked = torch.cat(
        [
            inputs,
            y_pred.view(-1, channels, height, width),
            y_star.view(-1, channels, height, width),
        ],
        dim=0,
    )
    grid = vutils.make_grid(stacked, nrow=x.size(0), normalize=True, pad_value=1)
    pixels = grid.mul(255).add_(0.5).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    Image.fromarray(pixels).save(save_path)


@torch.no_grad()
def calculate_eval_loss_all_loaders(model, eval_loaders, device):
    """Return the per-sample mean MSE of ``model`` on each loader in ``eval_loaders``.

    Each loader must yield ``(x, y_star, label)`` triples. The model is called
    as ``model(x, y_star)`` and compared against ``y_star``. The model is put
    in eval mode and left there.

    Returns:
        A list with one float per loader, in order. A loader that yields no
        samples gets ``float('nan')`` instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    eval_losses = []

    for eval_loader in eval_loaders:
        total_loss, n_samples = 0.0, 0
        for x, y_star, _ in eval_loader:
            x, y_star = x.to(device), y_star.to(device)
            y_pred = model(x, y_star)
            batch_size = x.size(0)
            # reduction='mean' averages over elements. Re-weighting by batch size
            # keeps the final value a per-sample mean even when the last batch
            # is smaller.
            total_loss += F.mse_loss(y_pred, y_star, reduction='mean').item() * batch_size
            n_samples += batch_size
        eval_losses.append(total_loss / n_samples if n_samples else float('nan'))

    return eval_losses

@torch.no_grad()
def evaluate_overall_metrics(model, dataloader, device, img_shape, batch_size=128, data_range=1.0):
    """Compute PSNR and SSIM of ``model`` over every sample of ``dataloader``.

    The loader must yield ``(x, y_star, label)`` triples; predictions come from
    ``model(x, y_star)`` and are reshaped to ``(-1, C, H, W)`` with
    ``img_shape = (C, H, W)``. All samples are concatenated in memory and
    re-chunked into groups of ``batch_size``, so the batch-level numbers do not
    depend on the loader's own batch size. The model runs once per chunk.

    Two aggregations are reported:

    * ``psnr`` / ``ssim``: mean over samples of the per-sample metric.
    * ``psnr_batch`` / ``ssim_batch``: the metric of each whole chunk,
      averaged over chunks and weighted by chunk size.

    SSIM gets the same explicit ``data_range`` as PSNR. Without it, torchmetrics
    infers the range from the data, which differs between a single sample and a
    whole chunk and is inconsistent with the PSNR setting.

    Args:
        batch_size: chunk size used for inference and for the batch metrics.
        data_range: value range of the images (1.0 assumes targets in [0, 1]).

    Returns:
        dict with keys ``'psnr'``, ``'ssim'``, ``'psnr_batch'`` and
        ``'ssim_batch'``. All values are NaN if the loader yields no samples.
    """
    model.eval()
    C, H, W = img_shape

    x_parts, y_parts = [], []
    for x, y_star, _ in dataloader:
        x_parts.append(x)
        y_parts.append(y_star)

    n_total = sum(part.size(0) for part in x_parts)
    if n_total == 0:
        nan = float('nan')
        return {'psnr': nan, 'ssim': nan, 'psnr_batch': nan, 'ssim_batch': nan}

    all_x = torch.cat(x_parts, dim=0)
    all_y_star = torch.cat(y_parts, dim=0)

    psnr_sum = ssim_sum = 0.0              # sums of per-sample metrics
    psnr_batch_sum = ssim_batch_sum = 0.0  # chunk metrics weighted by chunk size
    for start in range(0, n_total, batch_size):
        batch_x = all_x[start:start + batch_size].to(device)
        batch_y_star = all_y_star[start:start + batch_size].to(device)

        # A single forward pass feeds both aggregations. This halves inference
        # cost and guarantees both flavours score the same predictions.
        pred_img = model(batch_x, batch_y_star).reshape(-1, C, H, W)
        star_img = batch_y_star.reshape(-1, C, H, W)
        n = batch_x.size(0)

        for i in range(n):
            sample_pred = pred_img[i:i + 1]
            sample_star = star_img[i:i + 1]
            psnr_sum += peak_signal_noise_ratio(
                sample_pred, sample_star, data_range=data_range).item()
            ssim_sum += structural_similarity_index_measure(
                sample_pred, sample_star, data_range=data_range).item()

        psnr_batch_sum += peak_signal_noise_ratio(
            pred_img, star_img, data_range=data_range).item() * n
        ssim_batch_sum += structural_similarity_index_measure(
            pred_img, star_img, data_range=data_range).item() * n

    return {
        'psnr': psnr_sum / n_total,
        'ssim': ssim_sum / n_total,
        'psnr_batch': psnr_batch_sum / n_total,
        'ssim_batch': ssim_batch_sum / n_total,
    }


@torch.no_grad()
def evaluate_per_class_metrics(model, dataloader, device, img_shape, num_classes=10):
    """Compute PSNR/SSIM separately for each class label.

    ``dataloader`` must yield ``(x, y_star, labels)`` triples, where ``labels``
    supports ``.tolist()`` and yields class ids in ``range(num_classes)``.
    Samples are bucketed by label. Each non-empty bucket is then scored with
    ``evaluate_overall_metrics``, so per-class numbers use exactly the same
    procedure as the overall ones instead of a duplicated copy of it.

    Returns:
        ``{class_id: {'psnr', 'ssim', 'psnr_batch', 'ssim_batch', 'count'}}``
        for every ``class_id`` in ``range(num_classes)``. Classes with no
        samples get 0 for every metric and ``count == 0``.

    Raises:
        KeyError: if a label is not in ``range(num_classes)``.
    """
    model.eval()

    # class_id -> (list of x samples, list of y_star samples)
    buckets = {class_id: ([], []) for class_id in range(num_classes)}
    for x, y_star, labels in dataloader:
        for i, label in enumerate(labels.tolist()):
            xs, ys = buckets[label]
            xs.append(x[i])
            ys.append(y_star[i])

    class_metrics = {}
    for class_id, (xs, ys) in buckets.items():
        if not xs:
            class_metrics[class_id] = {
                'psnr': 0, 'ssim': 0,
                'psnr_batch': 0, 'ssim_batch': 0,
                'count': 0
            }
            continue

        class_x = torch.stack(xs)
        class_y_star = torch.stack(ys)
        # A single-batch "loader" makes evaluate_overall_metrics score exactly
        # this class. Its third (label) slot is ignored there.
        metrics = evaluate_overall_metrics(
            model, [(class_x, class_y_star, None)], device, img_shape)
        class_metrics[class_id] = {**metrics, 'count': class_x.size(0)}

    return class_metrics


@torch.no_grad()
def final_test_evaluation(model, test_loaders, device, img_shape, track_per_class=True, num_classes=10):
    """Evaluate ``model`` on every test loader and print a short summary per set.

    Returns:
        A list with one dict per loader, in the same order. Each dict holds
        ``'overall'`` (the result of ``evaluate_overall_metrics``) and
        ``'per_class'`` (the result of ``evaluate_per_class_metrics``, or None
        when ``track_per_class`` is False).
    """
    n_sets = len(test_loaders)
    print(f"\n[INFO] Final test evaluation on {n_sets} test sets")

    results = []
    for set_no, loader in enumerate(test_loaders, start=1):
        print(f"      Evaluating test set {set_no}/{n_sets}")

        overall = evaluate_overall_metrics(model, loader, device, img_shape)
        per_class = (
            evaluate_per_class_metrics(model, loader, device, img_shape, num_classes)
            if track_per_class
            else None
        )
        results.append({'overall': overall, 'per_class': per_class})

        print(f"        Test{set_no} Overall:")
        print(f"          PSNR: {overall['psnr']:.4f} | "
              f"PSNR_batch: {overall['psnr_batch']:.4f}")
        print(f"          SSIM: {overall['ssim']:.4f} | "
              f"SSIM_batch: {overall['ssim_batch']:.4f}")

    return results



def _append_nan_metrics(detailed_metrics, per_class_metrics, num_eval_loaders, num_classes):
    """Append one row of NaN placeholders to every PSNR/SSIM history.

    Used at eval points where only losses are computed, so that
    ``detailed_metrics`` and ``per_class_metrics`` stay index-aligned with
    ``eval_losses_all``. ``per_class_metrics`` may be None (no per-class tracking).
    """
    for history in detailed_metrics.values():
        history.append([float('nan')] * num_eval_loaders)
    if per_class_metrics is not None:
        for per_class_history in per_class_metrics.values():
            for class_id in range(num_classes):
                per_class_history[class_id].append([float('nan')] * num_eval_loaders)


def _build_scheduler(optimizer, train_loader, epochs, lr, min_lr,
                     use_linear_decay, use_step_logging, use_warmup, warmup_ratio):
    """Create the LR scheduler for ``train_and_eval`` (or None) and log the choice.

    In step-based mode the caller steps the scheduler once per optimizer step.
    In epoch-based mode it steps once per epoch. Warmup is honoured only in
    step-based mode.
    """
    if not use_linear_decay:
        print(f"[INFO] Using constant learning rate: {lr}")
        return None

    if use_step_logging:
        total_steps = epochs * len(train_loader)

        if use_warmup:
            warmup_steps = int(total_steps * warmup_ratio)
            # max(1, ...) prevents a ZeroDivisionError on the final scheduler
            # step when warmup_ratio >= 1 leaves no decay steps.
            decay_steps = max(1, total_steps - warmup_steps)
            end_factor = min_lr / lr

            def warmup_then_decay(step):
                if step < warmup_steps:
                    # Warmup phase: factor rises linearly from 0.1 to 1.0.
                    return 0.1 + 0.9 * step / warmup_steps
                # Decay phase: factor falls linearly from 1.0 to min_lr / lr.
                return 1.0 - (1.0 - end_factor) * (step - warmup_steps) / decay_steps

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_then_decay)
            print(f"[INFO] Using Warmup + Linear Decay: {lr} with {warmup_steps} warmup steps -> {min_lr} over {total_steps} steps")
            return scheduler

        # Linear decay only.
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, end_factor=min_lr / lr, total_iters=total_steps)
        print(f"[INFO] Using Linear Decay: {lr} -> {min_lr} over {total_steps} steps")
        return scheduler

    # Epoch-based decay.
    if min_lr == 0.0:
        # Factor is max(0, 1 - t / epochs), where t counts completed epochs.
        def linear_decay_to_zero(epoch):
            return max(0.0, 1.0 - epoch / epochs)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, linear_decay_to_zero)
        print(f"[INFO] Using Linear Decay: {lr} -> {min_lr} over {epochs} epochs (LambdaLR)")
    else:
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, end_factor=min_lr / lr, total_iters=epochs)
        print(f"[INFO] Using Linear Decay: {lr} -> {min_lr} over {epochs} epochs (LinearLR)")

    if use_warmup:
        print("[WARNING] Warmup only supported with step-based logging, ignoring warmup for epoch-based training")
    return scheduler


def train_and_eval(model,
                   train_loader,
                   eval_loaders,
                   test_loaders,
                   img_shape,
                   epochs,
                   lr,
                   device,
                   eval_interval,
                   primary_eval_idx=0,
                   model_name=None,
                   exp_dir=None,
                   use_linear_decay=True,
                   min_lr=0.0,
                   track_per_class=False,
                   num_classes=10,
                   total_params=None,
                   trainable_params=None,
                   mode=None,
                   down_scale=None,
                   # Step-based logging parameters
                   use_step_logging=False,
                   log_interval_steps=500,
                   eval_interval_steps=2500,
                   save_interval_steps=2500,
                   use_warmup=False,
                   warmup_ratio=0.1
                   ):
    """Train ``model`` with MSE loss, keep the best checkpoint, then test it.

    Training uses AdamW on the parameters with ``requires_grad``. An optional
    linear LR decay towards ``min_lr`` is applied, with an optional warmup in
    step-based mode. Every loader batch must be an ``(x, y_star, label)``
    triple, and the model is called as ``model(x, y_star)``.

    The best model is a deep copy taken when the MSE on
    ``eval_loaders[primary_eval_idx]`` improves. After training it is evaluated
    on every test loader.

    Args:
        model: network to train (updated in place).
        train_loader: training data loader.
        eval_loaders / test_loaders: a loader or a list of loaders.
        img_shape: ``(C, H, W)`` used to reshape predictions for metrics/images.
        epochs: number of passes over ``train_loader``.
        lr: initial (peak) learning rate.
        device: device the batches are moved to.
        eval_interval: epoch-based mode only. Compute PSNR/SSIM every this many
            epochs and on the final epoch.
        primary_eval_idx: index of the eval loader used for best-model selection.
        model_name: label used in log lines.
        exp_dir: directory under which sample images go (``exp_dir/samples``).
        use_linear_decay / min_lr: enable linear LR decay to ``min_lr``.
        track_per_class / num_classes: also record per-class PSNR/SSIM.
        total_params / trainable_params: parameter counts shown in logs.
            They are counted from ``model`` when omitted.
        mode / down_scale: forwarded to ``save_samples_with_border``.
        use_step_logging: log, evaluate and save by optimizer step instead of
            by epoch (suited to large datasets).
        log_interval_steps: print the mean training loss every this many steps.
        eval_interval_steps: compute eval losses every this many steps, plus at
            each epoch end not already covered.
        save_interval_steps: save a sample image grid every this many steps.
        use_warmup / warmup_ratio: step-based mode only. Warm up the LR over the
            first ``warmup_ratio`` fraction of all steps.

    Returns:
        ``(training_metrics, all_test_results)``.

        ``training_metrics`` holds ``train_losses``, ``eval_losses_all``,
        ``detailed_metrics`` and ``per_class_metrics``. These are
        index-aligned: one entry per epoch, or one per eval point in step mode.
        It also holds ``best_epoch``, ``best_step`` (None in epoch mode) and
        ``best_eval_loss``.
    """
    if not isinstance(eval_loaders, list):
        eval_loaders = [eval_loaders]
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    # Parameter counts only feed the log lines. Count them when not supplied,
    # because formatting None / 1e6 would raise TypeError.
    if total_params is None:
        total_params = sum(p.numel() for p in model.parameters())
    if trainable_params is None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params_str = f"Params: {trainable_params/1e6:.2f}M/{total_params/1e6:.2f}M"
    trainable_pct = 100 * trainable_params / total_params if total_params else 0.0
    num_eval_loaders = len(eval_loaders)

    logging_mode = "step-based" if use_step_logging else "epoch-based"
    print(f"\n[INFO] Training {model_name} with {len(eval_loaders)} eval sets, {len(test_loaders)} test sets")
    print(f"[INFO] Using {logging_mode} logging")
    print(f"[INFO] Using eval set {primary_eval_idx} as primary for best model selection")
    if use_step_logging:
        print(f"[INFO] Step-based intervals: Log={log_interval_steps}, Eval={eval_interval_steps}, Save={save_interval_steps}")

    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    scheduler = _build_scheduler(optimizer, train_loader, epochs, lr, min_lr,
                                 use_linear_decay, use_step_logging, use_warmup, warmup_ratio)

    samples_dir = os.path.join(exp_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    # Metric histories (see the Returns section for alignment).
    train_losses = []
    eval_losses_all = []
    detailed_metrics = {'psnr': [], 'ssim': [], 'psnr_batch': [], 'ssim_batch': []}
    per_class_metrics = None
    if track_per_class:
        per_class_metrics = {
            key: {i: [] for i in range(num_classes)}
            for key in ('psnr', 'ssim', 'psnr_batch', 'ssim_batch')
        }

    best_eval_loss = float('inf')
    best_model = None
    best_epoch = 0
    best_step = 0

    global_step = 0
    running_loss = 0.0       # loss sum since the last step-mode log line
    interval_loss_sum = 0.0  # loss sum since the last step-mode eval point
    interval_steps = 0
    last_log_time = time.time()

    # Fixed batch from the primary eval loader, reused for every sample image.
    vis_batch = next(iter(eval_loaders[primary_eval_idx]))
    x_vis, y_star_vis = vis_batch[0].to(device), vis_batch[1].to(device)

    def save_visualization(filename):
        """Predict on the fixed visualisation batch and save an image grid."""
        model.eval()
        with torch.no_grad():
            y_pred_vis = model(x_vis, y_star_vis)
        save_samples_with_border(x_vis[:8].cpu(), y_pred_vis[:8].cpu(), y_star_vis[:8].cpu(),
                                 img_shape, os.path.join(samples_dir, filename), mode, down_scale)

    def record_step_eval(epoch, at_epoch_end):
        """Step mode: record losses, update the best model, keep histories aligned."""
        nonlocal best_eval_loss, best_model, best_epoch, best_step
        nonlocal interval_loss_sum, interval_steps

        # Uses a dedicated accumulator rather than running_loss. The log branch
        # resets running_loss, so at eval steps that coincide with a log step
        # it would already be 0.
        train_losses.append(interval_loss_sum / interval_steps if interval_steps else float('nan'))
        interval_loss_sum, interval_steps = 0.0, 0

        eval_losses_step = calculate_eval_loss_all_loaders(model, eval_loaders, device)
        eval_losses_all.append(eval_losses_step)

        primary_eval_loss = eval_losses_step[primary_eval_idx]
        if primary_eval_loss < best_eval_loss:
            best_eval_loss = primary_eval_loss
            best_epoch = epoch
            best_step = global_step
            best_model = copy.deepcopy(model)
            suffix = " (end of epoch)" if at_epoch_end else ""
            print(f"      NEW BEST at Step {global_step}{suffix}! Eval loss: {primary_eval_loss:.6f}")

        # Step mode skips PSNR/SSIM for speed, so pad those histories with NaN.
        _append_nan_metrics(detailed_metrics, per_class_metrics, num_eval_loaders, num_classes)

        eval_losses_str = " | ".join(f"Eval{i}: {loss:.4f}" for i, loss in enumerate(eval_losses_step))
        tag = "END-EPOCH-EVAL" if at_epoch_end else "EVAL"
        print(f"      Step {global_step} {tag} | {eval_losses_str} | "
              f"PSNR: ---- | SSIM: ---- | {params_str}")

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        # LR in effect at the start of this epoch. Step mode refreshes it after
        # every scheduler step; epoch mode keeps it for this epoch's log line.
        current_lr = optimizer.param_groups[0]['lr'] if scheduler is not None else lr

        for batch_idx, (x, y_star, _) in enumerate(train_loader):
            x, y_star = x.to(device), y_star.to(device)
            optimizer.zero_grad()
            y_pred = model(x, y_star)
            loss = F.mse_loss(y_pred, y_star)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            running_loss += batch_loss
            interval_loss_sum += batch_loss
            interval_steps += 1
            global_step += 1

            if not use_step_logging:
                continue

            if scheduler is not None:
                scheduler.step()
                current_lr = scheduler.get_last_lr()[0]

            if global_step % log_interval_steps == 0:
                avg_loss = running_loss / log_interval_steps
                now = time.time()
                time_per_step = (now - last_log_time) / log_interval_steps
                steps_per_sec = 1.0 / time_per_step if time_per_step > 0 else float('inf')

                print(f"[{model_name}] Step {global_step} | "
                      f"Epoch {epoch}/{epochs} ({batch_idx+1}/{len(train_loader)}) | "
                      f"LR: {current_lr:.6f} | "
                      f"Loss: {avg_loss:.6f} | "
                      f"Speed: {steps_per_sec:.1f} steps/s | "
                      f"Time/step: {time_per_step:.3f}s")

                running_loss = 0.0
                last_log_time = now

            if global_step % eval_interval_steps == 0:
                record_step_eval(epoch, at_epoch_end=False)
                model.train()

            if global_step % save_interval_steps == 0:
                save_visualization(f"step{global_step}.png")
                model.train()

        if use_step_logging:
            # Always evaluate at the end of an epoch unless its last step already did.
            if global_step % eval_interval_steps != 0:
                record_step_eval(epoch, at_epoch_end=True)
            continue

        # ----- Epoch-based mode from here on -----
        if scheduler is not None:
            # current_lr deliberately keeps the value used during this epoch.
            scheduler.step()

        epoch_loss /= len(train_loader)
        train_losses.append(epoch_loss)

        # Eval loss is computed only every 5 epochs and on the last one. Other
        # epochs record +inf, so they can never be selected as best.
        if epoch % 5 == 0 or epoch == epochs:
            eval_losses_epoch = calculate_eval_loss_all_loaders(model, eval_loaders, device)
        else:
            eval_losses_epoch = [np.inf] * num_eval_loaders
        eval_losses_all.append(eval_losses_epoch)

        primary_eval_loss = eval_losses_epoch[primary_eval_idx]
        if primary_eval_loss < best_eval_loss:
            best_eval_loss = primary_eval_loss
            best_epoch = epoch
            best_model = copy.deepcopy(model)
            print(f"      New best model at epoch {epoch}! Primary eval loss: {primary_eval_loss:.6f}")

        if epoch % eval_interval == 0 or epoch == epochs:
            all_overall_metrics = []
            all_class_metrics = []
            for eval_loader in eval_loaders:
                all_overall_metrics.append(
                    evaluate_overall_metrics(model, eval_loader, device, img_shape))
                all_class_metrics.append(
                    evaluate_per_class_metrics(model, eval_loader, device, img_shape, num_classes)
                    if track_per_class else None)

            for key in detailed_metrics:
                detailed_metrics[key].append([m[key] for m in all_overall_metrics])
            if track_per_class:
                for key in per_class_metrics:
                    for class_id in range(num_classes):
                        per_class_metrics[key][class_id].append(
                            [cm[class_id][key] for cm in all_class_metrics])

            save_visualization(f"epoch{epoch}.png")

            primary_metrics = all_overall_metrics[primary_eval_idx]
            psnr_val = primary_metrics['psnr']
            ssim_val = primary_metrics['ssim']
            psnr_batch_val = primary_metrics['psnr_batch']
            ssim_batch_val = primary_metrics['ssim_batch']
        else:
            _append_nan_metrics(detailed_metrics, per_class_metrics, num_eval_loaders, num_classes)
            psnr_val = ssim_val = psnr_batch_val = ssim_batch_val = float('nan')

        eval_losses_str = " | ".join(f"Eval{i}: {loss:.4f}" for i, loss in enumerate(eval_losses_epoch))
        psnr_str, psnr_batch_str, ssim_str, ssim_batch_str = (
            "----" if np.isnan(v) else f"{v:.4f}"
            for v in (psnr_val, psnr_batch_val, ssim_val, ssim_batch_val)
        )

        print(
            f"[{model_name}] Epoch {epoch}/{epochs} | "
            f"LR: {current_lr:.6f} | "
            f"Train: {epoch_loss:.4f} | "
            f"{eval_losses_str} | "
            f"PSNR: {psnr_str} | "
            f"PSNR_batch: {psnr_batch_str} | "
            f"SSIM: {ssim_str} | "
            f"SSIM_batch: {ssim_batch_str} | "
            f"{params_str} "
            f"({trainable_pct:.1f}%)"
        )

    if best_model is None:
        # No primary eval loss ever went below +inf (e.g. every value was NaN).
        # Test the final weights instead of crashing on None below.
        print("[WARNING] No best model was recorded; using the final model for testing")
        best_model = model

    if use_step_logging:
        print(f"\n[INFO] Using best model (Step {best_step}, Epoch {best_epoch}, Primary eval loss: {best_eval_loss:.6f}) for final testing")
    else:
        print(f"\n[INFO] Using best model (Epoch {best_epoch}, Primary eval loss: {best_eval_loss:.6f}) for final testing")

    all_test_results = final_test_evaluation(
        best_model, test_loaders, device, img_shape, track_per_class, num_classes)

    # Sample grid from the first test loader using the best model.
    if test_loaders:
        test_batch = next(iter(test_loaders[0]))
        x_test, y_star_test = test_batch[0].to(device), test_batch[1].to(device)
        best_model.eval()
        with torch.no_grad():
            y_pred_test = best_model(x_test, y_star_test)
        save_samples_with_border(x_test[:8].cpu(), y_pred_test[:8].cpu(), y_star_test[:8].cpu(),
                                 img_shape, os.path.join(samples_dir, "final_test_best_model.png"),
                                 mode, down_scale)

    training_metrics = {
        'train_losses': train_losses,
        'eval_losses_all': eval_losses_all,  # [epoch][eval_idx] or [eval point][eval_idx]
        'detailed_metrics': detailed_metrics,
        'per_class_metrics': per_class_metrics,  # None if track_per_class=False
        'best_epoch': best_epoch,
        'best_step': best_step if use_step_logging else None,
        'best_eval_loss': best_eval_loss
    }

    return training_metrics, all_test_results