

import sys
from pathlib import Path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)
import project_config

"""
Train Multimodal Adaptive AE on NYU Depth V2 dataset using pretrained RGB and Depth VAEs.

This script:
1. Loads pretrained RGB and Depth VAE checkpoints from stage 2 pretraining
2. Initializes multimodal model with frozen VAEs
3. Trains fusion layers and adaptive rank-reduced bottlenecks
4. Saves final representations and model checkpoints
"""

import torch
import torch.nn as nn
import sys
import os
import argparse
from torch.utils.data import DataLoader
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import h5py
import requests
from tqdm import tqdm
import glob
from PIL import Image

# Add project root to path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.nyudepthv2_loader import NYUDepthV2Dataset
from src.models.larrp_image_depth import MultimodalAdaptiveAE_ImageDepth
from src.functions.train_larrp_multimodal_imagedepth import train_continuous_multimodal_ae

# =======================================================
# DATASET DOWNLOAD UTILITY
# =======================================================

def _nyu_already_extracted(img_dir, depth_dir, min_count=1400):
    """Return True when both output folders exist and each holds more than ``min_count`` PNGs."""
    if not (os.path.isdir(img_dir) and os.path.isdir(depth_dir)):
        return False
    num_images = sum(1 for f in os.listdir(img_dir) if f.endswith('.png'))
    num_depths = sum(1 for f in os.listdir(depth_dir) if f.endswith('.png'))
    if num_images > min_count and num_depths > min_count:
        print(f"✓ Dataset already extracted: {num_images} images, {num_depths} depth maps")
        return True
    return False


def _download_file(url, filepath, timeout=60):
    """Stream ``url`` into ``filepath`` through a temporary ``.part`` file.

    The destination path only appears after the transfer completed, so an
    interrupted download is never mistaken for a finished one on a later run.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        IOError: if fewer bytes than the announced Content-Length arrived.
    """
    tmp_path = filepath + '.part'
    try:
        # ``timeout`` bounds connect/idle time, not the total transfer time.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            written = 0
            with open(tmp_path, 'wb') as f, tqdm(total=total_size, unit='B', unit_scale=True,
                                                 desc=os.path.basename(filepath)) as bar:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
                        written += len(chunk)
                        bar.update(len(chunk))
        if total_size and written != total_size:
            raise IOError(f"Incomplete download: got {written} of {total_size} bytes from {url}")
        os.replace(tmp_path, filepath)
    except BaseException:
        # Do not leave a multi-GB partial file lying around.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _extract_nyu_mat(filepath, img_dir, depth_dir):
    """Write every RGB image / depth map of the HDF5 ``.mat`` file as PNGs."""
    with h5py.File(filepath, 'r') as f:
        images = f['images']
        depths = f['depths']
        num_samples = images.shape[0]

        print(f"Processing {num_samples} samples...")

        for i in tqdm(range(num_samples)):
            # Reverse the stored axis order to get the (H, W, C) layout PIL expects.
            img_data = np.transpose(images[i], (2, 1, 0))
            Image.fromarray(img_data.astype('uint8')).save(os.path.join(img_dir, f"{i:04d}.png"))

            depth_data = np.nan_to_num(np.transpose(depths[i], (1, 0)), nan=0.0)
            # Scale by 1000 (metres -> millimetres, as the original naming implies),
            # round, and clip to the uint16 range so out-of-range values cannot wrap.
            depth_mm = np.clip(np.rint(depth_data * 1000.0), 0, np.iinfo(np.uint16).max).astype(np.uint16)
            depth_img = Image.fromarray(depth_mm, mode='I;16')
            depth_img.save(os.path.join(depth_dir, f"{i:04d}_depth.png"))


def download_and_extract_nyu(data_root):
    """Download and extract NYU Depth V2 dataset if not already present.

    Fetches ``nyu_depth_v2_labeled.mat`` into ``data_root`` (unless it already
    exists there) and writes ``images/NNNN.png`` RGB images and
    ``depths/NNNN_depth.png`` 16-bit depth maps (depth values x 1000).
    Nothing is done when both folders already hold more than 1400 PNGs.

    Raises:
        requests.HTTPError: if the download server returns an error status.
        IOError: if the download is truncated.
    """
    url = "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"
    filename = "nyu_depth_v2_labeled.mat"
    filepath = os.path.join(data_root, filename)

    img_dir = os.path.join(data_root, "images")
    depth_dir = os.path.join(data_root, "depths")

    if _nyu_already_extracted(img_dir, depth_dir):
        return

    if not os.path.exists(filepath):
        print(f"File not found. Downloading {filename} (approx. 2.8 GB)...")
        os.makedirs(data_root, exist_ok=True)
        _download_file(url, filepath)
        print("Download complete.")
    else:
        print(f"Found {filename}, skipping download.")

    print("Extracting Images and Depth Maps from .mat file...")
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(depth_dir, exist_ok=True)
    _extract_nyu_mat(filepath, img_dir, depth_dir)

    print(f"Done! Data located in: {data_root}")

# =======================================================
# VISUALIZATION UTILITIES
# =======================================================

def plot_reconstruction_samples(model, dataloader, device, save_path, n_samples=6):
    """Plot RGB and depth reconstruction samples.

    Takes the first batch of ``dataloader`` (dict with ``'image'`` and
    ``'depth'`` tensors), reconstructs it with ``model(rgb, depth)`` (which
    must return ``(rgb_recon, depth_recon)``) and saves a 4-row grid to
    ``save_path``: RGB original, RGB recon, depth original, depth recon.
    At most ``n_samples`` columns are drawn (fewer if the batch is smaller).

    The ``training`` flag of every submodule of ``model`` is restored after the
    forward pass, so calling this does not change the model's train/eval mode.
    """
    def to_numpy(t):
        # bfloat16 (e.g. mixed-precision outputs) has no NumPy dtype; upcast first.
        t = t.detach().cpu()
        return (t.float() if t.dtype == torch.bfloat16 else t).numpy()

    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        batch = next(iter(dataloader))
        n_cols = min(n_samples, len(batch['image']))
        rgb = batch['image'][:n_cols].to(device)
        depth = batch['depth'][:n_cols].to(device)
        with torch.no_grad():
            rgb_recon, depth_recon = model(rgb, depth)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    rows = [
        ("RGB Original", to_numpy(rgb), True),
        ("RGB Recon", to_numpy(rgb_recon), True),
        ("Depth Original", to_numpy(depth), False),
        ("Depth Recon", to_numpy(depth_recon), False),
    ]

    # squeeze=False keeps ``axes`` 2-D even when there is a single column.
    fig, axes = plt.subplots(len(rows), n_cols, figsize=(n_cols * 1.5, 6), squeeze=False)
    for row, (title, data, is_rgb) in enumerate(rows):
        for col in range(n_cols):
            ax = axes[row, col]
            if is_rgb:
                # (C, H, W) -> (H, W, C); imshow clips float RGB outside [0, 1].
                ax.imshow(np.transpose(data[col], (1, 2, 0)))
            else:
                ax.imshow(data[col, 0], cmap='inferno')
            ax.axis('off')
        axes[row, 0].set_title(title, fontsize=10)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved reconstruction plot: {save_path}")

# =======================================================
# MAIN TRAINING
# =======================================================

def unwrap_state_dict(checkpoint):
    """Return the parameter mapping held by ``checkpoint``.

    Accepts a raw state dict, an ``nn.Module``, or a dict that nests the
    weights under ``'state_dict'`` or ``'model_state_dict'``. Anything else is
    returned unchanged.
    """
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for key in ('state_dict', 'model_state_dict'):
            inner = checkpoint.get(key)
            if isinstance(inner, dict):
                return inner
    return checkpoint


def load_pretrained_branch(branch, checkpoint, prefix):
    """Load one modality branch's weights from a checkpoint (non-strict).

    A leading ``'module.'`` (``nn.DataParallel``) is stripped from every key.
    If any key then starts with ``prefix`` (e.g. ``'rgb_branch.'``), only those
    keys are loaded, with the prefix removed ("multimodal" checkpoint).
    Otherwise the whole mapping is loaded as a "standalone" branch checkpoint.

    Returns:
        ``(source, n_loaded, n_missing)``. ``source`` is ``'multimodal'`` or
        ``'standalone'``. ``n_loaded`` is the number of branch state entries
        found in the checkpoint. ``n_missing`` is the number of entries left at
        their current values.
    """
    state = unwrap_state_dict(checkpoint)
    state = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in state.items()}
    branch_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    source = 'multimodal'
    if not branch_state:
        source = 'standalone'
        branch_state = state
    result = branch.load_state_dict(branch_state, strict=False)
    n_missing = len(result.missing_keys)
    return source, len(branch.state_dict()) - n_missing, n_missing


def tensor_to_numpy(value):
    """Convert a tensor to a detached CPU NumPy array; return other values unchanged.

    bfloat16 tensors are upcast to float32 because NumPy has no bfloat16 dtype.
    """
    if not torch.is_tensor(value):
        return value
    value = value.detach().cpu()
    if value.dtype == torch.bfloat16:
        value = value.float()
    return value.numpy()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train Multimodal AE on NYU Depth V2")
    parser.add_argument('--data_root', type=str, default='/ewsc/ewsc/nyu_depth_v2',
                        help='Path to NYU Depth V2 dataset')
    parser.add_argument('--rgb_checkpoint', type=str, required=True,
                        help='Path to pretrained RGB VAE checkpoint (stage 2)')
    parser.add_argument('--depth_checkpoint', type=str, required=True,
                        help='Path to pretrained Depth VAE checkpoint (stage 2)')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--wd', type=float, default=1e-6)
    parser.add_argument('--latent_dim', type=int, default=500)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--save_frequency', type=int, default=10, help='Save checkpoint every N epochs')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--warmup_epochs', type=int, default=50, help='Warmup epochs before rank reduction')
    parser.add_argument('--num_workers', type=int, default=8, help='DataLoader worker processes')
    # Rank-reduction frequency/threshold are fixed in the training call below.
    parser.add_argument('--mixed_precision', action='store_true',
                        help='Use mixed precision training (BF16) to reduce memory requirements')
    args = parser.parse_args()

    args.model_prefix = f"nyu_depth_v2_multimodal_seed{args.seed}_ld{args.latent_dim}"

    # Setup
    DEVICE = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    save_dir = project_config.NYU_RESULTS_DIR
    plot_dir = "./03_results/train_plots/nyudepthv2_multimodal"
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)

    # --- 0. Check/Download Data ---
    print("\n" + "=" * 80)
    print("Checking dataset availability...")
    print("=" * 80)
    download_and_extract_nyu(args.data_root)
    print("=" * 80 + "\n")

    # --- 1. Load Data ---
    print("Loading NYU Depth V2 dataset...")
    train_dataset = NYUDepthV2Dataset(args.data_root, split='train', size=256)
    val_dataset = NYUDepthV2Dataset(args.data_root, split='val', size=256)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True)

    print(f"Train size: {len(train_dataset)} | Val size: {len(val_dataset)}")

    # Sanity-check the batch structure the model and trainer rely on.
    print("\n" + "=" * 80)
    print("Testing batch structure...")
    print("=" * 80)
    test_batch = next(iter(train_loader))
    print(f"Batch size: {len(test_batch['image'])}")
    print(f"RGB shape: {test_batch['image'].shape}  <- (batch, channels, height, width)")
    print(f"Depth shape: {test_batch['depth'].shape}  <- (batch, channels, height, width)")
    print(f"RGB range: [{test_batch['image'].min():.3f}, {test_batch['image'].max():.3f}]")
    print(f"Depth range: [{test_batch['depth'].min():.3f}, {test_batch['depth'].max():.3f}]")
    print("=" * 80 + "\n")
    del test_batch

    # --- 2. Initialize Model ---
    print("Initializing Multimodal Adaptive AE...")

    model = MultimodalAdaptiveAE_ImageDepth(
        image_shape=(3, 256, 256),
        depth_shape=(1, 256, 256),
        latent_dim_rgb=args.latent_dim * 2,
        latent_dim_depth=args.latent_dim * 2,
        latent_dim_shared=args.latent_dim,
        vae_model_name="stabilityai/sd-vae-ft-mse",
        freeze_vae=True  # Keep VAEs frozen, only train projections and fusion
    ).to(DEVICE)

    print("\n" + "=" * 60)
    print("Model initialized with pretrained VAE branches")
    print("=" * 60 + "\n")

    # --- 3. Load Pretrained Checkpoints ---
    print("Loading pretrained checkpoints...")
    for label, branch, ckpt_path, prefix in (
        ("RGB", model.rgb_branch, args.rgb_checkpoint, 'rgb_branch.'),
        ("Depth", model.depth_branch, args.depth_checkpoint, 'depth_branch.'),
    ):
        print(f"Loading {label} checkpoint: {ckpt_path}")
        # NOTE: torch.load unpickles the file - only pass trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=DEVICE)
        source, n_loaded, n_missing = load_pretrained_branch(branch, checkpoint, prefix)
        print(f"  ✓ Loaded {label} branch ({source} checkpoint)")
        # load_state_dict(strict=False) silently ignores non-matching keys, so report them.
        if n_loaded == 0:
            print(f"  ⚠ Warning: no {label} branch weights matched the checkpoint keys - branch left unchanged!")
        elif n_missing:
            print(f"  ⚠ {n_missing} {label} branch entries not in checkpoint (kept current values)")
        del checkpoint

    # --- Freeze pretrained branches (VAE + projections) ---
    print("\nFreezing pretrained branches (VAE + projection layers)...")
    for branch in (model.rgb_branch, model.depth_branch):
        for param in branch.parameters():
            param.requires_grad = False
    print("  ✓ RGB branch frozen (VAE + projections)")
    print("  ✓ Depth branch frozen (VAE + projections)")
    print("  → Only training fusion layers (shared + modality-specific adaptive layers)")

    # Print parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    rgb_params = sum(p.numel() for p in model.rgb_branch.parameters())
    depth_params = sum(p.numel() for p in model.depth_branch.parameters())

    print("\nModel Statistics:")
    print(f"  Total parameters: {total_params/1e6:.2f}M")
    print(f"  Trainable parameters: {trainable_params/1e6:.2f}M")
    print(f"  RGB branch: {rgb_params/1e6:.2f}M")
    print(f"  Depth branch: {depth_params/1e6:.2f}M")
    print(f"  Frozen parameters: {(total_params - trainable_params)/1e6:.2f}M")

    # --- 4. Pre-training Visualization ---
    print("\n" + "=" * 80)
    print("Generating pre-training reconstruction samples...")
    print("=" * 80)

    # Best-effort: a plotting failure must not abort the training run.
    try:
        plot_reconstruction_samples(
            model, train_loader, DEVICE,
            f"{plot_dir}/{args.model_prefix}_epoch0.png",
            n_samples=6
        )
    except Exception as e:
        print(f"Warning: Could not generate pre-training plots: {e}")

    print("=" * 80 + "\n")

    # --- 5. Training with train_continuous_multimodal_ae ---
    print("Starting training with adaptive rank reduction...")
    print(f"Epochs: {args.epochs} | Warmup: {args.warmup_epochs}")

    # Minimal args namespace expected by train_continuous_multimodal_ae.
    from types import SimpleNamespace
    train_args = SimpleNamespace(
        multi_gpu=False,
        gpu_ids=str(args.gpu)
    )

    trained_model, reps, final_train_loss, final_rsquare, rank_history_dict, loss_history, sorted_indices = train_continuous_multimodal_ae(
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        device=DEVICE,
        latent_dim=args.latent_dim,
        args=train_args,
        epochs=args.epochs,
        early_stopping=args.epochs,  # No early stopping before completing epochs
        lr=args.lr,
        batch_size=args.batch_size,
        wd=args.wd,
        initial_rank_ratio=1.0,
        min_rank=10,
        rank_schedule=None,  # Auto-generate based on rank_reduction_frequency
        rank_reduction_frequency=5,
        rank_reduction_threshold=0.01,
        warmup_epochs=args.warmup_epochs,
        patience=20,
        reduce_on_best_loss='rsquare',
        r_square_threshold=0.05,
        threshold_type='absolute',
        compressibility_type='direct',
        reduction_criterion='r_squared',
        verbose=True,
        model_name=f"{args.model_prefix}",
        lr_schedule=None,  # 'cosine' is also accepted
        decision_metric='R2',
        input_shapes=[(3, 256, 256), (1, 256, 256)],
        end_lr=1e-5,
        save_frequency=args.save_frequency,
        modality_keys=['image', 'depth'],  # Specify the modality keys for continuous data
        mixed_precision=args.mixed_precision  # Enable mixed precision if requested
    )

    model = trained_model
    train_losses, val_losses = loss_history

    print("\n" + "=" * 80)
    print("Training complete!")
    print("=" * 80)

    # --- 6. Save Final Results ---
    print("\nSaving final results...")

    # pd.Series pads unequal-length histories with NaN instead of raising,
    # so a length mismatch cannot discard results at the very end of training.
    history_df = pd.DataFrame({
        'train_loss': pd.Series(list(train_losses)),
        'val_loss': pd.Series(list(val_losses)),
    })
    history_df.insert(0, 'epoch', range(len(history_df)))
    history_df.to_csv(f"{save_dir}/{args.model_prefix}_training_history.csv", index=False)

    rank_keys = ('epoch', 'total_rank', 'loss', 'val_loss')
    rank_df = pd.DataFrame({key: pd.Series(list(rank_history_dict[key])) for key in rank_keys})
    rank_df.to_csv(f"{save_dir}/{args.model_prefix}_rank_history.csv", index=False)

    # Representations returned by train_continuous_multimodal_ae:
    # reps = [shared_rep, rgb_specific_rep, depth_specific_rep]; sorted_indices
    # maps the rows back to dataset indices.
    for idx, name in enumerate(('shared', 'rgb', 'depth')):
        np.save(f"{save_dir}/{args.model_prefix}_{name}_reps.npy", tensor_to_numpy(reps[idx]))
    np.save(f"{save_dir}/{args.model_prefix}_train_indices.npy", tensor_to_numpy(sorted_indices))

    print(f"\n✓ All results saved to: {save_dir}/")
    print(f"  - Model checkpoint: {args.model_prefix}.pt (in ./03_results/models/)")
    print(f"  - Training history: {args.model_prefix}_training_history.csv")
    print(f"  - Rank history: {args.model_prefix}_rank_history.csv")
    print(f"  - Representations: {args.model_prefix}_*_reps.npy")
    print(f"  - Final train loss: {final_train_loss:.6f}")
    print(f"  - Final R²: {final_rsquare}")

    print("\n" + "=" * 80)
    print("Experiment completed!")
    print("=" * 80)