# utils/policy_logger.py
import os, csv
from typing import Dict, Any, List

class PolicyLogger:
    """Append per-step policy statistics to a CSV file.

    Each call to :meth:`log` writes one row holding the current epoch, a
    running step counter, the layer id, any metadata set via
    :meth:`set_meta`, the caller-supplied feature values, and batch-averaged
    statistics for every named transform.

    The column set is fixed by the first row written. If ``csv_path``
    already holds data, its first line is reused as the header so appended
    rows line up with the existing columns. Keys not present in the header
    are dropped; header columns missing from a row are written as ``""``.
    """

    def __init__(self, csv_path: str):
        """Prepare the logger and create the parent directory if needed.

        Args:
            csv_path: Destination CSV file. May be a bare filename, in which
                case no directory is created.
        """
        self.csv_path = csv_path
        # os.path.dirname() returns "" for a bare filename, and
        # os.makedirs("") raises FileNotFoundError, so only create a real dir.
        parent_dir = os.path.dirname(csv_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Recover the header of a pre-existing file; without it, appended
        # rows would be projected onto an empty column list (blank lines).
        self._header: List[str] = self._read_existing_header(csv_path)
        self._header_written = bool(self._header)
        self._step = 0
        self._epoch = 0
        self._meta: Dict[str, Any] = {}

    @staticmethod
    def _read_existing_header(csv_path: str) -> List[str]:
        """Return the first CSV row of ``csv_path``, or ``[]`` if unavailable."""
        if not (os.path.exists(csv_path) and os.path.getsize(csv_path) > 0):
            return []
        with open(csv_path, "r", newline="") as f:
            return next(csv.reader(f), [])

    def set_meta(self, **meta) -> None:
        """Replace the metadata columns merged into every subsequent row."""
        self._meta = dict(meta)

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch value recorded in subsequent rows."""
        self._epoch = int(epoch)

    def log(self, feature_row: Dict[str, float], probs, strengths, choice, layer_id: int, transforms: List[str]) -> None:
        """Write one row of batch-averaged policy statistics.

        Args:
            feature_row: Extra scalar columns; merged last, so these keys
                override metadata and the built-in epoch/step/layer columns.
            probs: Tensor of shape ``[B, T]``.
            strengths: Tensor of shape ``[B, C, T]``.
            choice: Tensor of shape ``[B, T]`` (one-hot).
            layer_id: Value stored in the ``layer`` column.
            transforms: Names of the ``T`` transforms, used to build the
                ``p_<name>``, ``alpha_<name>`` and ``sel_<name>`` columns.

        Raises:
            IndexError: If ``transforms`` has more entries than the last
                dimension of the tensors.
        """
        # Tensor.tolist() converts directly to Python floats, avoiding a
        # NumPy round-trip (which fails for dtypes NumPy cannot represent).
        p_mean = probs.detach().mean(dim=0).cpu().tolist()                    # [T]
        ch_mean = choice.detach().float().mean(dim=0).cpu().tolist()          # [T]
        s_mean = strengths.detach().mean(dim=0).mean(dim=0).cpu().tolist()    # [T]

        row = {
            "epoch": self._epoch,
            "step": self._step,
            "layer": layer_id,
            **self._meta,
            **feature_row,
        }
        for i, name in enumerate(transforms):
            row[f"p_{name}"] = float(p_mean[i])
            row[f"alpha_{name}"] = float(s_mean[i])
            row[f"sel_{name}"] = float(ch_mean[i])

        self._write(row)
        self._step += 1

    def _write(self, row: Dict[str, Any]) -> None:
        """Append ``row`` to the CSV, emitting the header first if needed."""
        # Always append: the header is only written when the file had no
        # usable header, and existing content is never truncated.
        with open(self.csv_path, "a", newline="") as f:
            if not self._header_written:
                self._header = list(row.keys())
                writer = csv.DictWriter(f, fieldnames=self._header)
                writer.writeheader()
                self._header_written = True
            else:
                writer = csv.DictWriter(f, fieldnames=self._header)
            # Project onto the fixed header: unknown keys are dropped,
            # missing ones become empty cells.
            writer.writerow({k: row.get(k, "") for k in self._header})
