# scripts/benchmark_sampling_time.py

import os
import argparse
import logging
import json
import time
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

# Import from our flowlet package
from flowlet.models import WaveletFlowMatching
from flowlet.utils import setup_logging, set_seed, get_logger

# Module-level logger from the flowlet logging helper; used by every function in this script.
logger = get_logger(__name__)

def load_model_for_inference(checkpoint_path, config_path, device):
    """Rebuild a WaveletFlowMatching model from a JSON config and load its weights.

    Args:
        checkpoint_path: Path passed to ``torch.load``. The checkpoint may hold
            the weights under ``"flow_net_state_dict"`` or
            ``"model_state_dict"``, or be the state dict itself.
        config_path: Path to the JSON file holding the model configuration.
        device: Device the model is moved to; also used as ``map_location``.

    Returns:
        ``(model, model_config)`` on success, with the model in eval mode.
        ``(None, None)`` when the config file is missing or its U-Net fields
        cannot be parsed (the reason is logged). The failure value is a pair
        so callers can always unpack the result into two names.

    Raises:
        Errors from ``json.load``, ``torch.load``, model construction and
        ``load_state_dict`` are not caught here and propagate to the caller.
    """
    # --- Load Configuration ---
    if not (config_path and os.path.exists(config_path)):
        logger.error("A valid --config_path is required for model reconstruction.")
        return None, None

    with open(config_path, 'r', encoding='utf-8') as f:
        model_config = json.load(f)
    logger.info(f"Loaded model configuration from: {config_path}")

    # --- Reconstruct U-Net Args ---
    try:
        condition_vars = model_config.get('condition_vars', [])
        # Both fields are stored as comma-separated integer strings, e.g. "1,2,4".
        attention_res = tuple(map(int, model_config['unet_attention_res'].split(',')))
        channel_mult = tuple(map(int, model_config['unet_channel_mult'].split(',')))
        # Every conditioning variable is registered with dimension 1.
        condition_dims_dict = {var: 1 for var in condition_vars} if condition_vars else {}
    except (KeyError, AttributeError, TypeError, ValueError) as e:
        logger.error(f"Failed to parse model configuration: {e}")
        return None, None

    unet_args = {
        "in_channels": 8, "model_channels": model_config.get('unet_model_channels', 128), "out_channels": 8,
        "num_res_blocks": model_config.get('unet_num_res_blocks', 2),
        "attention_resolutions": attention_res, "dropout": model_config.get('unet_dropout', 0.1),
        "channel_mult": channel_mult, "conv_resample": True, "dims": 3,
        "use_checkpoint": False,  # Important: Disable for inference
        "num_heads": model_config.get('unet_num_heads', 8),
        "num_head_channels": model_config.get('unet_num_head_channels', -1),
        "use_scale_shift_norm": True, "resblock_updown": True,
        "condition_dims": condition_dims_dict,
        "condition_embedding_dim": model_config.get('condition_embedding_dim', 512),
        "use_xformers": model_config.get('use_xformers', True),
        # Cross-attention is only enabled when there is something to condition on.
        "use_cross_attention": not model_config.get('unet_disable_cross_attn', False) and bool(condition_dims_dict),
        "norm_num_groups": model_config.get('unet_norm_num_groups', 32), "norm_eps": 1e-6,
    }

    # --- Load Model State ---
    wfm_model = WaveletFlowMatching(u_net_args=unet_args).to(device)
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider weights_only=True if the checkpoints
    # contain nothing but tensors and plain containers -- confirm before changing.
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt.get("flow_net_state_dict", ckpt.get("model_state_dict", ckpt))
    # Keys carrying an '_orig_mod.' prefix (presumably from a torch.compile-wrapped
    # module -- confirm) are renamed so they match the plain module's parameter names.
    if any(k.startswith('_orig_mod.') for k in state_dict.keys()):
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    wfm_model.flow_net.load_state_dict(state_dict)
    wfm_model.eval()
    logger.info(f"Model loaded successfully from epoch {ckpt.get('epoch', -1)+1}")

    return wfm_model, model_config

def _time_one_sample(wfm_model, model_input_size, conditions, device):
    """Return the wall-clock seconds taken by a single ``wfm_model.sample`` call."""
    if device.type == 'cuda':
        torch.cuda.synchronize()  # Wait for GPU to be idle before starting timer

    start_time = time.perf_counter()

    _ = wfm_model.sample(
        num_samples=1,
        model_output_size=model_input_size,
        conditions_dict=conditions
    )

    if device.type == 'cuda':
        torch.cuda.synchronize()  # Wait for the generation to complete before stopping timer

    return time.perf_counter() - start_time


def main():
    """Benchmark single-sample generation time for several flow-step counts.

    Parses the command line, loads the model, and for each step count runs a
    few untimed warm-up generations followed by timed ones. The mean and
    standard deviation (``np.std``) of the timings are logged, collected into
    a DataFrame, written to ``--output_csv`` and printed as a summary table.

    Returns early, after the loader has logged the reason, when the model
    cannot be loaded. Exits through ``parser.error`` when ``--device cuda`` is
    requested but CUDA is not available.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark the sampling time of a FlowLet model for various step counts.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.pth).")
    parser.add_argument("--config_path", type=str, required=True, help="Path to the model configuration JSON file.")
    parser.add_argument("--output_csv", type=str, required=True, help="Path to save the benchmark results CSV file.")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use.")

    args = parser.parse_args()

    # Reject an unusable device up front with a clear usage error instead of
    # failing with a traceback while the model is being moved to the device.
    if args.device == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda was requested but CUDA is not available; use --device cpu.")

    # --- Setup ---
    setup_logging(log_dir=".", filename_prefix="flowlet_benchmark_time")
    set_seed(42)
    device = torch.device(args.device)
    if device.type == 'cpu':
        logger.warning("Benchmarking on CPU. Timings will not be representative of GPU performance.")

    # --- Load Model ---
    # The loader can report failure by returning None; unpacking that directly
    # would raise TypeError, so inspect the result before splitting it. A tuple
    # whose model slot is None is treated as a failure as well.
    loaded = load_model_for_inference(args.checkpoint_path, args.config_path, device)
    if loaded is None or loaded[0] is None:
        return
    wfm_model, model_config = loaded

    # --- Benchmark Parameters ---
    steps_to_benchmark = [1, 2, 5, 10, 200]
    num_warmup_samples = 4
    num_timed_samples = 15

    # Prepare a dummy condition for generation. We use a normalized age of 0.5 (mid-range).
    # NOTE(review): the 'Age' key is hard-coded; presumably it must match the
    # model's configured condition variables -- confirm against the config.
    dummy_conditions = {'Age': torch.tensor([[0.5]], device=device)}
    model_input_size = tuple(model_config['model_input_size'])

    benchmark_results = []

    # --- Main Loop ---
    for num_steps in steps_to_benchmark:
        logger.info(f"--- Benchmarking {num_steps} steps ---")

        # Set the number of flow steps for the model's sampler
        wfm_model.num_flow_steps = num_steps

        # 1. GPU Warm-up Phase (untimed)
        logger.info(f"  Performing {num_warmup_samples} warm-up generations...")
        with torch.no_grad():
            for _ in tqdm(range(num_warmup_samples), desc="Warm-up", leave=False):
                _ = wfm_model.sample(
                    num_samples=1,
                    model_output_size=model_input_size,
                    conditions_dict=dummy_conditions
                )

        # 2. Timed Measurement Phase
        timings = []
        logger.info(f"  Performing {num_timed_samples} timed generations...")
        with torch.no_grad():
            for _ in tqdm(range(num_timed_samples), desc="Benchmarking", leave=False):
                timings.append(
                    _time_one_sample(wfm_model, model_input_size, dummy_conditions, device)
                )

        # 3. Calculate and store results
        mean_time = np.mean(timings)
        std_time = np.std(timings)

        logger.info(f"  Result: {mean_time:.4f} ± {std_time:.4f} seconds per sample.")

        benchmark_results.append({
            "Steps": num_steps,
            "Mean_Time_s": mean_time,
            "Std_Time_s": std_time,
            "Samples_Timed": num_timed_samples
        })

    # --- Save and Display Results ---
    df = pd.DataFrame(benchmark_results)

    # Saving is best-effort: a failed write is logged and the summary is still printed.
    try:
        output_dir = os.path.dirname(args.output_csv)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        df.to_csv(args.output_csv, index=False, float_format='%.6f')
        logger.info(f"Benchmark results saved to: {args.output_csv}")
    except Exception as e:
        logger.error(f"Failed to save CSV file: {e}")

    print("\n--- Sampling Time Benchmark Summary ---")
    print(df.to_string(index=False))
    print("---------------------------------------")

# Run the benchmark only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()