"""Calibrators that conformalize using a sampling-based approach
in a classification space (e.g. for LLMs)"""

import dataclasses
import functools
import typing as tp

from loguru import logger
import numpy as np

from generative_prediction_sets.core import (
  ConformityScore,
  NonConformityScore,
  Predictor,
  Calibrator,
  ConformalResult,
)
from generative_prediction_sets.simulation_data import get_k_min


def create_padded_array(X, M):
  """Build a left-filled 0/1 mask with one row per entry of ``X``.

  Row ``i`` of the result has ones in its first ``X[i]`` columns and zeros in
  the remaining columns.

  Args:
    X: One-dimensional sequence (list or array) of integer counts.
    M: Number of columns of the returned array.

  Returns:
    A float array of shape ``(len(X), M)``. A count of at least ``M`` yields a
    row of all ones; a count of zero or less yields a row of all zeros.
  """
  counts = np.asarray(X)
  # Column j of row i is set iff j < X[i]. Compared with slicing
  # ``Y[i, :X[i]]``, a negative count gives an empty row instead of wrapping
  # around and filling the row from the left up to ``M - |X[i]|``.
  return (np.arange(M) < counts[:, None]).astype(float)


@dataclasses.dataclass
class ConformalCalibrator(Calibrator[ConformalResult, Predictor]):
  """Calibrator built around a conformity score.

  Calibration scores are computed from the predictor's output on the
  calibration set, a quantile of those scores is taken, and that quantile is
  inverted around the test predictions.
  """

  # Score object; it is called on (targets, predictions) and must also provide
  # ``quantile`` and ``invert``.
  score: ConformityScore

  def compute(
    self, X_cal: np.ndarray, y_cal: np.ndarray, X_test: np.ndarray, *, alpha: float
  ) -> ConformalResult:
    """Calibrate on ``(X_cal, y_cal)`` and produce intervals for ``X_test``.

    Args:
      X_cal: Calibration inputs, passed to ``self.predictor``.
      y_cal: Calibration targets, scored against the calibration predictions.
      X_test: Test inputs, passed to ``self.predictor``.
      alpha: Level parameter; the score quantile is requested with
        ``alpha=1 - alpha``.

    Returns:
      A ``ConformalResult`` carrying the inverted test intervals together with
      the predictions, auxiliary outputs, calibration scores and the quantile.
    """
    # Calibration pass: predictions -> scores -> quantile.
    calibration_preds, calibration_aux = self.predictor.predict_with_aux(X_cal)
    calibration_scores = self.score(y_cal, calibration_preds)
    threshold = self.score.quantile(calibration_scores, alpha=1 - alpha)

    # Test pass: predictions -> intervals via the calibrated quantile.
    predictions, aux = self.predictor.predict_with_aux(X_test)
    intervals = self.score.invert(predictions, threshold)

    return ConformalResult(
      cis=intervals,
      test_preds=predictions,
      test_aux=aux,
      cal_scores=calibration_scores,
      cal_preds=calibration_preds,
      cal_aux=calibration_aux,
      quantile=threshold,
    )


# Default stopping rule used by StoppingRuleCalibrator: ``get_k_min`` with
# ``no_success_value`` bound to ``np.inf``.
# NOTE(review): presumably this makes rows without any success report inf as
# their stopping index -- confirm against ``get_k_min``.
_default_stopping_rule = functools.partial(get_k_min, no_success_value=np.inf)


@dataclasses.dataclass
class StoppingRuleCalibrator(Calibrator[ConformalResult, Predictor]):
  r"""A calibrator that reduces the sampling-based approach to a regression-based approach and back.

  Primarily, this produces a stopping rule for the sampling.

  Suppose we have input covariates X and a distribution P(Y | X).

  Let K be a stopping rule for the sampling.

  Given (X_1, X_2, ..., X_n), we can sample (Y_1, ..., Y_(K_i)) from P(Y | X_i) for
  each X_i, for some stopping rule K_i.

  Then for X_{n+1} we want to produce a stopping rule \hat{K} such
  that P\{K <= \hat{K}\} >= 1 - \alpha.

  For example, we can define K = inf{j: Y_j is admissible}.

  Then P\{\exists j <= \hat{K}: Y_j is admissible\} >= 1 - \alpha.

  """

  # Score object; it is called on (true ks, predicted ks) and must also provide
  # ``quantile`` and ``invert``.
  score: NonConformityScore
  # Maps the calibration samples ``y_cal`` to one stopping value per row.
  stopping_rule: tp.Callable[[np.ndarray], np.ndarray] = _default_stopping_rule
  # Sentinel for "stopping rule never satisfied": "inf" keeps np.inf, any other
  # value uses ``y_cal.shape[1] + k_max_offset``.
  nonadmissible_handling: tp.Literal["inf", "k_max"] = "inf"
  # When true, calibration scores of rows whose true k equals the sentinel are
  # overwritten with np.inf.
  force_abstention_scores_to_inf: bool = False
  # Offset added to the number of samples to form the finite sentinel.
  k_max_offset: int = 1

  def compute(
    self, X_cal: np.ndarray, y_cal: np.ndarray, X_test: np.ndarray, *, alpha: float
  ) -> ConformalResult:
    """Calibrate a stopping rule and turn it into a keep-mask for the test set.

    Args:
      X_cal: Calibration inputs, passed to ``self.predictor``.
      y_cal: Calibration samples; ``y_cal.shape[1]`` is the number of samples
        per input and ``self.stopping_rule`` is applied to it.
      X_test: Test inputs, passed to ``self.predictor``.
      alpha: Level parameter; the score quantile is requested with
        ``alpha=1 - alpha``.

    Returns:
      A ``ConformalResult`` whose ``cis`` is an array with ``y_cal.shape[1]``
      columns and one row per upper bound: row ``i`` has ones in its first
      ``ub_i`` columns and zeros elsewhere, or is entirely np.inf when the
      upper bound exceeded ``y_cal.shape[1]``.
    """
    K_MAX = y_cal.shape[1]
    # 1. Truncate (Y_1, ..., Y_M) to (Y_1, ..., Y_tau) using the stopping rule
    cal_true_ks = self.stopping_rule(y_cal)
    # 2. Predict the stopping rule for the calibration set
    cal_pred_ks, cal_aux = self.predictor.predict_with_aux(X_cal)
    # Some predictors return integer predictions, so we cast them to float
    # for handling NON_ADMISSIBLE_VALUE
    cal_pred_ks = cal_pred_ks.astype(float)
    # 3. The stopping rule returns an index between 0 and M-1, if the stopping rule
    # is satisfied. If the stopping rule is not satisfied, it returns np.inf.
    # This can either remain "inf"
    if self.nonadmissible_handling == "inf":
      NON_ADMISSIBLE_VALUE = np.inf
    else:
      # or we can set it to the number of samples + offset
      NON_ADMISSIBLE_VALUE = K_MAX + self.k_max_offset
    # A finite sentinel (rather than inf) affects the score quantile when both
    # the true and the predicted stopping value are non-admissible.
    cal_true_ks[cal_true_ks > K_MAX] = NON_ADMISSIBLE_VALUE
    # If predicted ks are greater than K_MAX, we set them to the NON_ADMISSIBLE_VALUE
    cal_pred_ks[cal_pred_ks > K_MAX] = NON_ADMISSIBLE_VALUE

    # Scores are computed using the NON_ADMISSIBLE_VALUE
    cal_scores = self.score(cal_true_ks, cal_pred_ks)
    if self.force_abstention_scores_to_inf:
      cal_scores[cal_true_ks == NON_ADMISSIBLE_VALUE] = np.inf

    # Compute the conformal quantile
    quantile = self.score.quantile(cal_scores, alpha=1 - alpha)
    # First invert the scores to get the upper bounds
    test_preds, test_aux = self.predictor.predict_with_aux(X_test)
    test_cis = self.score.invert(test_preds, quantile)
    # Work on a float copy of the upper-bound column: np.inf can then be stored
    # even if the intervals are integer-typed, and the array returned by
    # ``score.invert`` is not modified in place.
    test_ubs = test_cis[:, 1].astype(float)
    # For predicted upper bounds greater than K_MAX, we set them to np.inf
    test_ubs[test_ubs > K_MAX] = np.inf

    # Now we convert this back to a mask that tells us which samples to keep.
    # The infs cannot be cast to int, so they are zeroed out first ...
    test_ubs_infs = test_ubs == np.inf
    test_ubs[test_ubs_infs] = 0
    # ... then the upper bounds are expanded into a padded 0/1 array ...
    # NOTE(review): the mask keeps positions 0..ub-1; if the stopping rule
    # reports 0-based indices, position ub itself is excluded -- confirm that
    # this is intended.
    pred_ys = create_padded_array(test_ubs.astype(int), K_MAX)
    # ... and lastly the inf rows are restored.
    pred_ys[test_ubs_infs, :] = np.inf

    return ConformalResult(
      cis=pred_ys,
      cal_scores=cal_scores,
      quantile=quantile,
      test_preds=test_preds,
      test_aux=test_aux,
      cal_preds=cal_pred_ks,
      cal_aux=cal_aux,
    )


@dataclasses.dataclass
class FixedSamplingCalibrator(Calibrator[ConformalResult, Predictor]):
  """Calibrator that classifies over a fixed-size batch of samples.

  For every input x a fixed number M of samples (y_1, ..., y_M) is drawn.
  The M samples are then viewed as the classes of an output, and
  classification is performed over them.
  """

  # Score object; it is called on (targets, predictions) and must also provide
  # ``quantile`` and ``invert``.
  score: NonConformityScore

  def compute(
    self, X_cal: np.ndarray, y_cal: np.ndarray, X_test: np.ndarray, *, alpha: float
  ) -> ConformalResult:
    """Calibrate on ``(X_cal, y_cal)`` and produce prediction sets for ``X_test``.

    Args:
      X_cal: Calibration inputs, passed to ``self.predictor``.
      y_cal: Calibration targets with one column per sample; rows that sum to
        zero are treated as having no admissible sample.
      X_test: Test inputs, passed to ``self.predictor``.
      alpha: Level parameter; the score quantile is requested with
        ``alpha=1 - alpha``.

    Returns:
      A ``ConformalResult``. Its ``cis`` is the inverted score output, or an
      ``(X_test.shape[0], y_cal.shape[1])`` array filled with np.inf when the
      quantile itself is infinite.
    """
    calibration_preds, calibration_aux = self.predictor.predict_with_aux(X_cal)
    scores = self.score(y_cal, calibration_preds)

    # Rows without any admissible value get the most extreme score: -inf for a
    # conformity score, +inf otherwise.
    if isinstance(self.score, ConformityScore):
      extreme_score = -np.inf
    else:
      extreme_score = np.inf
    lacks_admissible = np.sum(y_cal, axis=1) == 0
    scores[lacks_admissible] = extreme_score

    threshold = self.score.quantile(scores, alpha=1 - alpha)
    predictions, aux = self.predictor.predict_with_aux(X_test)

    quantile_is_infinite = threshold == np.inf or threshold == -np.inf
    if quantile_is_infinite:
      # An infinite quantile is not inverted; every entry is reported as inf.
      logger.debug(f"Quantile is inf for {X_cal.shape[0]} samples")
      n_test, n_samples = X_test.shape[0], y_cal.shape[1]
      prediction_sets = np.full((n_test, n_samples), np.inf)
    else:
      prediction_sets = self.score.invert(predictions, threshold)

    return ConformalResult(
      cis=prediction_sets,
      cal_scores=scores,
      quantile=threshold,
      test_preds=predictions,
      test_aux=aux,
      cal_preds=calibration_preds,
      cal_aux=calibration_aux,
    )
