import numpy as np
from scipy.stats import multivariate_normal

class ConformalFunctionalBand:
    """
    Split conformal bands for functional data with local scaling (Diquigiovanni et al.),
    using *test-time* weighting as in Proposition 7.3 (weights normalized over calib+test).
    Assumes arrays shaped (batch, space).

    Typical use:
        band = ConformalFunctionalBand(alpha=0.1)
        band.calibrate_with_gaussian_shift(Y_cal, Yhat_cal, X_cal, mu_p, Sigma_p, mu_q, Sigma_q)
        lo, hi, w_test, q = band.band_with_gaussian_shift(Yhat, X_test, mu_p, Sigma_p, mu_q, Sigma_q)

    Parameters
        alpha:      miscoverage level; the weighted quantile targets mass (1 - alpha).
        scale_mode: "stdev", "mad" or "none" -- how the per-location scale is estimated
                    from the calibration residuals.
        side:       "two", "upper" or "lower" -- which side(s) of the band are finite.
        eps:        lower bound for the scale, also added to the scale when standardizing.
    """

    def __init__(self, alpha=0.1, scale_mode="stdev", side="two", eps=1e-8):
        if side not in ("two", "upper", "lower"):
            raise ValueError("side must be 'two', 'upper', or 'lower'")
        if scale_mode not in ("stdev", "mad", "none"):
            raise ValueError("scale_mode must be 'stdev', 'mad', or 'none'")
        self.alpha = float(alpha)
        self.scale_mode = scale_mode
        self.side = side
        self.eps = float(eps)

        # Fitted state
        self.sigma_ = None                 # (space,)
        self.scores_sorted_ = None         # (n_cal,)
        self.w_sorted_ = None              # (n_cal,) calib weights in same scale
        self.cumw_sorted_ = None           # (n_cal,) cumulative weights
        self.sumw_calib_ = None            # scalar
        self.logr_calib_sorted_ = None     # (n_cal,) log q/p for calib (sorted)
        self.logw_scale_ = None            # scalar used to stabilize exp()

    # ---------- helpers ----------

    @staticmethod
    def _safe_scale(residuals, mode, eps):
        """Per-location scale of `residuals` (n, space), floored at `eps`."""
        if mode == "none":
            sigma = np.ones(residuals.shape[1], dtype=float)
        elif mode == "stdev":
            sigma = residuals.std(axis=0, ddof=1)
        else:  # "mad"
            med = np.median(residuals, axis=0)
            mad = np.median(np.abs(residuals - med), axis=0)
            # 1.4826 rescales the MAD to be comparable to a Gaussian standard deviation
            sigma = 1.4826 * mad
        sigma = np.asarray(sigma, dtype=float)
        sigma[sigma < eps] = eps
        return sigma

    @staticmethod
    def _sup_standardized_scores(Y, Yhat, sigma, eps):
        """Sup-norm over space of |Y - Yhat| / (sigma + eps); one score per row."""
        return np.max(np.abs((Y - Yhat) / (sigma + eps)), axis=1)

    @staticmethod
    def _as_covariate_matrix(X, mu):
        """
        Return X as a float array with one covariate point per row.
        A 1-D input follows the scipy convention: for a 1-dimensional distribution
        each entry is a point, otherwise the vector is a single point.
        """
        X = np.asarray(X, float)
        k = np.atleast_1d(np.asarray(mu)).shape[0]
        if X.ndim == 0:
            X = X.reshape(1, 1)
        elif X.ndim == 1:
            X = X[:, None] if k == 1 else X[None, :]
        return X

    @staticmethod
    def _gaussian_log_ratio(X, mu_p, Sigma_p, mu_q, Sigma_q, allow_singular=False):
        """log q(x) - log p(x) for Gaussian p, q; always returned as a 1-D array."""
        p = multivariate_normal(mean=np.asarray(mu_p), cov=np.asarray(Sigma_p), allow_singular=allow_singular)
        q = multivariate_normal(mean=np.asarray(mu_q), cov=np.asarray(Sigma_q), allow_singular=allow_singular)
        X = np.asarray(X)
        # scipy squeezes logpdf output, so a single point yields a scalar; callers
        # index and mask the result, hence the promotion back to shape (n,)
        return np.atleast_1d(q.logpdf(X) - p.logpdf(X))

    # ---------- calibration ----------

    def calibrate_with_gaussian_shift(self, Y_cal, Yhat_cal, X_cal,
                                      mu_p, Sigma_p, mu_q, Sigma_q, allow_singular=False):
        """
        Calibrate using calib curves and exact Gaussian q/p weights from X_cal.
        Y_cal, Yhat_cal: (n_cal, d)
        X_cal: (n_cal, k)
        mu_p, Sigma_p / mu_q, Sigma_q: mean and covariance of the calibration (p)
        and test (q) covariate distributions.
        Returns self. Raises ValueError on inconsistent shapes, on an empty
        calibration set, or when "stdev" scaling is requested with fewer than 2 curves.
        """
        Y = np.asarray(Y_cal, float)
        Yh = np.asarray(Yhat_cal, float)
        if Y.shape != Yh.shape or Y.ndim != 2:
            raise ValueError("Y_cal and Yhat_cal must have same shape (n, d)")
        X = self._as_covariate_matrix(X_cal, mu_p)
        if X.shape[0] != Y.shape[0]:
            raise ValueError("X_cal and Y_cal must have same n")
        n_cal = Y.shape[0]
        if n_cal == 0:
            raise ValueError("calibration set must contain at least one curve")
        if self.scale_mode == "stdev" and n_cal < 2:
            # ddof=1 standard deviation of a single curve is NaN
            raise ValueError("scale_mode 'stdev' needs at least 2 calibration curves")

        # local scale profile (per-space)
        residuals = Y - Yh
        sigma = self._safe_scale(residuals, self.scale_mode, self.eps)

        # scores for calib curves
        scores = self._sup_standardized_scores(Y, Yh, sigma, self.eps)

        # unnormalized log-weights (log q/p) for calib covariates
        logr = self._gaussian_log_ratio(X, mu_p, Sigma_p, mu_q, Sigma_q, allow_singular=allow_singular)

        # sort by score once; keep everything sorted consistently
        order = np.argsort(scores)
        scores_sorted = scores[order]
        logr_sorted = logr[order]

        # choose a global log-scale to avoid overflow; store it so test weights use the same scale
        logw_scale = float(np.max(logr_sorted))
        w_sorted = np.exp(logr_sorted - logw_scale)   # same positive scaling for all calib points

        cumw_sorted = np.cumsum(w_sorted)
        sumw_calib = float(cumw_sorted[-1])

        # persist
        self.sigma_ = sigma
        self.scores_sorted_ = scores_sorted
        self.w_sorted_ = w_sorted
        self.cumw_sorted_ = cumw_sorted
        self.sumw_calib_ = sumw_calib
        self.logr_calib_sorted_ = logr_sorted
        self.logw_scale_ = logw_scale
        return self

    # ---------- bands at test time (uses test weight in normalization) ----------

    def band_with_gaussian_shift(self, Yhat, X_test, mu_p, Sigma_p, mu_q, Sigma_q, allow_singular=False):
        """
        Test-time bands using Proposition 7.3: normalize by (sum calib weights + test weight)
        and include the +∞ atom for the test point.
        Yhat: (m, d) predictions, or (d,) for a single test point
        X_test: (m, k) covariates, or a single point
        Returns (lo, hi, w_test, q): band limits shaped like Yhat, the test weight(s)
        on the calibration scale, and the weighted score quantile(s) (+∞ when the
        (1-α) target mass exceeds the total calibration weight). For a 1-D Yhat the
        leading axis is dropped from all four outputs.
        Raises RuntimeError if not calibrated, ValueError on shape mismatches.
        """
        if self.sigma_ is None:
            raise RuntimeError("call calibrate_with_gaussian_shift() first")

        # --- shapes ---
        Yh = np.asarray(Yhat, float)
        squeeze = Yh.ndim == 1
        if squeeze:
            Yh = Yh[None, :]
        if Yh.ndim != 2 or Yh.shape[1] != self.sigma_.shape[0]:
            raise ValueError("Yhat second dimension must match calibrated space dimension")

        X = self._as_covariate_matrix(X_test, mu_p)
        if X.shape[0] != Yh.shape[0]:
            raise ValueError("X_test and Yhat must have the same number of test points")

        # --- test weights on SAME scale as calibration ---
        log_r_test = self._gaussian_log_ratio(
            X, mu_p, Sigma_p, mu_q, Sigma_q, allow_singular=allow_singular
        )  # log(q/p)(x_{n+1}), shape (m,)
        w_test = np.exp(log_r_test - self.logw_scale_)  # (m,)

        # target mass for each test point: (1-α)*(sum_w_calib + w_test)
        target_mass = (1.0 - self.alpha) * (self.sumw_calib_ + w_test)

        # if target exceeds total calib mass, the quantile is +∞ (because of w_test δ_{+∞})
        needs_inf = target_mass > self.sumw_calib_

        # find weighted quantile on calibration CDF for the rest
        tm_clipped = np.minimum(target_mass, self.sumw_calib_)
        idx = np.searchsorted(self.cumw_sorted_, tm_clipped, side="left")
        idx = np.clip(idx, 0, len(self.scores_sorted_) - 1)

        # assemble q and enforce the +∞ atom
        q = np.where(needs_inf, np.inf, self.scores_sorted_[idx])

        # Scores were standardized by (sigma + eps), so the band must be un-standardized
        # with the same denominator; using sigma alone would make it slightly too narrow.
        width = q[:, None] * (self.sigma_ + self.eps)[None, :]

        if self.side == "two":
            lo = Yh - width
            hi = Yh + width
        elif self.side == "upper":
            lo = np.full_like(Yh, -np.inf)
            hi = Yh + width
        else:  # "lower"
            lo = Yh - width
            hi = np.full_like(Yh, np.inf)

        if squeeze:
            return lo[0], hi[0], w_test[0], q[0]
        return lo, hi, w_test, q


    def get_coverage(self, Ytrue, Yhat, X_test, mu_p, Sigma_p, mu_q, Sigma_q, allow_singular=False,
                     max_outside=3):
        """
        Get empirical coverage and bandwidth on test set.
        Ytrue, Yhat: (n_test, d)
        X_test: (n_test, k)
        max_outside: number of grid points per curve allowed to fall outside the band
        while the curve still counts as covered (default 3 keeps the historical
        tolerance; use 0 for strict simultaneous coverage).
        Returns (coverage, covered, hi - lo): the fraction of covered curves among
        test points whose quantile is finite (0.0 if there are none), the per-curve
        boolean flags, and the band width array. Test points whose quantile is +∞
        are flagged as not covered and excluded from the denominator.
        """
        Y = np.asarray(Ytrue, float)
        Yh = np.asarray(Yhat, float)
        if Y.shape != Yh.shape or Y.ndim != 2:
            raise ValueError("Ytrue and Yhat must have same shape (n, d)")
        if max_outside < 0:
            raise ValueError("max_outside must be non-negative")

        lo, hi, _, q = self.band_with_gaussian_shift(
            Yh, X_test, mu_p, Sigma_p, mu_q, Sigma_q, allow_singular=allow_singular
        )
        if lo.shape != hi.shape or lo.shape != Y.shape:
            raise ValueError("band shape must match Ytrue/Yhat shape")

        inside = (Y >= lo) & (Y <= hi)
        # Detect the +∞ atom from the quantile itself: one-sided bands always have an
        # infinite limit by construction, so inspecting lo/hi would discard every row.
        isinf = np.isinf(q)
        covered = np.sum(inside, axis=1) >= Y.shape[1] - max_outside
        covered = covered & ~isinf  # if band is infinite, count as uncovered
        if np.sum(~isinf) == 0:
            coverage = 0.0
        else:
            coverage = np.sum(covered) / np.sum(~isinf)
        return coverage, covered, hi - lo
