#!/usr/bin/env python3
"""System benchmark utility for extracting hardware and software information.

This script extracts and logs system configuration for reproducibility reporting
in scientific publications. Can be run standalone or imported as a module.

Usage:
    python sys_bench.py                    # Print a human-readable summary to stdout
    python sys_bench.py -o sys_info.json   # Save to JSON file
    python sys_bench.py --latex            # Output LaTeX-formatted table
    python sys_bench.py --json             # Print raw JSON to stdout
"""

import argparse
import json
import os
import platform
import subprocess
import sys
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any


@dataclass
class CPUInfo:
    """CPU hardware information.

    All fields carry placeholder defaults, so a bare ``CPUInfo()`` means
    "nothing detected yet"; ``get_cpu_info()`` fills in the real values.
    """
    # Processor model string; "Unknown" when it could not be determined.
    model: str = "Unknown"
    # Number of physical cores; 0 until populated.
    physical_cores: int = 0
    # Number of logical CPUs; 0 until populated.
    logical_cores: int = 0
    # Machine architecture identifier; "Unknown" when not determined.
    architecture: str = "Unknown"


@dataclass
class GPUInfo:
    """GPU hardware information.

    The defaults describe "no GPU detected"; ``get_gpu_info()`` returns an
    instance in this state when ``nvidia-smi`` is not available.
    """
    # GPU product name; "Unknown" when not detected.
    name: str = "Unknown"
    # Installed GPU driver version string.
    driver_version: str = "Unknown"
    # CUDA version string parsed from the nvidia-smi banner.
    cuda_version: str = "Unknown"
    # Total GPU memory as an integer number of megabytes; 0 when unknown.
    memory_total_mb: int = 0
    # Compute capability as a "major.minor" string; "Unknown" when not known.
    compute_capability: str = "Unknown"


@dataclass
class MemoryInfo:
    """System memory information.

    Both values default to 0.0, which is also what remains when no memory
    source could be read.
    """
    # Total system memory in gigabytes.
    total_gb: float = 0.0
    # Memory currently available in gigabytes.
    available_gb: float = 0.0


@dataclass
class SoftwareVersions:
    """Software version information.

    One version string per tracked component. Every field defaults to
    "Unknown", which is what remains when a package is not importable.
    """
    # Version of the running Python interpreter.
    python: str = "Unknown"
    # Versions of the tracked third-party packages; each field name matches
    # the package's import name.
    jax: str = "Unknown"
    jaxlib: str = "Unknown"
    flax: str = "Unknown"
    diffrax: str = "Unknown"
    optax: str = "Unknown"
    numpy: str = "Unknown"
    scipy: str = "Unknown"


@dataclass
class GitInfo:
    """Git repository information.

    Holds commit, branch and dirty state for the two tracked repositories
    (CFD and TFMPE). String fields default to "Unknown" and the dirty flags
    to False, which is the state kept when a repository cannot be queried.
    """
    # Abbreviated commit hash of the CFD checkout.
    cfd_commit: str = "Unknown"
    # Branch name of the CFD checkout.
    cfd_branch: str = "Unknown"
    # True when the CFD working tree has uncommitted changes.
    cfd_dirty: bool = False
    # Abbreviated commit hash of the TFMPE checkout.
    tfmpe_commit: str = "Unknown"
    # Branch name of the TFMPE checkout.
    tfmpe_branch: str = "Unknown"
    # True when the TFMPE working tree has uncommitted changes.
    tfmpe_dirty: bool = False


@dataclass
class SystemInfo:
    """Complete system information for benchmarking.

    Top-level record that aggregates the per-area dataclasses plus a few
    host-level strings. Nested records use ``default_factory`` so each
    ``SystemInfo`` owns its own sub-objects rather than sharing defaults.
    """
    # ISO-8601 timestamp of when the information was collected.
    timestamp: str = ""
    # Network name of the host.
    hostname: str = ""
    # Operating system name, version string and kernel release.
    os_name: str = ""
    os_version: str = ""
    kernel: str = ""
    # Hardware sections.
    cpu: CPUInfo = field(default_factory=CPUInfo)
    gpu: GPUInfo = field(default_factory=GPUInfo)
    memory: MemoryInfo = field(default_factory=MemoryInfo)
    # Python/package versions and repository state.
    software: SoftwareVersions = field(default_factory=SoftwareVersions)
    git: GitInfo = field(default_factory=GitInfo)
    # Name of JAX's default backend; "Unknown" when JAX is unavailable.
    jax_backend: str = "Unknown"
    # String representation of each JAX device; empty when JAX is unavailable.
    jax_devices: List[str] = field(default_factory=list)


def run_command(
    cmd: List[str],
    default: str = "Unknown",
    timeout: float = 10,
    cwd: Optional[str] = None,
) -> str:
    """Run an external command and return its stripped stdout.

    The command is executed directly from the argument list (no shell is
    involved), with stdout/stderr captured as text.

    Args:
        cmd: Program name followed by its arguments.
        default: Value returned whenever the command cannot be run, times
            out, or exits with a non-zero status.
        timeout: Maximum run time in seconds before the command is abandoned.
        cwd: Optional working directory for the child process; ``None`` keeps
            the current one.

    Returns:
        The command's stdout with surrounding whitespace removed, or
        ``default`` on any failure.
    """
    # An empty argument list can never be executed; treat it as a failure
    # instead of letting subprocess raise.
    if not cmd:
        return default

    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout, cwd=cwd
        )
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: executable missing / not permitted / bad cwd.
        # SubprocessError: includes TimeoutExpired.
        # ValueError: invalid arguments or undecodable output.
        return default

    if result.returncode != 0:
        return default
    return result.stdout.strip()


def get_cpu_info() -> CPUInfo:
    """Collect CPU model, core counts and architecture for the current host.

    Architecture and logical core count come from the standard library. The
    model name and physical core count are platform specific:

    * Linux: model from ``/proc/cpuinfo`` (falling back to ``lscpu``),
      physical cores as ``Core(s) per socket * Socket(s)`` from ``lscpu``.
    * macOS: both values from ``sysctl``.
    * Other systems: ``platform.processor()`` and the logical core count.

    Whenever the physical core count cannot be determined, the logical core
    count is reported instead.

    Returns:
        A populated ``CPUInfo``; fields that could not be detected keep their
        defaults.
    """
    info = CPUInfo()
    info.architecture = platform.machine()
    info.logical_cores = os.cpu_count() or 0

    system = platform.system()
    if system == "Linux":
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if line.startswith("model name"):
                        # Keep everything after the first colon so a model
                        # string that itself contains ":" is not truncated.
                        model = line.partition(":")[2].strip()
                        if model:
                            info.model = model
                        break
        except OSError:
            pass

        cores_per_socket: Optional[int] = None
        sockets: Optional[int] = None
        lscpu_model = ""
        for line in run_command(["lscpu"], default="").splitlines():
            key, _, value = line.partition(":")
            key = key.strip()
            value = value.strip()
            if key == "Core(s) per socket":
                try:
                    cores_per_socket = int(value)
                except ValueError:
                    cores_per_socket = None
            elif key == "Socket(s)":
                try:
                    sockets = int(value)
                except ValueError:
                    # A present but unparseable socket count is treated as a
                    # single socket.
                    sockets = 1
            elif key == "Model name" and not lscpu_model:
                lscpu_model = value

        # /proc/cpuinfo has no "model name" entry on some architectures; use
        # the lscpu value when that is all we have.
        if info.model == "Unknown" and lscpu_model:
            info.model = lscpu_model

        # Only trust the product when both factors were found and non-zero;
        # otherwise fall back to the logical count rather than reporting 0.
        if cores_per_socket and sockets:
            info.physical_cores = cores_per_socket * sockets
        else:
            info.physical_cores = info.logical_cores

    elif system == "Darwin":
        info.model = run_command(["sysctl", "-n", "machdep.cpu.brand_string"])
        phys = run_command(["sysctl", "-n", "hw.physicalcpu"])
        try:
            info.physical_cores = int(phys)
        except ValueError:
            info.physical_cores = info.logical_cores
    else:
        info.model = platform.processor() or "Unknown"
        info.physical_cores = info.logical_cores

    return info


def _query_nvidia_smi(field_name: str) -> str:
    """Return one ``--query-gpu`` field for the first GPU, or "" on failure."""
    output = run_command(
        ["nvidia-smi", f"--query-gpu={field_name}", "--format=csv,noheader,nounits"],
        default="",
    )
    # nvidia-smi prints one row per GPU; keep only the first so the value
    # stays a single line on multi-GPU hosts.
    rows = output.splitlines()
    return rows[0].strip() if rows else ""


def get_gpu_info() -> GPUInfo:
    """Extract GPU information using ``nvidia-smi``.

    Reports the first GPU listed by ``nvidia-smi``. When the tool is not on
    ``PATH`` a default ``GPUInfo`` is returned; individual fields that cannot
    be queried or parsed keep their defaults.

    Compute capability is read from the ``compute_cap`` query field when the
    installed driver supports it, and otherwise guessed from the GPU name via
    a small lookup table.

    Returns:
        A ``GPUInfo`` describing the first detected NVIDIA GPU.
    """
    # Function-scope imports: these modules are only needed here.
    import re
    import shutil

    info = GPUInfo()

    # shutil.which is portable, unlike shelling out to the `which` program.
    if shutil.which("nvidia-smi") is None:
        return info

    info.name = _query_nvidia_smi("name") or info.name
    info.driver_version = _query_nvidia_smi("driver_version") or info.driver_version

    # The CUDA version is only shown in the banner of the plain nvidia-smi
    # output, so it has to be scraped from there.
    banner = run_command(["nvidia-smi"], default="")
    cuda_match = re.search(r"CUDA Version:\s*(\S+)", banner)
    if cuda_match:
        info.cuda_version = cuda_match.group(1)

    try:
        info.memory_total_mb = int(_query_nvidia_smi("memory.total").split()[0])
    except (ValueError, IndexError):
        pass

    # Prefer the value reported by the driver; older nvidia-smi versions do
    # not know the field, in which case the output is not "major.minor".
    reported_cap = _query_nvidia_smi("compute_cap")
    if re.fullmatch(r"\d+\.\d+", reported_cap):
        info.compute_capability = reported_cap
    else:
        # Fallback: substring match on the product name. Order matters
        # ("A100" must be tested before "A10"), and the table only covers a
        # handful of common GPUs.
        gpu_compute_caps = {
            "A100": "8.0", "A10": "8.6", "A6000": "8.6",
            "V100": "7.0", "T4": "7.5", "RTX 3090": "8.6",
            "RTX 3080": "8.6", "RTX 4090": "8.9", "RTX 4080": "8.9",
            "H100": "9.0", "L40": "8.9",
        }
        gpu_name = info.name.lower()
        for gpu_key, cap in gpu_compute_caps.items():
            if gpu_key.lower() in gpu_name:
                info.compute_capability = cap
                break

    return info


def get_memory_info() -> MemoryInfo:
    """Report total and available system memory.

    Uses ``psutil`` when it can be imported. Without it, Linux hosts are
    served by parsing ``/proc/meminfo``; on any other platform the returned
    ``MemoryInfo`` keeps its zero defaults.
    """
    info = MemoryInfo()

    try:
        import psutil
        vm = psutil.virtual_memory()
        info.total_gb = vm.total / (1024**3)
        info.available_gb = vm.available / (1024**3)
    except ImportError:
        # No psutil: only Linux has a fallback source.
        if platform.system() != "Linux":
            return info

        # /proc/meminfo line prefix -> MemoryInfo attribute it feeds.
        wanted = {"MemTotal:": "total_gb", "MemAvailable:": "available_gb"}
        try:
            with open("/proc/meminfo", "r") as meminfo:
                for row in meminfo:
                    for prefix, attr in wanted.items():
                        if row.startswith(prefix):
                            # Second column is the amount; scale by 1024**2.
                            amount = int(row.split()[1])
                            setattr(info, attr, amount / (1024**2))
                            break
        except (IOError, ValueError, IndexError):
            pass

    return info


def get_software_versions() -> SoftwareVersions:
    """Extract versions of Python and the key scientific packages.

    Each tracked package is imported and its ``__version__`` attribute read.
    A package that is missing, fails to import, or has no ``__version__``
    is reported as "Unknown".

    Returns:
        A ``SoftwareVersions`` with one version string per component.
    """
    versions = SoftwareVersions()
    versions.python = platform.python_version()

    # Field names on SoftwareVersions are identical to the import names.
    tracked = ("jax", "jaxlib", "flax", "diffrax", "optax", "numpy", "scipy")

    for pkg_name in tracked:
        try:
            pkg = __import__(pkg_name)
        except Exception:
            # Deliberately broad: a broken installation can raise more than
            # ImportError at import time (e.g. a RuntimeError on a jax/jaxlib
            # version mismatch). This is a best-effort report, so such a
            # package is left as "Unknown" instead of aborting the run.
            continue
        setattr(versions, pkg_name, getattr(pkg, "__version__", "Unknown"))

    return versions


def _git_repo_state(repo_path) -> tuple:
    """Return ``(commit, branch, dirty)`` for the git checkout at *repo_path*.

    Falls back to ``("Unknown", "Unknown", False)`` when the path is not a
    directory; individual git failures leave the affected value "Unknown".
    """
    repo = str(repo_path)
    # Guard explicitly: `git -C ""` would silently inspect the current
    # directory instead of failing.
    if not repo or not Path(repo).is_dir():
        return "Unknown", "Unknown", False

    # `git -C <dir>` runs git as if started in <dir>, so the process-wide
    # working directory never has to be changed.
    base = ["git", "-C", repo]
    commit = run_command(base + ["rev-parse", "HEAD"])[:12]
    branch = run_command(base + ["rev-parse", "--abbrev-ref", "HEAD"])
    status = run_command(base + ["status", "--porcelain"])
    # Any porcelain output means uncommitted changes; "Unknown" means the
    # status command itself failed.
    dirty = bool(status and status != "Unknown")
    return commit, branch, dirty


def get_git_info(cfd_path: Optional[str] = None, tfmpe_path: Optional[str] = None) -> GitInfo:
    """Extract git commit information for the CFD and TFMPE repositories.

    Args:
        cfd_path: Path to the CFD checkout. Defaults to the directory that
            contains this file.
        tfmpe_path: Path to the TFMPE checkout. Defaults to a ``tfmpe``
            directory next to this file's parent directory. The TFMPE fields
            are only queried when this path exists.

    Returns:
        A ``GitInfo`` with 12-character commit hashes, branch names and dirty
        flags; values that cannot be determined keep their defaults.
    """
    info = GitInfo()

    # CFD repo
    if cfd_path is None:
        cfd_path = Path(__file__).parent
    info.cfd_commit, info.cfd_branch, info.cfd_dirty = _git_repo_state(cfd_path)

    # TFMPE repo
    if tfmpe_path is None:
        tfmpe_path = Path(__file__).parent.parent / "tfmpe"
    if Path(tfmpe_path).exists():
        (
            info.tfmpe_commit,
            info.tfmpe_branch,
            info.tfmpe_dirty,
        ) = _git_repo_state(tfmpe_path)

    return info


def get_jax_info() -> tuple:
    """Get JAX backend and device information.

    Returns:
        A ``(backend, devices)`` pair: the name of JAX's default backend and
        a list with the string form of every JAX device. When JAX cannot be
        imported, or its backend cannot be initialised, the pair is
        ``("Unknown", [])`` (or keeps whichever part was already obtained).
    """
    backend = "Unknown"
    devices = []

    try:
        import jax
    except Exception:
        # Deliberately broad: a broken JAX installation can fail at import
        # time with errors other than ImportError, and this best-effort
        # report should still be produced.
        return backend, devices

    try:
        backend = str(jax.default_backend())
        devices = [str(d) for d in jax.devices()]
    except RuntimeError:
        # Backend initialisation failed (e.g. an accelerator runtime that
        # cannot start); keep what was collected so far.
        pass

    return backend, devices


def collect_system_info(
    cfd_path: Optional[str] = None,
    tfmpe_path: Optional[str] = None
) -> SystemInfo:
    """Collect all system information into a single dataclass.

    Args:
        cfd_path: Path to the CFD repository, forwarded to ``get_git_info``.
        tfmpe_path: Path to the TFMPE repository, forwarded to
            ``get_git_info``.

    Returns:
        A fully populated ``SystemInfo``. The timestamp is an ISO-8601 string
        in local time that includes the UTC offset, so records gathered on
        hosts in different time zones remain comparable.
    """
    jax_backend, jax_devices = get_jax_info()

    return SystemInfo(
        # astimezone() on a naive "now" attaches the local time zone, which
        # makes isoformat() emit the UTC offset.
        timestamp=datetime.now().astimezone().isoformat(),
        hostname=platform.node(),
        os_name=platform.system(),
        os_version=platform.version(),
        kernel=platform.release(),
        cpu=get_cpu_info(),
        gpu=get_gpu_info(),
        memory=get_memory_info(),
        software=get_software_versions(),
        git=get_git_info(cfd_path, tfmpe_path),
        jax_backend=jax_backend,
        jax_devices=jax_devices,
    )


def to_dict(info: SystemInfo) -> Dict[str, Any]:
    """Return *info* as a plain nested dictionary.

    Nested dataclass fields are converted recursively, so the result holds
    only built-in containers and values.
    """
    as_mapping: Dict[str, Any] = asdict(info)
    return as_mapping


def to_json(info: SystemInfo, indent: int = 2) -> str:
    """Serialize *info* to a JSON document.

    Args:
        info: The collected system information.
        indent: Indentation width passed to ``json.dumps``.

    Returns:
        The JSON text of the nested-dictionary form of *info*.
    """
    payload = asdict(info)
    return json.dumps(payload, indent=indent)


# Characters with special meaning in LaTeX text mode and their escaped forms.
# Used with str.translate so every character is replaced in a single pass
# (no risk of re-escaping the backslashes that the replacements introduce).
_LATEX_ESCAPES = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


def _latex_escape(value: Any) -> str:
    """Return ``str(value)`` with LaTeX special characters escaped."""
    return str(value).translate(_LATEX_ESCAPES)


def to_latex_table(info: SystemInfo) -> str:
    """Generate a complete LaTeX ``tabular`` describing the system.

    The table uses the ``\\toprule`` / ``\\midrule`` / ``\\bottomrule`` rules
    (as provided by the LaTeX ``booktabs`` package) and lists hardware, key
    software versions and the git commits. A ``*`` after a commit hash marks
    a repository with uncommitted changes.

    All detected string values are escaped, so characters such as ``_``,
    ``%``, ``#`` or ``&`` (common in kernel releases and model names) cannot
    break compilation of the document.

    Args:
        info: The collected system information.

    Returns:
        The table source as a newline-joined string.
    """
    esc = _latex_escape
    cfd_mark = "*" if info.git.cfd_dirty else ""
    tfmpe_mark = "*" if info.git.tfmpe_dirty else ""

    lines = [
        r"\begin{tabular}{@{}ll@{}}",
        r"\toprule",
        r"\textbf{Component} & \textbf{Specification} \\",
        r"\midrule",
        f"CPU & {esc(info.cpu.model)} ({info.cpu.physical_cores} cores) \\\\",
        f"GPU & {esc(info.gpu.name)} ({info.gpu.memory_total_mb} MB VRAM) \\\\",
        f"RAM & {info.memory.total_gb:.1f} GB \\\\",
        f"OS & {esc(info.os_name)} {esc(info.kernel)} \\\\",
        r"\midrule",
        f"Python & {esc(info.software.python)} \\\\",
        f"JAX & {esc(info.software.jax)} (backend: {esc(info.jax_backend)}) \\\\",
        f"Flax & {esc(info.software.flax)} \\\\",
        f"Diffrax & {esc(info.software.diffrax)} \\\\",
        r"\midrule",
        f"CFD commit & \\texttt{{{esc(info.git.cfd_commit)}}}{cfd_mark} \\\\",
        f"TFMPE commit & \\texttt{{{esc(info.git.tfmpe_commit)}}}{tfmpe_mark} \\\\",
        r"\bottomrule",
        r"\end{tabular}",
    ]
    return "\n".join(lines)


def print_summary(info: SystemInfo) -> None:
    """Write a human-readable report of *info* to stdout.

    The report has a banner followed by hardware, software, JAX device and
    git sections. The text is assembled line by line first and emitted with
    a single ``print`` call.
    """
    rule = "=" * 60
    cpu, gpu, mem = info.cpu, info.gpu, info.memory
    sw, git = info.software, info.git

    cfd_flag = " [dirty]" if git.cfd_dirty else ""
    tfmpe_flag = " [dirty]" if git.tfmpe_dirty else ""

    report = [
        rule,
        "SYSTEM BENCHMARK INFORMATION",
        rule,
        f"Timestamp: {info.timestamp}",
        f"Hostname:  {info.hostname}",
        "",
        "--- Hardware ---",
        f"CPU:       {cpu.model}",
        f"           {cpu.physical_cores} physical / {cpu.logical_cores} logical cores",
        f"GPU:       {gpu.name}",
        f"           {gpu.memory_total_mb} MB VRAM, CUDA {gpu.cuda_version}",
        f"           Driver {gpu.driver_version}, Compute {gpu.compute_capability}",
        f"RAM:       {mem.total_gb:.1f} GB total, {mem.available_gb:.1f} GB available",
        f"OS:        {info.os_name} {info.os_version}",
        f"Kernel:    {info.kernel}",
        "",
        "--- Software ---",
        f"Python:    {sw.python}",
        f"JAX:       {sw.jax} (backend: {info.jax_backend})",
        f"Jaxlib:    {sw.jaxlib}",
        f"Flax:      {sw.flax}",
        f"Diffrax:   {sw.diffrax}",
        f"Optax:     {sw.optax}",
        f"NumPy:     {sw.numpy}",
        f"SciPy:     {sw.scipy}",
        "",
        "--- JAX Devices ---",
    ]
    # One indexed line per device; nothing is added when the list is empty.
    report.extend(
        f"  [{index}] {device}" for index, device in enumerate(info.jax_devices)
    )
    report += [
        "",
        "--- Git Information ---",
        f"CFD repo:   {git.cfd_commit} ({git.cfd_branch})" + cfd_flag,
        f"TFMPE repo: {git.tfmpe_commit} ({git.tfmpe_branch})" + tfmpe_flag,
        rule,
    ]
    print("\n".join(report))


def main():
    """Command-line entry point.

    Parses the options, collects the system information once, and emits it
    in exactly one form. Precedence: ``--output`` (JSON written to a file),
    then ``--latex``, then ``--json``; with none of these the human-readable
    summary is printed.
    """
    cli = argparse.ArgumentParser(
        description="Extract system information for benchmark reproducibility"
    )
    cli.add_argument(
        "-o", "--output", type=str, default=None,
        help="Output file path for JSON (default: stdout)",
    )
    cli.add_argument(
        "--latex", action="store_true", help="Output LaTeX-formatted table"
    )
    cli.add_argument("--json", action="store_true", help="Output raw JSON")
    # The two repository options share the same shape.
    for flag, description in (
        ("--cfd-path", "Path to CFD repository"),
        ("--tfmpe-path", "Path to TFMPE repository"),
    ):
        cli.add_argument(flag, type=str, default=None, help=description)

    opts = cli.parse_args()

    report = collect_system_info(
        cfd_path=opts.cfd_path, tfmpe_path=opts.tfmpe_path
    )

    # A file target wins over every stdout format.
    if opts.output:
        with open(opts.output, "w") as sink:
            json.dump(to_dict(report), sink, indent=2)
        print(f"System info saved to {opts.output}")
        return

    if opts.latex:
        print(to_latex_table(report))
        return

    if opts.json:
        print(to_json(report))
        return

    print_summary(report)


# Run the command-line interface only when this file is executed as a script;
# importing it as a module has no side effects.
if __name__ == "__main__":
    main()
