from __future__ import annotations

import argparse
import csv
import gc
import json
import math
import os
import shutil
import subprocess
import sys
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import torch

from perf_utils import RssSampler, measure_cuda_peak_memory, nvtx_range, run_ncu_metrics, slugify, sync_cuda


def _ensure_task_path(repo_root: Path) -> None:
    """Make ``<repo_root>/Compare_RNN/task`` importable by prepending it to ``sys.path``.

    Nothing is inserted when that directory string is already on ``sys.path``, so
    repeated calls do not grow the search path.
    """
    candidate = str(repo_root / "Compare_RNN" / "task")
    if candidate in sys.path:
        return
    sys.path.insert(0, candidate)


def _parse_int_list(value: str) -> List[int]:
    """Parse a comma-separated string into integers, dropping blanks and repeats.

    Whitespace around entries is ignored and empty entries (e.g. from a trailing
    comma) are skipped. Only the first occurrence of each integer is kept, in
    input order. A non-integer entry raises ``ValueError`` from ``int()``.
    """
    tokens = (piece.strip() for piece in str(value).split(","))
    # dict keys keep insertion order, so fromkeys() is an order-preserving dedup.
    return list(dict.fromkeys(int(token) for token in tokens if token))


def _tile_to_length(arr: np.ndarray, length: int) -> np.ndarray:
    """Truncate or cyclically repeat ``arr`` along its last axis to ``length`` entries.

    Args:
        arr: A 3D array. The last axis is the one resized (the time axis ``T`` in
            the ``(N, C, T)`` layout named by the error message).
        length: Target size of the last axis; must be non-negative.

    Returns:
        A slice of ``arr`` when ``length <= T``. Otherwise a new array built by
        tiling ``arr`` along the last axis and cutting the result to ``length``.

    Raises:
        ValueError: If ``arr`` is not 3D, if ``length`` is negative, or if ``arr``
            has an empty last axis while ``length > 0`` (nothing to tile).
    """
    if arr.ndim != 3:
        raise ValueError(f"Expected 3D array (N,C,T); got {arr.shape}")
    length = int(length)
    if length < 0:
        # A negative slice bound would silently drop steps from the end instead.
        raise ValueError(f"length must be non-negative; got {length}")
    base = int(arr.shape[2])
    if length <= base:
        return arr[:, :, :length]
    if base == 0:
        raise ValueError("Cannot tile an array with an empty time axis to a positive length.")
    reps = -(-length // base)  # integer ceil(length / base), no float rounding
    return np.tile(arr, (1, 1, reps))[:, :, :length]


def _method_palette(count: int) -> List[str]:
    """Return ``count`` hex colour strings, cycling through a fixed 8-colour palette.

    A non-positive ``count`` yields an empty list.
    """
    swatches = (
        "#0072B2",
        "#D55E00",
        "#009E73",
        "#CC79A7",
        "#E69F00",
        "#56B4E9",
        "#F0E442",
        "#999999",
    )
    full_cycles, remainder = divmod(max(0, count), len(swatches))
    return list(swatches) * full_cycles + list(swatches[:remainder])


def _scale_bytes(values: np.ndarray) -> Tuple[np.ndarray, str]:
    """Choose a binary byte unit (B/KB/MB/GB) based on the largest finite value.

    Args:
        values: Byte counts. NaN/inf entries are ignored when choosing the unit,
            but they are still scaled along with everything else.

    Returns:
        ``(scaled_values, unit)``. When the unit is ``"B"`` (including when there
        are no finite entries), the input array itself is returned unscaled.
    """
    finite = values[np.isfinite(values)]
    if not finite.size:
        return values, "B"
    largest = float(finite.max())
    if largest < 1024.0:
        return values, "B"
    # (exclusive upper bound, divisor, unit), ascending; anything larger is GB.
    for upper, divisor, unit in (
        (1024.0**2, 1024.0, "KB"),
        (1024.0**3, 1024.0**2, "MB"),
    ):
        if largest < upper:
            return values / divisor, unit
    return values / (1024.0**3), "GB"


def _scale_time_sec(values: np.ndarray) -> Tuple[np.ndarray, str]:
    """Pick a readable time unit ("us", "ms" or "s") for durations given in seconds.

    The unit comes from the largest finite entry. NaN/inf entries are ignored for
    the choice but scaled like the rest. With no finite entries, or when the
    largest value is at least one second, the input is returned as-is with "s".
    """
    finite = values[np.isfinite(values)]
    largest = float(np.max(finite)) if finite.size else None
    if largest is None or largest >= 1.0:
        return values, "s"
    if largest >= 1e-3:
        return values * 1e3, "ms"
    return values * 1e6, "us"


def _scale_flops(values: np.ndarray) -> Tuple[np.ndarray, str]:
    """Choose a decimal FLOP unit (FLOPs/MFLOPs/GFLOPs/TFLOPs) for ``values``.

    The unit comes from the largest finite entry. NaN/inf entries are ignored for
    the choice but scaled like the rest. Below 1e6 (or with no finite entries),
    the input array is returned unscaled with unit "FLOPs".
    """
    finite = values[np.isfinite(values)]
    if not finite.size:
        return values, "FLOPs"
    largest = float(finite.max())
    if largest < 1e6:
        return values, "FLOPs"
    # (exclusive upper bound, divisor, unit), ascending; anything larger is TFLOPs.
    for upper, divisor, unit in (
        (1e9, 1e6, "MFLOPs"),
        (1e12, 1e9, "GFLOPs"),
    ):
        if largest < upper:
            return values / divisor, unit
    return values / 1e12, "TFLOPs"


@dataclass(frozen=True)
class CostPoint:
    """One measurement row of the cost sweep for a single (method, steps) configuration.

    Instances are written field-by-field (via ``__dict__``) to the JSON/CSV outputs
    and rebuilt with ``CostPoint(**payload)`` from isolated child runs, so the field
    names and order double as the on-disk schema.
    """

    method: str  # method display name, e.g. "BPTT" or "TBPTT-10"
    steps: int  # sequence length (time steps) used for this measurement
    batch_size: int  # number of sequences in the measured batch
    flops: float  # FLOP count for one training update; NaN when not measured
    mem_peak_bytes: float  # memory figure in bytes; peak or delta, as described by mem_label
    mem_label: str  # human-readable description of what mem_peak_bytes measures
    time_sec: float  # wall-clock seconds of one update (middle of sorted repeat timings)


# Set by main() to the first line of the error once an Nsight Compute (ncu) run fails;
# while non-None, FLOP measurement falls back to the torch profiler.
_NCU_DISABLED_REASON: str | None = None


def _sync_if_cuda(device: torch.device) -> None:
    """Synchronize ``device`` around timed/profiled regions via ``perf_utils.sync_cuda``.

    Callers in this file invoke it for CPU and CUDA devices alike, so it relies on
    ``sync_cuda`` handling non-CUDA devices. NOTE(review): presumably a no-op for
    CPU; confirm in ``perf_utils``.
    """
    sync_cuda(device)


def _cuda_cleanup(device: torch.device) -> None:
    """Free cached CUDA memory and reset peak-memory statistics for ``device``.

    Does nothing unless ``device`` is a CUDA device and CUDA is available. The
    device is synchronized both before emptying the cache and after resetting
    the peak-memory counters.
    """
    on_cuda = device.type == "cuda" and torch.cuda.is_available()
    if on_cuda:
        torch.cuda.synchronize(device)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize(device)


def _run_train_update(model: Any, inputs: torch.Tensor, targets: torch.Tensor, h_prev: torch.Tensor) -> None:
    """Run one training update on ``model``, whichever update API it exposes.

    Models that define ``run_one_cycle_and_update_directly`` use that method.
    All others fall back to ``train_batch``. Both receive the same
    ``(inputs, targets, h_prev)`` arguments, and any return value is discarded.
    """
    if not hasattr(model, "run_one_cycle_and_update_directly"):
        model.train_batch(inputs, targets, h_prev)
        return
    model.run_one_cycle_and_update_directly(inputs, targets, h_prev)


def _warmup_train_batch(
    model: Any,
    inputs_batch: np.ndarray,
    targets_batch: np.ndarray,
    device: torch.device,
    *,
    warmup: int,
) -> None:
    """Run ``warmup`` un-measured training updates on one batch, then synchronize.

    The updates go through ``_run_train_update`` and so modify the model. Callers
    in this file reload the initial parameters afterwards. A non-positive
    ``warmup`` returns immediately without converting any data.
    """
    from common.sequence_core import to_tensor

    iterations = max(0, int(warmup))
    if not iterations:
        return
    x = to_tensor(inputs_batch, device=device, dtype=torch.float32)
    y = to_tensor(targets_batch, device=device, dtype=torch.float32)
    # Zero initial hidden state laid out as (hidden_size, batch).
    h0 = torch.zeros((model.hidden_size, int(x.shape[0])), dtype=torch.float32, device=device)
    for _ in range(iterations):
        _run_train_update(model, x, y, h0)
    _sync_if_cuda(device)


def _measure_update_memory(
    model: Any,
    inputs_batch: np.ndarray,
    targets_batch: np.ndarray,
    device: torch.device,
    *,
    mem_mode: str,
    cuda_mem: str,
) -> Tuple[float, str]:
    """Measure the memory used by a single training update.

    On CPU (or when CUDA is unavailable), process RSS is sampled with
    ``RssSampler`` around one update. ``mem_mode == "peak"`` returns the peak RSS;
    any other mode returns the peak minus the baseline, clamped at 0.

    On CUDA, ``measure_cuda_peak_memory`` (with ``empty_cache=False``) is used.
    ``cuda_mem == "reserved"`` selects reserved-memory statistics, and anything
    else selects allocated-memory statistics. ``mem_mode == "delta"`` selects the
    delta figure, and anything else selects the peak.

    Returns:
        ``(bytes, label)``, where ``label`` describes the measured quantity for plots.
    """
    from common.sequence_core import to_tensor

    on_gpu = device.type == "cuda" and torch.cuda.is_available()
    x = to_tensor(inputs_batch, device=device, dtype=torch.float32)
    y = to_tensor(targets_batch, device=device, dtype=torch.float32)
    h0 = torch.zeros((model.hidden_size, int(x.shape[0])), dtype=torch.float32, device=device)

    if not on_gpu:
        with RssSampler() as sampler:
            _run_train_update(model, x, y, h0)
        rss_peak = int(sampler.peak_bytes)
        rss_base = int(sampler.baseline_bytes)
        if mem_mode == "peak":
            return float(rss_peak), "Peak RSS"
        return float(max(0, rss_peak - rss_base)), "Peak RSS delta"

    def step() -> None:
        _run_train_update(model, x, y, h0)

    stats = measure_cuda_peak_memory(step, device=device, empty_cache=False)
    readings = {
        ("peak", "allocated"): stats.peak_allocated_bytes,
        ("peak", "reserved"): stats.peak_reserved_bytes,
        ("delta", "allocated"): stats.delta_allocated_bytes,
        ("delta", "reserved"): stats.delta_reserved_bytes,
    }
    kind = "delta" if mem_mode == "delta" else "peak"
    pool = "reserved" if cuda_mem == "reserved" else "allocated"
    label = "GPU running mem." if mem_mode == "peak" else "GPU running mem. delta"
    return float(readings[(kind, pool)]), label


def _measure_update_flops_ncu(
    script_path: Path,
    method: str,
    steps: int,
    args: argparse.Namespace,
    *,
    device: torch.device,
    repo_root: Path,
) -> float:
    """Count the FLOPs of one training update with Nsight Compute (``ncu``).

    ``script_path`` is re-run in ``--single-flops`` mode, which performs one update
    inside an NVTX range. ``run_ncu_metrics`` collects the configured metrics for
    kernels in that range, and the reported total is returned.

    Returns:
        The summed metric total, or NaN when ``device`` is not an available CUDA device.

    Raises:
        SystemExit: If ``args.ncu_metrics`` lists no metric names.
        Any exception from ``run_ncu_metrics`` propagates; ``main()`` treats a
        ``RuntimeError`` as "ncu unavailable".
    """
    if device.type != "cuda" or not torch.cuda.is_available():
        return float("nan")

    metric_names = [name.strip() for name in str(args.ncu_metrics).split(",") if name.strip()]
    if not metric_names:
        raise SystemExit("--ncu-metrics must contain at least one metric name.")

    custom_range = getattr(args, "nvtx_range", None)
    range_name = str(custom_range) if custom_range else f"train_update_{slugify(method)}"

    target_cmd: List[str] = [
        sys.executable,
        str(script_path),
        "--single-method",
        str(method),
        "--single-steps",
        str(int(steps)),
        "--single-flops",
        "--nvtx-range",
        str(range_name),
    ]
    # Settings forwarded verbatim so the child builds an identical model and batch.
    forwarded = (
        ("--hidden", int(args.hidden)),
        ("--batch-size", int(args.batch_size)),
        ("--lr", float(args.lr)),
        ("--seed", int(args.seed)),
        ("--train-limit", int(args.train_limit)),
        ("--tbptt", int(args.tbptt)),
        ("--warmup", int(args.warmup)),
        ("--time-repeats", int(args.time_repeats)),
        ("--mem-mode", args.mem_mode),
        ("--cuda-mem", args.cuda_mem),
    )
    for flag, setting in forwarded:
        target_cmd += [flag, str(setting)]
    if args.device:
        target_cmd += ["--device", str(args.device)]

    timeout = float(args.ncu_timeout_sec)
    result = run_ncu_metrics(
        target_cmd,
        ncu_path=str(args.ncu_path),
        nvtx_include=f"{range_name}/",
        metrics=metric_names,
        cwd=str(repo_root),
        env=os.environ,
        timeout_sec=timeout if timeout > 0 else None,  # 0 (or less) disables the timeout
    )
    return float(result.total)


def _measure_update_flops_torch_profiler(
    model: Any,
    inputs_batch: np.ndarray,
    targets_batch: np.ndarray,
    device: torch.device,
) -> float:
    """Estimate the FLOPs of one training update with ``torch.profiler`` (best effort).

    One update is profiled with ``with_flops=True``, and the ``flops`` of all
    aggregated events are summed. On an available CUDA device, the sum over
    CUDA-typed events is preferred when it is positive.

    Returns:
        The FLOP estimate, or NaN when the profiler is unavailable, profiling
        raises, or no positive FLOP count was recorded.
    """
    from common.sequence_core import to_tensor

    try:
        from torch.profiler import ProfilerActivity, profile
    except Exception:
        return float("nan")

    x = to_tensor(inputs_batch, device=device, dtype=torch.float32)
    y = to_tensor(targets_batch, device=device, dtype=torch.float32)
    h0 = torch.zeros((model.hidden_size, int(x.shape[0])), dtype=torch.float32, device=device)

    use_cuda = device.type == "cuda" and torch.cuda.is_available()
    activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if use_cuda else [])

    try:
        with profile(activities=activities, with_flops=True, record_shapes=True) as prof:
            _run_train_update(model, x, y, h0)
            _sync_if_cuda(device)
    except Exception:
        return float("nan")

    def flops_of(evt: Any) -> float:
        # Missing or None flops count as zero.
        return float(getattr(evt, "flops", 0.0) or 0.0)

    events = prof.key_averages()
    total_flops = float(sum(flops_of(evt) for evt in events))
    fallback = total_flops if total_flops > 0.0 else float("nan")
    if not use_cuda:
        return fallback

    try:
        from torch.autograd.profiler_util import DeviceType

        cuda_flops = float(
            sum(flops_of(evt) for evt in events if getattr(evt, "device_type", None) == DeviceType.CUDA)
        )
    except Exception:
        cuda_flops = 0.0

    return cuda_flops if cuda_flops > 0.0 else fallback


def _time_train_batch(
    model: Any,
    inputs_batch: np.ndarray,
    targets_batch: np.ndarray,
    device: torch.device,
    *,
    repeats: int,
    warmup: int,
) -> float:
    """Time one training update on a batch and return a robust central estimate.

    ``warmup`` untimed updates (clamped to >= 0) run first. Then ``repeats``
    (clamped to >= 1) updates are timed individually with ``time.perf_counter``,
    with a device sync before and after each one.

    Returns:
        The middle element of the sorted timings in seconds: the median for an
        odd count, the upper median for an even count.
    """
    from common.sequence_core import to_tensor

    n_timed = max(1, int(repeats))
    n_warm = max(0, int(warmup))
    x = to_tensor(inputs_batch, device=device, dtype=torch.float32)
    y = to_tensor(targets_batch, device=device, dtype=torch.float32)
    h0 = torch.zeros((model.hidden_size, int(x.shape[0])), dtype=torch.float32, device=device)

    for _ in range(n_warm):
        _run_train_update(model, x, y, h0)
    _sync_if_cuda(device)

    def timed_once() -> float:
        _sync_if_cuda(device)
        t0 = time.perf_counter()
        _run_train_update(model, x, y, h0)
        _sync_if_cuda(device)
        return time.perf_counter() - t0

    durations = sorted(timed_once() for _ in range(n_timed))
    return float(durations[len(durations) // 2])


def _run_isolated_points(
    script_path: Path,
    method_names: List[str],
    steps_list: List[int],
    args: argparse.Namespace,
) -> List[CostPoint]:
    """Measure each (steps, method) pair in its own child process.

    Each pair re-invokes ``script_path`` with ``--single-method``/``--single-steps``
    plus the parent's measurement settings. The child writes its single
    ``CostPoint`` as JSON into a temporary file, which is parsed here and then
    deleted.

    Args:
        script_path: Script run with ``sys.executable`` (this file, when called from ``main``).
        method_names: Method names forwarded as ``--single-method``.
        steps_list: Sequence lengths forwarded as ``--single-steps``.
        args: Parsed parent arguments whose settings are forwarded to each child.

    Returns:
        One ``CostPoint`` per pair, ordered by ``steps_list`` then ``method_names``.

    Raises:
        subprocess.CalledProcessError: If a child run exits with a non-zero status.
            The temporary file is removed in that case as well.
    """
    points: List[CostPoint] = []
    for steps in steps_list:
        for name in method_names:
            # delete=False + close(): the file must outlive this handle so the child
            # process can open it by name and write its result.
            handle = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
            tmp_path = handle.name
            handle.close()

            try:
                cmd = [
                    sys.executable,
                    str(script_path),
                    "--single-method",
                    name,
                    "--single-steps",
                    str(int(steps)),
                    "--single-out",
                    tmp_path,
                    "--hidden",
                    str(int(args.hidden)),
                    "--batch-size",
                    str(int(args.batch_size)),
                    "--lr",
                    str(float(args.lr)),
                    "--seed",
                    str(int(args.seed)),
                    "--train-limit",
                    str(int(args.train_limit)),
                    "--tbptt",
                    str(int(args.tbptt)),
                    "--warmup",
                    str(int(args.warmup)),
                    "--time-repeats",
                    str(int(args.time_repeats)),
                    "--mem-mode",
                    str(args.mem_mode),
                    "--cuda-mem",
                    str(args.cuda_mem),
                    "--flops-mode",
                    str(args.flops_mode),
                    "--ncu-path",
                    str(args.ncu_path),
                    "--ncu-metrics",
                    str(args.ncu_metrics),
                    "--ncu-timeout-sec",
                    str(args.ncu_timeout_sec),
                ]
                if args.device:
                    cmd.extend(["--device", str(args.device)])

                subprocess.check_call(cmd)
                with open(tmp_path, "r", encoding="utf-8") as handle_read:
                    payload = json.load(handle_read)
            finally:
                # Always remove the temp file, even if the child failed or wrote
                # invalid JSON, so aborted sweeps do not leave stray files behind.
                Path(tmp_path).unlink(missing_ok=True)

            points.append(CostPoint(**payload))
    return points


def main(argv: List[str] | None = None) -> int:
    global _NCU_DISABLED_REASON
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-limit", type=int, default=512)
    parser.add_argument("--steps", type=str, default="16,32,64,128")
    parser.add_argument("--tbptt", type=int, default=1)
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--time-repeats", type=int, default=5)
    parser.add_argument("--mem-mode", choices=["peak", "delta"], default="delta")
    parser.add_argument("--cuda-mem", choices=["reserved", "allocated"], default="allocated")
    parser.add_argument("--flops-mode", choices=["ncu", "torch", "none"], default="ncu")
    parser.add_argument("--ncu-path", type=str, default="ncu")
    parser.add_argument(
        "--ncu-metrics",
        type=str,
        default="flop_count_sp,flop_count_hp,flop_count_dp,flop_count_tensor",
        help="Comma-separated Nsight Compute metric names to collect and sum.",
    )
    parser.add_argument(
        "--ncu-timeout-sec",
        type=float,
        default=0.0,
        help="Optional timeout for ncu runs (0 disables).",
    )
    parser.add_argument("--out-dir", type=str, default=str(Path("plots") / "uci_har_cost_sweep"))
    parser.add_argument("--tag", type=str, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--isolate", action="store_true")
    parser.add_argument("--single-method", type=str, default=None, help=argparse.SUPPRESS)
    parser.add_argument("--single-steps", type=int, default=None, help=argparse.SUPPRESS)
    parser.add_argument("--single-out", type=str, default=None, help=argparse.SUPPRESS)
    parser.add_argument("--single-flops", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--nvtx-range", type=str, default=None, help=argparse.SUPPRESS)
    args = parser.parse_args(argv)

    repo_root = Path(__file__).resolve().parents[1]
    _ensure_task_path(repo_root)

    from common.sequence_core import (
        DEFAULT_DEVICE,
        EPROP_FEEDBACK,
        EPROP_SEED,
        FPTT_LAMBDA,
        FPTT_ORACLE_MOMENTUM,
        FPTT_PARTS,
        StrictFPTTClassifier,
        TorchBPTTRNN,
        TorchLocalRuleRNN,
        extract_params,
        load_params,
        set_global_seed,
    )
    from methods.standard_eprop import StandardEPropRNN
    from common.sequence_core import load_uci_har_sequences

    device = torch.device(args.device) if args.device else DEFAULT_DEVICE
    flops_mode = str(args.flops_mode).strip().lower()
    if device.type != "cuda" or not torch.cuda.is_available():
        if flops_mode == "ncu":
            flops_mode = "none"
    if flops_mode == "ncu" and device.type == "cuda" and torch.cuda.is_available():
        ncu_path = str(args.ncu_path)
        if shutil.which(ncu_path) is None and not Path(ncu_path).exists():
            raise SystemExit(f"Nsight Compute CLI not found: '{ncu_path}'. Install it or pass --ncu-path.")

    set_global_seed(int(args.seed))
    torch.set_num_threads(max(1, os.cpu_count() or 1))

    x_train, y_train, _, _, _, _ = load_uci_har_sequences(
        npz_path=None,
        train_limit=int(args.train_limit) if args.train_limit is not None else None,
        test_limit=1,
        auto_download=True,
        root=None,
    )
    if x_train.shape[0] == 0:
        raise SystemExit("Empty UCI HAR dataset after applying train-limit.")

    batch_size = min(int(args.batch_size), int(x_train.shape[0]))
    x_batch = x_train[:batch_size]
    y_batch = y_train[:batch_size]
    input_size = int(x_batch.shape[1])
    output_size = int(y_batch.shape[1])

    local_model = TorchLocalRuleRNN(
        input_size,
        int(args.hidden),
        output_size,
        eta=float(args.lr),
        loss_mode="ce",
        seed=int(args.seed),
        device=device,
    )
    local_model.initialize_weights_with_gain(1.0, seed=int(args.seed))
    init_params = extract_params(local_model)
    del local_model
    _cuda_cleanup(device)
    gc.collect()

    def build_local() -> Any:
        return TorchLocalRuleRNN(
            input_size,
            int(args.hidden),
            output_size,
            eta=float(args.lr),
            loss_mode="ce",
            seed=int(args.seed),
            device=device,
        )

    def build_bptt(tbptt_steps: int | None) -> Any:
        return TorchBPTTRNN(
            input_size,
            int(args.hidden),
            output_size,
            eta=float(args.lr),
            loss_mode="ce",
            tbptt_steps=tbptt_steps,
            time_normalization=False,
            seed=int(args.seed),
            device=device,
        )

    def build_eprop() -> Any:
        return StandardEPropRNN(
            input_size,
            int(args.hidden),
            output_size,
            eta=float(args.lr),
            # For cost-sweep (memory scaling vs steps), we benchmark the stop-gradient (detach) variant.
            # This matches the decay_lambda=0 setting used in the gradcheck and avoids allocating
            # per-example (B,H,H) eligibility traces.
            decay_lambda=0.0,
            feedback=EPROP_FEEDBACK,
            seed=int(EPROP_SEED),
            loss_mode="ce",
            device=device,
        )

    def build_fptt(steps: int) -> Any:
        # Keep chunk length roughly constant across different sequence lengths so that
        # FPTT activation memory does not scale with `steps` in this cost sweep.
        #
        # `StrictFPTT*` chunks a sequence into `parts` segments, each doing a local
        # backward pass over that chunk. With a fixed `parts`, chunk length grows
        # with `steps` and so does activation memory. Here we instead target a
        # chunk length ~= TBPTT window.
        chunk_len = max(1, int(args.tbptt))
        parts = max(1, int(math.ceil(float(int(steps)) / float(chunk_len))))
        return StrictFPTTClassifier(
            input_size,
            int(args.hidden),
            output_size,
            eta=float(args.lr),
            parts=parts,
            clip=5.0,
            lmbda=float(FPTT_LAMBDA),
            oracle_momentum=float(FPTT_ORACLE_MOMENTUM),
            label_mode="last",
            oracle_id="uci_har",
            use_oracle=True,
            device=device,
        )

    method_builders: List[Tuple[str, Any]] = [
        ("Local Rule", lambda steps: build_local()),
        ("BPTT", lambda steps: build_bptt(None)),
        ("E-Prop", lambda steps: build_eprop()),
        ("FPTT", build_fptt),
    ]

    # Include TBPTT-10 as a standard reference, plus the user-requested TBPTT window.
    tbptt_values: List[int] = []
    for v in (int(args.tbptt), 10):
        if v <= 0:
            continue
        if v not in tbptt_values:
            tbptt_values.append(v)
    for v in tbptt_values:
        method_builders.append((f"TBPTT-{v}", lambda steps, v=v: build_bptt(v)))

    method_map = {name: builder for name, builder in method_builders}
    if args.single_method is not None:
        if args.single_steps is None:
            raise SystemExit("--single-steps is required when using --single-method.")
        if args.single_method not in method_map:
            raise SystemExit(f"Unknown --single-method '{args.single_method}'.")
        method_builders = [(args.single_method, method_map[args.single_method])]
        steps_list = [int(args.single_steps)]
        args.isolate = False
    else:
        steps_list = _parse_int_list(args.steps)
    base_steps = int(x_batch.shape[2])
    max_steps = max(steps_list) if steps_list else base_steps
    if max_steps <= 0:
        raise SystemExit("Invalid --steps list.")

    if max_steps > base_steps and not args.single_flops:
        print(f"[INFO] steps sweep exceeds dataset length ({base_steps}); will tile sequences.")

    if args.single_flops:
        if args.single_method is None or args.single_steps is None:
            raise SystemExit("--single-flops requires --single-method and --single-steps.")
        builder = method_map[args.single_method]
        steps = int(args.single_steps)
        x_steps = _tile_to_length(x_batch, steps)
        y_steps = _tile_to_length(y_batch, steps)

        model = builder(steps)
        _cuda_cleanup(device)
        gc.collect()
        load_params(model, init_params)
        if hasattr(model, "reset_learning_state"):
            model.reset_learning_state()
        if hasattr(model, "reset_state_buffers"):
            model.reset_state_buffers()

        _warmup_train_batch(model, x_steps, y_steps, device, warmup=int(args.warmup))
        load_params(model, init_params)
        if hasattr(model, "reset_learning_state"):
            model.reset_learning_state()
        if hasattr(model, "reset_state_buffers"):
            model.reset_state_buffers()

        from common.sequence_core import to_tensor

        inputs = to_tensor(x_steps, device=device, dtype=torch.float32)
        targets = to_tensor(y_steps, device=device, dtype=torch.float32)
        batch_size = int(inputs.shape[0])
        h_prev = torch.zeros((model.hidden_size, batch_size), dtype=torch.float32, device=device)

        range_name = str(args.nvtx_range) if args.nvtx_range else f"train_update_{slugify(args.single_method)}"
        with nvtx_range(range_name):
            _run_train_update(model, inputs, targets, h_prev)
        _sync_if_cuda(device)
        return 0

    tag = args.tag or time.strftime("%Y%m%d_%H%M%S")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.isolate and device.type == "cuda" and torch.cuda.is_available() and args.single_method is None:
        points = _run_isolated_points(
            Path(__file__).resolve(),
            [name for name, _ in method_builders],
            steps_list,
            args,
        )
    else:
        points: List[CostPoint] = []
        for steps in steps_list:
            steps = int(steps)
            x_steps = _tile_to_length(x_batch, steps)
            y_steps = _tile_to_length(y_batch, steps)

            for name, builder in method_builders:
                model = builder(steps)
                _cuda_cleanup(device)
                gc.collect()
                load_params(model, init_params)
                if hasattr(model, "reset_learning_state"):
                    model.reset_learning_state()
                if hasattr(model, "reset_state_buffers"):
                    model.reset_state_buffers()

                _warmup_train_batch(model, x_steps, y_steps, device, warmup=int(args.warmup))
                load_params(model, init_params)
                if hasattr(model, "reset_learning_state"):
                    model.reset_learning_state()
                if hasattr(model, "reset_state_buffers"):
                    model.reset_state_buffers()

                mem_mode = str(args.mem_mode).lower()
                cuda_mem = str(args.cuda_mem).lower()
                mem_peak, mem_label = _measure_update_memory(
                    model,
                    x_steps,
                    y_steps,
                    device,
                    mem_mode=mem_mode,
                    cuda_mem=cuda_mem,
                )
                load_params(model, init_params)
                if hasattr(model, "reset_learning_state"):
                    model.reset_learning_state()
                if hasattr(model, "reset_state_buffers"):
                    model.reset_state_buffers()

                flops = float("nan")
                if flops_mode == "torch":
                    flops = _measure_update_flops_torch_profiler(model, x_steps, y_steps, device)
                elif flops_mode == "ncu":
                    if _NCU_DISABLED_REASON is not None:
                        flops = _measure_update_flops_torch_profiler(model, x_steps, y_steps, device)
                    else:
                        try:
                            flops = _measure_update_flops_ncu(
                                Path(__file__).resolve(),
                                method=str(name),
                                steps=int(steps),
                                args=args,
                                device=device,
                                repo_root=repo_root,
                            )
                        except RuntimeError as exc:
                            _NCU_DISABLED_REASON = str(exc).splitlines()[0].strip() or "ncu failed"
                            print(f"[WARN] ncu unavailable ({_NCU_DISABLED_REASON}); switching to torch profiler FLOPs.")
                            flops_mode = "torch"
                            flops = _measure_update_flops_torch_profiler(model, x_steps, y_steps, device)

                load_params(model, init_params)
                if hasattr(model, "reset_learning_state"):
                    model.reset_learning_state()
                if hasattr(model, "reset_state_buffers"):
                    model.reset_state_buffers()
                time_sec = _time_train_batch(
                    model,
                    x_steps,
                    y_steps,
                    device,
                    repeats=int(args.time_repeats),
                    warmup=int(args.warmup),
                )

                points.append(
                    CostPoint(
                        method=str(name),
                        steps=int(steps),
                        batch_size=int(batch_size),
                        flops=float(flops),
                        mem_peak_bytes=float(mem_peak),
                        mem_label=mem_label,
                        time_sec=float(time_sec),
                    )
                )
                print(
                    f"[COST] steps={steps:4d} | {name:10s} | "
                    f"flops={points[-1].flops:.3e} | mem={points[-1].mem_peak_bytes:.3e} B | time={time_sec:.4f}s"
                )
                del model
                _cuda_cleanup(device)
                gc.collect()

    if args.single_method is not None:
        if not points:
            raise SystemExit("No measurement recorded for --single-method run.")
        payload = points[0].__dict__
        if args.single_out:
            with open(args.single_out, "w", encoding="utf-8") as handle:
                json.dump(payload, handle, indent=2, ensure_ascii=False)
        else:
            print(json.dumps(payload, ensure_ascii=False))
        return 0

    json_path = out_dir / f"{tag}_uci_har_cost_sweep.json"
    payload = {
        "schema_version": 1,
        "task": "UCI HAR (Human Activity Recognition)",
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "device": str(device),
        "flops_mode": str(flops_mode),
        "ncu_path": str(args.ncu_path),
        "ncu_metrics": str(args.ncu_metrics),
        "hidden": int(args.hidden),
        "batch_size": int(batch_size),
        "lr": float(args.lr),
        "seed": int(args.seed),
        "steps_list": steps_list,
        "tbptt": int(args.tbptt),
        "mem_mode": str(args.mem_mode),
        "cuda_mem": str(args.cuda_mem),
        "mem_label": points[0].mem_label if points else "Peak RSS delta",
        "points": [p.__dict__ for p in points],
    }
    with json_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False)

    csv_path = out_dir / f"{tag}_uci_har_cost_sweep.csv"
    with csv_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(
            handle,
            fieldnames=["steps", "method", "batch_size", "flops", "mem_peak_bytes", "mem_label", "time_sec"],
        )
        writer.writeheader()
        for p in points:
            writer.writerow(
                {
                    "steps": p.steps,
                    "method": p.method,
                    "batch_size": p.batch_size,
                    "flops": p.flops,
                    "mem_peak_bytes": p.mem_peak_bytes,
                    "mem_label": p.mem_label,
                    "time_sec": p.time_sec,
                }
            )

    try:
        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D
    except Exception as exc:
        print(f"[PLOT] Skipping plots: matplotlib unavailable ({exc}).")
        print(f"[OK] Wrote: {json_path}")
        print(f"[OK] Wrote: {csv_path}")
        return 0

    plt.rcParams.update(
        {
            # Academic/Formal style
            "font.family": ["Times New Roman", "Times", "serif"],
            "font.size": 12,
            "axes.labelsize": 12,
            "axes.titlesize": 16,       # 调整了标题大小以适配正式风格
            "axes.titleweight": "normal", # 明确指定不加粗
            "xtick.labelsize": 11,
            "ytick.labelsize": 11,
            "legend.fontsize": 12,
            "axes.linewidth": 1.0,      # 线条稍微变细
            "lines.linewidth": 1.5,
            "lines.markersize": 5.0,
            # "path.sketch": ...        # 已移除手绘效果
        }
    )
    methods_list = [name for name, _ in method_builders]

    def _color_for_method(name: str) -> str:
        name = str(name)
        if name.startswith("TBPTT-"):
            return "#F1C40F"  # yellow
        mapping = {
            "Local Rule": "#1F77B4",  # blue
            "BPTT": "#FF7F0E",        # orange
            "E-Prop": "#2CA02C",      # green
            "FPTT": "#E377C2",        # pink
        }
        return mapping.get(name, "#7F7F7F")  # gray fallback

    colors = [_color_for_method(m) for m in methods_list]
    steps_arr = np.asarray(steps_list, dtype=int)
    x_group = np.arange(len(steps_arr), dtype=float)
    width = 0.8 / max(1, len(methods_list))

    def _grid(ax: Any) -> None:
        ax.grid(True, axis="y", linestyle=":", linewidth=1.0, alpha=0.5, zorder=0)
        ax.set_axisbelow(True)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    def _points_matrix(field: str) -> np.ndarray:
        mat = np.full((len(methods_list), len(steps_arr)), np.nan, dtype=float)
        for mi, m in enumerate(methods_list):
            for si, s in enumerate(steps_arr):
                for p in points:
                    if p.method == m and p.steps == int(s):
                        mat[mi, si] = float(getattr(p, field))
                        break
        return mat

    def _plot_sorted_group_bars(ax: Any, values: np.ndarray) -> List[Line2D]:
        """Draw grouped bars for ``values`` plus one trend line per method.

        ``values`` is indexed ``[method, step]``: rows follow ``methods_list`` /
        ``colors``, and columns follow ``x_group``. Within each step group, only
        finite cells get a bar. Those bars are ordered from smallest to largest
        value (stable sort, so ties keep method order) and centred on the
        group's x position. Each method's trend line is then drawn through the
        x positions its own bars landed on, so the markers sit on their bars.

        Returns one legend handle per entry in ``methods_list``, including
        methods that had no finite value in this chart.
        """
        n_methods, n_steps = values.shape
        bar_x = np.full((n_methods, n_steps), np.nan, dtype=float)

        for step_idx in range(n_steps):
            column = values[:, step_idx]
            valid = np.flatnonzero(np.isfinite(column))
            if not valid.size:
                continue
            ranked = valid[np.argsort(column[valid], kind="stable")]
            # Offset each slot from the group centre so the visible bars stay centred.
            centre_rank = (ranked.size - 1) / 2.0
            for slot, method_idx in enumerate(ranked):
                x = x_group[step_idx] + (slot - centre_rank) * width
                bar_x[method_idx, step_idx] = x
                ax.bar(
                    x,
                    float(column[method_idx]),
                    width=width,
                    color=colors[method_idx],
                    edgecolor="black",
                    linewidth=1.1,
                    alpha=0.85,
                    zorder=3,
                )

        # Trend lines run through the post-sort bar positions, drawn above the bars.
        for method_idx in range(n_methods):
            drawn = np.isfinite(bar_x[method_idx]) & np.isfinite(values[method_idx])
            if not drawn.any():
                continue
            ax.plot(
                bar_x[method_idx][drawn],
                values[method_idx][drawn],
                color=colors[method_idx],
                linestyle="-",
                marker="o",
                markeredgecolor="black",
                markeredgewidth=0.8,
                zorder=4,
            )

        return [
            Line2D(
                [0],
                [0],
                label=methods_list[method_idx],
                color=colors[method_idx],
                linestyle="-",
                marker="o",
                markeredgecolor="black",
                markeredgewidth=0.8,
            )
            for method_idx in range(n_methods)
        ]

    mem_label = points[0].mem_label if points else "Peak RSS delta"
    mem_bytes = _points_matrix("mem_peak_bytes")
    flops = _points_matrix("flops")
    time_sec = _points_matrix("time_sec")

    mem_scaled, mem_unit = _scale_bytes(mem_bytes)
    time_scaled, time_unit = _scale_time_sec(time_sec)
    flops_scaled, flops_unit = _scale_flops(flops)

    # Paper-friendly axis labels (avoid task-specific naming in figures).
    if device.type == "cuda" and torch.cuda.is_available():
        mem_kind = "allocated" if str(args.cuda_mem).lower() == "allocated" else "reserved"
        if str(args.mem_mode).lower() == "delta":
            mem_axis_label = f"Δ GPU {mem_kind} memory"
        else:
            mem_axis_label = f"GPU peak {mem_kind} memory"
    else:
        mem_axis_label = "Δ RSS" if str(args.mem_mode).lower() == "delta" else "Peak RSS"

    def _finish(fig: Any, title: str, legend_handles: List[Line2D]) -> None:
        fig.legend(
            handles=legend_handles,
            loc="upper center",
            ncols=3,
            frameon=False,
            bbox_to_anchor=(0.5, 0.995),
            handlelength=2.2,
            columnspacing=2.0,
        )
        fig.suptitle(title, y=0.84, fontsize=26, fontweight="bold")
        fig.subplots_adjust(top=0.70, bottom=0.15)

    fig, ax = plt.subplots(figsize=(8.2, 5.1))
    handles = _plot_sorted_group_bars(ax, mem_scaled)
    ax.set_xticks(x_group)
    ax.set_xticklabels([str(s) for s in steps_arr], rotation=0)
    ax.set_xlabel("Steps")
    ax.set_ylabel(f"{mem_axis_label} ({mem_unit})")
    _grid(ax)
    _finish(fig, "Update Memory vs Steps", handles)
    mem_path = out_dir / f"{tag}_uci_har_mem_vs_steps.png"
    fig.savefig(mem_path, dpi=300, bbox_inches="tight", pad_inches=0.2)
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(8.2, 5.1))
    handles = _plot_sorted_group_bars(ax, time_scaled)
    ax.set_xticks(x_group)
    ax.set_xticklabels([str(s) for s in steps_arr], rotation=0)
    ax.set_xlabel("Steps")
    ax.set_ylabel(f"Time per update ({time_unit})")
    _grid(ax)
    _finish(fig, "Update Time vs Steps", handles)
    time_path = out_dir / f"{tag}_uci_har_time_vs_steps.png"
    fig.savefig(time_path, dpi=300, bbox_inches="tight", pad_inches=0.2)
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(8.2, 5.1))
    handles = _plot_sorted_group_bars(ax, flops_scaled)
    ax.set_xticks(x_group)
    ax.set_xticklabels([str(s) for s in steps_arr], rotation=0)
    ax.set_xlabel("Steps")
    ax.set_ylabel(f"FLOPs per update ({flops_unit})")
    _grid(ax)
    _finish(fig, "Update FLOPs vs Steps", handles)
    flops_path = out_dir / f"{tag}_uci_har_flops_vs_steps.png"
    fig.savefig(flops_path, dpi=300, bbox_inches="tight", pad_inches=0.2)
    plt.close(fig)

    print(f"[OK] Wrote: {json_path}")
    print(f"[OK] Wrote: {csv_path}")
    print(f"[OK] Wrote: {mem_path}")
    print(f"[OK] Wrote: {time_path}")
    print(f"[OK] Wrote: {flops_path}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
