#!/usr/bin/env python3
"""
Performance Benchmarking Tool for Multi-GPU Training

This script provides comprehensive benchmarking for different multi-GPU strategies:
1. Single GPU training
2. DataParallel (DP) training
3. DistributedDataParallel (DDP) training (not run by this script; see train_ddp.py)

It measures training speed, GPU utilization, and memory usage to help you choose
the best strategy for your specific model and hardware setup.

Note: this testing script was written by Claude-4-sonnet. It was used to benchmark the
performance of the transformer model on multiple GPUs. The results in the paper were
generated without multi-GPU training because communication overhead slowed training
down considerably.
"""

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import time
import os
import json
import numpy as np
from datetime import datetime
from contextlib import contextmanager
import psutil
import threading
from collections import defaultdict

import model.transformer as tf


class PerformanceMonitor:
    """Sample per-GPU memory and utilization on a background thread.

    While monitoring is active, every ``interval`` seconds the monitor appends
    one value per visible CUDA device to ``stats`` under these keys:

    * ``gpu_{i}_memory_allocated`` -- ``torch.cuda.memory_allocated(i)`` in GiB
    * ``gpu_{i}_memory_cached``    -- ``torch.cuda.memory_reserved(i)`` in GiB
    * ``gpu_{i}_utilization``      -- NVML GPU utilization in percent; only
      recorded when the optional ``pynvml`` module is importable and NVML
      initializes successfully.
    """

    def __init__(self, interval=0.5):
        """Create an idle monitor.

        Args:
            interval: Seconds to wait between two sampling passes.
        """
        self.interval = interval
        self.monitoring = False
        self.stats = defaultdict(list)
        self.monitor_thread = None
        # Lets stop_monitoring() interrupt the wait between samples instead of
        # blocking for up to a full interval.
        self._stop_event = threading.Event()

    def start_monitoring(self):
        """Discard earlier samples and start the sampling thread."""
        # A second start must not orphan a still-running sampler.
        if self.monitor_thread is not None and self.monitor_thread.is_alive():
            self.stop_monitoring()

        self.stats.clear()
        self._stop_event.clear()
        self.monitoring = True
        # Daemon thread: if the caller crashes before stop_monitoring() runs,
        # the sampler must not keep the interpreter alive.
        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()

    def stop_monitoring(self):
        """Stop the sampling thread and return the collected samples.

        Returns:
            dict: Maps each stat key to the list of sampled values. Empty when
            nothing was sampled (e.g. CUDA is unavailable or monitoring was
            never started).
        """
        self.monitoring = False
        self._stop_event.set()
        if self.monitor_thread:
            self.monitor_thread.join()
        return dict(self.stats)

    @staticmethod
    def _load_nvml():
        """Return an initialized ``pynvml`` module, or None if it is unusable."""
        try:
            import pynvml
        except ImportError:
            # Optional dependency (pip install nvidia-ml-py3); utilization is
            # simply not recorded without it.
            return None
        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            return None
        return pynvml

    def _monitor_loop(self):
        """Sampling loop executed on the background thread."""
        gib = 1024 ** 3
        # Initialize NVML once per monitoring session rather than per sample.
        nvml = self._load_nvml()
        try:
            while self.monitoring:
                if torch.cuda.is_available():
                    for i in range(torch.cuda.device_count()):
                        self.stats[f'gpu_{i}_memory_allocated'].append(
                            torch.cuda.memory_allocated(i) / gib)
                        self.stats[f'gpu_{i}_memory_cached'].append(
                            torch.cuda.memory_reserved(i) / gib)

                        if nvml is None:
                            continue
                        # NOTE: NVML enumerates physical devices and ignores
                        # CUDA_VISIBLE_DEVICES, so index i can refer to a
                        # different GPU than CUDA device i when that variable
                        # is set.
                        try:
                            handle = nvml.nvmlDeviceGetHandleByIndex(i)
                            util = nvml.nvmlDeviceGetUtilizationRates(handle)
                        except nvml.NVMLError:
                            # A failed query skips this sample; it must not
                            # kill the monitoring thread.
                            continue
                        self.stats[f'gpu_{i}_utilization'].append(util.gpu)

                # Returns early as soon as stop_monitoring() sets the event.
                self._stop_event.wait(self.interval)
        finally:
            if nvml is not None:
                try:
                    nvml.nvmlShutdown()
                except nvml.NVMLError:
                    # Best-effort cleanup; nothing useful can be done here.
                    pass


class BenchmarkConfig:
    """Model, training, data and timing settings for a benchmark run.

    Every setting is a plain instance attribute, so ``config.__dict__`` holds
    exactly the settings listed in ``__init__``.
    """

    def __init__(self, **overrides):
        """Populate the default settings, then apply keyword overrides.

        Args:
            **overrides: Optional replacement values keyed by setting name,
                e.g. ``BenchmarkConfig(batch_size=64, num_epochs=6)``.

        Raises:
            TypeError: If an override names a setting that does not exist
                (guards against silently ignored typos).
        """
        # Model parameters
        self.vocab_size = 1000
        self.d_model = 256
        self.n_layers = 4
        self.n_heads = 8
        self.max_length = 100
        self.num_classes = 2
        self.dropout_rate = 0.1

        # Training parameters
        self.batch_size = 32
        self.num_epochs = 10
        self.lr = 0.001
        self.weight_decay = 0.01

        # Data parameters
        self.train_samples = 2000
        self.val_samples = 500

        # Benchmark parameters
        self.warmup_epochs = 2  # Epochs to skip for timing
        self.measure_epochs = 5   # Epochs to measure

        unknown = sorted(set(overrides) - set(self.__dict__))
        if unknown:
            raise TypeError(
                f"Unknown BenchmarkConfig setting(s): {', '.join(unknown)}"
            )
        # Updating existing keys keeps the attribute order of the defaults.
        self.__dict__.update(overrides)


def create_dummy_dataset(config):
    """Build random training and validation datasets for benchmarking.

    Each sample is a ``(token_ids, attention_mask, label)`` triple: token ids
    are drawn uniformly from ``[0, config.vocab_size)`` with length
    ``config.max_length``, the mask is all ones, and the label is drawn
    uniformly from ``[0, config.num_classes)``.

    Args:
        config: Object providing ``vocab_size``, ``max_length``,
            ``num_classes``, ``train_samples`` and ``val_samples``.

    Returns:
        tuple: ``(train_dataset, val_dataset)`` as ``TensorDataset`` objects.
    """
    def _random_split(sample_count):
        # Draw tokens before labels so the RNG is consumed in a fixed order.
        tokens = torch.randint(0, config.vocab_size, (sample_count, config.max_length))
        targets = torch.randint(0, config.num_classes, (sample_count,))
        full_mask = torch.ones_like(tokens)
        return torch.utils.data.TensorDataset(tokens, full_mask, targets)

    train_dataset = _random_split(config.train_samples)
    val_dataset = _random_split(config.val_samples)
    return train_dataset, val_dataset


def benchmark_single_gpu(config):
    """Benchmark training on a single device (``cuda:0``, or CPU without CUDA).

    Builds a ``SimpleTransformer`` from ``config``, trains it on random data
    for ``config.num_epochs`` epochs and times every epoch. The first
    ``config.warmup_epochs`` epochs are excluded from the reported average.

    Args:
        config: Settings object providing the model, training and data
            attributes read below (see ``BenchmarkConfig``).

    Returns:
        dict: Keys ``strategy``, ``total_time``, ``avg_epoch_time``,
        ``epochs_measured``, ``gpu_count``, ``effective_batch_size``,
        ``model_params``, ``peak_memory_gb`` and ``avg_gpu_util``.
    """
    print("\n" + "="*50)
    print("BENCHMARKING: Single GPU Training")
    print("="*50)

    use_cuda = torch.cuda.is_available()
    device = 'cuda:0' if use_cuda else 'cpu'

    def _sync():
        # CUDA kernels run asynchronously; without a sync the wall clock would
        # be read while queued work is still executing.
        if use_cuda:
            torch.cuda.synchronize(0)

    # Create model
    model = tf.SimpleTransformer(
        vocab_size=config.vocab_size,
        d_model=config.d_model,
        n_layers=config.n_layers,
        n_heads=config.n_heads,
        max_length=config.max_length,
        num_classes=config.num_classes,
        dropout_rate=config.dropout_rate
    )
    model.apply(tf.init_weights)
    model = model.to(device)

    # Create datasets (the validation split is not used for timing)
    train_dataset, _ = create_dummy_dataset(config)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True,
        num_workers=4, pin_memory=use_cuda
    )

    # Setup training
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    criterion = nn.CrossEntropyLoss()

    if use_cuda:
        # Track the true allocator peak for this run; periodic sampling alone
        # can miss short-lived spikes.
        torch.cuda.reset_peak_memory_stats(0)

    # Start monitoring
    monitor = PerformanceMonitor()
    monitor.start_monitoring()

    # Training loop
    epoch_times = []
    try:
        _sync()
        start_time = time.perf_counter()

        for epoch in range(config.num_epochs):
            epoch_start = time.perf_counter()
            model.train()

            for input_ids, attention_mask, labels in train_loader:
                # non_blocking lets the copy from pinned memory overlap compute.
                input_ids = input_ids.to(device, non_blocking=True)
                attention_mask = attention_mask.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask)
                loss = criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            _sync()
            epoch_time = time.perf_counter() - epoch_start
            epoch_times.append(epoch_time)

            if epoch >= config.warmup_epochs:
                print(f"Epoch {epoch}: {epoch_time:.2f}s")

        total_time = time.perf_counter() - start_time
    finally:
        # Always stop the sampler, even if training raised.
        stats = monitor.stop_monitoring()

    # Calculate metrics
    measured_epochs = epoch_times[config.warmup_epochs:]
    avg_epoch_time = float(np.mean(measured_epochs)) if measured_epochs else 0.0

    peak_memory = max(stats.get('gpu_0_memory_allocated', [0]), default=0)
    if use_cuda:
        peak_memory = max(peak_memory, torch.cuda.max_memory_allocated(0) / (1024**3))

    utilization = stats.get('gpu_0_utilization')
    avg_utilization = float(np.mean(utilization)) if utilization else 0

    results = {
        'strategy': 'single_gpu',
        'total_time': total_time,
        'avg_epoch_time': avg_epoch_time,
        'epochs_measured': len(measured_epochs),
        'gpu_count': 1,
        'effective_batch_size': config.batch_size,
        'model_params': sum(p.numel() for p in model.parameters()),
        'peak_memory_gb': peak_memory,
        'avg_gpu_util': avg_utilization
    }

    print(f"Results: {avg_epoch_time:.2f}s/epoch, Peak Memory: {results['peak_memory_gb']:.2f}GB")
    return results


def benchmark_dataparallel(config):
    """Benchmark ``torch.nn.DataParallel`` training across all visible GPUs.

    The per-step batch is ``config.batch_size * gpu_count`` and the learning
    rate is scaled by ``gpu_count``. Every epoch is timed; the first
    ``config.warmup_epochs`` epochs are excluded from the reported average.

    Args:
        config: Settings object providing the model, training and data
            attributes read below (see ``BenchmarkConfig``).

    Returns:
        dict | None: ``None`` when fewer than two GPUs are visible, otherwise
        a dict with keys ``strategy``, ``total_time``, ``avg_epoch_time``,
        ``epochs_measured``, ``gpu_count``, ``effective_batch_size``,
        ``model_params``, ``peak_memory_gb`` (maximum over all GPUs) and
        ``avg_gpu_util`` (mean over all GPUs).
    """
    print("\n" + "="*50)
    print("BENCHMARKING: DataParallel Training")
    print("="*50)

    gpu_count = torch.cuda.device_count()
    if gpu_count < 2:
        print("Skipping DataParallel benchmark - requires at least 2 GPUs")
        return None

    device = 'cuda:0'

    def _sync_all():
        # CUDA kernels run asynchronously on every replica device; wait for all
        # of them before reading the wall clock.
        for idx in range(gpu_count):
            torch.cuda.synchronize(idx)

    # Create model
    model = tf.SimpleTransformer(
        vocab_size=config.vocab_size,
        d_model=config.d_model,
        n_layers=config.n_layers,
        n_heads=config.n_heads,
        max_length=config.max_length,
        num_classes=config.num_classes,
        dropout_rate=config.dropout_rate
    )
    model.apply(tf.init_weights)
    model = model.to(device)

    # Wrap with DataParallel
    model = torch.nn.DataParallel(model)

    # Create datasets with larger batch size for multi-GPU
    # (the validation split is not used for timing)
    effective_batch_size = config.batch_size * gpu_count
    train_dataset, _ = create_dummy_dataset(config)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=effective_batch_size, shuffle=True,
        num_workers=8, pin_memory=True
    )

    # Setup training
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr * gpu_count, weight_decay=config.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # Track the true allocator peak per device for this run; periodic sampling
    # alone can miss short-lived spikes.
    for idx in range(gpu_count):
        torch.cuda.reset_peak_memory_stats(idx)

    # Start monitoring
    monitor = PerformanceMonitor()
    monitor.start_monitoring()

    # Training loop
    epoch_times = []
    try:
        _sync_all()
        start_time = time.perf_counter()

        for epoch in range(config.num_epochs):
            epoch_start = time.perf_counter()
            model.train()

            for input_ids, attention_mask, labels in train_loader:
                # non_blocking lets the copy from pinned memory overlap compute.
                input_ids = input_ids.to(device, non_blocking=True)
                attention_mask = attention_mask.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask)
                loss = criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            _sync_all()
            epoch_time = time.perf_counter() - epoch_start
            epoch_times.append(epoch_time)

            if epoch >= config.warmup_epochs:
                print(f"Epoch {epoch}: {epoch_time:.2f}s")

        total_time = time.perf_counter() - start_time
    finally:
        # Always stop the sampler, even if training raised.
        stats = monitor.stop_monitoring()

    # Calculate metrics
    measured_epochs = epoch_times[config.warmup_epochs:]
    avg_epoch_time = float(np.mean(measured_epochs)) if measured_epochs else 0.0

    # Peak memory is the maximum over all GPUs; utilization is averaged over
    # all GPUs (a GPU without utilization samples contributes 0).
    gib = 1024 ** 3
    peak_memory = 0
    utilization_sum = 0.0
    for idx in range(gpu_count):
        sampled = stats.get(f'gpu_{idx}_memory_allocated', [0])
        peak_memory = max(
            peak_memory,
            max(sampled, default=0),
            torch.cuda.max_memory_allocated(idx) / gib,
        )

        gpu_util = stats.get(f'gpu_{idx}_utilization')
        if gpu_util:
            utilization_sum += float(np.mean(gpu_util))

    avg_utilization = utilization_sum / gpu_count

    results = {
        'strategy': 'dataparallel',
        'total_time': total_time,
        'avg_epoch_time': avg_epoch_time,
        'epochs_measured': len(measured_epochs),
        'gpu_count': gpu_count,
        'effective_batch_size': effective_batch_size,
        'model_params': sum(p.numel() for p in model.parameters()),
        'peak_memory_gb': peak_memory,
        'avg_gpu_util': avg_utilization
    }

    print(f"Results: {avg_epoch_time:.2f}s/epoch, Peak Memory: {results['peak_memory_gb']:.2f}GB")
    return results


def run_comprehensive_benchmark():
    """Run the single-GPU and DataParallel benchmarks and report the outcome.

    Prints the detected hardware and configuration, runs each applicable
    benchmark, prints a comparison, writes everything to
    ``benchmark_results_<timestamp>.json`` in the current directory and prints
    recommendations.

    Returns:
        list[dict] | None: One result dict per benchmark that ran, or ``None``
        when CUDA is not available.
    """
    print("🚀 Starting Comprehensive Multi-GPU Performance Benchmark")
    print("=" * 70)

    # Check GPU availability
    if not torch.cuda.is_available():
        print("❌ CUDA not available. Cannot run benchmarks.")
        return

    gpu_count = torch.cuda.device_count()
    print(f"📊 Found {gpu_count} GPU(s)")

    gpu_names = []
    for i in range(gpu_count):
        props = torch.cuda.get_device_properties(i)
        gpu_names.append(props.name)
        print(f"   GPU {i}: {props.name} ({props.total_memory / (1024**3):.1f} GB)")

    config = BenchmarkConfig()
    results = []

    print("\n🔧 Benchmark Configuration:")
    print(f"   Model: {config.n_layers} layers, {config.d_model} dim, {config.n_heads} heads")
    print(f"   Data: {config.train_samples} samples, batch size {config.batch_size}")
    print(f"   Training: {config.num_epochs} epochs ({config.warmup_epochs} warmup)")

    # Single GPU benchmark
    single_gpu_results = benchmark_single_gpu(config)
    if single_gpu_results:
        results.append(single_gpu_results)

    # DataParallel benchmark
    if gpu_count > 1:
        dp_results = benchmark_dataparallel(config)
        if dp_results:
            results.append(dp_results)

    # Look results up by strategy name rather than by list position.
    by_strategy = {result['strategy']: result for result in results}
    single = by_strategy.get('single_gpu')
    dp = by_strategy.get('dataparallel')

    # None means "no valid comparison"; otherwise True/False.
    dp_is_faster = None
    single_gpu_time = dp_time = 0
    if single and dp:
        single_gpu_time = single['avg_epoch_time']
        dp_time = dp['avg_epoch_time']
        # Averages are 0 when no epoch was measured (warmup_epochs >=
        # num_epochs); dividing would raise ZeroDivisionError.
        if single_gpu_time > 0 and dp_time > 0:
            dp_is_faster = dp_time < single_gpu_time

    # Print comparison
    print("\n" + "="*70)
    print("📈 PERFORMANCE COMPARISON")
    print("="*70)

    if single and dp:
        if dp_is_faster is None:
            print("⚠️  No measured epochs - cannot compare strategies")
        elif dp_is_faster:
            speedup = single_gpu_time / dp_time
            print(f"🎉 DataParallel is {speedup:.2f}x FASTER than single GPU")
        else:
            slowdown = dp_time / single_gpu_time
            print(f"⚠️  DataParallel is {slowdown:.2f}x SLOWER than single GPU")

    # Show per-strategy numbers even when only one benchmark ran.
    if results:
        print("\nDetailed Comparison:")
        for result in results:
            print(f"  {result['strategy']:12}: {result['avg_epoch_time']:6.2f}s/epoch, "
                  f"{result['peak_memory_gb']:5.2f}GB memory, "
                  f"{result['avg_gpu_util']:5.1f}% GPU util")

    # Save results
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    results_file = f"benchmark_results_{timestamp}.json"

    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump({
            'timestamp': timestamp,
            'config': vars(config),
            'system_info': {
                'gpu_count': gpu_count,
                'gpu_names': gpu_names,
                'pytorch_version': torch.__version__,
                'cuda_version': torch.version.cuda
            },
            'results': results
        }, f, indent=2)

    print(f"\n💾 Results saved to: {results_file}")

    # Recommendations
    print("\n💡 RECOMMENDATIONS:")
    if dp_is_faster is True:
        print("✅ Use DataParallel for this model size - it provides speedup")
    elif dp_is_faster is False:
        print("⚠️  Consider single GPU or try:")
        print("   - Larger model (more parameters)")
        print("   - Larger batch size")
        print("   - DistributedDataParallel (DDP) for better efficiency")
        print("   - Check if model is too small for multi-GPU overhead")

    print("🔍 For DDP benchmarks, run: python train_ddp.py")

    return results


# Run the full benchmark suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_comprehensive_benchmark()
