
import os
import os.path as osp
import random
import time
import copy
from typing import List, Dict, Any, Optional
import numpy as np
import PIL.Image as PImage
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import lpips
import wandb

from inference_VQ_Diffusion import VQ_Diffusion


class VQVAEFinetuningConfig:
    """Static hyperparameters for one VQVAE finetuning run.

    Every setting is a plain class attribute; edit the values below in place to
    configure a run. ``OUTPUT_DIR`` is derived from the batch size, learning
    rate, weight decay, epoch count and seed, so those must stay defined above it.
    """

    # ---- data ---------------------------------------------------------------
    DATASET_ROOT: str = "/workspace/VQ-Diffusion/generated_images/ithq_np10000_pr2_seed5"  # Update this path
    IMAGE_SIZE: int = 256
    BATCH_SIZE: int = 16
    NUM_WORKERS: int = 8

    # ---- augmentation (only applied to the 'train' phase) -------------------
    USE_AUGMENTATIONS: bool = False         # master switch for all augmentations
    HORIZONTAL_FLIP_PROB: float = 0.5       # chance of a horizontal flip
    ROTATION_DEGREES: int = 15              # rotation range is [-deg, +deg]
    COLOR_JITTER_BRIGHTNESS: float = 0.2    # ColorJitter brightness strength
    COLOR_JITTER_CONTRAST: float = 0.2      # ColorJitter contrast strength
    COLOR_JITTER_SATURATION: float = 0.2    # ColorJitter saturation strength
    COLOR_JITTER_HUE: float = 0.1           # ColorJitter hue strength
    GAUSSIAN_BLUR_PROB: float = 0.2         # chance of applying a Gaussian blur
    GAUSSIAN_BLUR_SIGMA: tuple = (0.1, 2.0) # sigma range sampled by the blur

    # ---- optimization -------------------------------------------------------
    NUM_EPOCHS: int = 50
    LEARNING_RATE: float = 5e-5
    WEIGHT_DECAY: float = 1e-4
    OPTIMIZER: str = "Adam"                 # "Adam" or "AdamW"

    # ---- LR schedule --------------------------------------------------------
    SCHEDULER: str = "StepLR"               # "StepLR" or "CosineAnnealingLR"
    STEP_LR_GAMMA: float = 0.9              # StepLR multiplicative decay factor
    STEP_LR_STEP_SIZE: int = 2              # StepLR period, in epochs

    # ---- model --------------------------------------------------------------
    # NOTE(review): neither of these two is read by the visible training code;
    # the codec's config/checkpoint paths are supplied in initialize_vq_codec.
    VQ_DIFFUSION_CONFIG: str = "configs/imagenet.yaml"
    VQVAE_CKPT_PATH: Optional[str] = None

    # ---- which components receive gradients ---------------------------------
    TRAIN_ENCODER: bool = True
    TRAIN_DECODER: bool = False
    TRAIN_QUANTIZER: bool = False

    # ---- loss weights -------------------------------------------------------
    MSE_FEAT_WEIGHT: float = 1.0            # encoder features vs. ground-truth codebook vectors
    MSE_IMG_WEIGHT: float = 0.0             # pixel-space reconstruction
    LPIPS_WEIGHT: float = 0.0               # perceptual reconstruction

    # ---- runtime ------------------------------------------------------------
    GPU_ID: int = 0
    SEED: int = 0

    # ---- logging ------------------------------------------------------------
    LOG_IMAGE_FREQ: int = 100               # image logging period, in global steps
    LOG_METRICS_FREQ: int = 10              # scalar logging period, in global steps (1 = every step)
    ASYNC_LOGGING: bool = True              # log scalars/images with commit=False
    WANDB_PROJECT: str = "vqvae-finetuning-vqdiffusion"
    EXPERIMENT_NAME: str = ""               # empty -> auto-generated in main()

    # ---- output -------------------------------------------------------------
    OUTPUT_DIR: str = f"./finetuned_models_BS{BATCH_SIZE}_L{LEARNING_RATE}_W{WEIGHT_DECAY}_E{NUM_EPOCHS}_SEED{SEED}"
    SAVE_EVERY_N_EPOCHS: int = 2            # checkpoint period, in epochs


class ImageDatasetForVQVAE(Dataset):
    """Dataset pairing generated images with their ground-truth token sets.

    Uses the same preprocessing approach as batch_loss_analysis.py for
    consistency: resize to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, convert to a tensor,
    and normalize every channel from [0, 1] to [-1, 1].

    Image ``i`` (in sorted filename order) is paired with row ``i`` of
    ``<data_root>/all_tokens.pt``.
    NOTE(review): this assumes generate_images.py writes tokens in the same order
    that lexicographic filename sorting yields (e.g. zero-padded file names) —
    confirm, otherwise images and tokens are silently misaligned.
    """

    # Upper bound on images tried by __getitem__ before giving up, so a
    # directory of unreadable files raises instead of recursing forever.
    MAX_LOAD_ATTEMPTS = 10

    def __init__(self, data_root: str, config: "VQVAEFinetuningConfig", phase: str = 'train'):
        """Index the images in ``data_root`` and load their ground-truth tokens.

        Args:
            data_root: Directory containing the images and ``all_tokens.pt``.
            config: Provides ``IMAGE_SIZE``, ``USE_AUGMENTATIONS`` and the
                augmentation settings.
            phase: Augmentations are only applied when this is ``'train'``.

        Raises:
            FileNotFoundError: if no images or no token file are found.
            ValueError: if there are fewer token sets than images.
        """
        self.data_root = data_root
        self.config = config
        self.image_size = config.IMAGE_SIZE
        self.phase = phase

        # Create base transform (same as batch_loss_analysis.py)
        self.base_transform = transforms.Compose([
            transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Maps [0,1] to [-1,1]
        ])

        # NOTE(review): augmentations change the image, but __getitem__ returns the
        # token set unchanged, so geometric augmentations can misalign the
        # feature-matching target — confirm before enabling USE_AUGMENTATIONS.
        self.use_augmentations = config.USE_AUGMENTATIONS and phase == 'train'
        if self.use_augmentations:
            self.augmentation_transform = self._create_augmentation_transform()

        # Load image paths (same pattern as batch_loss_analysis.py)
        self.image_paths = self._load_image_paths()

        # Load ground truth tokens (like in the reference script)
        self.load_ground_truth_tokens()

        # Fail fast on a token/image count mismatch instead of at sample time.
        self._validate_token_alignment()

    def _load_image_paths(self) -> List[str]:
        """Return sorted paths of all .png/.jpg/.jpeg files in ``data_root`` (same as batch_loss_analysis.py)."""
        image_paths = sorted([
            os.path.join(self.data_root, f) for f in os.listdir(self.data_root)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ])

        if not image_paths:
            raise FileNotFoundError(f"No images found in {self.data_root}")

        print(f"Found {len(image_paths)} images in {self.data_root}")
        return image_paths

    def load_ground_truth_tokens(self):
        """Load ground truth tokens from ``<data_root>/all_tokens.pt`` (REQUIRED - from generate_images.py output).

        Raises:
            FileNotFoundError: if the token file does not exist.
        """
        # Load tokens saved by generate_images.py
        token_file = os.path.join(self.data_root, 'all_tokens.pt')
        if not os.path.exists(token_file):
            raise FileNotFoundError(f"Ground truth tokens are REQUIRED but not found at {token_file}. "
                                  f"Please run generate_images.py first to create the tokens.")

        print(f"Loading ground truth tokens from {token_file}")
        # Keep tokens on CPU; batches are moved to the device in the training loop.
        # weights_only=False unpickles arbitrary objects: only load trusted token files.
        self.ground_truth_tokens = torch.load(token_file, map_location='cpu', weights_only=False)
        print(f"Loaded ground truth tokens with shape: {self.ground_truth_tokens.shape}")
        print(f"Number of token sets: {self.ground_truth_tokens.shape[0]}")

    def _validate_token_alignment(self) -> None:
        """Check that every image index has a token set.

        Raises:
            ValueError: if there are fewer token sets than images.
        """
        num_tokens = len(self.ground_truth_tokens)
        num_images = len(self.image_paths)
        if num_tokens < num_images:
            raise ValueError(
                f"Found {num_images} images but only {num_tokens} token sets in "
                f"{self.data_root}; every image needs a matching token set."
            )
        if num_tokens > num_images:
            print(f"Warning: {num_tokens} token sets but only {num_images} images in "
                  f"{self.data_root}; extra token sets are ignored. Check image/token alignment.")

    def _create_augmentation_transform(self):
        """Build the tensor-space augmentation pipeline from the configuration."""
        augmentations = []

        # Horizontal flip
        if self.config.HORIZONTAL_FLIP_PROB > 0:
            augmentations.append(
                transforms.RandomHorizontalFlip(p=self.config.HORIZONTAL_FLIP_PROB)
            )

        # Rotation (uncovered corners are filled with 0, i.e. black in [0,1] space)
        if self.config.ROTATION_DEGREES > 0:
            augmentations.append(
                transforms.RandomRotation(
                    degrees=self.config.ROTATION_DEGREES,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                    fill=0
                )
            )

        # Color jitter
        if (self.config.COLOR_JITTER_BRIGHTNESS > 0 or
            self.config.COLOR_JITTER_CONTRAST > 0 or
            self.config.COLOR_JITTER_SATURATION > 0 or
            self.config.COLOR_JITTER_HUE > 0):
            augmentations.append(
                transforms.ColorJitter(
                    brightness=self.config.COLOR_JITTER_BRIGHTNESS,
                    contrast=self.config.COLOR_JITTER_CONTRAST,
                    saturation=self.config.COLOR_JITTER_SATURATION,
                    hue=self.config.COLOR_JITTER_HUE
                )
            )

        # Gaussian blur
        if self.config.GAUSSIAN_BLUR_PROB > 0:
            augmentations.append(
                transforms.RandomApply([
                    transforms.GaussianBlur(
                        kernel_size=5,
                        sigma=self.config.GAUSSIAN_BLUR_SIGMA
                    )
                ], p=self.config.GAUSSIAN_BLUR_PROB)
            )

        # Additional augmentations for robustness (always added when augmenting)
        augmentations.extend([
            # Random perspective (subtle)
            transforms.RandomPerspective(distortion_scale=0.1, p=0.2),

            # Random affine transformation (subtle)
            transforms.RandomAffine(
                degrees=0,
                translate=(0.05, 0.05),
                scale=(0.95, 1.05),
                shear=5,
                interpolation=transforms.InterpolationMode.BILINEAR,
                fill=0
            ),
        ])

        return transforms.Compose(augmentations)

    def __len__(self) -> int:
        return len(self.image_paths)

    def _load_image(self, image_path: str) -> torch.Tensor:
        """Open ``image_path`` as RGB, apply optional augmentations, then the base transform."""
        # Context manager closes the underlying file handle promptly.
        with PImage.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.use_augmentations:
            # Augment in [0,1] tensor space, then go back to PIL for the base transform.
            image_tensor = self.augmentation_transform(transforms.ToTensor()(image))
            image = transforms.ToPILImage()(image_tensor)

        return self.base_transform(image)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return ``{'image', 'path', 'ground_truth_tokens'}`` for ``idx``.

        An unreadable image is replaced by a random other index (up to
        ``MAX_LOAD_ATTEMPTS`` tries); the tokens always come from the same index
        as the image that was finally loaded.

        Raises:
            IndexError: if ``idx`` has no token set.
            RuntimeError: if ``MAX_LOAD_ATTEMPTS`` consecutive loads fail.
        """
        # Checked outside the retry loop so a token/image mismatch is reported
        # instead of being silently swapped for a random sample.
        if idx >= len(self.ground_truth_tokens):
            raise IndexError(f"Token index {idx} out of range. Only {len(self.ground_truth_tokens)} token sets available.")

        for _ in range(self.MAX_LOAD_ATTEMPTS):
            image_path = self.image_paths[idx]
            try:
                image_tensor = self._load_image(image_path)
            except Exception as e:  # deliberate best-effort: skip corrupt/unreadable files
                print(f"Error loading image {image_path}: {e}")
                idx = random.randint(0, len(self) - 1)
                continue
            return {
                'image': image_tensor,
                'path': image_path,
                'ground_truth_tokens': self.ground_truth_tokens[idx].clone(),
            }

        raise RuntimeError(
            f"Failed to load an image after {self.MAX_LOAD_ATTEMPTS} attempts; last tried {image_path}"
        )


def setup_environment(config: VQVAEFinetuningConfig) -> torch.device:
    """Seed the Python, NumPy and torch RNGs and select the compute device.

    Uses ``cuda:<GPU_ID>`` (also made the current CUDA device and seeded) when
    CUDA is available, otherwise the CPU.

    Args:
        config: Provides ``SEED`` and ``GPU_ID``.

    Returns:
        The selected ``torch.device``.
    """
    seed = config.SEED
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        gpu_index = config.GPU_ID
        torch.cuda.set_device(gpu_index)
        device = torch.device(f"cuda:{gpu_index}")
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    print(f"Using device: {device}")
    return device


def initialize_vq_codec(config: VQVAEFinetuningConfig,
                        config_path: str = '/workspace/VQ-Diffusion/configs/ithq.yaml',
                        ckpt_path: str = '/checkpoints/pretrained_model/ithq_learnable.pth'):
    """Load the pretrained VQ-Diffusion model and return its content codec.

    Initializes the codec exactly like batch_loss_analysis.py. The defaults are
    the ITHQ config/checkpoint used there; pass ``config_path`` / ``ckpt_path``
    to load a different model without editing this function.

    Args:
        config: Run configuration. NOTE(review): currently unused — the paths
            do not come from ``config.VQ_DIFFUSION_CONFIG`` /
            ``config.VQVAE_CKPT_PATH``.
        config_path: VQ-Diffusion YAML config passed to ``VQ_Diffusion``.
        ckpt_path: Pretrained checkpoint passed to ``VQ_Diffusion``.

    Returns:
        ``vq_model.model.content_codec`` of the loaded model.
    """
    vq_model = VQ_Diffusion(config=config_path, path=ckpt_path)
    return vq_model.model.content_codec


def setup_trainable_parameters(vqvae, config: "VQVAEFinetuningConfig") -> List[torch.nn.Parameter]:
    """Freeze the whole VQVAE, then unfreeze only the components selected in ``config``.

    Components are looked up by attribute name: ``enc`` (encoder), ``dec``
    (decoder) and ``quantize`` (quantizer/codebook). The quantizer is searched
    at the top level first and then inside the encoder as ``enc.quantize``,
    which is where ``compute_losses`` accesses it.

    If the quantizer is nested in the encoder and ``TRAIN_QUANTIZER`` is False,
    its parameters stay frozen even when ``TRAIN_ENCODER`` is True. Otherwise
    the codebook would be trained against the flags. Loss terms weighted by 0
    still produce all-zero gradients, and the optimizer's weight decay would
    then move the codebook.

    Args:
        vqvae: The content codec (an ``nn.Module``).
        config: Provides ``TRAIN_ENCODER``, ``TRAIN_DECODER``, ``TRAIN_QUANTIZER``.

    Returns:
        De-duplicated list of the parameters that now require gradients.

    Raises:
        ValueError: if no parameters end up trainable.
    """
    # Debug: Print model structure
    print("VQVAE model structure:")
    for name, module in vqvae.named_children():
        print(f"  - {name}: {type(module)}")

    # First, freeze everything
    for param in vqvae.parameters():
        param.requires_grad_(False)

    encoder = getattr(vqvae, 'enc', None)
    decoder = getattr(vqvae, 'dec', None)
    quantizer = getattr(vqvae, 'quantize', None)
    if quantizer is None and encoder is not None:
        quantizer = getattr(encoder, 'quantize', None)
    quantizer_param_ids = (
        {id(p) for p in quantizer.parameters()} if quantizer is not None else set()
    )

    trainable_params: List[torch.nn.Parameter] = []
    seen_ids = set()

    def _unfreeze(module, skip_ids=frozenset()):
        """Unfreeze ``module``'s parameters (except ``skip_ids``) and record each once."""
        for param in module.parameters():
            if id(param) in skip_ids or id(param) in seen_ids:
                continue
            param.requires_grad_(True)
            seen_ids.add(id(param))
            trainable_params.append(param)

    if config.TRAIN_ENCODER:
        if encoder is not None:
            print("Unfreezing ENC (encoder).")
            skip_ids = set() if config.TRAIN_QUANTIZER else quantizer_param_ids
            if any(id(p) in skip_ids for p in encoder.parameters()):
                print("  Keeping the quantizer inside ENC frozen (TRAIN_QUANTIZER=False).")
            _unfreeze(encoder, skip_ids)
        else:
            print("Warning: TRAIN_ENCODER=True but 'enc' component not found!")

    if config.TRAIN_DECODER:
        if decoder is not None:
            print("Unfreezing DEC (decoder).")
            _unfreeze(decoder)
        else:
            print("Warning: TRAIN_DECODER=True but 'dec' component not found!")

    if config.TRAIN_QUANTIZER:
        if quantizer is not None:
            print("Unfreezing QUANTIZE.")
            _unfreeze(quantizer)
        else:
            print("Warning: TRAIN_QUANTIZER=True but 'quantize' component not found!")

    if not trainable_params:
        print("\nAvailable model components:")
        for name, module in vqvae.named_children():
            print(f"  - {name}: {type(module)}")
        raise ValueError("No trainable parameters found! Available components shown above.")

    print(f"Total trainable parameters: {sum(p.numel() for p in trainable_params):,}")
    return trainable_params


def compute_losses(vqvae, batch: Dict[str, torch.Tensor], config: "VQVAEFinetuningConfig",
                  lpips_loss_fn, device: torch.device) -> Dict[str, torch.Tensor]:
    """Run the VQVAE on one batch and compute every loss component.

    The continuous encoder output is matched against the codebook vectors
    selected by the ground-truth tokens (``mse_feat``). The image is also
    reconstructed from the encoder's own codebook indices, so that pixel
    (``mse_img``) and perceptual (``lpips``) losses can be computed and logged.

    Args:
        vqvae: Content codec exposing ``enc.encoder``, ``enc.quant_conv``,
            ``enc.quantize`` and ``dec.post_quant_conv``, ``dec.decoder``.
        batch: Dataloader batch with ``'image'`` (normalized to [-1, 1] by the
            dataset transform) and ``'ground_truth_tokens'``.
        config: Provides ``MSE_IMG_WEIGHT``, ``MSE_FEAT_WEIGHT``, ``LPIPS_WEIGHT``.
        lpips_loss_fn: An ``lpips.LPIPS`` module. With its default
            ``normalize=False`` it expects inputs in [-1, 1].
        device: Device to run on.

    Returns:
        Dict with scalar tensors ``'mse_img'``, ``'mse_feat'``, ``'lpips'``,
        ``'total'`` (weighted sum) and the clamped ``'reconstructed'`` images.

    Raises:
        ValueError: if the batch has no ground-truth tokens, or flattened
            tokens do not match the encoder's latent grid size.
    """
    if 'ground_truth_tokens' not in batch:
        raise ValueError("Ground truth tokens are REQUIRED but not found in batch. "
                        "Make sure the dataset has tokens loaded.")

    images = batch['image'].to(device, non_blocking=True)
    batch_size = images.shape[0]
    quantize = vqvae.enc.quantize

    # Encoder forward pass: continuous features plus the encoder's own codes.
    h = vqvae.enc.encoder(images)
    f_continuous = vqvae.enc.quant_conv(h)
    latent_h, latent_w = f_continuous.shape[-2], f_continuous.shape[-1]
    # NOTE: the codes are presumably discrete (integer) ids, so autograd cannot
    # flow through them. The reconstruction losses therefore do not reach the
    # encoder; they can only train the decoder and the codebook.
    indices = quantize.only_get_indices(f_continuous).view(batch_size, -1)
    z_quantized = quantize.get_codebook_entry(indices.view(-1), shape=(batch_size, latent_h, latent_w))

    # Target: codebook vectors for the ground-truth tokens (no gradient needed).
    # .long() is a no-op for int64 tokens and makes other integer dtypes usable as indices.
    ground_truth_tokens = batch['ground_truth_tokens'].to(device, non_blocking=True).long()
    if ground_truth_tokens.dim() == 3:  # [batch, H, W]
        target_shape = (batch_size, ground_truth_tokens.shape[-2], ground_truth_tokens.shape[-1])
    else:  # flattened tokens: lay them out on the encoder's latent grid
        num_tokens = ground_truth_tokens.shape[-1]
        if num_tokens != latent_h * latent_w:
            raise ValueError(
                f"Ground truth tokens have {num_tokens} entries per image but the encoder "
                f"latent grid is {latent_h}x{latent_w} ({latent_h * latent_w} entries)."
            )
        target_shape = (batch_size, latent_h, latent_w)
    with torch.no_grad():
        z_target = quantize.get_codebook_entry(ground_truth_tokens.view(-1), shape=target_shape)

    # Decode to get reconstructed image, kept in the same [-1, 1] range as the input.
    f_quantized = vqvae.dec.post_quant_conv(z_quantized)
    reconstructed = torch.clamp(vqvae.dec.decoder(f_quantized), -1., 1.)

    losses = {
        # Image reconstruction loss (MSE in pixel space)
        'mse_img': F.mse_loss(reconstructed, images),
        # Feature matching loss (predicted continuous features vs. ground-truth codebook vectors)
        'mse_feat': F.mse_loss(f_continuous, z_target),
    }

    # Perceptual loss. Inputs are passed in [-1, 1], the range lpips.LPIPS expects
    # by default. When LPIPS is not part of the objective it is still computed for
    # logging, but without building an autograd graph through the VGG network.
    if config.LPIPS_WEIGHT:
        losses['lpips'] = lpips_loss_fn(reconstructed, images).mean()
    else:
        with torch.no_grad():
            losses['lpips'] = lpips_loss_fn(reconstructed, images).mean()

    losses['total'] = (
        config.MSE_IMG_WEIGHT * losses['mse_img']
        + config.MSE_FEAT_WEIGHT * losses['mse_feat']
        + config.LPIPS_WEIGHT * losses['lpips']
    )
    losses['reconstructed'] = reconstructed  # For logging
    return losses


def _format_time(seconds: float) -> str:
    """Format a duration in seconds as ``'1h 2m 3s'``, ``'2m 3s'`` or ``'3s'``."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    if hours > 0:
        return f"{hours}h {minutes}m {secs}s"
    if minutes > 0:
        return f"{minutes}m {secs}s"
    return f"{secs}s"


def _config_snapshot(config) -> Dict[str, Any]:
    """Return every UPPER_CASE setting of ``config``, whether class- or instance-level."""
    # ``config.__dict__`` only holds attributes assigned on the instance (at most
    # EXPERIMENT_NAME); the hyperparameters are class attributes, so use getattr.
    return {name: getattr(config, name) for name in dir(config) if name.isupper()}


def _default_experiment_name(config) -> str:
    """Build ``vqvae_finetune_<timestamp>_<trained parts>_epochs<N>``."""
    timestamp = time.strftime('%Y%m%d_%H%M%S')
    components = [
        tag for enabled, tag in (
            (config.TRAIN_ENCODER, "enc"),
            (config.TRAIN_DECODER, "dec"),
            (config.TRAIN_QUANTIZER, "quant"),
        ) if enabled
    ]
    return f"vqvae_finetune_{timestamp}_{'_'.join(components)}_epochs{config.NUM_EPOCHS}"


def _create_optimizer(trainable_params, config) -> torch.optim.Optimizer:
    """Build the optimizer selected by ``config.OPTIMIZER`` ("Adam" or "AdamW")."""
    optimizer_classes = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW}
    if config.OPTIMIZER not in optimizer_classes:
        raise ValueError(f"Unsupported optimizer: {config.OPTIMIZER}")
    return optimizer_classes[config.OPTIMIZER](
        trainable_params,
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
    )


def _create_scheduler(optimizer, config, total_steps: int):
    """Build the LR scheduler selected by ``config.SCHEDULER``.

    Returns:
        ``(scheduler, step_on)``, where ``step_on`` is ``"epoch"`` for StepLR
        and ``"step"`` for CosineAnnealingLR (stepped once per batch).
    """
    if config.SCHEDULER == "StepLR":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.STEP_LR_STEP_SIZE,
            gamma=config.STEP_LR_GAMMA,
        )
        return scheduler, "epoch"
    if config.SCHEDULER == "CosineAnnealingLR":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
        return scheduler, "step"
    raise ValueError(f"Unsupported scheduler: {config.SCHEDULER}")


def main():
    """Run VQVAE finetuning end to end using ``VQVAEFinetuningConfig``.

    The steps are:
    - seed the RNGs, pick the device and start the wandb run;
    - load the codec and unfreeze the configured components;
    - build the dataset, optimizer and scheduler;
    - train for ``NUM_EPOCHS``, logging to wandb;
    - save periodic and final checkpoints to ``OUTPUT_DIR``.

    Raises:
        ValueError: for ``NUM_EPOCHS < 1``, an unsupported optimizer or
            scheduler, or a dataset too small to form one batch.
    """
    config = VQVAEFinetuningConfig()

    # Fail fast: zero epochs would leave the final timing summary with no data.
    if config.NUM_EPOCHS < 1:
        raise ValueError(f"NUM_EPOCHS must be at least 1, got {config.NUM_EPOCHS}")

    # Generate experiment name if not provided
    if not config.EXPERIMENT_NAME:
        config.EXPERIMENT_NAME = _default_experiment_name(config)

    device = setup_environment(config)

    # Initialize Wandb
    wandb_config = {
        'dataset_root': config.DATASET_ROOT,
        'image_size': config.IMAGE_SIZE,
        'batch_size': config.BATCH_SIZE,
        'num_epochs': config.NUM_EPOCHS,
        'learning_rate': config.LEARNING_RATE,
        'weight_decay': config.WEIGHT_DECAY,
        'optimizer': config.OPTIMIZER,
        'scheduler': config.SCHEDULER,
        'step_lr_gamma': config.STEP_LR_GAMMA if config.SCHEDULER == "StepLR" else None,
        'step_lr_step_size': config.STEP_LR_STEP_SIZE if config.SCHEDULER == "StepLR" else None,
        'use_augmentations': config.USE_AUGMENTATIONS,
        'train_encoder': config.TRAIN_ENCODER,
        'train_decoder': config.TRAIN_DECODER,
        'train_quantizer': config.TRAIN_QUANTIZER,
        'mse_feat_weight': config.MSE_FEAT_WEIGHT,
        'mse_img_weight': config.MSE_IMG_WEIGHT,
        'lpips_weight': config.LPIPS_WEIGHT,
        'log_metrics_freq': config.LOG_METRICS_FREQ,
        'log_image_freq': config.LOG_IMAGE_FREQ,
        'async_logging': config.ASYNC_LOGGING,
    }

    wandb.init(
        project=config.WANDB_PROJECT,
        name=config.EXPERIMENT_NAME,
        config=wandb_config
    )

    print(f"Starting experiment: {config.EXPERIMENT_NAME}")

    # Initialize codec exactly like batch_loss_analysis.py
    print("Initializing VQ codec...")
    vqvae = initialize_vq_codec(config)
    vqvae = vqvae.to(device)
    vqvae.eval()

    trainable_params = setup_trainable_parameters(vqvae, config)

    # Watch model in wandb
    wandb.watch(vqvae, log="all", log_freq=100)

    print("Setting up dataset...")
    dataset = ImageDatasetForVQVAE(
        data_root=config.DATASET_ROOT,
        config=config,
        phase='train'
    )

    dataloader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        drop_last=True,
        pin_memory=True
    )

    steps_per_epoch = len(dataloader)
    # With drop_last=True a dataset smaller than one batch yields no batches,
    # which would otherwise end in a division by zero when averaging losses.
    if steps_per_epoch == 0:
        raise ValueError(
            f"Dataset has {len(dataset)} images, fewer than BATCH_SIZE={config.BATCH_SIZE}; "
            "no full batch can be formed (drop_last=True)."
        )
    total_steps = config.NUM_EPOCHS * steps_per_epoch

    print("Setting up optimizer...")
    optimizer = _create_optimizer(trainable_params, config)

    print(f"Setting up {config.SCHEDULER} scheduler...")
    scheduler, scheduler_step_on = _create_scheduler(optimizer, config, total_steps)

    print(f"Optimizer: {config.OPTIMIZER}, LR: {config.LEARNING_RATE}")
    print(f"Scheduler: {config.SCHEDULER}, Steps on: {scheduler_step_on}")
    if config.SCHEDULER == "StepLR":
        print(f"  - Step size: {config.STEP_LR_STEP_SIZE} epochs")
        print(f"  - Gamma: {config.STEP_LR_GAMMA}")

    lpips_loss_fn = lpips.LPIPS(net="vgg").to(device)

    os.makedirs(config.OUTPUT_DIR, exist_ok=True)
    # Full hyperparameter snapshot stored in every checkpoint.
    config_snapshot = _config_snapshot(config)

    print(f"\nStarting finetuning for {config.NUM_EPOCHS} epochs...")
    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Total steps: {total_steps}")

    # NOTE(review): this also puts frozen components in train mode; if they contain
    # dropout or batch-norm layers, their forward behavior changes — confirm intended.
    vqvae.train()
    global_step = 0

    training_start_time = time.time()
    epoch_times = []

    for epoch in range(config.NUM_EPOCHS):
        epoch_start_time = time.time()
        epoch_losses = {'total': 0, 'mse_img': 0, 'mse_feat': 0, 'lpips': 0}

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS}")

        for batch_idx, batch in enumerate(progress_bar):
            optimizer.zero_grad()

            losses = compute_losses(vqvae, batch, config, lpips_loss_fn, device)
            losses['total'].backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            optimizer.step()

            # Step scheduler if it's step-based (like CosineAnnealingLR)
            if scheduler_step_on == "step":
                scheduler.step()

            for key in epoch_losses:
                epoch_losses[key] += losses[key].item()

            if global_step % config.LOG_METRICS_FREQ == 0:
                log_dict = {
                    'train/total_loss': losses['total'].item(),
                    'train/mse_img_loss': losses['mse_img'].item(),
                    'train/mse_feat_loss': losses['mse_feat'].item(),
                    'train/lpips_loss': losses['lpips'].item(),
                    'train/learning_rate': scheduler.get_last_lr()[0],
                    'train/epoch': epoch + 1,
                }
                # NOTE(review): per the wandb docs, commit already defaults to False
                # when `step` is given, so both branches may behave the same.
                if config.ASYNC_LOGGING:
                    wandb.log(log_dict, step=global_step, commit=False)
                else:
                    wandb.log(log_dict, step=global_step)

            progress_bar.set_postfix({
                'loss': f"{losses['total'].item():.6f}",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}",
                'step': f"{batch_idx+1}/{steps_per_epoch}"
            })

            # Log images periodically (within epoch)
            if global_step % config.LOG_IMAGE_FREQ == 0:
                with torch.no_grad():
                    # Only 2 images to keep uploads small
                    original_images = batch['image'][:2].cpu()
                    reconstructed_images = losses['reconstructed'][:2].cpu()

                    # Convert from [-1,1] to [0,1] for logging
                    original_images = (original_images + 1) / 2
                    reconstructed_images = (reconstructed_images + 1) / 2

                    image_log_dict = {
                        'images/original': [wandb.Image(img) for img in original_images],
                        'images/reconstructed': [wandb.Image(img) for img in reconstructed_images],
                    }

                    if config.ASYNC_LOGGING:
                        wandb.log(image_log_dict, step=global_step, commit=False)
                    else:
                        wandb.log(image_log_dict, step=global_step)

            global_step += 1

        for key in epoch_losses:
            epoch_losses[key] /= steps_per_epoch

        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_duration)

        avg_epoch_time = sum(epoch_times) / len(epoch_times)
        total_elapsed = epoch_end_time - training_start_time
        remaining_epochs = config.NUM_EPOCHS - (epoch + 1)
        estimated_remaining_time = remaining_epochs * avg_epoch_time

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average Total Loss: {epoch_losses['total']:.6f}")
        print(f"  Average MSE Image Loss: {epoch_losses['mse_img']:.6f}")
        print(f"  Average MSE Feature Loss: {epoch_losses['mse_feat']:.6f}")
        print(f"  Average LPIPS Loss: {epoch_losses['lpips']:.6f}")
        print(f"  Epoch Duration: {_format_time(epoch_duration)}")
        print(f"  Total Elapsed: {_format_time(total_elapsed)}")
        if remaining_epochs > 0:
            print(f"  Estimated Remaining: {_format_time(estimated_remaining_time)}")
            print(f"  Estimated Total: {_format_time(total_elapsed + estimated_remaining_time)}")

        epoch_summary_dict = {
            'epoch/avg_total_loss': epoch_losses['total'],
            'epoch/avg_mse_img_loss': epoch_losses['mse_img'],
            'epoch/avg_mse_feat_loss': epoch_losses['mse_feat'],
            'epoch/avg_lpips_loss': epoch_losses['lpips'],
            'timing/epoch_duration_seconds': epoch_duration,
            'timing/total_elapsed_seconds': total_elapsed,
            'timing/avg_epoch_time_seconds': avg_epoch_time,
        }

        # Commit on the epoch's last step (global_step - 1; >= 0 because every
        # epoch has at least one batch). Committing at global_step would move
        # wandb's step counter past the step the next epoch's first batch logs to,
        # and wandb drops logs to steps below its current step.
        wandb.log(epoch_summary_dict, step=global_step - 1, commit=True)

        # Step scheduler if it's epoch-based (like StepLR)
        if scheduler_step_on == "epoch":
            scheduler.step()
            print(f"  Learning rate after epoch {epoch+1}: {scheduler.get_last_lr()[0]:.2e}")

        # Save checkpoint periodically (every N epochs); N <= 0 disables this.
        if config.SAVE_EVERY_N_EPOCHS > 0 and (epoch + 1) % config.SAVE_EVERY_N_EPOCHS == 0:
            checkpoint_path = osp.join(
                config.OUTPUT_DIR,
                f"{config.EXPERIMENT_NAME}_epoch_{epoch+1}.pth"
            )
            torch.save({
                'epoch': epoch + 1,
                'global_step': global_step,
                'model_state_dict': vqvae.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'config': config_snapshot,
                'epoch_losses': epoch_losses
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")

    total_training_time = time.time() - training_start_time
    avg_time_per_epoch = sum(epoch_times) / len(epoch_times)
    avg_time_per_step = total_training_time / global_step

    print("Finetuning completed!")
    print(f"\n=== TRAINING TIME SUMMARY ===")
    print(f"Total Training Time: {_format_time(total_training_time)}")
    print(f"Average Time per Epoch: {_format_time(avg_time_per_epoch)}")
    print(f"Fastest Epoch: {_format_time(min(epoch_times))}")
    print(f"Slowest Epoch: {_format_time(max(epoch_times))}")
    print(f"Total Steps: {global_step}")
    print(f"Average Time per Step: {avg_time_per_step:.3f} seconds")
    print("=" * 30)

    final_model_path = osp.join(config.OUTPUT_DIR, f"{config.EXPERIMENT_NAME}_final.pth")
    torch.save({
        'epoch': config.NUM_EPOCHS,
        'global_step': global_step,
        'model_state_dict': vqvae.state_dict(),
        'config': config_snapshot,
        'training_time_seconds': total_training_time,
        'epoch_times': epoch_times,
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")

    wandb.log({
        'timing/total_training_time_seconds': total_training_time,
        'timing/avg_time_per_epoch_seconds': avg_time_per_epoch,
        'timing/fastest_epoch_seconds': min(epoch_times),
        'timing/slowest_epoch_seconds': max(epoch_times),
        'timing/avg_time_per_step_seconds': avg_time_per_step,
    })

    wandb.finish()


if __name__ == "__main__":
    # Run finetuning only when executed as a script, not when imported.
    main()
