#!/usr/bin/env python
"""Test speed vs K using actual checkpoint and test data."""

import torch
import numpy as np
from src.utils import DataAttr
from src.models.ace import InferenceEngine2, AmortizedConditioningEngine
from src.models.modules import Embedder, Transformer, MixtureGaussian
from pathlib import Path
import time
import matplotlib.pyplot as plt
import argparse
from huggingface_hub import snapshot_download

def load_model_from_checkpoint(checkpoint_path: str, device) -> AmortizedConditioningEngine:
    """Rebuild an AmortizedConditioningEngine from a training checkpoint.

    The architecture is reconstructed from ``checkpoint['config']['model']``;
    any missing key falls back to the default hard-coded below. The weights
    in ``checkpoint['model_state_dict']`` are then loaded.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: From ``load_state_dict`` (strict loading) if the weights
            do not match the rebuilt architecture.
    """
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Architecture hyper-parameters; nested sections default to empty dicts so
    # every individual .get() below can fall back to its default.
    model_config = checkpoint.get('config', {}).get('model', {})
    embedder_cfg = model_config.get('embedder', {})
    backbone_cfg = model_config.get('backbone', {})
    head_cfg = model_config.get('head', {})
    dim_x = model_config.get('dim_x', 1)
    dim_y = model_config.get('dim_y', 1)
    dim_model = model_config.get('dim_model', 128)

    embedder = Embedder(
        dim_x=dim_x,
        dim_y=dim_y,
        hidden_dim=embedder_cfg.get('hidden_dim', 256),
        out_dim=dim_model,
        depth=embedder_cfg.get('depth', 3)
    )

    backbone = Transformer(
        num_layers=backbone_cfg.get('num_layers', 6),
        dim_model=dim_model,
        num_head=backbone_cfg.get('num_heads', 4),
        dim_feedforward=backbone_cfg.get('dim_feedforward', 256),
        dropout=backbone_cfg.get('dropout', 0.0)
    )

    head = MixtureGaussian(
        dim_y=dim_y,
        dim_model=dim_model,
        dim_feedforward=head_cfg.get('dim_feedforward', 256),
        num_components=head_cfg.get('num_components', 20)
    )

    model = AmortizedConditioningEngine(
        embedder=embedder,
        backbone=backbone,
        head=head,
        max_buffer_size=model_config.get('max_buffer_size', 8),
        num_target_points=model_config.get('num_target_points', 256),
        targets_block_size_for_buffer_attend=model_config.get('targets_block_size_for_buffer_attend', 4)
    )

    # Checkpoints saved from a torch.compile'd model prefix every key with
    # '_orig_mod.'. Strip it only at the start of a key so that any other
    # occurrence inside a parameter name is left untouched.
    prefix = '_orig_mod.'
    state_dict = checkpoint['model_state_dict']
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    print(f"Model loaded successfully (epoch {checkpoint.get('epoch', 'unknown')}, step {checkpoint.get('step', 'unknown')})")
    return model

def sample_sequence_K1_timed(engine, batch: DataAttr, checkpoints):
    """Decode every target one at a time (K=1), recording elapsed time at checkpoints.

    Per the in-code note, this is a copy of the engine's ``sample_sequence_K1``
    loop with timing added. For each target it:
    1. prefills the KV cache from the current context,
    2. decodes that single target, and
    3. appends the prediction to the context.

    Args:
        engine: Inference engine exposing ``store_context_embeddings``,
            ``init_kv_cache``, ``prefill_kv_cache``, ``autoregressive_decode``
            and ``update_context_embeddings``.
        batch: DataAttr whose ``xt`` is indexed as (B, T, dim_x).
        checkpoints: Ascending, 1-based target counts at which to record the
            elapsed time since the start of generation.

    Returns:
        A tuple ``(result, times)``:
        - ``result`` is a DataAttr whose ``yc`` holds the (B, T, 1) predicted
          values and whose ``xt`` holds the target positions.
        - ``times`` lists the seconds elapsed at each checkpoint reached, in
          order.
    """
    # Resolve the helper before the clock starts so import lookups are not timed.
    from src.utils import fetch_next_query_batch

    B, T = batch.xt.shape[0], batch.xt.shape[1]
    device = batch.xt.device
    use_cuda = device.type == 'cuda'

    times = []
    checkpoint_idx = 0

    # On GPU, kernels run asynchronously, so time with CUDA events and
    # synchronize once at the end. On CPU, a wall clock is accurate.
    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        checkpoint_events = [torch.cuda.Event(enable_timing=True) for _ in checkpoints]
        torch.cuda.synchronize()
        start_event.record()
    else:
        start_time = time.perf_counter()

    engine.store_context_embeddings(batch)
    # The context grows by one prediction per step, so size the cache for all T.
    max_seq = engine.context_embeddings.shape[1] + T
    engine.init_kv_cache(B, max_seq, device=device)

    predicted_positions = torch.zeros(
        B, T, batch.xt.shape[2], device=device, dtype=batch.xt.dtype
    )
    predicted_values = torch.zeros(
        B, T, 1, device=device, dtype=batch.xt.dtype
    )

    for t in range(T):
        N_C = engine.context_embeddings.shape[1]

        # Debug: print the context size at a few fixed steps
        if t in (0, 15, 31, 63, 127):
            print(f"    [K=1] t={t+1}, context_size={N_C}")

        SA_mask = engine._get_cached_selfattn_mask(N_C, device)
        engine.prefill_kv_cache(SA_mask)

        query = fetch_next_query_batch(batch, t, 1)
        batch_predictions = engine.autoregressive_decode(query)

        # Positions are the queried targets; values come from the prediction's yc
        predicted_positions[:, t:t + 1, :] = batch.xt[:, t:t + 1, :]
        predicted_values[:, t:t + 1, :] = batch_predictions.yc

        engine.update_context_embeddings(batch_predictions)

        # Record every checkpoint reached so far. The while/>= form matches
        # sample_sequence_timed and never stalls on an out-of-range checkpoint.
        while checkpoint_idx < len(checkpoints) and t + 1 >= checkpoints[checkpoint_idx]:
            if use_cuda:
                checkpoint_events[checkpoint_idx].record()
            else:
                times.append(time.perf_counter() - start_time)
            checkpoint_idx += 1

    if use_cuda:
        torch.cuda.synchronize()
        # elapsed_time() returns milliseconds
        times.extend(
            start_event.elapsed_time(event) / 1000.0
            for event in checkpoint_events[:checkpoint_idx]
        )

    result = DataAttr(
        xc=batch.xc,
        yc=predicted_values,
        xb=batch.xb,
        yb=batch.yb,
        xt=predicted_positions,
        yt=batch.yt
    )

    return result, times

def sample_sequence_timed(engine, batch: DataAttr, K: int, checkpoints):
    """Decode targets in blocks of K, recording elapsed time at checkpoints.

    Per the in-code note, this follows the engine's ``sample_sequence``. For
    each block it:
    1. prefills the KV cache from the current context,
    2. decodes the block's targets one query at a time, and
    3. appends all of the block's predictions to the context at once.

    Args:
        engine: Inference engine exposing ``store_context_embeddings``,
            ``init_kv_cache``, ``prefill_kv_cache``, ``autoregressive_decode``
            and ``update_context_embeddings``.
        batch: DataAttr whose ``xt`` is indexed as (B, T, dim_x).
        K: Number of targets decoded per context update; must be >= 1.
        checkpoints: Ascending, 1-based target counts at which to record the
            elapsed time. A checkpoint is recorded after the first block whose
            end reaches it. When K does not divide a checkpoint, the recorded
            time therefore covers up to K-1 extra targets.

    Returns:
        A tuple ``(result, times)``:
        - ``result`` is a DataAttr whose ``yc`` holds the (B, T, 1) predicted
          values and whose ``xt`` holds the predicted positions.
        - ``times`` lists the seconds elapsed at each checkpoint reached, in
          order.

    Raises:
        ValueError: If ``K < 1``. The block count ``(T + K - 1) // K`` would
            otherwise divide by zero (K=0) or produce no blocks (K < 0).
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")

    # Resolve helpers before the clock starts so import lookups are not timed.
    from src.utils import fetch_next_query, concatenate_batches

    B, T = batch.xt.shape[0], batch.xt.shape[1]
    device = batch.xt.device
    use_cuda = device.type == 'cuda'

    times = []
    checkpoint_idx = 0

    # On GPU, kernels run asynchronously, so time with CUDA events and
    # synchronize once at the end. On CPU, a wall clock is accurate.
    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        checkpoint_events = [torch.cuda.Event(enable_timing=True) for _ in checkpoints]
        torch.cuda.synchronize()
        start_event.record()
    else:
        start_time = time.perf_counter()

    engine.store_context_embeddings(batch)
    max_seq = engine.context_embeddings.shape[1] + T
    engine.init_kv_cache(B, max_seq, device=device)
    engine.offset = torch.zeros([], dtype=torch.int64, device=device)

    predicted_positions = torch.zeros(
        B, T, batch.xt.shape[2], device=device, dtype=batch.xt.dtype
    )
    predicted_values = torch.zeros(
        B, T, 1, device=device, dtype=batch.xt.dtype
    )

    num_batches = (T + K - 1) // K  # ceil(T / K); the last block may be short

    for batch_idx in range(num_batches):
        start_idx = batch_idx * K
        end_idx = min(start_idx + K, T)

        N_C = engine.context_embeddings.shape[1]

        # Debug: print the context size at a few fixed block ends
        if end_idx in [16, 32, 64, 128]:
            print(f"    [K={K}] generated up to t={end_idx}, context_size={N_C}")

        SA_mask = engine._get_cached_selfattn_mask(N_C, device)
        engine.prefill_kv_cache(SA_mask)

        # Decode the block's targets one query at a time, accumulating the
        # predictions into a single batch in target order.
        batch_predictions = None
        for target_idx in range(start_idx, end_idx):
            prediction = engine.autoregressive_decode(fetch_next_query(batch, target_idx))
            if batch_predictions is None:
                batch_predictions = prediction
            else:
                batch_predictions = concatenate_batches(batch_predictions, prediction)

        predicted_positions[:, start_idx:end_idx, :] = batch_predictions.xc
        predicted_values[:, start_idx:end_idx, :] = batch_predictions.yc

        # Append the whole block's predictions to the context at once
        engine.update_context_embeddings(batch_predictions)

        # A block can cross several checkpoints; record each one reached
        while checkpoint_idx < len(checkpoints) and end_idx >= checkpoints[checkpoint_idx]:
            if use_cuda:
                checkpoint_events[checkpoint_idx].record()
            else:
                times.append(time.perf_counter() - start_time)
            checkpoint_idx += 1

    if use_cuda:
        torch.cuda.synchronize()
        # elapsed_time() returns milliseconds
        times.extend(
            start_event.elapsed_time(event) / 1000.0
            for event in checkpoint_events[:checkpoint_idx]
        )

    result = DataAttr(
        xc=batch.xc,
        yc=predicted_values,
        xb=batch.xb,
        yb=batch.yb,
        xt=predicted_positions,
        yt=batch.yt
    )

    return result, times

# Parse command line arguments
parser = argparse.ArgumentParser(description='Test ACE model speed with different K values')
parser.add_argument('--precision', type=str, required=True,
                    choices=['cpu', 'float32', 'bfloat16'],
                    help='Precision to use: cpu (float32), float32 (GPU), or bfloat16 (GPU)')
parser.add_argument('--checkpoint', type=str, required=True,
                    help='Path to model checkpoint')
parser.add_argument('--data', type=str, required=True,
                    help='Path to test data file')
parser.add_argument('--k-values', type=int, nargs='+', default=[4, 8, 16],
                    help='K values to test (default: 4 8 16). K=1 is always included.')
args = parser.parse_args()
# K is a decoding block size. In the block-count arithmetic (T + K - 1) // K,
# K=0 divides by zero and a negative K yields no blocks, so reject both here
# with a clear usage error.
if any(k < 1 for k in args.k_values):
    parser.error(f"--k-values must be positive integers, got {args.k_values}")

# Determine device and dtype based on the precision argument. The GPU
# precisions fall back to CPU when CUDA is unavailable; warn in that case so
# the reported timings are not mistaken for GPU numbers.
if args.precision == 'cpu':
    device = torch.device('cpu')
    dtype = torch.float32
    print("Running on CPU with float32 precision")
else:  # 'float32' or 'bfloat16' (argparse restricts the choices)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if args.precision == 'bfloat16' else torch.float32
    if device.type != 'cuda':
        print(f"WARNING: --precision {args.precision} targets a GPU but CUDA is not available; "
              f"falling back to CPU")
    print(f"Running on {device} with {args.precision} precision")

# Load the model checkpoint
checkpoint_path = args.checkpoint
print(f"Loading model from {checkpoint_path}")
model = load_model_from_checkpoint(checkpoint_path, device=device)

# Convert to specified dtype
model = model.to(dtype)
print(f"Model loaded and converted to {dtype} on {device}")

# Load test data - handle both local files and HuggingFace dataset repo IDs.
# An argument containing '/' that does not exist locally is treated as a repo
# ID (username/dataset-name). The exception is an argument naming a .pt file:
# that is a missing local file and should raise FileNotFoundError, not a
# confusing Hub error.
data_arg = Path(args.data)
if '/' in args.data and not data_arg.exists() and data_arg.suffix != '.pt':
    print(f"\nDetected HuggingFace dataset: {args.data}")
    print("Downloading dataset...")
    dataset_path = snapshot_download(
        repo_id=args.data,
        repo_type="dataset",
        cache_dir="./.cache/huggingface"
    )
    # Sort so "the first .pt file" is deterministic: glob order depends on
    # the filesystem.
    dataset_path = Path(dataset_path)
    pt_files = sorted(dataset_path.glob("**/*.pt"))
    if not pt_files:
        raise ValueError(f"No .pt files found in downloaded dataset {args.data}")
    test_data_path = pt_files[0]
    print(f"Using file: {test_data_path.name}")
    if len(pt_files) > 1:
        print(f"Note: Found {len(pt_files)} .pt files, using the first one")
else:
    # Local file path
    test_data_path = data_arg
    if not test_data_path.exists():
        raise FileNotFoundError(f"Data file not found: {test_data_path}")

print(f"Loading test data from {test_data_path}")
# Map to CPU: the batch tensors are explicitly moved to `device` (and cast)
# when the benchmark batch is built. Loading to CPU also lets data saved from
# a GPU run load on CPU-only machines.
# NOTE(review): torch.load unpickles the file, and the HF branch loads data
# downloaded from the Hub - only use trusted repositories.
test_data = torch.load(test_data_path, map_location='cpu')

# Benchmark input: the first function of the first stored batch. Its context
# is replicated num_permutations times along the batch axis, and each copy
# visits the same targets in its own random order.
print("\nUsing first batch from test data")
batch_idx = 0
selected_batch = test_data[batch_idx]

# Axis 1 counts points for both the context and the target tensors
nc = selected_batch['xc'].shape[1]
nt = selected_batch['xt'].shape[1]
print(f"\nUsing batch {batch_idx}:\n  Context: {nc} points\n  Targets: {nt} points")

# The context keeps a leading batch axis of size 1 (sliced with 0:1) so it can
# be repeated. The targets drop it (indexed with 0) so they can be permuted
# along their first axis.
first_xc, first_yc = (selected_batch[key][0:1].float() for key in ('xc', 'yc'))
first_xt, first_yt = (selected_batch[key][0].float() for key in ('xt', 'yt'))

max_targets = nt  # every target of the batch is generated
num_permutations = 128  # rows decoded in parallel for better GPU utilization

# One random target order per row. The orders are kept so that predictions can
# be restored to the original target order for plotting.
perm_indices = [torch.randperm(max_targets) for _ in range(num_permutations)]
permuted_xt = [first_xt[order].unsqueeze(0) for order in perm_indices]
permuted_yt = [first_yt[order].unsqueeze(0) for order in perm_indices]

# Move to the benchmark device/dtype in one step. The buffer is left empty.
batch = DataAttr(
    xc=first_xc.repeat(num_permutations, 1, 1).to(device=device, dtype=dtype),
    yc=first_yc.repeat(num_permutations, 1, 1).to(device=device, dtype=dtype),
    xb=torch.empty(num_permutations, 0, 1, dtype=dtype, device=device),
    yb=torch.empty(num_permutations, 0, 1, dtype=dtype, device=device),
    xt=torch.cat(permuted_xt, dim=0).to(device=device, dtype=dtype),
    yt=torch.cat(permuted_yt, dim=0).to(device=device, dtype=dtype),
)

print(
    f"\nBatch prepared:\n"
    f"  Batch size: {num_permutations} permutations\n"
    f"  Context: {batch.xc.shape[1]} points\n"
    f"  Targets: {batch.xt.shape[1]} points (limited to {max_targets})"
)

# Create inference engine
engine = InferenceEngine2.from_trained_model(model)

# K=1 is always benchmarked because it is the baseline of the speedup tables.
# User values are de-duplicated in order, and a user-supplied 1 is dropped, so
# each K runs exactly once and gets exactly one column/panel/line in the output.
K_values = [1] + [k for k in dict.fromkeys(args.k_values) if k != 1]
# Timing checkpoints every 16 targets up to max_targets
target_points = list(range(16, max_targets + 1, 16))
if not target_points:
    print(f"WARNING: only {max_targets} targets available; no timing checkpoints "
          f"(one every 16 targets) will be recorded")

# Store results (including predictions for visualization)
results = {k: {'targets': [], 'times': [], 'predictions': None} for k in K_values}

print("\n" + "="*60)
print("Testing different K values with SINGLE generation run")
print("Using actual checkpoint and test data")
print("="*60)
print(f"Generating {max_targets} targets ONCE, recording actual time at checkpoints")
print("="*60)

for K in K_values:
    print(f"\nTesting K={K}...")

    # Skip warmup for CPU. On GPU, warm up first so one-time costs are not timed.
    if device.type == 'cpu':
        print(f"  Skipping warmup on CPU...")
    else:
        print(f"  Warming up with 3 runs...")
        with torch.no_grad():
            for warmup_iter in range(3):
                if K == 1:
                    _ = engine.sample_sequence_K1(batch)
                else:
                    _ = engine.sample_sequence(batch, K=K)
                torch.cuda.synchronize()

    # Run timed generation
    print(f"  Running timed generation...")
    with torch.no_grad():
        if K == 1:
            predictions, checkpoint_times = sample_sequence_K1_timed(engine, batch, target_points)
        else:
            predictions, checkpoint_times = sample_sequence_timed(engine, batch, K, target_points)

    # Throughput counts points generated across all parallel permutations
    for t, elapsed in zip(target_points, checkpoint_times):
        throughput = (t * num_permutations) / elapsed
        print(f"  t={t:3d}: {elapsed:6.3f}s, throughput: {throughput:7.1f} points/s")

    results[K]['targets'] = target_points
    results[K]['times'] = checkpoint_times
    results[K]['predictions'] = predictions  # Store predictions for visualization

# Create aesthetic visualization of predictions for all K values
print("\n" + "="*60)
print("Creating aesthetic visualization of predictions for all K values")
print("="*60)

# Use the predictions we already generated during timing
print("Using predictions from timed runs for visualization...")

# Convert to numpy for plotting. Cast to float32 first: NumPy has no bfloat16,
# so calling .numpy() on a bfloat16 tensor (--precision bfloat16) raises TypeError.
xc_np = batch.xc[0, :, 0].float().cpu().numpy()
yc_np = batch.yc[0, :, 0].float().cpu().numpy()

# Original (unpermuted) targets; first_xt / first_yt are already float32
xt_original = first_xt[:, 0].cpu().numpy()
yt_original = first_yt[:, 0].cpu().numpy()

# Sort by x value for cleaner plotting
sort_idx = np.argsort(xt_original)
xt_sorted = xt_original[sort_idx]
yt_sorted = yt_original[sort_idx]

# In row i, position j holds original target perm_indices[i][j]. The argsort
# of a permutation is its inverse, so gathering with it restores the original
# target order for every row at once.
inv_perms = torch.stack([torch.argsort(perm) for perm in perm_indices])  # [num_permutations, nt]

predictions_by_k = {}
for K in K_values:
    pred_permuted = results[K]['predictions'].yc[:, :, 0].float().cpu()  # [num_permutations, nt]
    pred_original_order = torch.gather(pred_permuted, 1, inv_perms).numpy()
    all_pred_array = pred_original_order[:, sort_idx]  # rows sorted by x for plotting

    # Per-target statistics across permutations
    predictions_by_k[K] = {
        'all': all_pred_array,
        'mean': np.mean(all_pred_array, axis=0),
        'median': np.median(all_pred_array, axis=0),
        'q10': np.percentile(all_pred_array, 10, axis=0),
        'q90': np.percentile(all_pred_array, 90, axis=0),
        'q25': np.percentile(all_pred_array, 25, axis=0),
        'q75': np.percentile(all_pred_array, 75, axis=0)
    }

# One subplot per K value on a 2-column grid; rows grow with the number of K
# values. squeeze=False always returns a 2-D array of axes, even for one panel.
plt.style.use('seaborn-v0_8-darkgrid')
n_panels = len(K_values)
ncols = 2 if n_panels > 1 else 1
nrows = (n_panels + ncols - 1) // ncols
fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 6 * nrows), squeeze=False)
axes = axes.flatten()
# Hide grid cells left over when the number of K values is odd
for unused_ax in axes[n_panels:]:
    unused_ax.set_visible(False)

# Color palette
context_color = '#E74C3C'  # Elegant red
true_color = '#2C3E50'  # Dark blue-gray
mean_color = '#3498DB'  # Bright blue
sample_color = '#95A5A6'  # Light gray
band_color_outer = '#85C1E9'  # Light blue
band_color_inner = '#5DADE2'  # Medium blue

# The same 20 evenly spaced permutation rows are drawn in every panel
n_samples = 20
sample_indices = np.linspace(0, num_permutations - 1, n_samples, dtype=int)

for idx, K in enumerate(K_values):
    ax = axes[idx]
    pred_data = predictions_by_k[K]

    # Uncertainty bands across permutations
    ax.fill_between(xt_sorted, pred_data['q10'], pred_data['q90'],
                    alpha=0.2, color=band_color_outer,
                    label='10-90% Percentile', zorder=1)
    ax.fill_between(xt_sorted, pred_data['q25'], pred_data['q75'],
                    alpha=0.3, color=band_color_inner,
                    label='25-75% Percentile', zorder=2)

    # Individual samples; only the last one is labelled so the legend has one entry
    for sample_idx in sample_indices[:-1]:
        ax.plot(xt_sorted, pred_data['all'][sample_idx],
                color=sample_color, alpha=0.3, linewidth=0.8, zorder=3)
    ax.plot(xt_sorted, pred_data['all'][sample_indices[-1]],
            color=sample_color, alpha=0.3, linewidth=0.8,
            label=f'Individual Samples ({n_samples} shown)', zorder=3)

    # True function
    ax.plot(xt_sorted, yt_sorted,
            color=true_color, linewidth=2.5,
            label='True Function', zorder=5, linestyle='-', alpha=0.9)

    # Mean prediction
    ax.plot(xt_sorted, pred_data['mean'],
            color=mean_color, linewidth=2.5,
            label='Mean Prediction',
            zorder=6, linestyle='--')

    # Context points
    ax.scatter(xc_np, yc_np,
               c=context_color, s=80,
               label=f'Context ({nc} pts)',
               zorder=10, edgecolors='white', linewidth=1.5,
               marker='o')

    # RMSE of the mean prediction against the true targets
    rmse = np.sqrt(np.mean((pred_data['mean'] - yt_sorted)**2))

    # Styling
    ax.set_xlabel('Input (x)', fontsize=11)
    ax.set_ylabel('Output (y)', fontsize=11)
    ax.set_title(f'K={K} (RMSE: {rmse:.3f})', fontsize=13, fontweight='bold')

    # Legend only on first subplot
    if idx == 0:
        ax.legend(loc='best', fontsize=9, frameon=True, fancybox=True, shadow=True)

    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    ax.set_axisbelow(True)
    ax.set_facecolor('#F8F9FA')

# Overall title
fig.suptitle(f'ACE Model Predictions: {nc} Context → {max_targets} Target Points\nComparing Different K Values ({num_permutations} Permutations)',
             fontsize=16, fontweight='bold', y=1.02)

plt.tight_layout()

# Save with high DPI, then release the figure before the next plot is created
plt.savefig('ace_predictions_by_k.png', dpi=200, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.close(fig)
print("Aesthetic plot saved to 'ace_predictions_by_k.png'")

# Time-vs-targets plot, one line per K. The colors and markers cycle, and the
# line style changes once the color list wraps, so any number of K values can
# be plotted without running off the end of the style lists.
plt.figure(figsize=(10, 6))

colors = ['red', 'blue', 'green', 'orange', 'purple']
markers = ['o', 's', '^', 'D', 'v']
linestyles = ['-', '--', ':']

for i, K in enumerate(K_values):
    targets = results[K]['targets']
    times = results[K]['times']

    plt.plot(targets, times,
             marker=markers[i % len(markers)],
             color=colors[i % len(colors)],
             linestyle=linestyles[(i // len(colors)) % len(linestyles)],
             linewidth=2,
             markersize=8,
             label=f'K={K}')

plt.xlabel('Number of Targets', fontsize=12)
plt.ylabel('Time (seconds)', fontsize=12)
plt.title('ACE Inference Speed vs K (Checkpoint Model)\n(Lower is Better)', fontsize=14)
plt.legend(fontsize=11, loc='best')
plt.grid(True, alpha=0.3)
plt.xticks(target_points)

plt.tight_layout()
plt.savefig('speed_vs_k_checkpoint.png', dpi=150)
plt.close()
print(f"\nPlot saved to 'speed_vs_k_checkpoint.png'")

# Summary tables built from the single timed run per K:
# 1. raw checkpoint times,
# 2. speedup over the K=1 baseline, and
# 3. the fastest K at each checkpoint.
_RULE = "=" * 60
_DIVIDER = "-" * 60


def _print_section(title):
    """Print ``title`` between two 60-character '=' rules, preceded by a blank line."""
    print("\n" + _RULE)
    print(title)
    print(_RULE)


_print_section("Summary Table: ACTUAL Time (seconds) from single generation run")
print(f"{'Targets':<10}" + "".join(f"K={K:<2d}      " for K in K_values))
print(_DIVIDER)
for row, n_targets in enumerate(target_points):
    cells = "".join(f"{results[K]['times'][row]:8.3f}  " for K in K_values)
    print(f"{n_targets:<10}" + cells)

# K=1 is the baseline, so it gets no speedup column
_print_section("Speedup relative to K=1")
print(f"{'Targets':<10}" + "".join(f"K={K:<2d}       " for K in K_values[1:]))
print(_DIVIDER)
for row, n_targets in enumerate(target_points):
    cells = "".join(
        f"{results[1]['times'][row] / results[K]['times'][row]:8.2f}x  "
        for K in K_values[1:]
    )
    print(f"{n_targets:<10}" + cells)

# Fastest K per checkpoint (ties go to the K listed first)
_print_section("Optimal K for each target count")
print(f"{'Targets':<10} {'Best K':<10} {'Time (s)':<12} {'vs K=1':<10}")
print(_DIVIDER)
for row, n_targets in enumerate(target_points):
    best_k, best_time = min(
        ((K, results[K]['times'][row]) for K in K_values),
        key=lambda pair: pair[1],
    )
    speedup = results[1]['times'][row] / best_time
    print(f"{n_targets:<10} {best_k:<10} {best_time:<12.3f} {speedup:<10.2f}x")

print("\nScript complete!")