"""Benchmarking utilities for rigorous timing and energy measurement.

This module provides JIT-aware timing, energy measurement via RAPL/nvidia-smi,
and statistical aggregation for ICML-quality benchmark reporting.

Key design decisions:
1. JIT warmup: First call compiles, subsequent calls measure execution time
2. Multiple repetitions with statistical reporting (median, IQR)
3. Synchronization barriers for accurate GPU timing
4. Energy integration via power sampling or direct energy counters
"""

import gc
import os
import subprocess
import time
import threading
from contextlib import contextmanager
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np


# =============================================================================
# TIMING UTILITIES
# =============================================================================

@dataclass
class TimingResult:
    """Statistical summary of repeated wall-clock measurements of one operation.

    Every duration is expressed in seconds. When ``times_s`` is non-empty at
    construction time, the summary fields below are derived from it
    automatically; otherwise they keep their 0.0 defaults.

    Attributes:
        name: Label of the measured operation.
        n_warmup: Count of untimed warmup calls (excluded from the statistics).
        n_repeats: Count of timed calls.
        times_s: Raw per-call durations.
        mean_s: Arithmetic mean of ``times_s``.
        median_s: Median of ``times_s`` (robust to outliers).
        std_s: Sample standard deviation (ddof=1); 0.0 for a single sample.
        min_s: Fastest call.
        max_s: Slowest call.
        iqr_s: Interquartile range, 75th minus 25th percentile.
        jit_compile_time_s: Duration of the first warmup call (includes JIT).
    """
    name: str
    n_warmup: int
    n_repeats: int
    times_s: List[float] = field(default_factory=list)
    mean_s: float = 0.0
    median_s: float = 0.0
    std_s: float = 0.0
    min_s: float = 0.0
    max_s: float = 0.0
    iqr_s: float = 0.0
    jit_compile_time_s: float = 0.0

    def __post_init__(self):
        # Leave the zero defaults untouched when no raw samples were given.
        if not self.times_s:
            return
        self._compute_stats()

    def _compute_stats(self):
        """Fill in the summary fields from ``times_s``."""
        durations = np.array(self.times_s)
        p25, p75 = np.percentile(durations, [25, 75])

        self.mean_s = float(durations.mean())
        self.median_s = float(np.median(durations))
        # ddof=1 is undefined for one sample, so report zero spread instead.
        self.std_s = float(durations.std(ddof=1)) if len(durations) > 1 else 0.0
        self.min_s = float(durations.min())
        self.max_s = float(durations.max())
        self.iqr_s = float(p75 - p25)

    def to_dict(self) -> Dict[str, Any]:
        """Return every field (including raw ``times_s``) as a plain dict."""
        return asdict(self)


def sync_jax_devices():
    """Make the host wait for a tiny JAX computation on the default device.

    JAX dispatches work asynchronously, so a wall-clock timer stopped right
    after a call can capture only the launch overhead. This helper creates a
    one-element array and copies it back to the host with ``jax.device_get``,
    which blocks until that array exists. Whether that also waits for *all*
    earlier work depends on in-order execution on the default device.
    NOTE(review): other devices are not probed, so this is not a full
    multi-device barrier. Silently does nothing when JAX cannot be imported.
    """
    try:
        import jax
        probe = jax.numpy.zeros(1)
        jax.device_get(probe)  # host blocks until `probe` is materialised
    except ImportError:
        pass  # JAX unavailable: nothing to synchronise


def _run_once(
    fn: Callable,
    args: Tuple,
    kwargs: Dict,
    sync_devices: bool,
) -> Tuple[Any, float]:
    """Call ``fn(*args, **kwargs)`` once; return ``(return_value, seconds)``.

    With ``sync_devices`` the JAX device is synchronised before the clock
    starts (so earlier work is not billed to this call) and again before the
    clock stops (so asynchronously dispatched work is included).
    """
    if sync_devices:
        sync_jax_devices()
    began = time.perf_counter()
    value = fn(*args, **kwargs)
    if sync_devices:
        sync_jax_devices()
    return value, time.perf_counter() - began


def time_function(
    fn: Callable,
    args: Tuple = (),
    kwargs: Optional[Dict] = None,
    n_warmup: int = 3,
    n_repeats: int = 10,
    name: str = "operation",
    sync_devices: bool = True,
    gc_collect: bool = True,
) -> Tuple[Any, TimingResult]:
    """Time a function with JIT warmup and statistical aggregation.

    Procedure:
    1. Optionally run ``gc.collect()`` once up front.
    2. Execute ``n_warmup`` untimed calls; the duration of the first one is
       kept separately as the JIT/compile time.
    3. Execute ``n_repeats`` timed calls, one wall-clock sample per call.
    4. Summarise the samples in a ``TimingResult``.

    Parameters
    ----------
    fn : Callable
        Function to time
    args : tuple
        Positional arguments for fn
    kwargs : dict, optional
        Keyword arguments for fn
    n_warmup : int
        Number of warmup iterations (default: 3)
    n_repeats : int
        Number of timed iterations (default: 10)
    name : str
        Name for reporting
    sync_devices : bool
        Whether to sync JAX devices around each call (default: True)
    gc_collect : bool
        Whether to run garbage collection before timing (default: True)

    Returns
    -------
    result : Any
        Return value of the final call (None if fn was never called)
    timing : TimingResult
        Timing statistics
    """
    call_kwargs = kwargs or {}

    if gc_collect:
        gc.collect()

    result = None

    # Warmup: untimed for the statistics, but the first call's duration is kept
    # because it includes any JIT compilation.
    first_call_s = None
    for idx in range(n_warmup):
        result, elapsed = _run_once(fn, args, call_kwargs, sync_devices)
        if idx == 0:
            first_call_s = elapsed

    durations = []
    for _ in range(n_repeats):
        result, elapsed = _run_once(fn, args, call_kwargs, sync_devices)
        durations.append(elapsed)

    timing = TimingResult(
        name=name,
        n_warmup=n_warmup,
        n_repeats=n_repeats,
        times_s=durations,
        jit_compile_time_s=first_call_s or 0.0,
    )
    return result, timing


@contextmanager
def timed_block(name: str = "block", sync_devices: bool = True):
    """Measure the wall-clock time spent inside a ``with`` block.

    The yielded object exposes ``elapsed_s``: 0.0 while the block runs, set to
    the block's duration in seconds on exit, even if the block raised. When
    ``sync_devices`` is True, JAX devices are synchronised on entry and before
    the clock is read on exit. ``name`` is accepted for labelling but unused.

    Usage:
        with timed_block("training") as timer:
            # ... training code ...
        print(f"Took {timer.elapsed_s:.2f}s")
    """
    class Timer:
        def __init__(self):
            self.elapsed_s = 0.0

    def settle():
        # Drain pending asynchronous device work when requested.
        if sync_devices:
            sync_jax_devices()

    timer = Timer()
    settle()
    began = time.perf_counter()
    try:
        yield timer
    finally:
        settle()
        timer.elapsed_s = time.perf_counter() - began


# =============================================================================
# ENERGY MEASUREMENT
# =============================================================================

@dataclass
class EnergyResult:
    """Outcome of one energy measurement window.

    Attributes:
        source: Which backend produced the numbers ("nvidia-smi", "rapl", "none").
        energy_joules: Energy consumed over the window, in Joules.
        duration_s: Length of the window, in seconds.
        avg_power_watts: Average power over the window, in Watts.
        power_samples: Raw power readings, when the backend samples power.
        sample_interval_s: Nominal spacing between power readings, in seconds.
    """
    source: str = "none"
    energy_joules: float = 0.0
    duration_s: float = 0.0
    avg_power_watts: float = 0.0
    power_samples: List[float] = field(default_factory=list)
    sample_interval_s: float = 0.1

    def to_dict(self) -> Dict[str, Any]:
        """Serialise to a dict, replacing the raw samples with their count.

        The sample list can be long, so only ``n_samples`` (appended as the
        last key) is kept in place of ``power_samples``.
        """
        summary = {
            key: value
            for key, value in asdict(self).items()
            if key != "power_samples"
        }
        summary["n_samples"] = len(self.power_samples)
        return summary


class RAPLReader:
    """Read CPU energy counters from the Linux powercap RAPL interface.

    Each domain directory matching ``<rapl_base>/intel-rapl:*`` exposes a
    cumulative counter in microjoules (``energy_uj``). The counter wraps around
    at the value stored in the sibling ``max_energy_range_uj`` file.

    Only domains whose counter can actually be read are kept. On many recent
    kernels ``energy_uj`` is readable by root only, so an existing file is not
    enough to conclude that RAPL is usable.

    NOTE(review): every matched top-level domain is summed by
    ``read_energy_uj``. If a platform exposes overlapping domains (e.g. a
    "psys" domain alongside the package domain) this may double count; confirm
    on the target machine.

    Attributes:
        available: True if at least one domain counter could be read.
        energy_paths: ``energy_uj`` paths of the readable domains, sorted.
        max_energy: Wrap-around range per entry of ``energy_paths``
            (2**63 when the range file cannot be read).
    """

    def __init__(self, rapl_base: Union[str, Path] = "/sys/class/powercap/intel-rapl"):
        """Discover readable RAPL domains under ``rapl_base``.

        Parameters
        ----------
        rapl_base : str or Path
            Powercap directory to scan (default: the standard sysfs location).
        """
        self.available = False
        self.energy_paths: List[Path] = []
        self.max_energy: List[int] = []

        base = Path(rapl_base)
        if not base.exists():
            return

        # Sorted for a deterministic domain order across runs.
        for domain in sorted(base.glob("intel-rapl:*")):
            energy_path = domain / "energy_uj"
            # Probe with a real read: missing, permission-denied or
            # unparseable counters would otherwise contribute silent zeros.
            if self._read_int(energy_path) is None:
                continue
            self.energy_paths.append(energy_path)
            max_range = self._read_int(domain / "max_energy_range_uj")
            self.max_energy.append(max_range if max_range is not None else 2**63)

        self.available = len(self.energy_paths) > 0

    @staticmethod
    def _read_int(path: Path) -> Optional[int]:
        """Return the integer stored in ``path``, or None if unreadable/unparseable."""
        try:
            with open(path) as f:
                return int(f.read().strip())
        except (OSError, ValueError):
            return None

    def read_energy_uj(self) -> int:
        """Sum the current counters of all domains, in microjoules.

        Domains whose counter cannot be read at this moment are skipped.
        """
        total = 0
        for path in self.energy_paths:
            value = self._read_int(path)
            if value is not None:
                total += value
        return total

    def read_energy_j(self) -> float:
        """Sum the current counters of all domains, in Joules."""
        return self.read_energy_uj() / 1e6


class NvidiaSMIPowerSampler:
    """Sample total GPU power draw with ``nvidia-smi`` in a background thread.

    Each sample is the sum of ``power.draw`` over all GPUs listed by
    ``nvidia-smi``. Energy is estimated by trapezoidal integration of the
    samples over their timestamps.

    The worker thread appends each (power, timestamp) pair under a lock, and
    readers snapshot under the same lock. The two lists therefore never
    disagree in length, even if a query is still in flight when ``stop()``
    returns.
    """

    # Per-query timeouts in seconds; the availability probe is more lenient.
    _PROBE_TIMEOUT_S = 5
    _QUERY_TIMEOUT_S = 2
    _QUERY_CMD = ("nvidia-smi", "--query-gpu=power.draw", "--format=csv,noheader,nounits")

    def __init__(self, sample_interval_s: float = 0.1):
        """Create a sampler polling every ``sample_interval_s`` seconds."""
        self.sample_interval = sample_interval_s
        self._samples: List[float] = []
        self._timestamps: List[float] = []
        # Guards _samples/_timestamps, which the worker thread appends to.
        self._lock = threading.Lock()
        # Set by stop(); also wakes the worker early from its inter-sample wait.
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self.available = self._check_available()

    @staticmethod
    def _parse_power_csv(stdout: str) -> Optional[float]:
        """Sum the per-GPU wattages in ``nvidia-smi`` CSV output.

        Non-numeric lines (e.g. ``[N/A]``) are skipped. Returns None when no
        line parses, so an unsupported query is not mistaken for a real 0 W.
        """
        readings = []
        for line in stdout.splitlines():
            try:
                readings.append(float(line.strip()))
            except ValueError:
                continue
        return sum(readings) if readings else None

    def _query_power_watts(self, timeout: float) -> Optional[float]:
        """Run one ``nvidia-smi`` power query; None on any failure."""
        try:
            proc = subprocess.run(
                list(self._QUERY_CMD),
                capture_output=True, text=True, timeout=timeout
            )
        except (subprocess.TimeoutExpired, OSError):
            return None
        if proc.returncode != 0:
            return None
        return self._parse_power_csv(proc.stdout)

    def _check_available(self) -> bool:
        """True if ``nvidia-smi`` runs and reports at least one numeric power value."""
        return self._query_power_watts(self._PROBE_TIMEOUT_S) is not None

    def _read_power_watts(self) -> Optional[float]:
        """Read current total GPU power draw in Watts, or None if unavailable."""
        return self._query_power_watts(self._QUERY_TIMEOUT_S)

    def _sample_loop(self, stop_event: threading.Event):
        """Background thread body: poll until ``stop_event`` is set."""
        while not stop_event.is_set():
            power = self._read_power_watts()
            if power is not None:
                stamp = time.perf_counter()
                with self._lock:
                    self._samples.append(power)
                    self._timestamps.append(stamp)
            # Unlike time.sleep, Event.wait returns as soon as stop() fires.
            stop_event.wait(self.sample_interval)

    def start(self):
        """Start power sampling in background, discarding previous samples."""
        # Never leave an orphaned worker appending into the fresh lists.
        if self._thread is not None and self._thread.is_alive():
            self.stop()
        with self._lock:
            self._samples = []
            self._timestamps = []
        self._stop_event = threading.Event()
        self._thread = threading.Thread(
            target=self._sample_loop, args=(self._stop_event,), daemon=True
        )
        self._thread.start()

    def stop(self) -> Tuple[List[float], List[float]]:
        """Stop sampling and return copies of (power_samples, timestamps)."""
        self._stop_event.set()
        if self._thread is not None:
            # Long enough for an in-flight query to finish or hit its timeout.
            self._thread.join(timeout=self._QUERY_TIMEOUT_S + self.sample_interval + 1.0)
        with self._lock:
            return self._samples.copy(), self._timestamps.copy()

    def compute_energy_joules(self) -> float:
        """Integrate the recorded power samples into energy (Joules).

        Uses trapezoidal integration, E = integral(P dt), over the span between
        the first and last sample. Returns 0.0 with fewer than two samples.
        """
        with self._lock:
            points = list(zip(self._timestamps, self._samples))
        if len(points) < 2:
            return 0.0

        energy = 0.0
        for (t0, p0), (t1, p1) in zip(points, points[1:]):
            energy += 0.5 * (p0 + p1) * (t1 - t0)
        return energy


class EnergyMonitor:
    """Unified interface for energy measurement from multiple sources.

    Supports:
    - RAPL (Intel CPU energy counters, Linux)
    - nvidia-smi (GPU power sampling)

    When both are available, only the GPU measurement is reported; CPU (RAPL)
    energy is then ignored.

    Usage:
        monitor = EnergyMonitor()
        monitor.start()
        # ... computation ...
        result = monitor.stop()
        print(f"Energy: {result.energy_joules:.1f} J")
    """

    def __init__(
        self,
        use_rapl: bool = True,
        use_nvidia: bool = True,
        nvidia_sample_interval: float = 0.1
    ):
        self.rapl = RAPLReader() if use_rapl else None
        self.nvidia = NvidiaSMIPowerSampler(nvidia_sample_interval) if use_nvidia else None

        # Sum of RAPL counters at start() (kept for backward compatibility).
        self._rapl_start: int = 0
        # Per-domain RAPL counters at start(); None where a read failed.
        self._rapl_start_domains: List[Optional[int]] = []
        self._start_time: float = 0.0

    @property
    def available_sources(self) -> List[str]:
        """List of available energy measurement sources."""
        sources = []
        if self.rapl and self.rapl.available:
            sources.append("rapl")
        if self.nvidia and self.nvidia.available:
            sources.append("nvidia-smi")
        return sources

    def _read_rapl_domains(self) -> List[Optional[int]]:
        """Read each RAPL domain counter (microjoules) separately; None on failure."""
        readings: List[Optional[int]] = []
        for path in self.rapl.energy_paths:
            try:
                with open(path) as f:
                    readings.append(int(f.read().strip()))
            except (OSError, ValueError):
                readings.append(None)
        return readings

    @staticmethod
    def _rapl_delta_uj(
        start: List[Optional[int]],
        end: List[Optional[int]],
        max_energy: List[int],
    ) -> int:
        """Total energy between two per-domain snapshots, in microjoules.

        Each domain counter wraps independently at its own range, so the
        wrap correction is applied per domain before summing. Domains missing
        either reading are skipped.
        """
        total = 0
        for begin, finish, wrap in zip(start, end, max_energy):
            if begin is None or finish is None:
                continue
            delta = finish - begin
            if delta < 0:
                delta += wrap
            total += delta
        return total

    @staticmethod
    def _mean_power_watts(samples: List[float], timestamps: List[float]) -> Optional[float]:
        """Time-weighted mean of the power samples, or None if there are none.

        With two or more distinct timestamps this is the trapezoidal integral
        divided by the sampled span; otherwise it is the plain average.
        """
        points = list(zip(timestamps, samples))
        if not points:
            return None
        span = points[-1][0] - points[0][0]
        if len(points) < 2 or span <= 0:
            return sum(p for _, p in points) / len(points)
        area = sum(
            0.5 * (p0 + p1) * (t1 - t0)
            for (t0, p0), (t1, p1) in zip(points, points[1:])
        )
        return area / span

    def start(self):
        """Start energy monitoring."""
        self._start_time = time.perf_counter()

        if self.rapl and self.rapl.available:
            self._rapl_start_domains = self._read_rapl_domains()
            self._rapl_start = sum(v for v in self._rapl_start_domains if v is not None)

        if self.nvidia and self.nvidia.available:
            self.nvidia.start()

    def stop(self) -> EnergyResult:
        """Stop monitoring and return energy result."""
        duration = time.perf_counter() - self._start_time

        result = EnergyResult(
            duration_s=duration,
            sample_interval_s=self.nvidia.sample_interval if self.nvidia else 0.0
        )

        # Prefer GPU energy if available (more relevant for ML workloads)
        if self.nvidia and self.nvidia.available:
            samples, timestamps = self.nvidia.stop()
            result.source = "nvidia-smi"
            result.power_samples = samples
            # Samples only cover [first sample, last sample], which is shorter
            # than the monitored window (nvidia-smi start-up latency, final
            # sleep). Integrating over that span alone underestimates energy,
            # so extrapolate the mean power to the full measured duration.
            mean_power = self._mean_power_watts(samples, timestamps)
            if mean_power is not None:
                result.avg_power_watts = mean_power
                result.energy_joules = mean_power * duration

        elif self.rapl and self.rapl.available:
            rapl_end = self._read_rapl_domains()
            result.source = "rapl"
            energy_uj = self._rapl_delta_uj(
                self._rapl_start_domains, rapl_end, self.rapl.max_energy
            )
            result.energy_joules = energy_uj / 1e6
            if duration > 0:
                result.avg_power_watts = result.energy_joules / duration

        return result


@contextmanager
def energy_monitored(
    use_rapl: bool = True,
    use_nvidia: bool = True,
    nvidia_sample_interval: float = 0.1
):
    """Measure energy consumed while a ``with`` block runs.

    Builds an ``EnergyMonitor`` with the given options and starts it before
    the block. The yielded handle's ``result`` is None during the block and is
    set to the monitor's ``EnergyResult`` on exit, including when the block
    raises.

    Usage:
        with energy_monitored() as monitor:
            # ... computation ...
        print(f"Energy: {monitor.result.energy_joules:.1f} J")
    """
    class MonitorWrapper:
        # Populated with the EnergyResult when the block exits.
        result: Optional[EnergyResult] = None

    handle = MonitorWrapper()
    monitor = EnergyMonitor(
        use_rapl=use_rapl,
        use_nvidia=use_nvidia,
        nvidia_sample_interval=nvidia_sample_interval,
    )

    monitor.start()
    try:
        yield handle
    finally:
        handle.result = monitor.stop()


# =============================================================================
# COMBINED BENCHMARKING
# =============================================================================

@dataclass
class BenchmarkResult:
    """Timing statistics, optional energy reading and metadata for one benchmark."""
    # Benchmark label used in reports.
    name: str
    # Timing statistics of the measured calls.
    timing: TimingResult
    # Energy reading, or None when energy was not measured.
    energy: Optional[EnergyResult] = None
    # Free-form extra information, passed through unchanged by to_dict().
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict, delegating to the nested results' ``to_dict``."""
        timing_dict = self.timing.to_dict()
        energy_dict = self.energy.to_dict() if self.energy else None
        return dict(
            name=self.name,
            timing=timing_dict,
            energy=energy_dict,
            metadata=self.metadata,
        )


def benchmark_function(
    fn: Callable,
    args: Tuple = (),
    kwargs: Optional[Dict] = None,
    name: str = "benchmark",
    n_warmup: int = 3,
    n_repeats: int = 10,
    measure_energy: bool = True,
    sync_devices: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, BenchmarkResult]:
    """Run a comprehensive benchmark with timing and energy measurement.

    This is the main entry point for benchmarking. It:
    1. Runs warmup iterations for JIT compilation (not energy-monitored)
    2. Measures execution time of the timed repetitions with statistical
       aggregation
    3. Measures energy consumption (if available) over exactly those timed
       repetitions, so energy and timing statistics describe the same calls

    Parameters
    ----------
    fn : Callable
        Function to benchmark
    args : tuple
        Positional arguments
    kwargs : dict, optional
        Keyword arguments
    name : str
        Benchmark name for reporting
    n_warmup : int
        Warmup iterations (default: 3)
    n_repeats : int
        Timed iterations (default: 10)
    measure_energy : bool
        Whether to measure energy (default: True)
    sync_devices : bool
        Sync JAX devices for accurate GPU timing (default: True)
    metadata : dict, optional
        Additional metadata to include in results

    Returns
    -------
    result : Any
        Return value of the last call to the function
    benchmark : BenchmarkResult
        Timing and energy statistics
    """
    kwargs = kwargs or {}
    metadata = metadata or {}

    # Phase 1: warmup (includes JIT compilation), kept outside the energy
    # window so compile energy is not attributed to the benchmark.
    warm_result, warm_timing = time_function(
        fn=fn,
        args=args,
        kwargs=kwargs,
        n_warmup=n_warmup,
        n_repeats=0,
        name=name,
        sync_devices=sync_devices,
    )

    # Phase 2: timed repetitions, bracketed by the energy monitor.
    monitor = EnergyMonitor() if measure_energy else None
    energy_result = None
    gc.collect()  # collect warmup garbage before the measured window opens
    if monitor is not None:
        monitor.start()
    try:
        result, timing = time_function(
            fn=fn,
            args=args,
            kwargs=kwargs,
            n_warmup=0,
            n_repeats=n_repeats,
            name=name,
            sync_devices=sync_devices,
            gc_collect=False,
        )
    finally:
        # Always stop the monitor so its sampling thread does not outlive a
        # benchmark whose function raised.
        if monitor is not None:
            energy_result = monitor.stop()

    # Report both phases as one TimingResult, as a single time_function call would.
    timing.n_warmup = n_warmup
    timing.jit_compile_time_s = warm_timing.jit_compile_time_s
    if n_repeats <= 0:
        result = warm_result  # no timed call ran; last call was a warmup call

    benchmark = BenchmarkResult(
        name=name,
        timing=timing,
        energy=energy_result,
        metadata=metadata,
    )

    return result, benchmark


# =============================================================================
# REPORTING UTILITIES
# =============================================================================

def format_timing_table(results: List[TimingResult], include_jit: bool = True) -> str:
    """Render timing results as a fixed-width ASCII table.

    Columns: operation name, median, mean and standard deviation (seconds),
    plus the first-call JIT time when ``include_jit`` is True. Numbers use
    four decimals; the table is framed by ``=`` rules as wide as the header.
    """
    columns = [("Operation", 30), ("Median (s)", 12), ("Mean (s)", 12), ("Std (s)", 10)]
    if include_jit:
        columns.append(("JIT (s)", 10))

    header = " ".join(f"{title:<{width}}" for title, width in columns)
    rule = "=" * len(header)
    numeric_widths = [width for _, width in columns[1:]]

    rows = []
    for r in results:
        values = [r.median_s, r.mean_s, r.std_s]
        if include_jit:
            values.append(r.jit_compile_time_s)
        cells = [f"{r.name:<30}"]
        cells.extend(f"{value:<{width}.4f}" for value, width in zip(values, numeric_widths))
        rows.append(" ".join(cells))

    return "\n".join([rule, header, rule, *rows, rule])


# Translation table for characters that are special in LaTeX text mode.
_LATEX_ESCAPES = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


def _latex_escape(text: Any) -> str:
    """Escape LaTeX special characters so ``text`` renders literally."""
    # str.translate is a single pass, so inserted backslashes are not re-escaped.
    return str(text).translate(_LATEX_ESCAPES)


def format_benchmark_latex(
    results: List[BenchmarkResult],
    caption: str = "Benchmark results",
    label: str = "tab:benchmark",
    escape_names: bool = True,
) -> str:
    """Format benchmark results as a LaTeX table (uses booktabs rules).

    Parameters
    ----------
    results : list of BenchmarkResult
        One table row per result.
    caption : str
        Table caption, inserted verbatim (may contain LaTeX markup).
    label : str
        Table label, inserted verbatim.
    escape_names : bool
        Escape LaTeX special characters in operation names (default: True).
        Unescaped names such as ``matmul_1024`` would otherwise break
        compilation; pass False only if names are already valid LaTeX.

    Returns
    -------
    str
        The complete ``table`` environment. The energy column shows ``---``
        when no positive energy reading is available.
    """
    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\small",
        r"\begin{tabular}{@{}lrrrrr@{}}",
        r"\toprule",
        r"Operation & Median (s) & Mean (s) & Std (s) & JIT (s) & Energy (J) \\",
        r"\midrule",
    ]

    for r in results:
        energy = f"{r.energy.energy_joules:.1f}" if r.energy and r.energy.energy_joules > 0 else "---"
        row_name = _latex_escape(r.name) if escape_names else r.name
        lines.append(
            f"{row_name} & {r.timing.median_s:.4f} & {r.timing.mean_s:.4f} & "
            f"{r.timing.std_s:.4f} & {r.timing.jit_compile_time_s:.4f} & {energy} \\\\"
        )

    lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        f"\\caption{{{caption}}}",
        f"\\label{{{label}}}",
        r"\end{table}",
    ])

    return "\n".join(lines)
