"""
Shallow Water Equation - Model Performance Benchmark Testing

Statistics collected for each model:
- Parameter count
- Inference time (single-step forward pass and multi-step rollout)
- Rollout speed (steps/sec)
- GPU memory usage (CUDA only)

The results are written to a Markdown report, model_benchmark.md.
"""

import os
import sys
import time
import torch
import numpy as np
from datetime import datetime
from typing import Dict, List

# Repository root = three directory levels above this script. Put it first on
# sys.path so the project-local `experiments` package below can be imported
# when this file is run directly as a script.
project_root = os.path.abspath(__file__)
for _ in range(3):
    project_root = os.path.dirname(project_root)
sys.path[:0] = [project_root]

from experiments.common.experiment_runner import create_model, ModelConfig


def count_parameters(model: torch.nn.Module) -> Dict:
    """Return parameter counts for ``model``.

    Keys:
        total_params:     number of elements over all ``model.parameters()``.
        trainable_params: the subset whose ``requires_grad`` is True.
        total_params_M:   ``total_params`` expressed in millions.
    """
    total = 0
    trainable = 0
    # One pass over the parameters, accumulating both counters together.
    for param in model.parameters():
        n_elems = param.numel()
        total += n_elems
        if param.requires_grad:
            trainable += n_elems
    return {
        'total_params': total,
        'trainable_params': trainable,
        'total_params_M': total / 1e6,
    }


def measure_inference_time(
    model: torch.nn.Module,
    input_shape: tuple,
    device: torch.device,
    num_warmup: int = 10,
    num_runs: int = 100
) -> Dict:
    """Time a single forward pass of ``model``.

    Side effects: the model is put in eval mode and moved to ``device``.
    One random-normal input of ``input_shape`` is reused for every call.
    ``num_warmup`` untimed calls are followed by ``num_runs`` timed calls.
    On CUDA the device is synchronized so each timing covers the completed
    kernel work.

    Returns:
        inference_time_ms:          mean wall time per call, milliseconds.
        inference_time_std:         standard deviation of those times, ms.
        throughput_samples_per_sec: ``input_shape[0]`` (treated as the batch
                                    size) divided by the mean time in seconds.
    """
    on_cuda = device.type == 'cuda'

    model.eval()
    model.to(device)
    sample = torch.randn(*input_shape).to(device)

    elapsed_ms = []
    with torch.no_grad():
        # Untimed warm-up passes.
        for _ in range(num_warmup):
            _ = model(sample)
        if on_cuda:
            torch.cuda.synchronize()

        for _ in range(num_runs):
            t0 = time.perf_counter()
            _ = model(sample)
            if on_cuda:
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            elapsed_ms.append((t1 - t0) * 1000)

    mean_ms = np.mean(elapsed_ms)
    return {
        'inference_time_ms': mean_ms,
        'inference_time_std': np.std(elapsed_ms),
        'throughput_samples_per_sec': input_shape[0] / (mean_ms / 1000),
    }


def measure_rollout_time(
    model: torch.nn.Module,
    input_shape: tuple,
    device: torch.device,
    num_steps: int = 100,
    num_runs: int = 5
) -> Dict:
    """Time an autoregressive rollout of ``num_steps`` model calls.

    Side effects: the model is put in eval mode and moved to ``device``.
    Each of the ``num_runs`` runs starts from a fresh random-normal input of
    ``input_shape``. Every output is fed back in as the next input; when the
    model returns a tuple, only its first element is fed back. This assumes
    the output is a valid model input (e.g. the same shape) — TODO confirm
    per model. On CUDA the device is synchronized around each timed run.

    Returns:
        rollout_total_ms:      mean wall time of one full rollout, ms.
        rollout_per_step_ms:   that mean divided by ``num_steps``.
        rollout_steps_per_sec: ``num_steps`` per mean rollout second.
    """
    def _wait_for_device():
        if device.type == 'cuda':
            torch.cuda.synchronize()

    model.eval()
    model.to(device)

    run_ms = []
    with torch.no_grad():
        for _ in range(num_runs):
            state = torch.randn(*input_shape).to(device)
            _wait_for_device()

            t0 = time.perf_counter()
            for _step in range(num_steps):
                out = model(state)
                state = out[0] if isinstance(out, tuple) else out
            _wait_for_device()
            t1 = time.perf_counter()

            run_ms.append((t1 - t0) * 1000)

    mean_ms = np.mean(run_ms)
    return {
        'rollout_total_ms': mean_ms,
        'rollout_per_step_ms': mean_ms / num_steps,
        'rollout_steps_per_sec': num_steps / (mean_ms / 1000),
    }


def measure_memory(
    model: torch.nn.Module,
    input_shape: tuple,
    device: torch.device
) -> Dict:
    """Measure CUDA memory around one forward pass of ``model``.

    On CUDA, the model is switched to eval mode (consistent with the timing
    helpers) and moved to ``device``. A random-normal tensor of
    ``input_shape`` is then pushed through it under ``torch.no_grad()``.

    Returns:
        ``{'memory_MB': ..., 'memory_peak_MB': ...}`` in decimal megabytes
        (bytes / 1e6).
        ``memory_MB`` is the tensor memory the allocator reports as
        allocated on ``device`` right after the forward pass. The input and
        output are still referenced at that point. This total also includes
        anything else already resident on the device, not only this model.
        ``memory_PEAK_MB`` — see ``memory_peak_MB``: the maximum allocated
        since the peak counter was reset at the start of this call.
        Both values are 0 for non-CUDA devices, where nothing is measured and
        the model is left untouched.
    """
    if device.type != 'cuda':
        return {'memory_MB': 0, 'memory_peak_MB': 0}

    # The other measure_* helpers benchmark in eval mode; do the same so that
    # train-only behavior (dropout, batch-norm running-stat updates) is not
    # part of the measured pass when this helper is used on its own.
    model.eval()
    model.to(device)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()

    x = torch.randn(*input_shape).to(device)

    with torch.no_grad():
        _ = model(x)  # keep the output alive so it counts toward memory_MB

    current = torch.cuda.memory_allocated(device) / 1e6
    peak = torch.cuda.max_memory_allocated(device) / 1e6

    return {
        'memory_MB': current,
        'memory_peak_MB': peak,
    }


def benchmark_model(
    model_config: ModelConfig,
    dataset_type: str,
    input_shape: tuple,
    device: torch.device
) -> Dict:
    """Run every benchmark for one model configuration.

    The model is built with ``create_model(model_config, dataset_type)``.
    It is then measured for parameter count, single-step inference time, a
    50-step rollout and, on CUDA only, memory usage. Each timing/memory stage
    runs in its own ``try``, so one failing stage does not hide the others.

    Returns:
        A dict that always contains ``'model_type'`` and ``'head_config'``
        (``''`` when the config has none or it is None), so failures can be
        attributed to a model.
        If model construction fails, the dict additionally holds ``'error'``
        (the exception message) and no metrics.
        Otherwise it holds the keys produced by ``count_parameters`` and the
        ``measure_*`` helpers. It also carries an ``'inference_error'``,
        ``'rollout_error'`` or ``'memory_error'`` message for any stage that
        raised.
    """
    identity = {
        'model_type': getattr(model_config, 'model_type', ''),
        'head_config': getattr(model_config, 'head_config', None) or '',
    }

    try:
        model = create_model(model_config, dataset_type)
    except Exception as e:  # record the failure instead of aborting the whole run
        return {**identity, 'error': str(e)}

    results = dict(identity)

    # Parameter count
    results.update(count_parameters(model))

    # Single-step inference time
    try:
        results.update(measure_inference_time(model, input_shape, device))
    except Exception as e:
        results['inference_error'] = str(e)

    # Multi-step rollout time
    try:
        results.update(measure_rollout_time(model, input_shape, device, num_steps=50))
    except Exception as e:
        results['rollout_error'] = str(e)

    # Memory: only measured on CUDA. On CPU, measure_memory would return
    # zeros, and their presence would make the report render a bogus memory
    # table.
    try:
        if device.type == 'cuda':
            results.update(measure_memory(model, input_shape, device))
    except Exception as e:
        results['memory_error'] = str(e)

    return results


def generate_benchmark_report(
    results: List[Dict],
    output_path: str,
    dataset_type: str,
    input_shape: tuple = (16, 3, 64, 64),
    num_warmup: int = 10,
    num_runs: int = 100,
    rollout_steps: int = 50,
):
    """Write benchmark results to ``output_path`` as a Markdown report.

    Args:
        results: one dict per model, as produced by ``benchmark_model``.
            Dicts with an ``'error'`` key (model construction failed) are
            left out of the metric tables. They are listed, together with
            any per-stage ``*_error`` messages, under "Failed Measurements".
        output_path: destination file; it is overwritten if it exists.
        dataset_type: dataset name shown in the report header.
        input_shape, num_warmup, num_runs, rollout_steps: benchmark settings
            echoed in the "Test Configuration" section. The defaults equal
            the values used by ``main`` and ``benchmark_model`` in this
            script.

    A metric that is missing because its measurement failed is shown as
    "N/A" rather than as 0.
    """

    def _cell(value) -> str:
        # Keep free text (e.g. exception messages) from breaking the table:
        # '|' would start a new column and a newline would end the row.
        return str(value).replace('|', '\\|').replace('\r', ' ').replace('\n', ' ')

    def _name(r: Dict) -> str:
        return _cell(r.get('model_type') or '(unknown)')

    def _head(r: Dict) -> str:
        # A missing or None head config renders as an empty cell, not "None".
        return _cell(r.get('head_config') or '')

    def _metric(r: Dict, key: str, spec: str) -> str:
        value = r.get(key)
        return 'N/A' if value is None else format(value, spec)

    succeeded = [r for r in results if 'error' not in r]
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    lines = [
        "# Shallow Water Equation Model Performance Benchmark",
        "",
        f"Dataset: {dataset_type}",
        f"Test Time: {timestamp}",
        "",
        "---",
        "",
        "## Parameter Count Statistics",
        "",
        "| Model | Head Config | Total Params | Trainable Params | Params (M) |",
        "|-------|-------------|--------------|------------------|------------|",
    ]
    for r in succeeded:
        lines.append(
            f"| {_name(r)} | {_head(r)} | {_metric(r, 'total_params', ',')} "
            f"| {_metric(r, 'trainable_params', ',')} | {_metric(r, 'total_params_M', '.2f')} |"
        )

    lines += [
        "",
        "---",
        "",
        "## Inference Speed",
        "",
        "| Model | Head Config | Single-Step (ms) | Throughput (samples/s) | Rollout (steps/s) |",
        "|-------|-------------|------------------|------------------------|-------------------|",
    ]
    for r in succeeded:
        if 'inference_time_ms' in r:
            single = f"{r['inference_time_ms']:.2f}±{r.get('inference_time_std', 0):.2f}"
        else:
            single = 'N/A'
        lines.append(
            f"| {_name(r)} | {_head(r)} | {single} "
            f"| {_metric(r, 'throughput_samples_per_sec', '.1f')} "
            f"| {_metric(r, 'rollout_steps_per_sec', '.1f')} |"
        )

    if any('memory_MB' in r for r in results):
        lines += [
            "",
            "---",
            "",
            "## GPU Memory Usage",
            "",
            "| Model | Head Config | Current (MB) | Peak (MB) |",
            "|-------|-------------|--------------|-----------|",
        ]
        for r in results:
            if 'memory_MB' in r:
                lines.append(
                    f"| {_name(r)} | {_head(r)} | {_metric(r, 'memory_MB', '.1f')} "
                    f"| {_metric(r, 'memory_peak_MB', '.1f')} |"
                )

    # Surface every failure instead of silently dropping the model/metric.
    stages = (
        ('error', 'model creation'),
        ('inference_error', 'inference'),
        ('rollout_error', 'rollout'),
        ('memory_error', 'memory'),
    )
    failure_rows = [
        f"| {_name(r)} | {_head(r)} | {stage} | {_cell(r[key])} |"
        for r in results
        for key, stage in stages
        if key in r
    ]
    if failure_rows:
        lines += [
            "",
            "---",
            "",
            "## Failed Measurements",
            "",
            "| Model | Head Config | Stage | Error |",
            "|-------|-------------|-------|-------|",
            *failure_rows,
        ]

    shape = tuple(input_shape)
    if len(shape) == 4:
        shape_desc = "(batch={}, channels={}, H={}, W={})".format(*shape)
    else:
        shape_desc = str(shape)

    lines += [
        "",
        "---",
        "",
        "## Test Configuration",
        "",
        f"- Input shape: {shape_desc}",
        f"- Warmup iterations: {num_warmup}",
        f"- Timing iterations: {num_runs}",
        f"- Rollout steps: {rollout_steps}",
        "",
    ]

    md_content = "\n".join(lines) + "\n"
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(md_content)

    print(f"Benchmark report saved to: {output_path}")


def main():
    """Benchmark the shallow-water models and write the Markdown report.

    Command-line options:
        --gpu     CUDA device index (CPU is used when CUDA is unavailable).
        --output  Report path. The default is
                  ``<project_root>/results/shallow_water/model_benchmark.md``,
                  derived from this file's location instead of a hard-coded,
                  machine-specific absolute path.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Shallow Water Model Benchmark')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument(
        '--output',
        type=str,
        default=os.path.join(project_root, 'results', 'shallow_water', 'model_benchmark.md'),
        help='Path of the Markdown report to write',
    )
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    input_shape = (16, 3, 64, 64)  # batch, channels, H, W
    dataset_type = 'shallow_water'

    # Model configurations for benchmark testing
    configs = [
        # Our method (FluxNet-LAP)
        ModelConfig(model_type='FluxNet_SW_2D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=5, head_config='LAP', lower_bound=0.0),
        # Ablation models
        ModelConfig(model_type='FluxNet_SW_2D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=5, head_config='PPP'),
        ModelConfig(model_type='FluxNet_SW_2D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=5, head_config='LPP', lower_bound=0.0),
        ModelConfig(model_type='FluxNet_SW_2D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=5, head_config='PAP', lower_bound=0.0),
        ModelConfig(model_type='FluxNet_SW_2D', base_channels=64, num_blocks=6, kernel_size=5, neighborhood_size=5, head_config='LAP_no_gate', lower_bound=0.0),
        # Baseline models
        ModelConfig(model_type='FluxNet_SW_Baseline', base_channels=64, num_blocks=6, kernel_size=5, prediction_mode='residual'),
        ModelConfig(model_type='FNO_SW', modes=16, width=64, num_layers=4),
        ModelConfig(model_type='FNO_SW_Proj', modes=16, width=64, num_layers=4, projection_mode='box_mass', prediction_mode='residual'),
        ModelConfig(model_type='FNO_FluxLAP', modes=16, width=64, num_layers=4, neighborhood_size=5, lower_bound=0.0),
    ]

    results = []
    for config in configs:
        model_name = f"{config.model_type}"
        head = getattr(config, 'head_config', None)
        if head:
            model_name += f" ({head})"
        print(f"Testing: {model_name}")

        result = benchmark_model(config, dataset_type, input_shape, device)
        results.append(result)

        if 'error' in result:
            print(f"  Error: {result['error']}")
            continue
        total = result.get('total_params')
        # Guard the fallback: formatting the string 'N/A' with ',' would raise.
        print(f"  Params: {total:,}" if total is not None else "  Params: N/A")
        # Per-stage failures would otherwise only be visible in the report.
        for key in ('inference_error', 'rollout_error', 'memory_error'):
            if key in result:
                print(f"  {key}: {result[key]}")

    output_path = args.output
    output_dir = os.path.dirname(output_path)
    if output_dir:  # os.makedirs('') raises for a bare filename
        os.makedirs(output_dir, exist_ok=True)
    generate_benchmark_report(results, output_path, dataset_type)

    print("\nBenchmark test complete!")


if __name__ == "__main__":
    main()
