"""
Model Performance Benchmark Testing

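Statistics collected for each model:
- Parameter count
- Inference time (single-step forward pass and multi-step rollout)
- Rollout speed (steps/sec); training speed is not measured yet
- GPU memory usage (CUDA devices only)

Results are written to a model_benchmark.md report for each dataset.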
"""

import os
import sys
import time
import torch
import numpy as np
from datetime import datetime
from typing import Dict, List

# Repository root: three directory levels above this file. It is put first on
# sys.path so the `experiments` package import below resolves when this script
# is executed directly.
project_root = os.path.abspath(__file__)
for _level in range(3):
    project_root = os.path.dirname(project_root)
sys.path.insert(0, project_root)

from experiments.common.experiment_runner import create_model, ModelConfig


def count_parameters(model: torch.nn.Module) -> Dict:
    """Return parameter-count statistics for ``model``.

    Returns:
        dict with ``total_params`` (all parameter elements),
        ``trainable_params`` (elements of parameters with ``requires_grad``)
        and ``total_params_M`` (``total_params`` expressed in millions).
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        n_elems = param.numel()
        total += n_elems
        if param.requires_grad:
            trainable += n_elems
    return {
        'total_params': total,
        'trainable_params': trainable,
        'total_params_M': total / 1e6,
    }


def measure_inference_time(
    model: torch.nn.Module,
    input_shape: tuple,
    device: torch.device,
    num_warmup: int = 10,
    num_runs: int = 100
) -> Dict:
    """Time a single forward pass of ``model``.

    The model is switched to eval mode and moved to ``device``. One random
    input of ``input_shape`` is generated and reused for every pass. After
    ``num_warmup`` untimed passes, ``num_runs`` passes are timed individually
    with ``time.perf_counter``. On CUDA the device is synchronized so queued
    kernels are included in each measurement.

    Returns:
        dict with the mean (``inference_time_ms``) and standard deviation
        (``inference_time_std``) of the per-pass latency in milliseconds, and
        ``throughput_samples_per_sec`` derived from ``input_shape[0]`` and the
        mean latency.
    """
    on_cuda = device.type == 'cuda'
    model.eval()
    model.to(device)
    batch = torch.randn(*input_shape).to(device)

    with torch.no_grad():
        for _ in range(num_warmup):
            model(batch)
        if on_cuda:
            torch.cuda.synchronize()

        latencies_ms = []
        for _ in range(num_runs):
            t0 = time.perf_counter()
            model(batch)
            if on_cuda:
                # Wait for asynchronous kernels before stopping the clock.
                torch.cuda.synchronize()
            latencies_ms.append((time.perf_counter() - t0) * 1000)

    mean_ms = np.mean(latencies_ms)
    return {
        'inference_time_ms': mean_ms,
        'inference_time_std': np.std(latencies_ms),
        'throughput_samples_per_sec': input_shape[0] / (mean_ms / 1000),
    }


def measure_rollout_time(
    model: torch.nn.Module,
    input_shape: tuple,
    device: torch.device,
    num_steps: int = 100,
    num_runs: int = 5
) -> Dict:
    """Time an autoregressive rollout of ``num_steps`` forward passes.

    Each of the ``num_runs`` repetitions starts from a fresh random input of
    ``input_shape``, created outside the timed region. The model's output
    becomes the next step's input; for tuple outputs the first element is
    fed back. No separate warmup is performed here.

    Returns:
        dict with the mean total rollout time in ms (``rollout_total_ms``),
        the per-step time in ms (``rollout_per_step_ms``) and the rollout
        rate (``rollout_steps_per_sec``).
    """
    model.eval()
    model.to(device)

    def _sync():
        # Only CUDA execution is asynchronous and needs an explicit barrier.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    run_totals_ms = []
    with torch.no_grad():
        for _ in range(num_runs):
            state = torch.randn(*input_shape).to(device)
            _sync()

            t0 = time.perf_counter()
            for _step in range(num_steps):
                out = model(state)
                state = out[0] if isinstance(out, tuple) else out
            _sync()
            run_totals_ms.append((time.perf_counter() - t0) * 1000)

    mean_total_ms = np.mean(run_totals_ms)
    return {
        'rollout_total_ms': mean_total_ms,
        'rollout_per_step_ms': mean_total_ms / num_steps,
        'rollout_steps_per_sec': num_steps / (mean_total_ms / 1000),
    }


def measure_memory(
    model: torch.nn.Module,
    input_shape: tuple,
    device: torch.device
) -> Dict:
    """Measure CUDA memory used around one no-grad forward pass.

    Peak-memory statistics for ``device`` are reset, then one forward pass is
    run on a random input of ``input_shape``.

    Returns:
        dict with ``memory_MB`` (memory currently allocated by tensors after
        the pass, which includes model parameters/buffers, the input and the
        output) and ``memory_peak_MB`` (maximum allocated since the reset).
        Values are in megabytes of 1e6 bytes. On non-CUDA devices nothing is
        measured and both values are 0.
    """
    if device.type != 'cuda':
        return {'memory_MB': 0, 'memory_peak_MB': 0}

    # Match the other measure_* helpers. In train mode, layers such as
    # BatchNorm would update their running statistics during this pass even
    # under no_grad, and Dropout would alter the measured workload.
    model.eval()
    model.to(device)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()

    x = torch.randn(*input_shape).to(device)

    # Forward pass
    with torch.no_grad():
        _ = model(x)

    # Read the counters while the input and output tensors are still alive,
    # so they are included in the "current" figure.
    current = torch.cuda.memory_allocated(device) / 1e6
    peak = torch.cuda.max_memory_allocated(device) / 1e6

    return {
        'memory_MB': current,
        'memory_peak_MB': peak,
    }


def benchmark_model(
    model_config: ModelConfig,
    dataset_type: str,
    input_shape: tuple,
    device: torch.device
) -> Dict:
    """Run the complete benchmark suite for one model configuration.

    Builds the model with ``create_model`` and then runs, in order:
    - parameter counting;
    - single-step inference timing;
    - a 50-step rollout timing;
    - on CUDA only, memory measurement.

    A failure in any timing or memory stage is recorded as a string under
    ``inference_error`` / ``rollout_error`` / ``memory_error`` instead of
    being raised, so one stage failing does not hide the others.

    Returns:
        dict of collected metrics keyed by name, always including
        ``model_type``. If the model cannot be created, the dict contains only
        ``model_type`` and ``error`` (the exception message).
    """
    try:
        model = create_model(model_config, dataset_type)
    except Exception as e:
        # Keep the model name so failed builds remain identifiable in console
        # output and reports; getattr avoids masking the original error.
        return {
            'model_type': getattr(model_config, 'model_type', None),
            'error': str(e),
        }

    results = {
        'model_type': model_config.model_type,
    }

    # Parameter count
    param_stats = count_parameters(model)
    results.update(param_stats)

    # Inference time
    try:
        inference_stats = measure_inference_time(model, input_shape, device)
        results.update(inference_stats)
    except Exception as e:
        results['inference_error'] = str(e)

    # Rollout time
    try:
        rollout_stats = measure_rollout_time(model, input_shape, device, num_steps=50)
        results.update(rollout_stats)
    except Exception as e:
        results['rollout_error'] = str(e)

    # Memory (only meaningful on CUDA devices)
    try:
        if device.type == 'cuda':
            mem_stats = measure_memory(model, input_shape, device)
            results.update(mem_stats)
    except Exception as e:
        results['memory_error'] = str(e)

    return results


def generate_benchmark_report(
    results: List[Dict],
    output_path: str,
    dataset_type: str,
    num_warmup: int = 10,
    num_runs: int = 100,
    rollout_steps: int = 50,
):
    """Write a Markdown benchmark report for ``results`` to ``output_path``.

    Args:
        results: one dict per model, as produced by ``benchmark_model``.
            Entries containing an ``'error'`` key (model creation failed) are
            left out of the metric tables. All recorded errors are listed in a
            "Failures" section instead of being silently dropped.
        output_path: destination file; overwritten if it exists. Its parent
            directory must already exist.
        dataset_type: dataset name shown in the report header.
        num_warmup: warmup iterations echoed in "Test Configuration".
        num_runs: timing iterations echoed in "Test Configuration".
        rollout_steps: rollout steps echoed in "Test Configuration".
            The three defaults match the values ``benchmark_model`` uses.
    """
    # Human-readable stage name for each error key benchmark_model can record.
    error_stages = {
        'error': 'model creation',
        'inference_error': 'inference',
        'rollout_error': 'rollout',
        'memory_error': 'memory',
    }
    built = [r for r in results if 'error' not in r]

    lines = [
        "# Model Performance Benchmark",
        "",
        f"Dataset: {dataset_type}",
        f"Test Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "---",
        "",
        "## Parameter Count Statistics",
        "",
        "| Model | Total Params | Trainable Params | Params (M) |",
        "|-------|--------------|------------------|------------|",
    ]
    for r in built:
        lines.append(
            f"| {r['model_type']} | {r['total_params']:,} | "
            f"{r['trainable_params']:,} | {r['total_params_M']:.2f} |"
        )

    lines += [
        "",
        "---",
        "",
        "## Inference Speed",
        "",
        "| Model | Single-Step (ms) | Throughput (samples/s) | Rollout (steps/s) |",
        "|-------|------------------|------------------------|-------------------|",
    ]
    for r in built:
        if 'inference_time_ms' not in r:
            continue  # inference failed; listed under Failures
        inf_time = f"{r['inference_time_ms']:.2f}±{r.get('inference_time_std', 0):.2f}"
        throughput = f"{r.get('throughput_samples_per_sec', 0):.1f}"
        # Show N/A rather than a misleading 0.0 when the rollout failed.
        if 'rollout_steps_per_sec' in r:
            rollout_speed = f"{r['rollout_steps_per_sec']:.1f}"
        else:
            rollout_speed = "N/A"
        lines.append(f"| {r['model_type']} | {inf_time} | {throughput} | {rollout_speed} |")

    memory_rows = [r for r in results if 'memory_MB' in r]
    if memory_rows:
        lines += [
            "",
            "---",
            "",
            "## GPU Memory Usage",
            "",
            "| Model | Current (MB) | Peak (MB) |",
            "|-------|--------------|-----------|",
        ]
        for r in memory_rows:
            lines.append(f"| {r['model_type']} | {r['memory_MB']:.1f} | {r['memory_peak_MB']:.1f} |")

    failures = []
    for r in results:
        for key, stage in error_stages.items():
            if key in r:
                # Collapse whitespace so multi-line messages stay on one bullet.
                message = ' '.join(str(r[key]).split())
                failures.append(f"- {r.get('model_type', 'unknown')} ({stage}): {message}")
    if failures:
        lines += ["", "---", "", "## Failures", ""] + failures

    lines += [
        "",
        "---",
        "",
        "## Test Configuration",
        "",
        f"- Warmup iterations: {num_warmup}",
        f"- Timing iterations: {num_runs}",
        f"- Rollout steps: {rollout_steps}",
        "",
        "",
    ]

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(lines))

    print(f"Benchmark report saved to: {output_path}")


def run_traffic_flow_benchmark(gpu_id: int = 0):
    """Benchmark the 1D traffic-flow models and write their report.

    Uses ``cuda:<gpu_id>`` when CUDA is available, otherwise the CPU. Every
    configuration is benchmarked with input shape (16, 2, 256). The Markdown
    report is written to
    ``<project_root>/results/traffic_flow/model_benchmark.md``.
    """
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    input_shape = (16, 2, 256)  # batch, channels, length
    dataset_type = 'traffic_flow'

    configs = [
        ModelConfig(model_type='FluxNet_N_1D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=11),
        ModelConfig(model_type='FluxNet_P_1D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=11),
        ModelConfig(model_type='FluxNet_L_1D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=11, lower_bound=0.0),
        ModelConfig(model_type='FluxNet_D_1D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=11, lower_bound=0.0, upper_bound=1.0),
        ModelConfig(model_type='FluxNet_U_1D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=11, upper_bound=1.0),
        ModelConfig(model_type='CNN_Baseline_1D', base_channels=64, num_blocks=6, kernel_size=5, prediction_mode='residual'),
        ModelConfig(model_type='FNO_1D', modes=16, width=64, num_layers=4, prediction_mode='residual'),
    ]

    results = []
    for config in configs:
        print(f"Testing: {config.model_type}")
        result = benchmark_model(config, dataset_type, input_shape, device)
        results.append(result)
        # A failed build has no 'total_params'. Applying the ',' format spec
        # to a string placeholder raises ValueError and would abort the
        # remaining benchmarks, so only format real counts with it.
        params = result.get('total_params')
        if params is None:
            print(f"  Params: N/A ({result.get('error', 'unknown error')})")
        else:
            print(f"  Params: {params:,}")

    output_path = os.path.join(project_root, 'results', 'traffic_flow', 'model_benchmark.md')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    generate_benchmark_report(results, output_path, dataset_type)


def run_shallow_water_benchmark(gpu_id: int = 0):
    """Benchmark the 2D shallow-water models and write their report.

    Uses ``cuda:<gpu_id>`` when CUDA is available, otherwise the CPU. Every
    configuration is benchmarked with input shape (16, 3, 64, 64). The
    Markdown report is written to
    ``<project_root>/results/shallow_water/model_benchmark.md``.
    """
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    input_shape = (16, 3, 64, 64)  # batch, channels, H, W
    dataset_type = 'shallow_water'

    configs = [
        ModelConfig(model_type='FluxNet_SW_2D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=5, head_config='LAP', lower_bound=0.0),
        ModelConfig(model_type='FluxNet_SW_2D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=5, head_config='PPP'),
        ModelConfig(model_type='FNO_SW', modes1=16, modes2=16, width=64, num_layers=4),
        ModelConfig(model_type='FluxNet_SW_Baseline', base_channels=64, num_blocks=6, kernel_size=5, prediction_mode='residual'),
    ]
    model_types = [c.model_type for c in configs]

    results = []
    for config in configs:
        head_config = getattr(config, 'head_config', '')
        print(f"Testing: {config.model_type} ({head_config})")
        result = benchmark_model(config, dataset_type, input_shape, device)
        # Several configs share a model_type (differing only in head_config).
        # Label them apart so the report rows stay distinguishable, the same
        # way the spinodal benchmark suffixes the neighborhood size.
        if head_config and model_types.count(config.model_type) > 1:
            result['model_type'] = f"{config.model_type}_{head_config}"
        results.append(result)
        # A failed build has no 'total_params'. Applying the ',' format spec
        # to a string placeholder raises ValueError and would abort the
        # remaining benchmarks, so only format real counts with it.
        params = result.get('total_params')
        if params is None:
            print(f"  Params: N/A ({result.get('error', 'unknown error')})")
        else:
            print(f"  Params: {params:,}")

    output_path = os.path.join(project_root, 'results', 'shallow_water', 'model_benchmark.md')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    generate_benchmark_report(results, output_path, dataset_type)


def run_spinodal_benchmark(gpu_id: int = 0):
    """Benchmark FluxNet_D at several neighborhood sizes on spinodal data.

    Uses ``cuda:<gpu_id>`` when CUDA is available, otherwise the CPU. Every
    configuration is benchmarked with input shape (16, 1, 128, 128). Each
    result is labeled ``<model_type>_n<neighborhood_size>``. The Markdown
    report is written to
    ``<project_root>/results/spinodal_decomposition/model_benchmark.md``.
    """
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    input_shape = (16, 1, 128, 128)  # batch, channels, H, W
    dataset_type = 'spinodal_decomposition'

    configs = [
        ModelConfig(model_type='FluxNet_D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=5, lower_bound=0.0, upper_bound=1.0),
        ModelConfig(model_type='FluxNet_D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=9, lower_bound=0.0, upper_bound=1.0),
        ModelConfig(model_type='FluxNet_D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=15, lower_bound=0.0, upper_bound=1.0),
    ]

    results = []
    for config in configs:
        print(f"Testing: {config.model_type} (neighborhood={config.neighborhood_size})")
        result = benchmark_model(config, dataset_type, input_shape, device)
        # All configs share a model_type; disambiguate by neighborhood size.
        result['model_type'] = f"{config.model_type}_n{config.neighborhood_size}"
        results.append(result)
        # A failed build has no 'total_params'. Applying the ',' format spec
        # to a string placeholder raises ValueError and would abort the
        # remaining benchmarks, so only format real counts with it.
        params = result.get('total_params')
        if params is None:
            print(f"  Params: N/A ({result.get('error', 'unknown error')})")
        else:
            print(f"  Params: {params:,}")

    output_path = os.path.join(project_root, 'results', 'spinodal_decomposition', 'model_benchmark.md')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    generate_benchmark_report(results, output_path, dataset_type)


if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description='Model Benchmark')
    cli.add_argument('--dataset', type=str, default='all',
                     choices=['traffic_flow', 'shallow_water', 'spinodal', 'all'])
    cli.add_argument('--gpu', type=int, default=0)
    opts = cli.parse_args()

    # (dataset key, banner title, runner), executed in this order.
    suites = (
        ('traffic_flow', "Traffic Flow Model Benchmark", run_traffic_flow_benchmark),
        ('shallow_water', "Shallow Water Equation Model Benchmark", run_shallow_water_benchmark),
        ('spinodal', "Spinodal Decomposition Model Benchmark", run_spinodal_benchmark),
    )
    rule = "=" * 60
    for key, title, runner in suites:
        if opts.dataset not in (key, 'all'):
            continue
        print("\n" + rule)
        print(title)
        print(rule)
        runner(opts.gpu)

    print("\nBenchmark test complete!")
