import abc
import dataclasses
import threading
import typing as tp
import warnings
from contextlib import contextmanager

import numpy as np

from generative_prediction_sets import utils


@dataclasses.dataclass
class CalibrationContext:
  """Context object that can be used to pass information to predictors during calibration.

  This is a singleton class that manages thread-local context. Use CalibrationContext.get()
  to get the current context and CalibrationContext.set() to create a new context.

  Usage:
    # Get current context
    context = CalibrationContext.get()

    # Create new context
    with CalibrationContext.set(alpha=0.1):
      # Operations here will have access to alpha=0.1
      predictor.predict(X)  # Values recorded to context's tape
  """

  alpha: float | None = None
  """Given alpha, we must produce sets with coverage >= 1 - alpha."""

  # Thread-local storage
  _context = threading.local()
  _DEFAULT = None  # Will be set after class definition

  @classmethod
  def get(cls) -> "CalibrationContext":
    """Get the current calibration context, using default if none exists."""
    context = getattr(cls._context, "current", None)
    if context is None:
      if cls._DEFAULT is None:
        cls._DEFAULT = cls()
      warnings.warn(
        "No explicit calibration context set. Using default context. "
        "Consider using 'with CalibrationContext.set(alpha=...)' to set context explicitly.",
        UserWarning,
        stacklevel=2,
      )
      return cls._DEFAULT
    return context

  @classmethod
  @contextmanager
  def set(
    cls,
    alpha: float | None = None,
  ):
    """Create a new context with the given parameters.

    Args:
        alpha: Target coverage level

    Usage:
        with CalibrationContext.set(alpha=0.1):
            predictor.predict(X)
    """
    old_context = getattr(cls._context, "current", None)
    context = cls(
      alpha=alpha,
    )
    cls._context.current = context

    try:
      yield context
    finally:
      if old_context is not None:
        cls._context.current = old_context
      else:
        delattr(cls._context, "current")


# Eagerly create the shared fallback context (alpha unset) that
# CalibrationContext.get() hands out when no explicit context is active.
CalibrationContext._DEFAULT = CalibrationContext(alpha=None)


class Predictor(abc.ABC):
  """Abstract interface for models used by calibrators.

  Subclasses implement ``fit`` and ``predict_with_aux``; calling the predictor
  directly returns only the predictions.
  """

  @abc.abstractmethod
  def fit(self, X, y):
    """Fit the underlying model on features ``X`` and targets ``y``."""

  @abc.abstractmethod
  def predict_with_aux(
    self, X: np.ndarray
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Return ``(predictions, aux)`` for inputs ``X``.

    ``aux`` maps names to auxiliary arrays produced alongside the predictions.
    """

  def __call__(self, X: np.ndarray) -> np.ndarray:
    """Compute predictions for ``X``, discarding auxiliary outputs.

    The auxiliary outputs returned by ``predict_with_aux`` are dropped and not
    recorded anywhere; call ``predict_with_aux`` directly when they are needed.

    Args:
        X: Input features

    Returns:
        Model predictions
    """
    preds, _aux = self.predict_with_aux(X)
    return preds


class ScoreFunction(tp.Protocol):
  """Structural type for callables that score predictions against targets."""

  def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return scores computed from ``y_true`` and ``y_pred``."""


@dataclasses.dataclass
class ConformalResult:
  """Base class for all conformal calibration results."""

  test_preds: np.ndarray
  """Predictions from the predictor on the test set."""
  test_aux: dict[str, np.ndarray]
  """Auxiliary outputs from the predictor on the test set."""

  cal_scores: np.ndarray
  """Scores computed on the calibration set."""
  cal_preds: np.ndarray
  """Predictions from the predictor on the calibration set."""
  cal_aux: dict[str, np.ndarray]
  """Auxiliary outputs from the predictor on the calibration set."""

  cis: np.ndarray
  """Confidence intervals (prediction sets) for the test set."""
  quantile: float
  """Quantile of the calibration scores used to build ``cis``."""


# Result type produced by a calibrator; it only appears as a return type.
ResultType = tp.TypeVar("ResultType", covariant=True, bound=ConformalResult)
# Predictor type a calibrator is built around.
# NOTE(review): declared contravariant, yet Calibrator stores and reads it as
# an attribute — invariance may be more accurate; confirm before changing.
PredictorType = tp.TypeVar("PredictorType", contravariant=True, bound=Predictor)


class NonConformityScore(abc.ABC):
  """Abstract score S(y_true, y_pred) with an inverse used to build prediction sets.

  Calling an instance delegates to ``compute``.
  """

  def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    return self.compute(y_true, y_pred)

  @abc.abstractmethod
  def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Compute the score for a given pair of true and predicted values."""
    ...

  @abc.abstractmethod
  def invert(self, y_pred: np.ndarray, quantile: float) -> np.ndarray:
    """Given ``y_pred`` and ``quantile``, return {y : S(y, y_pred) <= quantile}.

    This inverts the score: the result is the set of candidate y values whose
    score against ``y_pred`` does not exceed ``quantile``.
    """
    ...

  def quantile(self, scores: np.ndarray, *, alpha: float):
    """Return the conformal quantile of ``scores`` at level ``alpha``.

    Delegates to ``utils.conformal_quantile`` and returns the first element of
    its result.
    """
    return utils.conformal_quantile(scores, alpha=alpha)[0]


class ConformityScore(NonConformityScore):
  """Score whose quantile is taken at level ``1 - alpha`` instead of ``alpha``.

  Calling an instance is inherited from NonConformityScore (delegates to
  ``compute``); the previous byte-identical ``__call__`` override was redundant.

  NOTE(review): the name suggests larger scores mean better conformity, in
  which case ``invert`` implementations likely need ``S >= quantile`` rather
  than the ``<=`` described on the base class — confirm.
  """

  def quantile(self, scores: np.ndarray, *, alpha: float):
    """Return the base-class quantile of ``scores`` evaluated at ``1 - alpha``."""
    return super().quantile(scores, alpha=1 - alpha)


@dataclasses.dataclass
class Calibrator(abc.ABC, tp.Generic[ResultType, PredictorType]):
  """Base class for all conformal calibrators.

  All calibrators should:
  1. Take a predictor in their constructor
  2. Implement fit() to fit the predictor if needed
  3. Implement compute() to perform calibration with the given alpha
  """

  predictor: PredictorType

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Fit the wrapped predictor on ``X`` and ``y``."""
    self.predictor.fit(X, y)

  @abc.abstractmethod
  def compute(
    self,
    X_cal: np.ndarray,
    y_cal: np.ndarray,
    X_test: np.ndarray,
    *,
    alpha: float,
  ) -> ResultType:
    """Perform conformal calibration.

    Args:
        X_cal: Calibration features
        y_cal: Calibration labels
        X_test: Test features to predict on
        alpha: Miscoverage level; sets should achieve coverage >= 1 - alpha

    Returns:
        A ConformalResult containing calibration scores, predictions, and confidence intervals
    """
    ...

  def __call__(
    self,
    X_cal: np.ndarray,
    y_cal: np.ndarray,
    X_test: np.ndarray,
    *,
    alpha: float,
  ) -> ResultType:
    """Perform conformal calibration inside a CalibrationContext.

    Installs ``CalibrationContext.set(alpha=alpha)`` for the duration of the
    call and then delegates to compute(). Calibrator implementations should
    override compute() rather than this method.
    """
    with CalibrationContext.set(alpha=alpha):
      return self.compute(
        X_cal,
        y_cal,
        X_test,
        alpha=alpha,
      )
