import os
import csv
import json
import time
import torch
import pynvml
import threading
import atexit
from datetime import datetime
from dataclasses import dataclass, field, asdict
from typing import Optional, Dict, List, Any, Tuple
from pathlib import Path

from codecarbon import EmissionsTracker, OfflineEmissionsTracker
from .environment import save_environment
from .utils import _safe_filename, _write_json


def _infer_nvml_device_indices(device_count: int) -> List[int]:
    """
    Infer which NVML GPU indices to poll based on environment variables.
    Handles common Slurm patterns + cgroup remapping.
    """
    cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
    if not cvd:
        return list(range(device_count))

    tokens = [t.strip() for t in cvd.split(",") if t.strip()]

    # CUDA_VISIBLE_DEVICES can be UUID-form ("GPU-...") on some systems.
    # If UUIDs are present, safest fallback is to poll visible indices only.
    if any(t.startswith("GPU-") or t.startswith("MIG-") for t in tokens):
        return list(range(device_count))

    # Numeric form
    try:
        inds = [int(t) for t in tokens]
    except ValueError:
        return list(range(device_count))

    # If cgroups/remapping expose only N GPUs, NVML count may be N, but CVD might still contain
    # physical indices (rare) or logical ones. If any index is out of range, treat as remapped.
    if not inds:
        return list(range(device_count))
    if max(inds) >= device_count:
        return list(range(device_count))

    # Otherwise treat as physical indices
    return inds


@dataclass
class StageMetrics:
    """Timing, throughput and energy figures recorded for one pipeline stage."""

    # Identity
    stage_id: str
    stage_name: str

    # Wall-clock window (epoch seconds) and its length
    start_time: float = 0.0
    end_time: float = 0.0
    duration_seconds: float = 0.0

    # Work done
    tokens_processed: int = 0

    # GPU figures derived from NVML power polling
    gpu_energy_joules: float = 0.0
    gpu_avg_power_watts: float = 0.0
    gpu_peak_power_watts: float = 0.0
    gpu_power_samples: List[float] = field(default_factory=list)

    # CodeCarbon figures; energy_consumed = cpu + gpu + ram, all in kWh
    total_codecarbon_energy_kwh: float = 0.0
    codecarbon_emissions_kg: float = 0.0
    codecarbon_cpu_energy_kwh: float = 0.0
    codecarbon_gpu_energy_kwh: float = 0.0
    codecarbon_ram_energy_kwh: float = 0.0

    # CPU energy and the selected stage total
    cpu_energy_joules: float = 0.0
    total_energy_joules: float = 0.0
    total_energy_kwh: float = 0.0

    # Values filled in by compute_derived_metrics()
    joules_per_token: float = 0.0
    kwh_total: float = 0.0
    tokens_per_second: float = 0.0

    def compute_derived_metrics(self, total_energy_policy: str = "measured"):
        """
        Fill in duration, throughput, total energy and per-token energy.

        Args:
            total_energy_policy: How the stage total is chosen.
                ``"codecarbon"`` uses the CodeCarbon total when it is positive,
                ``"gpu_only"`` uses the GPU energy when it is positive; any other
                value (and either of those two when its source is not positive)
                uses GPU + CPU joules if at least one is positive, else the
                CodeCarbon total, else 0.
        """
        joules_per_kwh = 3_600_000

        elapsed = self.end_time - self.start_time
        self.duration_seconds = elapsed
        # Throughput is only updated when both operands are positive;
        # otherwise the previous value is left as is.
        if elapsed > 0 and self.tokens_processed > 0:
            self.tokens_per_second = self.tokens_processed / elapsed

        gpu_joules = self.gpu_energy_joules
        cpu_joules = self.cpu_energy_joules
        codecarbon_kwh = self.total_codecarbon_energy_kwh
        codecarbon_joules = codecarbon_kwh * joules_per_kwh if codecarbon_kwh > 0 else 0.0

        # Pick the stage total according to the policy, with fallbacks.
        if total_energy_policy == "codecarbon" and codecarbon_joules > 0:
            chosen_joules = codecarbon_joules
        elif total_energy_policy == "gpu_only" and gpu_joules > 0:
            chosen_joules = gpu_joules
        elif gpu_joules > 0 or cpu_joules > 0:
            chosen_joules = gpu_joules + cpu_joules
        elif codecarbon_joules > 0:
            chosen_joules = codecarbon_joules
        else:
            chosen_joules = 0.0
        self.total_energy_joules = chosen_joules

        if chosen_joules > 0:
            self.total_energy_kwh = chosen_joules / joules_per_kwh
        else:
            self.total_energy_kwh = 0.0

        # Legacy kWh field: selected total first, then CodeCarbon, then GPU-only.
        if self.total_energy_kwh > 0:
            legacy_kwh = self.total_energy_kwh
        elif codecarbon_kwh > 0:
            legacy_kwh = codecarbon_kwh
        elif gpu_joules > 0:
            legacy_kwh = gpu_joules / joules_per_kwh
        else:
            legacy_kwh = 0.0
        self.kwh_total = legacy_kwh

        have_tokens_and_energy = self.tokens_processed > 0 and self.total_energy_joules > 0
        self.joules_per_token = (
            self.total_energy_joules / self.tokens_processed if have_tokens_and_energy else 0.0
        )

    def to_dict(self, include_power_samples: bool = False) -> Dict[str, Any]:
        """Return the fields as a dict, dropping ``gpu_power_samples`` unless requested."""
        as_mapping = asdict(self)
        if include_power_samples:
            return as_mapping
        as_mapping.pop("gpu_power_samples", None)
        return as_mapping


class RAPLReader:
    """
    Minimal RAPL wrapper to read package (and optional dram) energy via powercap.

    Counters are read from ``<rapl_root>/intel-rapl:N/energy_uj`` (and, for the
    "dram" domain, from sub-zones whose ``name`` file contains "dram") and are
    summed across all top-level zones.

    NOTE(review): wrap-around correction uses a single ``max_energy_range_uj``
    per domain (the first one read) against the summed counters; on hosts with
    several packages this presumably only approximates a wrap - verify.
    """

    def __init__(self, rapl_root: Path, domains: Optional[List[str]] = None):
        """
        Args:
            rapl_root: Directory containing the ``intel-rapl:*`` zones.
            domains: Domains to read ("package" and/or "dram"). Defaults to
                ``["package"]``. The list is copied so instances never share
                (or mutate) a caller's or a default list.
        """
        self.domains = ["package"] if domains is None else list(domains)
        self.rapl_root = Path(rapl_root)
        self.start_uj: Dict[str, int] = {}
        self.end_uj: Dict[str, int] = {}
        self.max_range_uj: Dict[str, int] = {}
        self.available = self.rapl_root.exists()

    def _read_int(self, p: Path) -> int:
        """Read an integer from a sysfs-style text file.

        Raises:
            PermissionError: re-raised with the offending path in the message.
        """
        try:
            return int(p.read_text().strip())
        except PermissionError as e:
            raise PermissionError(f"RAPL permission denied reading {p}") from e

    def _iter_zones(self):
        """Yield top-level zone directories (``intel-rapl:0``, ``intel-rapl:1``, ...)."""
        for zone in self.rapl_root.glob("intel-rapl:*"):
            # Sub-zones have a second colon in their name; skip them here.
            if zone.name.count(":") != 1:
                continue
            yield zone

    def _remember_max_range(self, domain: str, zone_dir: Path) -> None:
        """Cache the counter range for ``domain`` the first time it is seen."""
        # Read lazily: the range is constant, so one read per domain is enough.
        if domain not in self.max_range_uj:
            self.max_range_uj[domain] = self._read_int(zone_dir / "max_energy_range_uj")

    def _collect_raw(self) -> Dict[str, int]:
        """
        Return energy_uj sums per domain across packages.
        Keys: one per requested domain ("package", "dram"); 0 when nothing was read.
        """
        out = {d: 0 for d in self.domains}
        if not self.available:
            return out

        want_package = "package" in self.domains
        want_dram = "dram" in self.domains

        for zone in self._iter_zones():
            # package energy
            if want_package:
                out["package"] += self._read_int(zone / "energy_uj")
                self._remember_max_range("package", zone)

            # dram subzones (optional)
            if want_dram:
                for sub in zone.glob("intel-rapl:*:*"):
                    try:
                        name = (sub / "name").read_text().strip().lower()
                    except (OSError, ValueError):
                        # Unreadable/undecodable name file: skip this sub-zone.
                        continue
                    if "dram" in name:
                        out["dram"] += self._read_int(sub / "energy_uj")
                        self._remember_max_range("dram", sub)
        return out

    def start(self):
        """Record the starting counter values (no-op when RAPL is unavailable)."""
        if not self.available:
            return
        self.start_uj = self._collect_raw()

    def stop(self) -> Dict[str, float]:
        """
        Returns energy deltas in Joules per domain.

        A domain yields 0.0 when RAPL is unavailable, when no start reading was
        taken for it, or when both readings are 0. Deltas are never negative.
        """
        if not self.available:
            return {d: 0.0 for d in self.domains}
        self.end_uj = self._collect_raw()
        deltas_j: Dict[str, float] = {}
        for d in self.domains:
            if d not in self.start_uj:
                # No baseline: the raw counter value is not a meaningful delta.
                deltas_j[d] = 0.0
                continue
            s = self.start_uj[d]
            e = self.end_uj.get(d, 0)
            mr = self.max_range_uj.get(d, 0)
            if s == 0 and e == 0:
                deltas_j[d] = 0.0
                continue
            if mr > 0 and e < s:
                # Counter wrapped past its maximum range.
                delta_uj = e + (mr - s)
            else:
                delta_uj = e - s
            # microjoules -> joules; clamp so an unexplained backwards counter
            # cannot produce negative energy.
            deltas_j[d] = max(0.0, delta_uj / 1e6)
        return deltas_j


class NVMLPoller:
    """Background thread for polling GPU power via NVML.

    ``start()`` initializes NVML, resolves device handles and launches a daemon
    thread that samples power at a fixed interval; ``stop()`` ends the thread,
    closes the optional CSV log, releases the NVML reference taken by
    ``start()`` and returns the collected samples.
    """

    def __init__(
        self,
        poll_interval_ms: int = 500,
        device_indices: Optional[List[int]] = None,
        power_log_path: Optional[Path] = None,
    ):
        """
        Args:
            poll_interval_ms: Delay between samples, in milliseconds.
            device_indices: NVML indices to poll; when ``None`` they are
                inferred in ``start()`` via ``_infer_nvml_device_indices``.
            power_log_path: Optional CSV file that each sample is appended to.
        """
        self.poll_interval_sec = poll_interval_ms / 1000.0
        self.device_indices = device_indices
        self.power_log_path = Path(power_log_path) if power_log_path else None
        self.power_readings: List[Dict[str, float]] = []
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._handles: List[Any] = []
        self._initialized = False
        self._lock = threading.Lock()
        self._log_file = None
        self._log_writer = None

    def _shutdown_nvml(self):
        """Release the NVML reference taken by ``start()``, if any."""
        if not self._initialized:
            return
        self._initialized = False
        try:
            pynvml.nvmlShutdown()
        except pynvml.NVMLError:
            # Best effort: nothing useful can be done if shutdown fails.
            pass

    def start(self):
        """Initialize NVML and start polling thread.

        NVML errors are reported with a warning and leave the poller idle
        (no thread, no readings) instead of raising.
        """
        try:
            pynvml.nvmlInit()
            self._initialized = True

            # Get device handles
            device_count = pynvml.nvmlDeviceGetCount()
            if self.device_indices is None:
                self.device_indices = _infer_nvml_device_indices(device_count)

            self._handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in self.device_indices]

            # Prepare power log file if requested
            if self.power_log_path:
                try:
                    self.power_log_path.parent.mkdir(parents=True, exist_ok=True)
                    self._log_file = open(self.power_log_path, "a", newline="")
                    self._log_writer = csv.writer(self._log_file)
                    # Only a brand-new (empty) file gets a header row.
                    if self.power_log_path.stat().st_size == 0:
                        header = ["timestamp"] + [f"gpu_{i}_power_w" for i in self.device_indices] + ["total_power_w"]
                        self._log_writer.writerow(header)
                        self._log_file.flush()
                except Exception as e:
                    print(f"[NVMLPoller] Warning: failed to open power log at {self.power_log_path}: {e}")
                    # Do not leak the handle if the failure happened after open().
                    if self._log_file is not None:
                        try:
                            self._log_file.close()
                        except Exception:
                            pass
                    self._log_file = None
                    self._log_writer = None

            # Informational only: report which devices are being polled.
            try:
                uuids = []
                for h in self._handles:
                    uuids.append(pynvml.nvmlDeviceGetUUID(h))
                print(f"[NVMLPoller] Polling GPUs: indices={self.device_indices}, uuids={uuids}")
            except Exception:
                pass

            self._stop_event.clear()
            self._thread = threading.Thread(target=self._poll_loop, daemon=True)
            self._thread.start()

        except pynvml.NVMLError as e:
            print(f"Warning: NVML initialization failed: {e}")
            # Balance a successful nvmlInit() when a later NVML call failed.
            self._shutdown_nvml()

    def _poll_loop(self):
        """Polling loop running in background thread.

        Each sample is a dict with one ``gpu_<idx>`` entry per device plus
        ``timestamp`` and ``total_power_w``. A sample is dropped entirely if
        any device read raises an NVML error.
        """
        while not self._stop_event.is_set():
            try:
                timestamp = time.time()
                powers = {}
                total_power = 0.0

                for idx, handle in zip(self.device_indices, self._handles):
                    power_mw = pynvml.nvmlDeviceGetPowerUsage(handle)
                    power_w = power_mw / 1000.0  # milliwatts -> watts
                    powers[f"gpu_{idx}"] = power_w
                    total_power += power_w

                powers["timestamp"] = timestamp
                powers["total_power_w"] = total_power
                with self._lock:
                    self.power_readings.append(powers)
                if self._log_writer:
                    try:
                        row = [timestamp] + [powers.get(f"gpu_{i}", 0.0) for i in self.device_indices] + [total_power]
                        self._log_writer.writerow(row)
                        self._log_file.flush()
                    except Exception:
                        # Do not crash the poller if logging fails
                        pass

            except pynvml.NVMLError:
                pass  # Skip failed readings

            # Sleeps for the interval but wakes immediately when stop() is called.
            self._stop_event.wait(self.poll_interval_sec)

    def stop(self) -> List[Dict[str, float]]:
        """Stop polling, release NVML and return collected readings.

        The internal reading buffer is cleared. Safe to call on a poller that
        was never started (returns an empty list).
        """
        if self._thread is not None:
            self._stop_event.set()
            self._thread.join(timeout=2.0)
            self._thread = None

        with self._lock:
            readings = self.power_readings.copy()
            self.power_readings.clear()
        if self._log_file:
            try:
                self._log_file.flush()
                self._log_file.close()
            except Exception:
                pass
            self._log_file = None
            self._log_writer = None

        # Each start() takes one NVML reference; give it back so references
        # do not accumulate across pollers.
        self._shutdown_nvml()
        return readings

    def get_current_readings(self) -> List[Dict[str, float]]:
        """Get readings collected so far without stopping."""
        with self._lock:
            return self.power_readings.copy()


class EnergyTracker:
    """
    Unified energy tracker for distillation experiments.

    Wraps CodeCarbon, NVML polling, and optional RAPL CPU readings for
    energy measurement with stage-wise accounting.

    Usage:
        tracker = EnergyTracker(run_dir, config=config)

        tracker.start_stage("teacher_forward")
        # ... do work ...
        tracker.end_stage(tokens_processed=1000)

        tracker.start_stage("student_train")
        # ... do work ...
        tracker.end_stage(tokens_processed=5000)

        tracker.save_summary()
    """

    def __init__(
        self,
        run_dir: Optional[str] = None,
        experiment_name: Optional[str] = None,
        config: Optional[Any] = None,
        nvml_poll_interval_ms: Optional[int] = None,
        track_cpu: Optional[bool] = None,
        country_iso_code: Optional[str] = None,
        offline_mode: Optional[bool] = None,
        rapl_root: Optional[str] = None,
    ):
        self.config = config
        self._stage_counts: Dict[str, int] = {}

        def _cfg(explicit: Any, attr_name: str, dotted: str, default: Any) -> Any:
            if explicit is not None:
                return explicit
            if self.config is not None:
                if hasattr(self.config, attr_name):
                    val = getattr(self.config, attr_name)
                    if val is not None:
                        return val
                if hasattr(self.config, "get"):
                    nested_val = self.config.get(dotted, None)
                    if nested_val is not None:
                        return nested_val
            return default

        self.experiment_name = _cfg(None, "experiment_name", "experiment.name", "experiment")
        self.nvml_poll_interval_ms = _cfg(nvml_poll_interval_ms, "energy_nvml_poll_ms", "energy.nvml_poll_interval_ms", 500)
        self.country_iso_code = _cfg(country_iso_code, "energy_country_iso", "energy.country_iso_code", "USA")
        self.offline_mode = _cfg(offline_mode, "energy_offline_mode", "energy.offline_mode", True)
        self.track_cpu = _cfg(track_cpu, "energy_track_cpu", "energy.track_cpu", True)
        rapl_root_val = _cfg(rapl_root, "energy_rapl_root", "energy.rapl_root", "/sys/class/powercap/intel-rapl")
        self.rapl_root = Path(rapl_root_val)
        self.total_energy_policy = _cfg(None, "energy_total_policy", "energy.total_energy_policy", "measured")

        self.config = config

        # run directories
        resolved_run_dir = _cfg(run_dir, "run_dir", "output.run_dir", None)
        if not resolved_run_dir:
            raise ValueError("EnergyTracker requires run_dir (preferred) or output_dir.")
        self.run_dir = Path(resolved_run_dir)

        self.energy_root = self.run_dir
        self.stages_dir = self.energy_root / "stages"
        self.codecarbon_dir = self.energy_root / "codecarbon"
        self.stages_dir.mkdir(parents=True, exist_ok=True)
        self.codecarbon_dir.mkdir(parents=True, exist_ok=True)

        # Stage tracking
        self.stages: Dict[str, StageMetrics] = {}
        self.current_stage: Optional[str] = None
        self._stage_dirs: Dict[str, Path] = {}

        # Tool instances (created per-stage)
        self._codecarbon_tracker: Optional[EmissionsTracker] = None
        self._nvml_poller: Optional[NVMLPoller] = None
        self._rapl_reader = RAPLReader(rapl_root=self.rapl_root, domains=["package"]) if self.track_cpu else None

        # Experiment metadata
        self.start_time = datetime.now().isoformat()
        self.experiment_id = f"{self.experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Register cleanup
        atexit.register(self._cleanup)

    def start_stage(self, stage_name: str):
        """
        Start tracking a new stage.

        If another stage is still active it is ended first (with a warning).
        Then, in order: a unique stage ID is allocated, a per-stage directory is
        created under ``stages/``, environment and config snapshots are written
        (best effort; failures only print a warning), RAPL, CodeCarbon and the
        NVML poller are started, and the stage start timestamp is recorded.

        Args:
            stage_name: Requested stage label.
                NOTE(review): the label actually used is
                ``self.experiment_name or stage_name``, so ``stage_name`` only
                takes effect when ``self.experiment_name`` is falsy - confirm
                this is intended.
        """
        if self.current_stage is not None:
            print(f"Warning: Stage '{self.current_stage}' still running. Ending it first.")
            self.end_stage()

        stage_full_name = self.experiment_name or stage_name

        # Repeated labels get a numeric suffix: "name", "name__2", "name__3", ...
        n = self._stage_counts.get(stage_full_name, 0) + 1
        self._stage_counts[stage_full_name] = n
        stage_id = stage_full_name if n == 1 else f"{stage_full_name}__{n}"

        self.current_stage = stage_id
        stage_metrics = StageMetrics(stage_id=stage_id, stage_name=stage_full_name)
        self.stages[stage_id] = stage_metrics

        # Prepare per-stage folder under stages/
        safe_stage_name = _safe_filename(stage_full_name)
        safe_stage_id = _safe_filename(stage_id)
        stage_dir = self.stages_dir / safe_stage_id
        stage_dir.mkdir(parents=True, exist_ok=True)
        self._stage_dirs[stage_id] = stage_dir

        # Snapshot environment and config for this stage
        try:
            env_filename = f"environment_{safe_stage_name}__{safe_stage_id}.json"
            save_environment(output_dir=str(stage_dir), filename=env_filename)
        except Exception as e:
            print(f"[EnergyTracker] Warning: failed to save environment for stage '{stage_id}': {e}")

        if self.config is not None:
            try:
                # Prefer the config's own to_dict(); otherwise coerce with dict().
                cfg_dict = self.config.to_dict() if hasattr(self.config, "to_dict") else dict(self.config)
                cfg_payload = {
                    "stage_name": stage_full_name,
                    "stage_id": stage_id,
                    "config": cfg_dict,
                }
                cfg_filename = f"config_{safe_stage_name}__{safe_stage_id}.json"
                _write_json(stage_dir / cfg_filename, cfg_payload)
            except Exception as e:
                print(f"[EnergyTracker] Warning: failed to save config for stage '{stage_id}': {e}")

        # RAPL CPU energy: use it if available, otherwise fall back to CodeCarbon.
        # Any failure here sets the reader to None, which also disables RAPL for
        # all later stages of this tracker.
        if self.track_cpu and self._rapl_reader is not None:
            if not self._rapl_reader.available:
                print("[EnergyTracker] RAPL not available; will rely on CodeCarbon CPU estimates.")
                self._rapl_reader = None
            else:
                try:
                    self._rapl_reader.start()
                except PermissionError as e:
                    print(f"[EnergyTracker] RAPL permission denied; will rely on CodeCarbon CPU estimates. ({e})")
                    self._rapl_reader = None
                except Exception as e:
                    print(f"[EnergyTracker] RAPL start failed; will rely on CodeCarbon CPU estimates. ({e})")
                    self._rapl_reader = None

        # Start CodeCarbon (a failure only prints a warning; the stage continues without it)
        try:
            # The stage ID doubles as the CodeCarbon project name; end_stage()
            # looks the results up again under the same name.
            project_name = stage_id

            # Common kwargs for both offline and online trackers
            cc_kwargs = {
                "project_name": project_name,
                "output_dir": str(self.codecarbon_dir),
                "country_iso_code": self.country_iso_code,
                "log_level": "error",
                # process-level accounting so other users' CPU doesn't get picked up
                "tracking_mode": "process",
                # Raw CUDA_VISIBLE_DEVICES value (or None when unset)
                "gpu_ids": os.environ.get("CUDA_VISIBLE_DEVICES", None),
            }

            if self.offline_mode:
                self._codecarbon_tracker = OfflineEmissionsTracker(**cc_kwargs)
            else:
                self._codecarbon_tracker = EmissionsTracker(**cc_kwargs)

            self._codecarbon_tracker.start()
        except Exception as e:
            print(f"Warning: CodeCarbon start failed: {e}")
            self._codecarbon_tracker = None

        # Sync GPU before timing
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Start NVML polling; samples are also appended to power_log.csv in the stage folder
        self._nvml_poller = NVMLPoller(
            poll_interval_ms=self.nvml_poll_interval_ms,
            power_log_path=stage_dir / "power_log.csv",
        )
        self._nvml_poller.start()

        # Taken last, after all trackers have been started.
        stage_metrics.start_time = time.time()

        print(f"[EnergyTracker] Started stage: {stage_id}")

    def end_stage(self, tokens_processed: Optional[int] = None) -> StageMetrics:
        """End current stage and collect metrics."""
        if self.current_stage is None:
            raise RuntimeError("No stage is currently running")

        # Sync GPU before timing
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        stage_id = self.current_stage
        stage_metrics = self.stages[stage_id]
        if tokens_processed is not None:
            stage_metrics.tokens_processed = tokens_processed
        stage_metrics.end_time = time.time()
        stage_metrics.duration_seconds = stage_metrics.end_time - stage_metrics.start_time

        # Stop NVML and compute GPU energy
        if self._nvml_poller is not None:
            readings = self._nvml_poller.stop()
            self._nvml_poller = None

            if readings:
                energy_j, power_samples, avg_power, peak_power = self._integrate_power(
                    readings, stage_metrics.start_time, stage_metrics.end_time
                )
                stage_metrics.gpu_power_samples = power_samples
                stage_metrics.gpu_energy_joules = energy_j
                stage_metrics.gpu_avg_power_watts = avg_power
                stage_metrics.gpu_peak_power_watts = peak_power

        # Stop CodeCarbon
        if self._codecarbon_tracker is not None:
            try:
                emissions = self._codecarbon_tracker.stop()
                if emissions is not None:
                    stage_metrics.codecarbon_emissions_kg = float(emissions)
            except Exception as e:
                print(f"Warning: CodeCarbon stop failed: {e}")
            finally:
                self._codecarbon_tracker = None

            # Read per-component energy from emissions.csv (all kWh)
            cc_metrics = self._read_codecarbon_metrics(
                self.codecarbon_dir,
                project_name=stage_id,
            )
            if cc_metrics is not None:
                stage_metrics.total_codecarbon_energy_kwh = cc_metrics["energy_consumed_kwh"]
                stage_metrics.codecarbon_cpu_energy_kwh = cc_metrics["cpu_energy_kwh"]
                stage_metrics.codecarbon_gpu_energy_kwh = cc_metrics["gpu_energy_kwh"]
                stage_metrics.codecarbon_ram_energy_kwh = cc_metrics["ram_energy_kwh"]

                # If RAPL didn't populate CPU, fall back to CodeCarbon's CPU estimate
                if stage_metrics.cpu_energy_joules <= 0.0 and cc_metrics["cpu_energy_kwh"] > 0.0:
                    stage_metrics.cpu_energy_joules = cc_metrics["cpu_energy_kwh"] * 3_600_000.0  # kWh -> J
        
        # RAPL CPU energy
        if self.track_cpu and self._rapl_reader is not None:
            try:
                deltas = self._rapl_reader.stop()
                pkg_j = deltas.get("package", 0.0)
                stage_metrics.cpu_energy_joules = pkg_j
            except Exception as e:
                print(f"[EnergyTracker] RAPL stop failed: {e}")

        # Compute derived metrics
        stage_metrics.compute_derived_metrics(total_energy_policy=self.total_energy_policy)

        # Save stage metrics to JSON
        safe_stage_id = _safe_filename(stage_id)
        stage_dir = self._stage_dirs.get(stage_id, self.stages_dir / safe_stage_id)
        stage_dir.mkdir(parents=True, exist_ok=True)
        stage_path = stage_dir / f"{safe_stage_id}.json"
        payload = stage_metrics.to_dict(include_power_samples=True)
        payload["nvml_poll_interval_ms"] = self.nvml_poll_interval_ms
        _write_json(stage_path, payload)

        print(f"[EnergyTracker] Ended stage: {stage_id}")
        print(f"  Duration: {stage_metrics.duration_seconds:.2f}s")
        print(f"  GPU Energy: {stage_metrics.gpu_energy_joules:.2f} J")
        print(f"  Tokens: {stage_metrics.tokens_processed}, Joules/token: {stage_metrics.joules_per_token:.4f}")

        self.current_stage = None
        return stage_metrics

    def snapshot_stage(self, step: Optional[int] = None, suffix: str = "checkpoint") -> Optional[Path]:
        """Write a partial snapshot for the running stage without stopping it."""
        if self.current_stage is None:
            print("[EnergyTracker] No active stage to snapshot.")
            return None

        stage_id = self.current_stage
        now = time.time()
        # Clone current metrics so we don't mutate live stage state
        snapshot_metrics = StageMetrics(**asdict(self.stages[stage_id]))
        snapshot_metrics.end_time = now
        snapshot_metrics.duration_seconds = now - snapshot_metrics.start_time

        # GPU energy from current readings
        readings = self._nvml_poller.get_current_readings() if self._nvml_poller else []
        energy_j, power_samples, avg_power, peak_power = self._integrate_power(
            readings, snapshot_metrics.start_time, now
        )
        snapshot_metrics.gpu_energy_joules = energy_j
        snapshot_metrics.gpu_power_samples = power_samples
        snapshot_metrics.gpu_avg_power_watts = avg_power
        snapshot_metrics.gpu_peak_power_watts = peak_power

        # CodeCarbon partial metrics if available
        cc_metrics = self._read_codecarbon_metrics(self.codecarbon_dir, project_name=stage_id)
        if cc_metrics is not None:
            snapshot_metrics.total_codecarbon_energy_kwh = cc_metrics["energy_consumed_kwh"]
            snapshot_metrics.codecarbon_cpu_energy_kwh = cc_metrics["cpu_energy_kwh"]
            snapshot_metrics.codecarbon_gpu_energy_kwh = cc_metrics["gpu_energy_kwh"]
            snapshot_metrics.codecarbon_ram_energy_kwh = cc_metrics["ram_energy_kwh"]
            if snapshot_metrics.cpu_energy_joules <= 0.0 and cc_metrics["cpu_energy_kwh"] > 0.0:
                snapshot_metrics.cpu_energy_joules = cc_metrics["cpu_energy_kwh"] * 3_600_000.0

        # Partial CPU energy via RAPL without stopping it
        if self.track_cpu and self._rapl_reader is not None:
            try:
                current_raw = self._rapl_reader._collect_raw()
                pkg_start = self._rapl_reader.start_uj.get("package", 0)
                pkg_end = current_raw.get("package", 0)
                pkg_range = self._rapl_reader.max_range_uj.get("package", 0)
                if pkg_start or pkg_end:
                    if pkg_range > 0 and pkg_end < pkg_start:
                        delta_uj = pkg_end + (pkg_range - pkg_start)
                    else:
                        delta_uj = pkg_end - pkg_start
                    snapshot_metrics.cpu_energy_joules = max(0.0, delta_uj / 1e6)
            except Exception as e:
                print(f"[EnergyTracker] RAPL partial read failed: {e}")

        snapshot_metrics.compute_derived_metrics(total_energy_policy=self.total_energy_policy)

        # Persist snapshot JSON
        safe_stage_id = _safe_filename(stage_id)
        stage_dir = self._stage_dirs.get(stage_id, self.stages_dir / safe_stage_id)
        stage_dir.mkdir(parents=True, exist_ok=True)
        step_label = f"step_{step}" if step is not None else "partial"
        snapshot_filename = f"{safe_stage_id}__{_safe_filename(step_label)}.json"
        snapshot_path = stage_dir / snapshot_filename
        payload = snapshot_metrics.to_dict(include_power_samples=True)
        payload["nvml_poll_interval_ms"] = self.nvml_poll_interval_ms
        payload["snapshot"] = True
        payload["snapshot_type"] = suffix
        payload["snapshot_step"] = step
        payload["snapshot_time"] = datetime.now().isoformat()
        _write_json(snapshot_path, payload)

        print(f"[EnergyTracker] Snapshot saved for stage '{stage_id}' at step {step}: {snapshot_path}")
        return snapshot_path

    def add_tokens(self, count: int):
        """Add tokens to current stage's count (for incremental updates)."""
        if self.current_stage and self.current_stage in self.stages:
            self.stages[self.current_stage].tokens_processed += count

    def get_stage_metrics(self, stage_name: str) -> Optional[StageMetrics]:
        """Get metrics for a specific stage."""
        return self.stages.get(stage_name)

    def get_summary(self) -> Dict[str, Any]:
        """Get summary of all stages."""
        total_gpu_joules = sum(s.gpu_energy_joules for s in self.stages.values())
        total_cpu_joules = sum(s.cpu_energy_joules for s in self.stages.values())
        total_energy_joules = sum(s.total_energy_joules for s in self.stages.values())
        total_tokens = sum(s.tokens_processed for s in self.stages.values())
        total_duration = sum(s.duration_seconds for s in self.stages.values())

        return {
            "experiment_id": self.experiment_id,
            "experiment_name": self.experiment_name,
            "start_time": self.start_time,
            "end_time": datetime.now().isoformat(),
            "total_duration_seconds": total_duration,
            "total_tokens_processed": total_tokens,
            "total_gpu_energy_joules": total_gpu_joules,
            "total_cpu_energy_joules": total_cpu_joules,
            "total_energy_joules": total_energy_joules,
            "total_energy_kwh": total_energy_joules / 3_600_000 if total_energy_joules > 0 else 0.0,
            "overall_joules_per_token": total_energy_joules / total_tokens if total_tokens > 0 else 0.0,
            "overall_tokens_per_second": total_tokens / total_duration if total_duration > 0 else 0.0,
            "stages": {name: metrics.to_dict() for name, metrics in self.stages.items()},
        }

    def save_summary(self, additional_metadata: Optional[Dict] = None) -> Path:
        """
        Write the experiment summary as JSON.

        The summary goes to ``experiment_summary.json`` in the energy root and a
        copy with the same name is written into every stage folder.

        Args:
            additional_metadata: When truthy, stored under ``"metadata"``.

        Returns:
            Path of the summary file in the energy root.
        """
        report = self.get_summary()
        if additional_metadata:
            report["metadata"] = additional_metadata

        file_name = "experiment_summary.json"
        summary_file = self.energy_root / file_name
        _write_json(summary_file, report)

        # Mirror the summary into each stage's folder.
        for stage_id in self.stages:
            fallback_dir = self.stages_dir / _safe_filename(stage_id)
            target_dir = self._stage_dirs.get(stage_id, fallback_dir)
            target_dir.mkdir(parents=True, exist_ok=True)
            _write_json(target_dir / file_name, report)

        print(f"[EnergyTracker] Summary saved to: {summary_file}")
        return summary_file

    def get_wandb_metrics(self, prefix: str = "energy") -> Dict[str, float]:
        """
        Flatten per-stage and experiment-wide energy figures into one dict.

        Per-stage entries are keyed ``{prefix}/{stage key}/{metric}`` for every
        item of ``self.stages``; experiment totals are keyed
        ``{prefix}/{metric}`` and are read from ``self.get_summary()``.

        Args:
            prefix: Leading path component placed in front of every key.

        Returns:
            A flat mapping from slash-separated metric name to its value.
        """
        # (suffix used in the output key, attribute read from the stage object)
        stage_fields = (
            ("duration_sec", "duration_seconds"),
            ("gpu_energy_joules", "gpu_energy_joules"),
            ("cpu_energy_joules", "cpu_energy_joules"),
            ("total_energy_joules", "total_energy_joules"),
            ("gpu_avg_power_watts", "gpu_avg_power_watts"),
            ("tokens_processed", "tokens_processed"),
            ("joules_per_token", "joules_per_token"),
            ("tokens_per_second", "tokens_per_second"),
        )
        # (suffix used in the output key, key looked up in the summary dict)
        total_fields = (
            ("total_gpu_energy_joules", "total_gpu_energy_joules"),
            ("total_cpu_energy_joules", "total_cpu_energy_joules"),
            ("total_energy_joules", "total_energy_joules"),
            ("total_energy_kwh", "total_energy_kwh"),
            ("total_tokens", "total_tokens_processed"),
            ("overall_joules_per_token", "overall_joules_per_token"),
        )

        flat = {}
        for stage_key, stage in self.stages.items():
            for suffix, attr in stage_fields:
                flat[f"{prefix}/{stage_key}/{suffix}"] = getattr(stage, attr)

        # The summary is fetched only after the per-stage entries are written.
        totals = self.get_summary()
        for suffix, summary_key in total_fields:
            flat[f"{prefix}/{suffix}"] = totals[summary_key]

        return flat

    def _integrate_power(
        self,
        readings: List[Dict[str, float]],
        start_t: float,
        end_t: float,
    ) -> Tuple[float, List[float], float, float]:
        """
        Integrate power readings over [start_t, end_t] and return:
        energy_joules, power_samples, avg_power, peak_power.
        """
        if not readings:
            return 0.0, [], 0.0, 0.0

        readings = sorted(readings, key=lambda r: r["timestamp"])
        power_samples = [r["total_power_w"] for r in readings]

        def power_at(t: float) -> float:
            p = readings[0]["total_power_w"]
            for r in readings:
                if r["timestamp"] <= t:
                    p = r["total_power_w"]
                else:
                    break
            return p

        window = [{"timestamp": start_t, "total_power_w": power_at(start_t)}]
        for r in readings:
            ts = r["timestamp"]
            if start_t < ts < end_t:
                window.append({"timestamp": ts, "total_power_w": r["total_power_w"]})
        window.append({"timestamp": end_t, "total_power_w": power_at(end_t)})

        e_j = 0.0
        for a, b in zip(window, window[1:]):
            p0, t0 = a["total_power_w"], a["timestamp"]
            p1, t1 = b["total_power_w"], b["timestamp"]
            dt = max(0.0, t1 - t0)
            e_j += 0.5 * (p0 + p1) * dt

        duration = max(0.0, end_t - start_t)
        avg_power = e_j / duration if duration > 0 else 0.0
        peak_power = max(power_samples) if power_samples else 0.0

        return e_j, power_samples, avg_power, peak_power

    def _read_codecarbon_metrics(
        self, codecarbon_dir: Path, project_name: str
    ) -> Optional[Dict[str, float]]:
        """
        Read the last CodeCarbon row for this project_name and return
        energy metrics in kWh.

        Returns a dict containing:
            - energy_consumed_kwh
            - cpu_energy_kwh
            - gpu_energy_kwh
            - ram_energy_kwh
        or None if no matching row is found.
        """
        emissions_csv = codecarbon_dir / "emissions.csv"
        if not emissions_csv.exists():
            candidates = [p for p in codecarbon_dir.glob("*.csv") if p.name.startswith("emissions")]
            if candidates:
                emissions_csv = sorted(candidates)[-1]
            else:
                return None

        for _ in range(5):
            try:
                last_match = None
                with open(emissions_csv, newline="") as f:
                    reader = csv.DictReader(f)
                    for row in reader:
                        if row and row.get("project_name") == project_name:
                            last_match = row
                if last_match:
                    def _get(name: str) -> float:
                        val = last_match.get(name, "")
                        return float(val) if val not in (None, "") else 0.0

                    # Per CodeCarbon docs, these are all in kWh
                    return {
                        "energy_consumed_kwh": _get("energy_consumed"),
                        "cpu_energy_kwh": _get("cpu_energy"),
                        "gpu_energy_kwh": _get("gpu_energy"),
                        "ram_energy_kwh": _get("ram_energy"),
                    }
            except Exception:
                # File may still be flushing; retry
                pass
            time.sleep(0.1)

        return None

    # Backwards-compatible helper if anything else still calls the old name
    def _read_codecarbon_energy_kwh(self, codecarbon_dir: Path, project_name: str) -> float:
        metrics = self._read_codecarbon_metrics(codecarbon_dir, project_name)
        return metrics["energy_consumed_kwh"] if metrics is not None else 0.0

    def _cleanup(self):
        """Cleanup resources on exit."""
        if self.current_stage is not None:
            try:
                self.end_stage()
            except Exception:
                pass

        if self._nvml_poller is not None:
            try:
                self._nvml_poller.stop()
            except Exception:
                pass

        if self._codecarbon_tracker is not None:
            try:
                self._codecarbon_tracker.stop()
            except Exception:
                pass
