import numpy as np
import math
import torch
from typing import Callable, Optional, Dict, Any, Tuple, Union
import pandas as pd
import arviz as az
import time
from scipy.optimize import linear_sum_assignment


def init_polar(p, k, eps=0.05, *, device="cpu", dtype=torch.float64):
    """Return a random, nearly orthonormal (p, k) matrix flattened to shape (p * k,).

    A Gaussian (p, k) matrix is replaced by its polar factor ``U @ Vh`` (thin SVD),
    then i.i.d. Gaussian noise scaled by ``eps`` is added before flattening.
    """
    opts = {"dtype": dtype, "device": device}
    gauss = torch.randn(p, k, **opts)
    left, _sing, right_h = torch.linalg.svd(gauss, full_matrices=False)
    polar = left @ right_h
    jitter = eps * torch.randn(p, k, **opts)
    return torch.flatten(polar + jitter)

def Gamma_from_B(B: torch.Tensor) -> torch.Tensor:
    """Return the orthonormal polar factor ``U @ Vh`` of ``B`` from its thin SVD."""
    left, _, right_h = torch.linalg.svd(B, full_matrices=False)
    return torch.matmul(left, right_h)


class DualAveragingStepSize:
    """Dual-averaging adaptation of a step size toward a target acceptance rate.

    Call :meth:`update` once per adaptation iteration with the observed acceptance
    probability; it returns the step size to use next. When adaptation is over,
    :meth:`final_step_size` returns the averaged step size ``exp(log_eps_bar)``.

    Parameters
    ----------
    init_step_size : float
        Starting step size; must be > 0 because its logarithm is taken.
    target_accept : float
        Desired acceptance probability.
    gamma : float
        Divisor of the shrinkage term (smaller -> stronger reaction to error).
    t0 : float
        Offset damping the error average during early iterations.
    kappa : float
        Exponent of the decaying weight used to average the log step sizes.
    """

    def __init__(
        self,
        init_step_size: float,
        target_accept: float = 0.8,
        gamma: float = 0.05,
        t0: float = 10.0,
        kappa: float = 0.75,
    ):
        log_init = math.log(init_step_size)

        self.target = float(target_accept)
        self.gamma = float(gamma)
        self.t0 = float(t0)
        self.kappa = float(kappa)

        # Shrinkage point: log of 10x the initial step size (Stan-style choice).
        self.mu = math.log(10.0 * init_step_size)

        self.hbar = 0.0             # running average of (target - accept_prob)
        self.log_eps = log_init     # current (primal) log step size
        self.log_eps_bar = log_init  # averaged log step size
        self.t = 0                  # number of update() calls so far

    def update(self, accept_prob: float) -> float:
        """Record one acceptance probability and return the next step size."""
        self.t += 1
        it = self.t

        eta = 1.0 / (it + self.t0)
        self.hbar = (1.0 - eta) * self.hbar + eta * (self.target - accept_prob)

        # Move the log step away from mu in proportion to the accumulated error.
        self.log_eps = self.mu - (math.sqrt(it) / self.gamma) * self.hbar

        # Weight it**(-kappa) decays, so later iterates change the average less.
        avg_w = it ** (-self.kappa)
        self.log_eps_bar = avg_w * self.log_eps + (1.0 - avg_w) * self.log_eps_bar

        return math.exp(self.log_eps)

    def final_step_size(self) -> float:
        """Return the averaged step size, ``exp(log_eps_bar)``."""
        return math.exp(self.log_eps_bar)


@torch.no_grad()
def make_mass_diag_from_var(
    var: torch.Tensor,
    jitter: float = 1e-3,
    min_m: float = 1e-6,
    max_m: float = 1e6,
) -> torch.Tensor:
    """Build a diagonal mass vector from a variance estimate.

    Adds ``jitter`` to every entry, then clamps the result into ``[min_m, max_m]``.
    """
    return torch.clamp(var + jitter, min=min_m, max=max_m)

class RunningDiagVar:
    """Streaming per-coordinate mean and sample variance (Welford's update).

    Vectors of length ``d`` are folded in one at a time with :meth:`update`;
    :meth:`var` returns the (n - 1)-normalised variance, or ones while fewer than
    two vectors have been seen.
    """

    def __init__(self, d: int, device=None, dtype=None):
        opts = {"device": device, "dtype": dtype}
        self.n = 0                            # observations seen
        self.mean = torch.zeros(d, **opts)    # running mean
        self.M2 = torch.zeros(d, **opts)      # running sum of squared deviations

    @torch.no_grad()
    def update(self, x: torch.Tensor):
        """Fold one observation ``x`` (a length-``d`` vector) into the statistics."""
        self.n += 1
        before = x - self.mean
        self.mean += before / self.n
        after = x - self.mean
        self.M2 += before * after

    @torch.no_grad()
    def var(self) -> torch.Tensor:
        """Return the sample variance, or a vector of ones if ``n < 2``."""
        if self.n >= 2:
            return self.M2 / (self.n - 1)
        return torch.ones_like(self.mean)



def align_perm_sign_hungarian(Q: torch.Tensor, Qref: torch.Tensor) -> torch.Tensor:
    """Permute and sign-flip the columns of ``Q`` to best match ``Qref``.

    The column permutation maximises the total absolute overlap ``|Qref.T @ Q|``
    (Hungarian assignment). Each matched column is then flipped so that its inner
    product with the corresponding ``Qref`` column is non-negative (zeros keep +1).
    Both inputs are (p, r) matrices.
    """
    overlap = (Qref.T @ Q).abs().detach().cpu().numpy()  # (r, r)
    _, matched_cols = linear_sum_assignment(overlap, maximize=True)

    permuted = Q[:, matched_cols]
    signs = torch.sign(torch.diagonal(Qref.T @ permuted))
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    return permuted * signs

def procrustes_align(Q: torch.Tensor, Qref: torch.Tensor) -> torch.Tensor:
    """Rotate the columns of ``Q`` by the orthogonal matrix that best maps it onto ``Qref``.

    The rotation comes from the SVD of ``Qref.T @ Q`` (orthogonal Procrustes).
    """
    u, _, vh = torch.linalg.svd(Qref.T @ Q)
    rotation = torch.matmul(u, vh)
    return torch.matmul(Q, rotation.transpose(-2, -1))

def align_to_ref_hungarian_then_procrustes(Q: torch.Tensor, Qref: torch.Tensor) -> torch.Tensor:
    """Align ``Q`` to ``Qref``: discrete column permutation/sign fix, then Procrustes rotation."""
    return procrustes_align(align_perm_sign_hungarian(Q, Qref), Qref)

def align_filled_for_rhat(Q_post: torch.Tensor) -> torch.Tensor:
    """Align every draw of a (C, T, p, r) tensor to the first draw of the first chain.

    Each ``Q_post[c, t]`` is passed through
    ``align_to_ref_hungarian_then_procrustes`` with reference ``Q_post[0, 0]``;
    the result has the same shape, dtype, and device as ``Q_post``.
    """
    n_chain, n_draw, _, _ = Q_post.shape
    reference = Q_post[0, 0]
    aligned = torch.empty_like(Q_post)
    for idx in np.ndindex(n_chain, n_draw):
        aligned[idx] = align_to_ref_hungarian_then_procrustes(Q_post[idx], reference)
    return aligned


def compute_ess_safe(
    samples,
    *,
    max_lag=None,
    var_eps=1e-12,
    min_unique=5,
    round_unique=12,
    method="geyer",   # "positive" | "geyer"
):
    """Estimate the effective sample size (ESS) of a single chain with guardrails.

    Parameters
    ----------
    samples : array-like
        Draws of one chain (treated as a flat sequence of length n).
    max_lag : int, optional
        Number of autocorrelation lags to compute (lags 0..max_lag-1). Defaults to
        ``min(n - 1, 500)``; values are clamped to ``[1, n]`` so no lag exceeds n - 1.
    var_eps : float
        Variance threshold below which the chain is treated as constant.
    min_unique : int
        Minimum number of distinct (rounded) values required for an ACF estimate.
    round_unique : int
        Decimals used when counting distinct values.
    method : {"geyer", "positive"}
        ``"geyer"``: Geyer's initial monotone sequence estimator, using pair sums
        Gamma_m = rho_{2m} + rho_{2m+1} (m >= 0). The sum is truncated at the first
        non-positive Gamma_m, made non-increasing, and tau = -1 + 2 * sum(Gamma).
        ``"positive"``: tau = 1 + 2 * sum of autocorrelations before the first
        non-positive lag.

    Returns
    -------
    ess : float
        ``n / tau`` with tau floored at 1, so ESS lies in [1, n]. It is ``n`` for
        n < 3, 1.0 for degenerate (near-constant / low-unique) chains, and NaN when
        the input contains NaN or Inf.
    acf : np.ndarray
        Autocorrelations for lags 0..max_lag-1, or ``np.ones(1)`` on early exit.
    info : dict
        Diagnostic details (``"reason"`` on early exit, ``"tau"``, cut-off info).

    Raises
    ------
    ValueError
        If ``method`` is not ``"positive"`` or ``"geyer"``.
    """
    if method not in ("positive", "geyer"):
        raise ValueError("method must be 'positive' or 'geyer'")

    x = np.asarray(samples, dtype=float)
    n = x.size
    info = {"n": int(n), "method": method}

    if n < 3:
        return float(n), np.ones(1), {"reason": "too-short", **info}

    # Without this guard NaNs propagate into tau, and max(1.0, nan) == 1.0 would
    # silently report ESS == n for a broken chain.
    if not np.all(np.isfinite(x)):
        return float("nan"), np.ones(1), {"reason": "non-finite", **info}

    if max_lag is None:
        max_lag = min(n - 1, 500)
    # acf holds lags 0..max_lag-1; lags >= n would average empty slices.
    max_lag = int(max(1, min(max_lag, n)))

    # ---------- Guardrails ----------
    var = np.var(x, ddof=1)
    if var < var_eps:
        return 1.0, np.ones(1), {"reason": "near-constant", "var": float(var), **info}

    uniq = np.unique(np.round(x, round_unique)).size
    if uniq < min_unique:
        return 1.0, np.ones(1), {"reason": f"low-unique({uniq})", **info}

    # ---------- ACF ----------
    x0 = x - x.mean()
    denom = np.mean(x0 * x0)
    if denom < var_eps:
        return 1.0, np.ones(1), {"reason": "near-constant-after-centering", **info}

    acf = np.empty(max_lag)
    acf[0] = 1.0
    for k in range(1, max_lag):
        acf[k] = np.mean(x0[:-k] * x0[k:]) / denom

    # ---------- IACT (tau) ----------
    if method == "positive":
        tau = 1.0
        cutoff = max_lag
        for k in range(1, max_lag):
            if acf[k] <= 0:
                cutoff = k
                break
            tau += 2.0 * acf[k]
        info["cutoff_lag"] = int(cutoff)

    else:  # "geyer"
        gammas = []
        m = 0
        while 2 * m + 1 < max_lag:
            g = acf[2 * m] + acf[2 * m + 1]
            if g <= 0:
                break
            gammas.append(g)
            m += 1

        if not gammas:
            tau = 1.0
            info["cutoff_pairs"] = 0
        else:
            # Initial monotone sequence: enforce a non-increasing Gamma_m.
            gmono = np.minimum.accumulate(np.asarray(gammas, dtype=float))
            tau = -1.0 + 2.0 * float(np.sum(gmono))
            info["cutoff_pairs"] = int(gmono.size)

    # Floor tau at 1 so ESS never exceeds n (antithetic chains are not rewarded).
    tau = max(1.0, float(tau))
    ess = float(np.clip(n / tau, 1.0, n))

    info["tau"] = tau
    info["ess_raw"] = float(n / tau)

    return ess, acf, info

def ess_1d_series_ct(x_ct, *, method="geyer", max_lag=None, **kwargs):
    """Return the ESS of every chain of a (C, T) array as a float array of shape (C,).

    Each row is passed to ``compute_ess_safe`` with ``method``, ``max_lag`` and any
    extra keyword arguments; no summarising across chains is done here.
    """
    arr = np.asarray(x_ct, dtype=float)
    if arr.ndim != 2:
        raise ValueError(f"expected (C,T), got {arr.shape}")

    per_chain = [
        compute_ess_safe(chain, method=method, max_lag=max_lag, **kwargs)[0]
        for chain in arr
    ]
    return np.asarray(per_chain, dtype=float)  # (C,)

def compute_ess_and_eff(self, idata, out):
    """Module-level alias of ``Evaluator.compute_ess_and_eff``.

    This function previously duplicated the method body line for line. It now
    delegates so the two cannot drift apart. ``self`` must expose the attributes
    that method reads (``C``, ``T``, ``p``, ``r``, ``compute_summary``, ``nprop_c``,
    ``nprop_total``, ``runtime_sec``, ``runtime_total_sec``), e.g. an ``Evaluator``.
    """
    return Evaluator.compute_ess_and_eff(self, idata, out)

class Evaluator:
    """Mixing diagnostics for posterior draws of a (p, r) matrix ``Gamma``.

    ``Gamma`` has shape ``(T, p, r)`` (single chain, lifted to one chain) or
    ``(C, T, p, r)``. Each diagnostic reduces the draws to scalar (chain, draw)
    series and wraps them in an ArviZ ``InferenceData``. :meth:`compute_ess_and_eff`
    then reports chain-wise ESS, ESS per proposal and ESS per second.

    Parameters
    ----------
    Gamma : array-like or torch.Tensor
        Draws, converted to a float64 numpy array.
    angle_aggregate : str
        Stored; not used by the methods of this class.
    geodesic_ref : {"projection_mean", "first"}
        Reference subspace for :meth:`geodesic_distance`.
    geodesic_aggregate, increment_aggregate : {"sum_sq", "mean", "max"}
        How principal angles are reduced to one scalar.
    directional_dist : {"mvmf", "mbh"}
        Form of the sufficient statistic in :meth:`sufficient_statistic`.
    clamp_eps : float
        Singular values are clipped to ``[clamp_eps, 1 - clamp_eps]`` before arccos.
    n_proposals : int or array of shape (C,)
        Required; proposals per chain (all >= 1).
    runtime_sec : float or array of shape (C,)
        Required; runtime per chain in seconds (floored at 1e-12).
    A_ref : array of shape (p, r), optional
        Enables the sufficient-statistic diagnostic.
    compute_summary : bool
        Run ``az.summary`` (only when C >= 2).
    Q : array of shape (p, r), optional
        Needed for ``directional_dist="mbh"``, where it is multiplied elementwise
        with ``A_ref``.

    Raises
    ------
    ValueError
        On invalid option strings, Gamma shapes, or missing/invalid
        ``n_proposals`` / ``runtime_sec``.
    """

    def __init__(
        self,
        Gamma,
        angle_aggregate: str = "sum_sq",
        geodesic_ref: str = "projection_mean",
        geodesic_aggregate: str = "sum_sq",
        increment_aggregate: str = "sum_sq",
        directional_dist: str = "mvmf",
        clamp_eps: float = 1e-12,
        n_proposals: Optional[Union[int, np.ndarray]] = None,
        runtime_sec: Optional[Union[int, np.ndarray]] = None,
        A_ref: np.ndarray = None,
        compute_summary: bool = True,
        Q: Optional[np.ndarray] = None,
    ):
        self.angle_aggregate = str(angle_aggregate)          # "sum_sq" | "mean" | "max"
        self.geodesic_ref = str(geodesic_ref)                # "projection_mean" | "first"
        self.geodesic_aggregate = str(geodesic_aggregate)    # "sum_sq" | "mean" | "max"
        self.increment_aggregate = str(increment_aggregate)  # "sum_sq" | "mean" | "max"
        self.directional_dist = str(directional_dist)        # "mvmf" | "mbh"
        self.A_ref = A_ref
        # Previously never assigned, so dist="mbh" crashed with AttributeError.
        self.Q = Q
        self.compute_summary = bool(compute_summary)

        self.clamp_eps = clamp_eps
        self.n_proposals = n_proposals

        if self.increment_aggregate not in ("sum_sq", "mean", "max"):
            raise ValueError('increment_aggregate must be one of {"sum_sq","mean","max"}')
        if self.geodesic_ref not in ("projection_mean", "first"):
            raise ValueError('geodesic_ref must be one of {"projection_mean","first"}')
        if self.geodesic_aggregate not in ("sum_sq", "mean", "max"):
            raise ValueError('geodesic_aggregate must be one of {"sum_sq","mean","max"}')
        if self.directional_dist not in ("mvmf", "mbh"):
            raise ValueError('directional_dist must be one of {"mvmf","mbh"}')

        self.G, self.C, self.T, self.p, self.r = self._check_gamma(Gamma)
        self.nprop_c = self._as_nprop_array(self.n_proposals, self.C)
        self.nprop_total = int(self.nprop_c.sum())
        self.runtime_sec = self._as_runtime_array(runtime_sec, self.C)
        self.runtime_total_sec = float(np.sum(self.runtime_sec))

    @staticmethod
    def _as_nprop_array(n_proposals, C: int) -> np.ndarray:
        """Broadcast ``n_proposals`` to an int array of shape (C,) with entries >= 1."""
        if n_proposals is None:
            raise ValueError("n_proposals must be provided for efficiency (ESS/proposal).")
        if np.isscalar(n_proposals):
            arr = np.full(C, int(n_proposals), dtype=int)
        else:
            arr = np.asarray(n_proposals, dtype=int)
            if arr.shape != (C,):
                raise ValueError(f"n_proposals must be int or shape (C,), got {arr.shape}")
        if np.any(arr <= 0):
            raise ValueError("All n_proposals must be >= 1.")
        return arr

    @staticmethod
    def _as_runtime_array(runtime_sec, C: int) -> np.ndarray:
        """Broadcast ``runtime_sec`` to a float array of shape (C,), floored at 1e-12."""
        if runtime_sec is None:
            raise ValueError("runtime_sec must be provided.")
        if np.isscalar(runtime_sec):
            arr = np.full(C, float(runtime_sec), dtype=float)
        else:
            arr = np.asarray(runtime_sec, dtype=float)
            if arr.shape != (C,):
                raise ValueError(f"runtime_sec must be scalar or shape (C,), got {arr.shape}")
        # A runtime of 0 would make ESS/sec infinite, so floor it at a tiny epsilon.
        arr = np.maximum(arr, 1e-12)
        return arr

    @staticmethod
    def _to_numpy(x):
        """Convert a torch tensor or array-like to a float64 numpy array (a copy for tensors)."""
        if isinstance(x, torch.Tensor):
            # Cast inside torch first: Tensor.numpy() rejects dtypes such as bfloat16.
            # .copy() keeps the result independent of the caller's tensor memory.
            return x.detach().cpu().to(torch.float64).numpy().copy()
        return np.asarray(x, dtype=np.float64)

    @staticmethod
    def _clamp(x, lo=-1.0, hi=1.0):
        """Elementwise clip of ``x`` to ``[lo, hi]``."""
        return np.minimum(np.maximum(x, lo), hi)

    def _check_gamma(self, Gamma):
        """Return ``(G, C, T, p, r)`` with G as a float64 (C, T, p, r) array."""
        G = self._to_numpy(Gamma)

        # auto lift single-chain
        if G.ndim == 3:
            G = G[None, ...]
        elif G.ndim != 4:
            raise ValueError("Gamma must have shape (T,p,r) or (C,T,p,r).")

        C, T, p, r = G.shape

        if C < 1:
            raise ValueError("Need at least one chain (C >= 1).")
        if T < 1:
            raise ValueError("Need at least one draw per chain (T >= 1).")
        if p < 1 or r < 1:
            raise ValueError("p and r must be >= 1.")
        if r > p:
            raise ValueError("Require r <= p.")

        return G, C, T, p, r

    def _principal_angles(self, G1, G2) -> np.ndarray:
        """Principal angles (ascending) between the column spaces of G1 and G2 (both (p, r))."""
        G1 = np.asarray(G1, dtype=np.float64)
        G2 = np.asarray(G2, dtype=np.float64)

        if not (np.isfinite(G1).all() and np.isfinite(G2).all()):
            raise ValueError("Non-finite in Gamma (NaN/Inf)")

        # Orthonormalise so singular values of Q1.T @ Q2 are cosines of the angles.
        Q1, _ = np.linalg.qr(G1)
        Q2, _ = np.linalg.qr(G2)

        M = Q1.T @ Q2  # (r,r)

        try:
            s = np.linalg.svd(M, compute_uv=False)
        except np.linalg.LinAlgError:
            # Fallback: singular values via the eigenvalues of M^T M.
            gram = M.T @ M
            w = np.linalg.eigvalsh(gram)
            w = np.clip(w, 0.0, None)
            s = np.sqrt(w)[::-1]  # descending

        s = np.clip(s, 0.0, 1.0)
        s = np.clip(s, 0.0 + self.clamp_eps, 1.0 - self.clamp_eps)

        return np.arccos(s)

    @staticmethod
    def _aggregate_angles(thetas, aggregate):
        """Reduce a vector of principal angles to one float ("sum_sq", "mean" or "max")."""
        if aggregate == "sum_sq":
            return float(np.sum(thetas ** 2))
        if aggregate == "mean":
            return float(np.mean(thetas))
        if aggregate == "max":
            return float(np.max(thetas))
        raise ValueError("aggregate must be one of {'sum_sq','mean','max'}")

    def projection(self):
        """Diagnostic on the upper-triangular entries of the projection matrices G G^T."""
        # P: (C,T,p,p)
        P = np.einsum("ctpr,ctqr->ctpq", self.G, self.G)

        iu = np.triu_indices(self.p)
        P_ut = P[:, :, iu[0], iu[1]]  # (C,T,n_entries)
        n_entries = P_ut.shape[-1]

        idata_proj = az.from_dict(
            posterior={f"P_{i}": P_ut[:, :, i] for i in range(n_entries)},
            coords={
                "chain": np.arange(self.C),
                "draw": np.arange(self.T),
            },
            dims={f"P_{i}": ["chain", "draw"] for i in range(n_entries)},
        )

        return {
            "method": "projection",
            "n_entries": int(n_entries),
            "idata": idata_proj,
        }

    def principal_angle_increments(self, aggregate="sum_sq"):
        """Aggregated principal angles between consecutive draws; series has T - 1 draws."""
        g = np.empty((self.C, self.T - 1), dtype=np.float64)

        for c in range(self.C):
            for t in range(1, self.T):
                thetas = self._principal_angles(self.G[c, t], self.G[c, t - 1])
                g[c, t - 1] = self._aggregate_angles(thetas, aggregate)

        idata_inc = az.from_dict(
            posterior={"pai": g},
            coords={
                "chain": np.arange(self.C),
                "draw": np.arange(g.shape[1]),
            },
            dims={"pai": ["chain", "draw"]},
        )
        return {
            "method": "principal_angle_increments",
            "aggregate": aggregate,
            "idata": idata_inc,
        }

    def geodesic_distance(self, ref="projection_mean", aggregate="sum_sq"):
        """Aggregated principal angles between every draw and a reference subspace.

        ``ref="first"`` uses ``G[0, 0]``. ``ref="projection_mean"`` uses the top-r
        eigenvectors of the projection matrices averaged over chains and draws.
        """
        if ref == "first":
            Gref = self.G[0, 0]
        elif ref == "projection_mean":
            # P: (C,T,p,p)
            P = np.einsum("ctpr,ctqr->ctpq", self.G, self.G)
            Pbar = P.mean(axis=(0, 1))  # average over chain and draw -> (p,p)
            w, V = np.linalg.eigh(Pbar)
            idx = np.argsort(w)[::-1][:self.r]
            Gref = V[:, idx]
        else:
            raise ValueError("ref must be one of {'projection_mean','first'}")

        d = np.empty((self.C, self.T), dtype=np.float64)

        for c in range(self.C):
            for t in range(self.T):
                thetas = self._principal_angles(self.G[c, t], Gref)
                d[c, t] = self._aggregate_angles(thetas, aggregate)

        idata_geo = az.from_dict(
            posterior={"gd": d},
            coords={
                "chain": np.arange(self.C),
                "draw": np.arange(d.shape[1]),
            },
            dims={"gd": ["chain", "draw"]},
        )

        return {
            "method": "geodesic_distance",
            "ref": ref,
            "aggregate": aggregate,
            "idata": idata_geo,
        }

    def element_efficiency(self):
        """Diagnostic on every individual entry G[:, :, i, j]."""
        idata_elem = az.from_dict(
            posterior={f"G_{i}_{j}": self.G[:, :, i, j] for i in range(self.p) for j in range(self.r)},
            coords={"chain": np.arange(self.C), "draw": np.arange(self.T)},
            dims={f"G_{i}_{j}": ["chain", "draw"] for i in range(self.p) for j in range(self.r)},
        )

        return {
            "method": "element_efficiency",
            "n_elements": int(self.p * self.r),
            "idata": idata_elem,
        }

    def sufficient_statistic(self, dist="mvmf"):
        """Scalar statistic <W, G_{c,t}> per draw, or None when ``A_ref`` is not set.

        ``dist="mvmf"`` uses ``W = A_ref``; ``dist="mbh"`` uses ``W = A_ref * Q``
        (elementwise) and requires ``Q``.
        """
        if self.A_ref is None:
            return None
        if dist not in ("mvmf", "mbh"):
            raise ValueError(dist)

        G = self.G  # always (C, T, p, r) after _check_gamma
        A = self._to_numpy(self.A_ref)
        if A.shape != (self.p, self.r):
            raise ValueError(f"A_ref must have shape ({self.p}, {self.r}), got {A.shape}")

        if dist == "mvmf":
            weights = A
        else:
            Q = getattr(self, "Q", None)
            if Q is None:
                raise ValueError("Q is required for dist='mbh'")
            weights = A * self._to_numpy(Q)  # (p, r)

        stat = np.einsum("ctpr,pr->ct", G, weights)  # (C, T)

        idata_ss = az.from_dict(
            posterior={"ss": stat},
            coords={"chain": np.arange(stat.shape[0]), "draw": np.arange(stat.shape[1])},
            dims={"ss": ["chain", "draw"]},
        )

        return {"method": "sufficient_stat", "dist": dist, "idata": idata_ss}

    def compute_ess_and_eff(self, idata, out):
        """Attach chain-wise ESS and efficiency summaries to a diagnostic result.

        Every posterior variable is treated as a (chain, draw) series; other shapes
        are reshaped to ``(self.C, -1)``. Geyer ESS is computed per chain, so each
        (variable, chain) pair contributes one value to ``ess_elements``.

        Efficiency is normalised per chain: chain ``c``'s ESS is divided by that
        chain's proposals (``nprop_c[c]``) and runtime (``runtime_sec[c]``).
        Dividing one chain's ESS by totals over all chains would understate
        efficiency by roughly a factor of C.

        Returns a shallow copy of ``out`` extended with the ESS/efficiency entries.
        """
        ds = idata.posterior
        var_names = list(ds.data_vars.keys())
        if len(var_names) == 0:
            raise ValueError("idata.posterior has no variables.")

        ess_per_var = {}
        for vn in var_names:
            x = np.asarray(ds[vn].values)
            if x.ndim != 2:
                x = x.reshape(self.C, -1)
            ess_per_var[vn] = ess_1d_series_ct(x, method="geyer")  # (C,)

        # (n_vars, C) -> (n_vars*C,), variable-major ordering.
        ess_elements = np.concatenate(list(ess_per_var.values()), axis=0)

        az_sum = az.summary(idata, round_to=4) if (self.compute_summary and self.C >= 2) else None

        # Per-chain denominators repeated once per variable, matching ess_elements.
        n_vars = len(var_names)
        nprop_elements = np.tile(np.asarray(self.nprop_c, dtype=np.float64), n_vars)
        runtime_elements = np.tile(np.asarray(self.runtime_sec, dtype=np.float64), n_vars)
        eff_vals = ess_elements / nprop_elements
        sec_vals = ess_elements / runtime_elements

        ess_summary = {
            "median": float(np.nanmedian(ess_elements)),
            "mean":   float(np.nanmean(ess_elements)),
            "min":    float(np.nanmin(ess_elements)),
            "max":    float(np.nanmax(ess_elements)),
            "median_eff": float(np.nanmedian(eff_vals)),
            "mean_eff":   float(np.nanmean(eff_vals)),
            "min_eff":    float(np.nanmin(eff_vals)),
            "max_eff":    float(np.nanmax(eff_vals)),
            "median_sec": float(np.nanmedian(sec_vals)),
            "mean_sec":   float(np.nanmean(sec_vals)),
            "min_sec":    float(np.nanmin(sec_vals)),
            "max_sec":    float(np.nanmax(sec_vals)),
        }

        outputs = out.copy()
        outputs.update({
            "n_proposals_per_chain": self.nprop_c,
            "n_proposals_total": self.nprop_total,
            "ess_elements": ess_elements,   # all (var, chain) pairs, flattened
            "ess_by_var": ess_per_var,      # chain-wise ESS per variable
            "eff_elements": eff_vals,
            "ess_summary": ess_summary,
            "az_summary": az_sum,
            "C": self.C,
            "T": self.T,
            "p": self.p,
            "r": self.r,
            "runtime_total_sec": self.runtime_total_sec,
            "var_names": var_names,
        })

        return outputs

    def evaluate(self):
        """Run projection, increment and geodesic diagnostics (plus the sufficient
        statistic when ``A_ref`` is set) and return ``{name: compute_ess_and_eff(...)}``."""
        out_proj = self.projection()
        out_inc = self.principal_angle_increments(aggregate=self.increment_aggregate)
        out_geo = self.geodesic_distance(ref=self.geodesic_ref, aggregate=self.geodesic_aggregate)

        results = {
            "projection": self.compute_ess_and_eff(out_proj["idata"], out_proj),
            "principal_angle_increments": self.compute_ess_and_eff(out_inc["idata"], out_inc),
            "geodesic_distance": self.compute_ess_and_eff(out_geo["idata"], out_geo),
        }

        if self.directional_dist is not None and self.A_ref is not None:
            out_suff = self.sufficient_statistic(dist=self.directional_dist)
            if isinstance(out_suff, dict) and "idata" in out_suff:
                results["sufficient_stat"] = self.compute_ess_and_eff(out_suff["idata"], out_suff)

        return results
 

def _evaluate_chainwise_budget_throughput(
    Q_list,
    total_steps_per_chain,
    runtime_sec_per_chain,
    *,
    angle_aggregate="sum_sq",
    geodesic_ref="projection_mean",
    geodesic_aggregate="sum_sq",
    increment_aggregate="sum_sq",
    directional_dist='mvmf',
    clamp_eps=1e-12,
    ess_pick="mean",  # "median" or "mean"
    A_ref=None,
):
    """Per-chain ESS under a fixed budget, plus summed-ESS throughput aggregates.

    Each entry of ``Q_list`` (the draws of one chain, or None) is evaluated on its
    own with ``Evaluator``. Chains that are None or have fewer than 4 draws get NaN
    ESS columns and are excluded from the sums.

    ``"sufficient_stat"`` is only produced by ``Evaluator.evaluate`` when ``A_ref``
    is given. When it is missing, its per-chain columns are NaN, and its aggregates
    are NaN if any chain was evaluated. Previously this raised KeyError.

    Returns
    -------
    df_chain : pd.DataFrame
        One row per chain with ``budget_{metric}_ess_{ess_pick}`` columns.
    df_aggre : pd.DataFrame
        One row: summed chain ESS per metric, and that sum divided by the total
        steps and the total runtime over all chains.
    """
    C = len(Q_list)
    nprop = np.asarray(total_steps_per_chain, dtype=int)
    rt = np.asarray(runtime_sec_per_chain, dtype=float)

    metrics = ["projection", "principal_angle_increments", "geodesic_distance", "sufficient_stat"]

    rows = []
    ess_sum = {m: 0.0 for m in metrics}
    metric_seen = {m: False for m in metrics}
    valid_chain_mask = np.zeros(C, dtype=bool)

    for c in range(C):
        Gc = Q_list[c]
        if Gc is None:
            Tc = 0
        elif hasattr(Gc, "shape"):
            Tc = int(Gc.shape[0])
        else:
            Tc = len(Gc)

        row = {
            "chain": c,
            "budget_T": Tc,
            "budget_total_steps": int(nprop[c]),
            "budget_runtime_sec": float(rt[c]),
        }

        # Too few draws for a meaningful autocorrelation estimate.
        if (Gc is None) or (Tc < 4):
            for m in metrics:
                row[f"budget_{m}_ess_{ess_pick}"] = np.nan
            rows.append(row)
            continue

        ev = Evaluator(
            Gc,  # (T,p,r) -> lifted to (1,T,p,r)
            angle_aggregate=angle_aggregate,
            geodesic_ref=geodesic_ref,
            geodesic_aggregate=geodesic_aggregate,
            increment_aggregate=increment_aggregate,
            clamp_eps=clamp_eps,
            n_proposals=int(nprop[c]),
            runtime_sec=float(rt[c]),
            directional_dist=directional_dist,
            A_ref=A_ref,
        )
        out = ev.evaluate()

        valid_chain_mask[c] = True

        for m in metrics:
            key = f"budget_{m}_ess_{ess_pick}"
            if m not in out:
                row[key] = np.nan
                continue
            metric_seen[m] = True
            val = float(out[m]["ess_summary"][ess_pick])
            row[key] = val
            if np.isfinite(val):
                ess_sum[m] += val

        rows.append(row)

    df_chain = pd.DataFrame(rows)

    total_steps = float(np.sum(nprop))
    total_time = float(np.sum(rt))
    n_valid = int(np.sum(valid_chain_mask))

    ag = {
        "budget_total_steps_sum": int(total_steps),
        "budget_runtime_sum_sec": float(total_time),
        "budget_n_chains_valid_for_ess": int(n_valid),
    }

    for m in metrics:
        # A metric never produced despite evaluated chains has no meaningful sum.
        total_ess = ess_sum[m] if (metric_seen[m] or n_valid == 0) else np.nan
        ag[f"budget_{m}_ess_sum_chain_{ess_pick}"] = float(total_ess)
        ag[f"budget_{m}_ess_per_iter_total"] = float(total_ess / max(total_steps, 1.0))
        ag[f"budget_{m}_ess_per_sec_total"] = float(total_ess / max(total_time, 1e-12))

    df_aggre = pd.DataFrame([ag])
    return df_chain, df_aggre


def _budget_chainwise_ess_only(
    Q_list,
    steps_per_chain,
    rt_per_chain,
    *,
    angle_aggregate="sum_sq",
    geodesic_ref="projection_mean",
    geodesic_aggregate="sum_sq",
    increment_aggregate="sum_sq",
    directional_dist="mvmf",
    clamp_eps=1e-12,
    ess_pick="mean",   # "mean" or "median"
    A_ref=None,
):
    """Evaluate each budget chain separately and average its ESS metrics over chains.

    A chain that is None, or has fewer than 4 draws, only gets its bookkeeping
    columns. Otherwise every metric returned by ``Evaluator.evaluate`` contributes
    ESS, ESS-per-iteration and ESS-per-second columns selected by ``ess_pick``.

    Returns
    -------
    df_chain : pd.DataFrame
        One row per chain.
    df_ag : pd.DataFrame
        One row with ``budget_n_chains`` and the NaN-ignoring mean over chains of
        every metric column.
    """
    n_chain = len(Q_list)
    step_arr = np.asarray(steps_per_chain, dtype=int)
    time_arr = np.asarray(rt_per_chain, dtype=float)
    records = []

    for idx, draws in enumerate(Q_list):
        record = {
            "chain": idx,
            "budget_draws": 0,
            "budget_total_steps": int(step_arr[idx]),
            "budget_runtime_sec": float(time_arr[idx]),
        }
        if draws is None:
            records.append(record)
            continue

        arr = draws.detach().cpu().numpy() if isinstance(draws, torch.Tensor) else np.asarray(draws)
        record["budget_draws"] = int(arr.shape[0])

        # Too few draws to estimate autocorrelation; keep bookkeeping columns only.
        if record["budget_draws"] < 4:
            records.append(record)
            continue

        results = Evaluator(
            arr,  # (T,p,r)
            angle_aggregate=angle_aggregate,
            geodesic_ref=geodesic_ref,
            geodesic_aggregate=geodesic_aggregate,
            increment_aggregate=increment_aggregate,
            clamp_eps=clamp_eps,
            directional_dist=directional_dist,
            n_proposals=int(step_arr[idx]),
            runtime_sec=float(time_arr[idx]),
            compute_summary=False,
            A_ref=A_ref,
        ).evaluate()

        for name, res in results.items():
            summ = res["ess_summary"]
            record[f"budget_{name}_ess_{ess_pick}"] = float(summ[ess_pick])
            record[f"budget_{name}_ess_per_iter_{ess_pick}"] = float(summ[f"{ess_pick}_eff"])
            record[f"budget_{name}_ess_per_sec_{ess_pick}"] = float(summ[f"{ess_pick}_sec"])

        records.append(record)

    df_chain = pd.DataFrame(records)

    bookkeeping = ("budget_draws", "budget_total_steps", "budget_runtime_sec")
    summary = {"budget_n_chains": int(n_chain)}
    summary.update({
        f"{col}_mean_over_chains": float(np.nanmean(df_chain[col].to_numpy(dtype=float)))
        for col in df_chain.columns
        if col.startswith("budget_") and col not in bookkeeping
    })
    df_ag = pd.DataFrame([summary])

    return df_chain, df_ag


def report_one_rep_chains_to_dfs(
    samples: dict,
    meta: dict,
    budget: dict = None,
    *,
    gamma_key: str = "Q",
    angle_aggregate: str = "sum_sq",
    geodesic_ref: str = "projection_mean",
    geodesic_aggregate: str = "sum_sq",
    increment_aggregate: str = "sum_sq",
    directional_dist: str = "mvmf",
    clamp_eps: float = 1e-12,
    ess_pick: str = "mean",   # "mean" or "median"
    A_ref=None,
):
    """Summarise one replication's "filled" chains (and optional budget run) as DataFrames.

    Parameters
    ----------
    samples : dict
        ``samples[gamma_key]`` holds the filled draws, indexed as (C, T, ...).
    meta : dict
        Must contain ``total_steps_per_chain`` and ``runtime_sec_per_chain``; may
        contain ``acc_rate_per_chain`` and ``kept_rate_per_chain``.
    budget : dict, optional
        Must contain ``Q_list``, ``total_steps_per_chain`` and
        ``runtime_sec_per_chain``; optional per-chain extras (acceptance rate,
        kept rate, kept count, status, error) are merged in by chain index.

    Returns
    -------
    df_meta_chain, df_meta_aggre, df_rhat, df_ess_eff : pd.DataFrame
        Per-chain metadata (plus budget columns), aggregate metadata, R-hat per
        metric, and ESS/efficiency per metric (plus budget aggregates).
    """
    Gamma_filled = samples[gamma_key]  # (C,T,p,r)
    if isinstance(Gamma_filled, torch.Tensor):
        Gamma_filled = Gamma_filled.detach().cpu().numpy()
    else:
        Gamma_filled = np.asarray(Gamma_filled)

    C = int(Gamma_filled.shape[0])
    T_filled = int(Gamma_filled.shape[1])

    steps_filled = np.asarray(meta["total_steps_per_chain"], dtype=int)
    runtime_filled = np.asarray(meta["runtime_sec_per_chain"], dtype=float)

    ev_filled = Evaluator(
        Gamma_filled,
        angle_aggregate=angle_aggregate,
        geodesic_ref=geodesic_ref,
        geodesic_aggregate=geodesic_aggregate,
        increment_aggregate=increment_aggregate,
        directional_dist=directional_dist,
        clamp_eps=clamp_eps,
        n_proposals=steps_filled,
        runtime_sec=runtime_filled,
        compute_summary=True,   # filled chains need R-hat, which requires az.summary
        A_ref=A_ref,
    )
    out_filled = ev_filled.evaluate()

    df_meta_chain = pd.DataFrame({
        "chain": np.arange(C),
        "filled_n_draws": np.full(C, T_filled, dtype=int),
        "filled_total_steps": steps_filled,
        "filled_runtime_sec": runtime_filled,
        "filled_acc_rate": np.asarray(meta.get("acc_rate_per_chain", np.full(C, np.nan)), dtype=float),
        "filled_kept_rate": np.asarray(meta.get("kept_rate_per_chain", np.full(C, np.nan)), dtype=float),
    })

    df_budget_ag = pd.DataFrame([{}])

    if budget is not None:
        df_budget_chain, df_budget_ag = _budget_chainwise_ess_only(
            Q_list=budget["Q_list"],
            steps_per_chain=budget["total_steps_per_chain"],
            rt_per_chain=budget["runtime_sec_per_chain"],
            angle_aggregate=angle_aggregate,
            geodesic_ref=geodesic_ref,
            geodesic_aggregate=geodesic_aggregate,
            increment_aggregate=increment_aggregate,
            directional_dist=directional_dist,
            clamp_eps=clamp_eps,
            ess_pick=ess_pick,
            A_ref=A_ref,
        )

        # Per-chain budget extras are indexed by *budget* chain, which need not
        # equal the number of filled chains C.
        C_budget = len(budget["Q_list"])
        extra_cols = {}
        if "acc_rate_per_chain" in budget:
            extra_cols["budget_acc_rate"] = np.asarray(budget["acc_rate_per_chain"], dtype=float)
        if "kept_rate_per_chain" in budget:
            extra_cols["budget_kept_rate"] = np.asarray(budget["kept_rate_per_chain"], dtype=float)
        if "n_kept_actual_per_chain" in budget:
            extra_cols["budget_n_kept_actual"] = np.asarray(budget["n_kept_actual_per_chain"], dtype=int)
        if "status_per_chain" in budget:
            extra_cols["budget_status"] = np.asarray(budget["status_per_chain"], dtype=object)
        if "error_per_chain" in budget:
            extra_cols["budget_error"] = np.asarray(budget["error_per_chain"], dtype=object)

        if len(extra_cols) > 0:
            df_extra = pd.DataFrame({"chain": np.arange(C_budget), **extra_cols})
            df_budget_chain = df_budget_chain.merge(df_extra, on="chain", how="left")

        # Attach budget columns to the filled chains by chain index.
        df_meta_chain = df_meta_chain.merge(df_budget_chain, on="chain", how="left")

    df_meta_aggre = pd.DataFrame([{
        "n_chain": int(C),
        "filled_n_draws": int(T_filled),
        "filled_runtime_sum_sec": float(np.sum(df_meta_chain["filled_runtime_sec"].to_numpy(dtype=float))),
        "filled_runtime_mean_sec": float(np.mean(df_meta_chain["filled_runtime_sec"].to_numpy(dtype=float))),
        "filled_total_steps_sum": int(np.sum(df_meta_chain["filled_total_steps"].to_numpy(dtype=int))),
        "filled_total_steps_mean": float(np.mean(df_meta_chain["filled_total_steps"].to_numpy(dtype=float))),
        "filled_acc_rate_mean": float(np.nanmean(df_meta_chain["filled_acc_rate"].to_numpy(dtype=float))),
        "filled_kept_rate_mean": float(np.nanmean(df_meta_chain["filled_kept_rate"].to_numpy(dtype=float))),
    }])

    # R-hat per metric; NaN when az.summary was not computed (e.g. C < 2).
    rhat_row = {}
    for metric_name, metric_out in out_filled.items():
        az_sum = metric_out.get("az_summary", None)
        if az_sum is None or ("r_hat" not in az_sum.columns):
            rhat_row[f"filled_{metric_name}_rhat_max"] = np.nan
            rhat_row[f"filled_{metric_name}_rhat_mean"] = np.nan
        else:
            rhat_row[f"filled_{metric_name}_rhat_max"] = float(az_sum["r_hat"].max())
            rhat_row[f"filled_{metric_name}_rhat_mean"] = float(az_sum["r_hat"].mean())
    df_rhat = pd.DataFrame([rhat_row])

    ess_row = {}
    for metric_name, metric_out in out_filled.items():
        s = metric_out["ess_summary"]
        ess_row[f"filled_{metric_name}_ess_median"] = float(s["median"])
        ess_row[f"filled_{metric_name}_ess_mean"] = float(s["mean"])
        ess_row[f"filled_{metric_name}_ess_per_iter_median"] = float(s["median_eff"])
        ess_row[f"filled_{metric_name}_ess_per_iter_mean"] = float(s["mean_eff"])
        ess_row[f"filled_{metric_name}_ess_per_sec_median"] = float(s["median_sec"])
        ess_row[f"filled_{metric_name}_ess_per_sec_mean"] = float(s["mean_sec"])

    if budget is not None and len(df_budget_ag.columns) > 0:
        ess_row.update(df_budget_ag.iloc[0].to_dict())

    df_ess_eff = pd.DataFrame([ess_row])

    return df_meta_chain, df_meta_aggre, df_rhat, df_ess_eff




def nanmin(x, dim):
    """Minimum of ``x`` along ``dim`` ignoring NaNs; an all-NaN slice yields +inf."""
    filled = x.masked_fill(torch.isnan(x), float("inf"))
    return torch.min(filled, dim=dim).values

def nanmax(x, dim):
    """Maximum of ``x`` along ``dim`` ignoring NaNs; an all-NaN slice yields -inf."""
    filled = x.masked_fill(torch.isnan(x), float("-inf"))
    return torch.max(filled, dim=dim).values

def nanmean(x, dim):
    """Mean of ``x`` along ``dim`` over non-NaN entries; an all-NaN slice yields 0."""
    is_nan = torch.isnan(x)
    total = x.masked_fill(is_nan, 0).sum(dim=dim)
    count = (~is_nan).sum(dim=dim).clamp(min=1)  # avoid 0/0 on all-NaN slices
    return total / count