import numpy as np

class ConformalFunctionalBandNaive:
    """
    Split conformal bands for functional data with local scaling,
    using *equal weights* (no density weighting).
    Assumes arrays shaped (batch, space).

    Workflow: ``calibrate(Y_cal, Yhat_cal)`` estimates a per-location scale
    ``sigma_`` and stores the sorted sup-norm standardized scores;
    ``band(Yhat)`` then returns ``Yhat -/+ q * (sigma_ + eps)`` where ``q`` is
    the conformal quantile of those scores.

    Parameters
    ----------
    alpha : miscoverage level, must satisfy 0 <= alpha < 1.
    scale_mode : "stdev" (sample std of residuals), "mad" (1.4826 * median
        absolute deviation) or "none" (unit scale).
    side : "two" for a symmetric band, "upper" for (-inf, hi], "lower" for
        [lo, +inf). The score is the sup of *absolute* standardized residuals
        for every choice of ``side``.
    eps : floor applied to the scale, and additive guard in the
        standardization denominator.
    """

    def __init__(self, alpha=0.1, scale_mode="stdev", side="two", eps=1e-8):
        if side not in ("two", "upper", "lower"):
            raise ValueError("side must be 'two', 'upper', or 'lower'")
        if scale_mode not in ("stdev", "mad", "none"):
            raise ValueError("scale_mode must be 'stdev', 'mad', or 'none'")
        alpha = float(alpha)
        # alpha >= 1 would make the quantile index negative (wrapping around to
        # the largest score); NaN fails this comparison as well.
        if not 0.0 <= alpha < 1.0:
            raise ValueError("alpha must satisfy 0 <= alpha < 1")
        self.alpha = alpha
        self.scale_mode = scale_mode
        self.side = side
        self.eps = float(eps)

        # Fitted state
        self.sigma_ = None             # (space,)
        self.scores_sorted_ = None     # (n_cal,) ascending
        self.n_cal_ = None             # scalar

    # ---------- helpers ----------

    @staticmethod
    def _safe_scale(residuals, mode, eps):
        """Return the per-location scale of ``residuals`` (n, d), floored at ``eps``."""
        n_curves, n_points = residuals.shape
        if mode == "none":
            sigma = np.ones(n_points, dtype=float)
        elif mode == "stdev":
            if n_curves < 2:
                # The ddof=1 std is undefined (NaN) for a single curve; treat
                # the spread as zero so the eps floor below applies.
                sigma = np.zeros(n_points, dtype=float)
            else:
                sigma = residuals.std(axis=0, ddof=1)
        else:  # "mad"
            med = np.median(residuals, axis=0)
            mad = np.median(np.abs(residuals - med), axis=0)
            sigma = 1.4826 * mad
        sigma = np.asarray(sigma, dtype=float)
        sigma[sigma < eps] = eps
        return sigma

    @staticmethod
    def _sup_standardized_scores(Y, Yhat, sigma, eps):
        """Return, per curve, max over locations of |Y - Yhat| / (sigma + eps)."""
        return np.max(np.abs((Y - Yhat) / (sigma + eps)), axis=1)

    # ---------- calibration ----------

    def calibrate(self, Y_cal, Yhat_cal):
        """
        Calibrate using calibration curves with *equal weights*.
        Y_cal, Yhat_cal: (n_cal, d)

        Stores ``sigma_``, ``scores_sorted_`` and ``n_cal_``; returns ``self``.
        Raises ValueError on mismatched, non-2-D, or empty inputs.
        """
        Y = np.asarray(Y_cal, float)
        Yh = np.asarray(Yhat_cal, float)
        if Y.shape != Yh.shape or Y.ndim != 2:
            raise ValueError("Y_cal and Yhat_cal must have same shape (n, d)")
        if Y.shape[0] == 0 or Y.shape[1] == 0:
            raise ValueError("Y_cal and Yhat_cal must be non-empty")

        # local scale profile (per-space)
        sigma = self._safe_scale(Y - Yh, self.scale_mode, self.eps)

        # scores for calib curves, sorted once so band() is a plain lookup
        scores_sorted = np.sort(
            self._sup_standardized_scores(Y, Yh, sigma, self.eps)
        )

        # persist
        self.sigma_ = sigma
        self.scores_sorted_ = scores_sorted
        self.n_cal_ = int(scores_sorted.shape[0])
        return self

    # ---------- bands at test time (equal weights with +∞ atom) ----------

    def band(self, Yhat):
        """
        Test-time bands with equal weighting.
        Uses k = ceil((n_cal + 1)*(1 - alpha)) - 1 (0-indexed).
        If k >= n_cal, the quantile is +∞ (test-point atom).
        Returns (lo, hi, q), broadcasting over batch of Yhat.

        Yhat may be (d,) or (n, d); a 1-D input yields 1-D ``lo``/``hi``.
        Raises RuntimeError if not calibrated, ValueError on a bad shape.
        """
        if self.sigma_ is None:
            raise RuntimeError("call calibrate() first")

        Yh = np.asarray(Yhat, float)
        squeeze = Yh.ndim == 1
        if squeeze:
            Yh = Yh[None, :]
        if Yh.ndim != 2:
            raise ValueError("Yhat must be 1-D (d,) or 2-D (n, d)")
        if Yh.shape[1] != self.sigma_.shape[0]:
            raise ValueError("Yhat second dimension must match calibrated space dimension")

        n = self.n_cal_
        # target mass on (n + 1)-point CDF (n calib + test atom)
        target = (n + 1) * (1.0 - self.alpha)
        # 0-indexed; clamped at 0 so a hand-edited alpha >= 1 cannot wrap around
        k = max(int(np.ceil(target) - 1), 0)

        # k >= n means the +∞ atom is hit
        q = np.inf if k >= n else float(self.scores_sorted_[k])

        # Scores were standardized by (sigma + eps), so the band must be
        # un-standardized with the same denominator; otherwise a curve whose
        # score equals q could fall outside the band.
        width = q * (self.sigma_ + self.eps)[None, :]  # broadcasts over batch

        if self.side == "two":
            lo = Yh - width
            hi = Yh + width
        elif self.side == "upper":
            lo = np.full_like(Yh, -np.inf)
            hi = Yh + width
        else:  # "lower"
            lo = Yh - width
            hi = np.full_like(Yh, np.inf)

        if squeeze:
            return lo[0], hi[0], q
        return lo, hi, q

    def get_coverage(self, Ytrue, Yhat):
        """
        Get empirical coverage and bandwidth on test set.
        Ytrue, Yhat: (n_test, d)
        Returns: (coverage, covered, widths)
          coverage: fraction of curves lying entirely inside a finite band
                    (0.0 when no curve has a finite band)
          covered:  (n_test,) boolean vector of that per-curve indicator
          widths:   (n_test, d) array ``hi - lo``
        """
        Y = np.asarray(Ytrue, float)
        Yh = np.asarray(Yhat, float)
        lo, hi, _ = self.band(Yh)

        if Y.shape != Yh.shape or Y.ndim != 2:
            raise ValueError("Ytrue and Yhat must have same shape (n, d)")
        if lo.shape != hi.shape or lo.shape != Y.shape:
            raise ValueError("band shape must match Ytrue/Yhat shape")

        inside = (Y >= lo) & (Y <= hi)

        # Only the side(s) that are supposed to be finite can signal an
        # infinite quantile; a one-sided band is unbounded on the other side
        # by construction and must not be flagged.
        if self.side == "two":
            unbounded = np.isinf(lo) | np.isinf(hi)
        elif self.side == "upper":
            unbounded = np.isinf(hi)
        else:  # "lower"
            unbounded = np.isinf(lo)
        isinf = unbounded.any(axis=1)

        # Criterion: every coordinate of the curve is inside the band
        covered_vec = inside.all(axis=1)

        # Infinite bands are counted as *uncovered*
        covered = covered_vec & ~isinf
        covered_overall = np.mean(covered) if np.any(~isinf) else 0.0

        return covered_overall, covered, (hi - lo)
