# crc_custom.py  （将此类替换原实现）
import numpy as np
from typing import Optional

class CustomConformalRiskController:
    """Conformal Risk Control (CRC) over per-sample candidate sets of models.

    At threshold ``λ`` the candidate set of sample ``i`` is
    ``{k : scores[i, k] >= λ}``. The per-sample loss is the false-positive
    rate of that set: the number of selected models whose label is 0 divided
    by the sample's number of 0-labels (0 when the sample has none).
    ``fit`` divides that loss by ``B`` and picks the smallest ``λ`` on a
    uniform grid over [0, 1] such that

        (n / (n + 1)) * mean_i(loss_i(λ) / B) + 1 / (n + 1) <= alpha.

    Because ``alpha`` is compared with the loss divided by ``B``, it is
    expressed in units of ``B``; with the default ``B = 1`` it is a plain
    FPR level.

    Args:
        alpha: target risk level, must lie in (0, 1).
        c_all: scalar stored as ``self.c_all``; not used by the current loss
            or routing rule.
        B: loss normaliser / upper bound of the loss. ``None`` means 1.0.
        random_state: stored as ``self.random_state``; not used here.
    """

    def __init__(self, alpha: float, c_all: float, B: Optional[float] = None, random_state: int = 42):
        if not (0.0 < alpha < 1.0):
            raise ValueError("alpha must be in (0,1)")
        self.alpha = float(alpha)
        self.c_all = float(c_all)
        self.random_state = int(random_state)

        # An FPR is bounded by 1, so 1.0 is the natural default bound.
        self.B = 1.0 if B is None else float(B)

        self.lambdas = None
        self.lambda_idx_ = None
        self.lambda_ = None
        self.empirical_risk_curve_ = None
        self.empirical_size_curve_ = None

    def _make_lambda_grid(self, num_lam: int = 2000) -> np.ndarray:
        """Return an ascending float64 grid of ``max(2, num_lam)`` points on [0, 1]."""
        num_lam = max(2, int(num_lam))
        return np.linspace(0.0, 1.0, num_lam)

    def _validate_inputs(self, scores: np.ndarray, labels: np.ndarray, costs: np.ndarray):
        """Check shapes and values of calibration inputs.

        Raises:
            ValueError: if ``scores``/``labels`` are not (N, K) arrays, ``costs``
                is not a length-K 1-D array, labels are not in {0, 1}, or any
                cost is negative.
        """
        if not (isinstance(scores, np.ndarray) and scores.ndim == 2):
            raise ValueError("scores must be numpy array shape (N, K)")
        N, K = scores.shape
        if not (isinstance(labels, np.ndarray) and labels.shape == (N, K)):
            raise ValueError("labels must be numpy array shape (N, K)")
        if not (isinstance(costs, np.ndarray) and costs.ndim == 1 and costs.shape[0] == K):
            raise ValueError("costs must be 1D numpy array length K")
        if not np.all(np.logical_or(labels == 0, labels == 1)):
            raise ValueError("labels must be binary {0,1}")
        if np.any(costs < 0.0):
            raise ValueError("costs must be non-negative")

    def _build_risk_tables_vectorized(
        self,
        scores: np.ndarray,
        labels: np.ndarray,
        costs: np.ndarray,
        gate_probs: Optional[np.ndarray] = None,
        gate_thr: Optional[float] = None,
        gate_idx: Optional[int] = None
    ):
        """Compute per-sample FPR and candidate-set-size tables over ``self.lambdas``.

        Args:
            scores: (N, K) model scores.
            labels: (N, K) binary labels; 0 marks a negative.
            costs: (K,) model costs. Unused by the FPR loss; kept for
                signature compatibility.
            gate_probs, gate_thr, gate_idx: optional gate simulation. When all
                three are given, samples with ``gate_probs >= gate_thr`` get
                the fixed candidate set ``{gate_idx}`` at every lambda.

        Returns:
            ``(risk_table, size_table)``, both float32 arrays of shape (N, L).

        Raises:
            ValueError: on a malformed gate specification.
        """
        N, K = scores.shape

        # Compare in the dtype predict_with_gate effectively uses
        # (``scores[i] >= float(lambda_)`` keeps the scores' float dtype).
        # A hard-coded float32 grid rounds e.g. 0.1 upward and would exclude
        # float64 scores that prediction later includes.
        thr_dtype = scores.dtype if np.issubdtype(scores.dtype, np.floating) else np.float64
        thresholds = np.asarray(self.lambdas).astype(thr_dtype).reshape(1, 1, -1)  # (1,1,L)

        cand_all = scores[:, :, None] >= thresholds                 # (N,K,L) bool

        neg_mask = labels[:, :, None] == 0                          # (N,K,1)
        # Denominator: negatives per sample, shape (N,1). Clipped to 1 so a
        # sample without negatives gets FPR 0 instead of dividing by zero.
        neg_cnt = np.clip(neg_mask.sum(axis=1).astype(np.float32), 1.0, None)

        def _risk_from_cand(cand_bool: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
            # Numerator: selected negatives per (sample, lambda), shape (N,L).
            sel_neg = (cand_bool & neg_mask).sum(axis=1).astype(np.float32)
            risk_table = (sel_neg / neg_cnt).astype(np.float32)     # broadcasts to (N,L)
            # Candidate-set sizes, kept for visualisation/debugging.
            size_table = cand_bool.sum(axis=1).astype(np.float32)
            return risk_table, size_table

        # No gate: every sample uses its threshold-based candidate set.
        if gate_probs is None or gate_thr is None or gate_idx is None:
            return _risk_from_cand(cand_all)

        gate_probs = np.asarray(gate_probs).reshape(-1)
        if gate_probs.shape[0] != N:
            raise ValueError("gate_probs length must equal N")
        if not (0.0 <= gate_thr <= 1.0):
            raise ValueError("gate_thr must be in [0,1]")
        if not (0 <= gate_idx < K):
            raise ValueError("gate_idx must be a valid model index")

        # Gated samples get the fixed set {gate_idx}; the rest keep cand_all.
        # NOTE(review): non-gated sets here still contain gate_idx, whereas
        # predict_with_gate drops gate_idx for non-gated samples, so this is a
        # conservative (higher) FPR estimate for them -- confirm intended.
        gate_mask = gate_probs >= float(gate_thr)                   # (N,)
        cand = np.zeros_like(cand_all, dtype=bool)
        if gate_mask.any():
            cand[gate_mask, gate_idx, :] = True
        if (~gate_mask).any():
            cand[~gate_mask, :, :] = cand_all[~gate_mask, :, :]

        return _risk_from_cand(cand)

    def _select_lambda(self, risk_table: np.ndarray, size_table: np.ndarray,
                       B: Optional[float] = None):
        """Choose the minimal lambda index satisfying the CRC bound.

            lhs(λ) = (n/(n+1)) * r_hat(λ) + B/(n+1)  <= alpha

        where ``r_hat(λ) = mean_i risk_table[i, λ]``. Since ``self.lambdas`` is
        ascending and the candidate sets shrink as λ grows, the smallest
        valid λ gives the largest sets that still satisfy the bound.

        Args:
            risk_table: (n, L) per-sample losses.
            size_table: (n, L) per-sample candidate-set sizes.
            B: loss bound used in the correction term. Defaults to ``self.B``.
                ``self.B`` is never modified here.

        Returns:
            ``(lambda_value, lambda_index)``. Falls back to the largest lambda
            when ``n == 0`` or no lambda satisfies the bound.
        """
        n = risk_table.shape[0]
        last_idx = len(self.lambdas) - 1
        if n == 0:
            return float(self.lambdas[last_idx]), int(last_idx)

        r_hat = risk_table.mean(axis=0)   # (L,)
        s_hat = size_table.mean(axis=0)   # (L,)
        self.empirical_risk_curve_ = r_hat.copy()
        self.empirical_size_curve_ = s_hat.copy()

        bound = self.B if B is None else B
        if bound is None:
            # Only reachable if self.B was set to None after construction:
            # use the max observed risk plus a small margin.
            bound = float(r_hat.max() * 1.1 + 1e-6)

        lhs = (n / float(n + 1.0)) * r_hat + (float(bound) / float(n + 1.0))
        valid = np.flatnonzero(lhs <= self.alpha)
        if valid.size == 0:
            # Nothing satisfies the bound: use the most conservative lambda.
            return float(self.lambdas[last_idx]), int(last_idx)

        best_idx = int(valid.min())
        return float(self.lambdas[best_idx]), best_idx

    def fit(self, scores: np.ndarray, labels: np.ndarray, costs: np.ndarray,
            num_lam: int = 2000, vectorized: bool = True,
            gate_probs: Optional[np.ndarray] = None, gate_thr: Optional[float] = None, gate_idx: Optional[int] = None):
        """Fit CRC on a calibration set.

        Optionally simulate a fixed gate decision by providing
        ``gate_probs`` (N,), ``gate_thr`` (float) and ``gate_idx`` (int).

        The loss is divided by ``self.B`` and then bounded by 1.0 in the CRC
        correction. ``self.B`` itself is left unchanged, so repeated calls
        give the same result. The scale actually used is stored in
        ``self._B_orig``.

        Returns:
            self, with ``lambdas``, ``lambda_``, ``lambda_idx_`` and the
            empirical curves set.

        Raises:
            ValueError: on invalid inputs or a non-positive ``B``.
            NotImplementedError: if ``vectorized`` is False.
        """
        self._validate_inputs(scores, labels, costs)
        self.lambdas = self._make_lambda_grid(num_lam)

        if not vectorized:
            raise NotImplementedError("Non-vectorized fit with gate is not implemented in this helper.")

        # None means 1.0, consistent with __init__ (an FPR is bounded by 1).
        scale = 1.0 if self.B is None else float(self.B)
        if scale <= 0:
            raise ValueError("Computed B must be positive.")

        risk_table, size_table = self._build_risk_tables_vectorized(
            scores, labels, costs,
            gate_probs=gate_probs, gate_thr=gate_thr, gate_idx=gate_idx
        )

        # Normalise into [0, 1]; afterwards the loss bound is exactly 1.0.
        self._B_orig = scale
        risk_table = risk_table.astype(np.float64) / scale

        self.lambda_, self.lambda_idx_ = self._select_lambda(risk_table, size_table, B=1.0)
        print(f"[CRC] fit done. chosen lambda_idx={self.lambda_idx_}, lambda={self.lambda_:.6f}")
        return self

    def predict_with_gate(self, scores: np.ndarray, gate_probs: np.ndarray, gate_thr: float, gate_idx: int, costs: np.ndarray):
        """Route each sample to a single model index.

        Routing rules:
            * A sample with ``gate_probs[i] >= gate_thr`` goes to ``gate_idx``.
            * Otherwise, its candidate set is
              ``{k != gate_idx : scores[i, k] >= lambda_}``. The sample goes to
              the cheapest candidate; ties go to the lowest index.
            * If that set is empty, the sample goes to the highest-scoring
              non-gate model, or to ``gate_idx`` when it is the only model.

        Returns:
            int32 array of shape (N,) with the chosen model index per sample.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.lambda_ is None:
            raise RuntimeError("Not fitted")
        N, K = scores.shape
        thr = float(self.lambda_)
        gate_idx = int(gate_idx)
        other_idx = np.array([j for j in range(K) if j != gate_idx], dtype=np.int32)

        routed = np.empty(N, dtype=np.int32)   # every entry is assigned below
        for i in range(N):
            if gate_probs[i] >= gate_thr:
                routed[i] = gate_idx
                continue
            cand = np.flatnonzero(scores[i] >= thr)
            cand = cand[cand != gate_idx]
            if cand.size:
                routed[i] = int(cand[np.argmin(costs[cand])])
            elif other_idx.size:
                routed[i] = int(other_idx[np.argmax(scores[i, other_idx])])
            else:
                # K == 1 and that model is the gate model: nothing else to pick.
                routed[i] = gate_idx
        return routed
