#!/usr/bin/env python3

import torch
import torch.nn.functional as F
import numpy as np
import time
import os
from PIL import Image as PImage
from tqdm import tqdm
import warnings
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import seaborn as sns
warnings.filterwarnings("ignore")

from inference_VQ_Diffusion import VQ_Diffusion
from image_synthesis.modeling.codecs.image_codec.ema_vqvae import PatchVQVAE
from image_synthesis.utils.misc import instantiate_from_config


class TimingConfig:
    """Settings for the VQ-Diffusion loss-timing benchmark.

    Every setting is a plain instance attribute assigned in ``__init__``;
    change the values here (or on the instance) to alter a run.
    """

    def __init__(self):
        # --- Hardware ---------------------------------------------------
        gpu_available = torch.cuda.is_available()
        self.DEVICE = 'cuda' if gpu_available else 'cpu'

        # --- Workload ---------------------------------------------------
        self.IMAGE_SIZE = 256   # images are resized to IMAGE_SIZE x IMAGE_SIZE
        self.BATCH_SIZE = 1     # one image per batch so timings are per sample
        self.NUM_SAMPLES = 100  # images that are actually timed
        self.NUM_WARMUP = 10    # warmup iterations before timing starts
        self.NUM_REPEATS = 5    # timed repetitions of each method per sample

        # --- Paths ------------------------------------------------------
        # Evaluation images (one of the real datasets).
        self.TEST_DATASET_PATH = "/datasets/imagenet256_eval"
        # Optional finetuned VQVAE checkpoint; None keeps the pretrained weights.
        # Example value:
        # "/workspace/VQ-Diffusion/finetuned_models_BS16_L5e-05_W0.0001_E50_SEED0/vqvae_finetune_20250915_130057_enc_epochs50_final.pth"
        self.FINETUNED_VQVAE_WEIGHTS = None
        # Where CSV results and plots are written.
        self.OUTPUT_DIR = "/workspace/VQ-Diffusion/timing_results"

        # --- Latent tracer ----------------------------------------------
        self.LATENT_TRACER_LR = 0.01
        self.LATENT_TRACER_ITERS = 100    # optimisation steps per timed repeat
        self.ENABLE_LATENT_TRACER = True  # False skips the slow latent tracer


class CustomImageDataset(Dataset):
    """Flat-directory image dataset yielding ``(image, path)`` pairs.

    Files are collected from ``image_dir`` (non-recursively) by a
    case-insensitive suffix match and sorted by full path, so iteration order
    is deterministic.
    """

    def __init__(self, image_dir, transform=None, max_samples=None,
                 extensions=(".png", ".jpg", ".jpeg")):
        """
        Args:
            image_dir: directory scanned (non-recursively) for image files.
            transform: optional callable applied to each loaded RGB PIL image.
            max_samples: keep only the first N files after sorting; a falsy
                value (None or 0) keeps every file.
            extensions: accepted file suffixes, matched case-insensitively.

        Raises:
            FileNotFoundError: if no matching files are found.
        """
        self.transform = transform
        suffixes = tuple(ext.lower() for ext in extensions)
        self.image_files = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(suffixes)
        )
        if max_samples:
            self.image_files = self.image_files[:max_samples]
        if not self.image_files:
            raise FileNotFoundError(f"No images found in {image_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, return (image, path)."""
        path = self.image_files[idx]
        # The context manager releases the file handle as soon as the pixels
        # have been converted, instead of relying on garbage collection.
        with PImage.open(path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, path


def create_transform(config):
    """Return the preprocessing pipeline used for every benchmark image.

    Images are resized to a ``config.IMAGE_SIZE`` square, converted to a
    tensor, and normalised per channel with mean 0.5 / std 0.5, which maps
    ToTensor's [0, 1] output to [-1, 1].
    """
    side = config.IMAGE_SIZE
    pipeline = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(pipeline)


def initialize_vq_codec(config):
    """Build the VQ-Diffusion content codec, optionally overlaying finetuned weights.

    The base model is always constructed from the hard-coded ITHQ config and
    checkpoint. If ``config.FINETUNED_VQVAE_WEIGHTS`` names an existing file,
    its state dict is loaded strictly into the codec. A missing file or a
    failed load only prints a warning and keeps the pretrained weights.

    Args:
        config: object providing ``FINETUNED_VQVAE_WEIGHTS`` (path or None).

    Returns:
        The ``content_codec`` module of the loaded VQ-Diffusion model.

    Raises:
        Exception: whatever ``VQ_Diffusion`` raises while building the base
            model is re-raised after being printed.
    """
    print("Initializing VQ-Diffusion model...")
    try:
        diffusion = VQ_Diffusion(
            config='/workspace/VQ-Diffusion/configs/ithq.yaml',
            path='/checkpoints/pretrained_model/ithq_learnable.pth'
        )
        codec = diffusion.model.content_codec
        print("Successfully initialized base VQ codec")
    except Exception as e:
        print(f"Error initializing VQ codec: {e}")
        raise

    weights_path = config.FINETUNED_VQVAE_WEIGHTS
    if weights_path is None:
        return codec

    if not os.path.exists(weights_path):
        print(f"Warning: Finetuned weights path does not exist: {weights_path}")
        print("Continuing with default pretrained weights...")
        return codec

    print(f"Loading finetuned VQVAE weights from: {weights_path}")
    try:
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints.
        checkpoint = torch.load(weights_path, map_location='cpu')

        # Checkpoints may wrap the weights under one of these keys, or be a
        # bare state dict.
        for wrapper_key in ('model_state_dict', 'state_dict'):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break
        else:
            state_dict = checkpoint

        codec.load_state_dict(state_dict, strict=True)
        print("Successfully loaded finetuned VQVAE weights!")
    except Exception as e:
        # Best effort: timing does not depend on the exact weights.
        print(f"Warning: Failed to load finetuned weights: {e}")
        print("Continuing with default pretrained weights...")

    return codec


def time_latent_tracer_loss(codec, images, device, config):
    """Time the latent-tracer optimisation for one image batch.

    The batch is encoded and quantised once, outside the timed region. Then,
    for each of ``config.NUM_REPEATS`` repeats, a fresh copy of the quantised
    latent is optimised with Adam for ``config.LATENT_TRACER_ITERS`` steps.
    The objective is the MSE between the decoded image and the input, both
    rescaled to [0, 1]. The learning rate is halved every 50 steps. Each
    repeat ends with a final reconstruction loss copied to the CPU, and its
    wall-clock duration is recorded.

    Args:
        codec: VQ codec exposing ``enc.encoder``, ``enc.quant_conv``,
            ``enc.quantize`` and ``dec.post_quant_conv``/``dec.decoder``.
        images: image batch, presumably in [-1, 1] as produced by
            ``create_transform`` (it is mapped to [0, 1] via ``(x + 1) / 2``).
        device: ``torch.device`` or device string (e.g. ``'cuda'``).
        config: object providing ``NUM_REPEATS``, ``LATENT_TRACER_ITERS``
            and ``LATENT_TRACER_LR``.

    Returns:
        list[float]: wall-clock seconds for each repeat.
    """
    device = torch.device(device)
    is_cuda = device.type == 'cuda'

    images = images.to(device)
    # Target in [0, 1], matching the rescaled decoder output below.
    target = torch.clamp((images + 1.0) / 2.0, 0.0, 1.0)

    # Quantised starting latent (setup, not timed).
    with torch.no_grad():
        h = codec.enc.encoder(images)
        f_continuous = codec.enc.quant_conv(h)
        indices = codec.enc.quantize.only_get_indices(f_continuous).view(images.shape[0], -1)
        z_quantized = codec.enc.quantize.get_codebook_entry(
            indices.view(-1),
            shape=(images.shape[0], f_continuous.shape[-2], f_continuous.shape[-1]),
        )

    def decode_to_unit_range(latent):
        """Decode ``latent`` and rescale/clamp the image to [0, 1]."""
        decoded = codec.dec.decoder(codec.dec.post_quant_conv(latent))
        return torch.clamp((decoded + 1.0) / 2.0, 0.0, 1.0)

    times = []
    for _ in range(config.NUM_REPEATS):
        if is_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        # Wrap a tensor that already lives on `device`. Calling .to() on a
        # Parameter that needs a device move would return a non-leaf tensor,
        # which Adam refuses to optimise.
        fhat_optim = torch.nn.Parameter(z_quantized.detach().clone())
        optimizer = torch.optim.Adam([fhat_optim], lr=config.LATENT_TRACER_LR)

        # Gradients are needed even if the caller sits inside torch.no_grad().
        with torch.enable_grad():
            for iter_idx in range(config.LATENT_TRACER_ITERS):
                optimizer.zero_grad()
                loss = torch.nn.functional.mse_loss(decode_to_unit_range(fhat_optim), target)
                # Only the latent is optimised. Restricting backward to it
                # avoids computing (and silently accumulating) .grad on the
                # codec's own parameters.
                loss.backward(inputs=[fhat_optim])
                optimizer.step()

                # Halve the learning rate every 50 iterations.
                if iter_idx % 50 == 0 and iter_idx > 0:
                    for group in optimizer.param_groups:
                        group['lr'] = group['lr'] * 0.5

        # Final per-image loss. The copy to the CPU keeps the full
        # computation inside the timed region.
        with torch.no_grad():
            final_loss = torch.nn.functional.mse_loss(
                decode_to_unit_range(fhat_optim), target, reduction='none'
            )
            final_loss = final_loss.view(final_loss.size(0), -1).mean(dim=1).cpu().numpy()

        if is_cuda:
            torch.cuda.synchronize(device)
        times.append(time.perf_counter() - start_time)

    return times


@torch.no_grad()
def time_reconstruction_loss(codec, images, device, config):
    """Time the single-pass reconstruction loss for one image batch.

    Each of ``config.NUM_REPEATS`` repeats runs encode -> quantise -> decode
    and computes the per-image MSE against the input. The loss is copied to
    the CPU so the whole computation lands inside the timed region.

    Args:
        codec: VQ codec exposing ``enc.encoder``, ``enc.quant_conv``,
            ``enc.quantize`` and ``dec.post_quant_conv``/``dec.decoder``.
        images: image batch, presumably in [-1, 1] as produced by
            ``create_transform`` (the reconstruction is clamped to that range).
        device: ``torch.device`` or device string (e.g. ``'cuda'``).
        config: object providing ``NUM_REPEATS``.

    Returns:
        list[float]: wall-clock seconds for each repeat.
    """
    device = torch.device(device)
    is_cuda = device.type == 'cuda'
    images = images.to(device)
    times = []

    for _ in range(config.NUM_REPEATS):
        if is_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        # Single reconstruction pass.
        h = codec.enc.encoder(images)
        f_continuous = codec.enc.quant_conv(h)
        indices = codec.enc.quantize.only_get_indices(f_continuous).view(images.shape[0], -1)
        z_quantized = codec.enc.quantize.get_codebook_entry(
            indices.view(-1),
            shape=(images.shape[0], f_continuous.shape[-2], f_continuous.shape[-1]),
        )
        f_quantized = codec.dec.post_quant_conv(z_quantized)
        recon_quantized = torch.clamp(codec.dec.decoder(f_quantized), -1., 1.)

        # Per-image reconstruction loss. The value itself is discarded; only
        # its cost is measured.
        recon_loss = torch.nn.functional.mse_loss(recon_quantized, images, reduction='none')
        recon_loss = recon_loss.view(recon_loss.size(0), -1).mean(dim=1).cpu().numpy()

        if is_cuda:
            torch.cuda.synchronize(device)
        times.append(time.perf_counter() - start_time)

    return times


@torch.no_grad()
def time_reconstruction_loss_ratio(codec, images, device, config):
    """Time the reconstruction-loss ratio, which needs two codec round trips.

    Each of ``config.NUM_REPEATS`` repeats does the following:

    * reconstructs the input once (``single``);
    * reconstructs that reconstruction again (``double``);
    * computes the per-image ratio MSE(single, input) / (MSE(double, single) + 1e-8);
    * copies the ratio to the CPU so the whole computation is timed.

    Args:
        codec: VQ codec exposing ``enc.encoder``, ``enc.quant_conv``,
            ``enc.quantize`` and ``dec.post_quant_conv``/``dec.decoder``.
        images: image batch, presumably in [-1, 1] as produced by
            ``create_transform`` (reconstructions are clamped to that range).
        device: ``torch.device`` or device string (e.g. ``'cuda'``).
        config: object providing ``NUM_REPEATS``.

    Returns:
        list[float]: wall-clock seconds for each repeat.
    """
    device = torch.device(device)
    is_cuda = device.type == 'cuda'
    images = images.to(device)
    epsilon = 1e-8  # guards against division by zero when double == single

    def round_trip(x):
        """Encode, quantise and decode ``x``, clamping the output to [-1, 1]."""
        h = codec.enc.encoder(x)
        f_continuous = codec.enc.quant_conv(h)
        indices = codec.enc.quantize.only_get_indices(f_continuous).view(x.shape[0], -1)
        z_quantized = codec.enc.quantize.get_codebook_entry(
            indices.view(-1),
            shape=(x.shape[0], f_continuous.shape[-2], f_continuous.shape[-1]),
        )
        decoded = codec.dec.decoder(codec.dec.post_quant_conv(z_quantized))
        return torch.clamp(decoded, -1., 1.)

    def per_image_mse(a, b):
        """Mean squared error per batch element."""
        err = torch.nn.functional.mse_loss(a, b, reduction='none')
        return err.view(err.size(0), -1).mean(dim=1)

    times = []
    for _ in range(config.NUM_REPEATS):
        if is_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        recon_quantized = round_trip(images)
        recon_single = per_image_mse(recon_quantized, images)

        recon_double_quantized = round_trip(recon_quantized)
        recon_double = per_image_mse(recon_double_quantized, recon_quantized)

        # The ratio is discarded; .cpu() only forces it to be fully computed.
        recon_ratio = (recon_single / (recon_double + epsilon)).cpu().numpy()

        if is_cuda:
            torch.cuda.synchronize(device)
        times.append(time.perf_counter() - start_time)

    return times


@torch.no_grad()
def time_codebook_loss(codec, images, device, config):
    """Time the codebook (quantisation) loss for one image batch.

    Each of ``config.NUM_REPEATS`` repeats encodes the batch, looks up the
    quantised codebook entries, and computes the per-image MSE between the
    quantised and continuous features. The loss is copied to the CPU so the
    computation is fully timed. No decoding is performed.

    Args:
        codec: VQ codec exposing ``enc.encoder``, ``enc.quant_conv`` and
            ``enc.quantize``.
        images: image batch fed to the encoder.
        device: ``torch.device`` or device string (e.g. ``'cuda'``).
        config: object providing ``NUM_REPEATS``.

    Returns:
        list[float]: wall-clock seconds for each repeat.
    """
    device = torch.device(device)
    is_cuda = device.type == 'cuda'
    images = images.to(device)
    times = []

    for _ in range(config.NUM_REPEATS):
        if is_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        # Continuous features and their quantised counterparts.
        h = codec.enc.encoder(images)
        f_continuous = codec.enc.quant_conv(h)
        indices = codec.enc.quantize.only_get_indices(f_continuous).view(images.shape[0], -1)
        z_quantized = codec.enc.quantize.get_codebook_entry(
            indices.view(-1),
            shape=(images.shape[0], f_continuous.shape[-2], f_continuous.shape[-1]),
        )

        # Per-image codebook loss (value discarded; only the cost matters).
        codebook_loss = torch.nn.functional.mse_loss(z_quantized, f_continuous, reduction='none')
        codebook_loss = codebook_loss.view(codebook_loss.size(0), -1).mean(dim=1).cpu().numpy()

        if is_cuda:
            torch.cuda.synchronize(device)
        times.append(time.perf_counter() - start_time)

    return times


def run_timing_benchmark(config):
    """Run the full timing benchmark and return the raw per-repeat timings.

    Steps:

    1. Load the codec and the first ``NUM_WARMUP + NUM_SAMPLES`` images.
    2. Run the three cheap loss timers on the first ``NUM_WARMUP`` batches
       as warmup; the latent tracer is skipped there because it is slow.
    3. Time every method on the following ``NUM_SAMPLES`` batches. The
       warmup and timed images are disjoint.

    Args:
        config: a ``TimingConfig``-like object.

    Returns:
        dict mapping method name -> list of durations in seconds. Each
        timed sample contributes ``NUM_REPEATS`` entries per method.

    Raises:
        FileNotFoundError: if the dataset directory is missing or empty.
    """
    import itertools

    print("=== VQ-DIFFUSION TIMING BENCHMARK ===")
    print(f"Device: {config.DEVICE}")
    print(f"Number of samples: {config.NUM_SAMPLES}")
    print(f"Number of repeats per sample: {config.NUM_REPEATS}")
    print(f"Warmup iterations: {config.NUM_WARMUP}")
    print(f"Latent tracer enabled: {config.ENABLE_LATENT_TRACER}")
    print(f"Latent tracer iterations: {config.LATENT_TRACER_ITERS}")

    # Verify paths exist
    if not os.path.exists(config.TEST_DATASET_PATH):
        raise FileNotFoundError(f"Test dataset path not found: {config.TEST_DATASET_PATH}")

    os.makedirs(config.OUTPUT_DIR, exist_ok=True)

    print("\nInitializing VQ codec...")
    device = torch.device(config.DEVICE)
    codec = initialize_vq_codec(config)
    codec = codec.to(device)
    codec.eval()

    print("Loading test dataset...")
    transform = create_transform(config)
    try:
        dataset = CustomImageDataset(config.TEST_DATASET_PATH, transform=transform, max_samples=config.NUM_SAMPLES + config.NUM_WARMUP)
        dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=0)
        print(f"Successfully loaded dataset with {len(dataset)} images")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise

    timing_results = {
        'latent_tracer': [],
        'reconstruction_loss': [],
        'reconstruction_ratio': [],
        'codebook_loss': []
    }

    # A single shared iterator: warmup consumes the first NUM_WARMUP batches
    # and the timed phase continues from there. Re-iterating the dataloader
    # would restart at index 0 and re-time the warmup images.
    batches = iter(dataloader)

    print(f"\nWarming up with {config.NUM_WARMUP} samples...")
    for batch_images, _ in itertools.islice(batches, config.NUM_WARMUP):
        # Results are discarded; this only warms up kernels and caches.
        time_reconstruction_loss(codec, batch_images, device, config)
        time_codebook_loss(codec, batch_images, device, config)
        time_reconstruction_loss_ratio(codec, batch_images, device, config)
        # Skip latent tracer in warmup as it's very slow

    print(f"\nRunning timing benchmark on {config.NUM_SAMPLES} samples...")
    n_expected = min(config.NUM_SAMPLES, max(0, len(dataloader) - config.NUM_WARMUP))
    if n_expected < config.NUM_SAMPLES:
        print(f"Warning: only {n_expected} batches remain after warmup "
              f"(requested {config.NUM_SAMPLES})")

    sample_count = 0
    timed_batches = itertools.islice(batches, config.NUM_SAMPLES)
    for batch_images, _ in tqdm(timed_batches, total=n_expected, desc="Timing benchmark"):
        print(f"\nTiming sample {sample_count + 1}/{config.NUM_SAMPLES}")

        # 1. Reconstruction loss (fastest)
        print("  - Reconstruction loss...")
        timing_results['reconstruction_loss'].extend(
            time_reconstruction_loss(codec, batch_images, device, config))

        # 2. Codebook loss
        print("  - Codebook loss...")
        timing_results['codebook_loss'].extend(
            time_codebook_loss(codec, batch_images, device, config))

        # 3. Reconstruction loss ratio (requires double pass)
        print("  - Reconstruction ratio...")
        timing_results['reconstruction_ratio'].extend(
            time_reconstruction_loss_ratio(codec, batch_images, device, config))

        # 4. Latent tracer (every sample, like the other losses)
        if config.ENABLE_LATENT_TRACER:
            print("  - Latent tracer...")
            try:
                timing_results['latent_tracer'].extend(
                    time_latent_tracer_loss(codec, batch_images, device, config))
            except Exception as e:
                # Best effort: one failing sample should not abort the run.
                print(f"    Warning: Latent tracer failed: {e}")
        else:
            print("  - Latent tracer (skipped - disabled in config)")

        sample_count += 1

    print(f"\nBenchmark completed! Processed {sample_count} samples")
    return timing_results


def analyze_timing_results(timing_results, config):
    """Summarise timings, print report tables, and persist them to disk.

    For each method with at least one recorded duration (in seconds),
    computes mean / std / median / min / max in milliseconds and the number
    of recorded durations. Then it does the following:

    * prints a fixed-width table and a one-line "paper style" summary per
      method;
    * writes ``timing_results.csv`` into ``config.OUTPUT_DIR``;
    * attempts to render plots, reporting failures there instead of raising.

    Args:
        timing_results: mapping of method name -> list of durations (s).
        config: object providing ``OUTPUT_DIR``.

    Returns:
        dict mapping method name -> statistics dict; ``{}`` if there was
        no data at all.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TIMING ANALYSIS RESULTS")
    print(banner)

    def summarize(durations_s):
        ms = np.array(durations_s) * 1000  # seconds -> milliseconds
        return {
            'mean_ms': np.mean(ms),
            'std_ms': np.std(ms),
            'median_ms': np.median(ms),
            'min_ms': np.min(ms),
            'max_ms': np.max(ms),
            'n_samples': len(ms),
        }

    stats = {
        name: summarize(durations)
        for name, durations in timing_results.items()
        if durations
    }
    if not stats:
        print("Warning: No timing data collected!")
        return {}

    # --- Fixed-width table ---------------------------------------------------
    rule = "-" * 80
    print("\nPER-SAMPLE TIMING RESULTS (milliseconds):")
    print(rule)
    print(f"{'Method':<25} {'Mean ± Std':<20} {'Median':<10} {'Min':<10} {'Max':<10} {'N':<8}")
    print(rule)
    for name, s in stats.items():
        label = name.replace('_', ' ').title()
        mean_std = f"{s['mean_ms']:.2f} ± {s['std_ms']:.2f}"
        median_str, min_str, max_str = (f"{s[k]:.2f}" for k in ('median_ms', 'min_ms', 'max_ms'))
        print(f"{label:<25} {mean_std:<20} {median_str:<10} {min_str:<10} {max_str:<10} {s['n_samples']:<8}")

    # --- Paper-style one-liners ------------------------------------------------
    paper_names = {
        'latent_tracer': 'Latent Tracer',
        'reconstruction_loss': 'Reconstruction Loss',
        'reconstruction_ratio': 'Reconstruction Loss Ratio',
        'codebook_loss': 'Codebook Loss',
    }
    print("\n" + banner)
    print("RESEARCH PAPER FORMAT")
    print(banner)
    for name, s in stats.items():
        label = paper_names.get(name, name.replace('_', ' ').title())
        mean_ms, std_ms = s['mean_ms'], s['std_ms']
        if mean_ms >= 1000:  # report in seconds once a run takes >= 1 s
            time_str = f"{mean_ms / 1000:.2f} ± {std_ms / 1000:.2f} seconds"
        else:
            time_str = f"{mean_ms:.2f} ± {std_ms:.2f} milliseconds"
        print(f"{label}: {time_str} per sample")

    # --- CSV export ------------------------------------------------------------
    csv_columns = (
        ('Mean_ms', 'mean_ms'),
        ('Std_ms', 'std_ms'),
        ('Median_ms', 'median_ms'),
        ('Min_ms', 'min_ms'),
        ('Max_ms', 'max_ms'),
        ('N_samples', 'n_samples'),
    )
    records = [
        {'Method': name.replace('_', ' ').title(), **{col: s[key] for col, key in csv_columns}}
        for name, s in stats.items()
    ]
    csv_path = os.path.join(config.OUTPUT_DIR, 'timing_results.csv')
    pd.DataFrame(records).to_csv(csv_path, index=False)
    print(f"\nDetailed results saved to: {csv_path}")

    # Plotting is optional; a failure there must not lose the numbers above.
    try:
        create_timing_visualization(timing_results, config)
    except Exception as e:
        print(f"Warning: Visualization failed: {e}")
        print("Continuing without plots...")

    return stats


def create_timing_visualization(timing_results, config):
    """Save a box plot and per-method histograms of the collected timings.

    Writes ``timing_comparison.png`` (all methods, log-scaled y axis) and
    ``timing_histograms.png`` (one panel per method) into
    ``config.OUTPUT_DIR``. Methods without recorded times are skipped
    entirely, so they do not leave blank panels. The histogram grid grows by
    rows when there are more than four methods. Figures are always closed,
    even if plotting fails.

    Args:
        timing_results: mapping of method name -> list of durations (s).
        config: object providing ``OUTPUT_DIR``.
    """
    # Keep only methods that produced data; convert seconds -> milliseconds.
    populated = {
        method_name: np.array(times) * 1000
        for method_name, times in timing_results.items()
        if times
    }
    if not populated:
        print("No timing data available for visualization")
        return

    def display_name(method_name):
        return method_name.replace('_', ' ').title()

    df_plot = pd.DataFrame([
        {'Method': display_name(method_name), 'Time (ms)': time_ms}
        for method_name, times_ms in populated.items()
        for time_ms in times_ms
    ])

    # --- Box plot comparing all methods ---------------------------------------
    plot_path = os.path.join(config.OUTPUT_DIR, 'timing_comparison.png')
    box_fig = plt.figure(figsize=(12, 8))
    try:
        sns.boxplot(data=df_plot, x='Method', y='Time (ms)')
        plt.title('Per-Sample Timing Comparison\nVQ-VAE Loss Computations', fontsize=14, fontweight='bold')
        plt.xlabel('Loss Computation Method', fontsize=12)
        plt.ylabel('Time per Sample (milliseconds)', fontsize=12)
        plt.xticks(rotation=45)
        plt.yscale('log')  # methods differ by orders of magnitude
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        box_fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(box_fig)

    # --- One histogram per method ----------------------------------------------
    n_cols = 2
    # At least the original 2x2 grid; add rows when there are more methods.
    n_rows = max(2, -(-len(populated) // n_cols))
    hist_path = os.path.join(config.OUTPUT_DIR, 'timing_histograms.png')
    hist_fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    try:
        axes = axes.flatten()
        for ax, (method_name, times_ms) in zip(axes, populated.items()):
            ax.hist(times_ms, bins=20, alpha=0.7, edgecolor='black')
            ax.set_title(f'{display_name(method_name)}\nMean: {np.mean(times_ms):.2f} ± {np.std(times_ms):.2f} ms')
            ax.set_xlabel('Time per Sample (ms)')
            ax.set_ylabel('Frequency')
            ax.grid(True, alpha=0.3)

        # Hide the panels that received no method.
        for ax in axes[len(populated):]:
            ax.set_visible(False)

        hist_fig.suptitle('Timing Distribution for Each Loss Computation', fontsize=16, fontweight='bold')
        hist_fig.tight_layout()
        hist_fig.savefig(hist_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(hist_fig)

    print(f"Timing plots saved to: {plot_path} and {hist_path}")


def main():
    """Run the benchmark with the default ``TimingConfig`` and report results."""
    cfg = TimingConfig()
    results = run_timing_benchmark(cfg)
    # Prints tables and writes CSV/plots into cfg.OUTPUT_DIR.
    analyze_timing_results(results, cfg)
    print(f"\nBenchmark complete! Results saved to: {cfg.OUTPUT_DIR}")


if __name__ == "__main__":
    main()
