from generative_prediction_sets.core import ConformityScore, NonConformityScore
import dataclasses
import numpy as np
import typing as tp

# Which bound(s) of a regression prediction interval a score calibrates.
RegressionScoreSide = tp.Literal[
    "both",
    "lower",
    "upper",
]


@dataclasses.dataclass
class CQRScore(NonConformityScore):
  """Conformalized-quantile-regression (CQR) style non-conformity score.

  Predictions are quantile estimates given as an ``(n, 2)`` array whose rows
  are ``[q_lo, q_hi]``. ``compute`` measures how far each target falls outside
  its interval. ``invert`` shifts the interval bound(s) by a given score
  quantile to produce calibrated ``[lower, upper]`` intervals.

  Attributes:
    side: Which bound(s) are calibrated. ``"both"`` adjusts both bounds.
      ``"lower"`` / ``"upper"`` adjust only that bound; ``invert`` then fills
      the other bound from ``range``.
    range: ``(low, high)`` values used for the uncalibrated bound in one-sided
      mode. Defaults to the whole real line.
  """

  # NOTE(review): this class implements ``compute`` while
  # AbsoluteResidualScore in the same module implements ``score``; confirm
  # which name NonConformityScore expects.

  side: RegressionScoreSide = "both"
  # Shadows the builtin ``range`` only inside the class body; the field name
  # is part of the public dataclass interface, so it is kept.
  range: tuple[float, float] = (-np.inf, np.inf)

  @staticmethod
  def _split_quantiles(y_pred: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Validate ``y_pred`` as an ``(n, 2)`` array and return its two columns."""
    y_pred = np.asarray(y_pred)
    if y_pred.ndim != 2 or y_pred.shape[1] != 2:
      raise ValueError("y_pred must be a 2D array with 2 columns")
    return y_pred[:, 0], y_pred[:, 1]

  def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return per-sample non-conformity scores.

    Args:
      y_true: Observed targets, broadcast against ``y_pred[:, 0]``. Pass a
        1-D array of shape ``(n,)``. A column vector of shape ``(n, 1)``
        would broadcast against the ``(n,)`` quantile columns into an
        ``(n, n)`` result.
      y_pred: Array of shape ``(n, 2)`` whose rows are ``[q_lo, q_hi]``.

    Returns:
      For ``side="both"``: ``max(q_lo - y, y - q_hi)``, which is negative
      when ``y`` lies strictly inside the interval.
      For ``"lower"``: ``q_lo - y`` where ``y < q_lo``, else 0.
      For ``"upper"``: ``y - q_hi`` where ``y > q_hi``, else 0.

    Raises:
      ValueError: If ``y_pred`` is not ``(n, 2)`` or ``side`` is unknown.
    """
    q_lo, q_hi = self._split_quantiles(y_pred)
    if self.side == "both":
      # An element-wise max gives the textbook CQR score even when the
      # quantile model outputs crossed bounds (q_lo > q_hi). Branching on
      # ``q_lo > y`` alone does not.
      return np.maximum(q_lo - y_true, y_true - q_hi)
    if self.side == "lower":
      return np.where(q_lo > y_true, q_lo - y_true, 0)
    if self.side == "upper":
      return np.where(y_true > q_hi, y_true - q_hi, 0)
    raise ValueError(
        f"side must be one of 'both', 'lower', 'upper'; got {self.side!r}"
    )

  def invert(self, y_pred: np.ndarray, quantile: float) -> np.ndarray:
    """Turn quantile predictions into ``[lower, upper]`` intervals.

    Args:
      y_pred: Array of shape ``(n, 2)`` whose rows are ``[q_lo, q_hi]``.
      quantile: Amount by which the calibrated bound(s) are pushed outward
        (inward if negative). Presumably the conformal quantile of the
        ``compute`` scores; confirm against callers.

    Returns:
      Float array of shape ``(n, 2)``. For ``"both"`` the rows are
      ``[q_lo - quantile, q_hi + quantile]``. For ``"lower"`` the upper
      column is ``range[1]``. For ``"upper"`` the lower column is
      ``range[0]``.

    Raises:
      ValueError: If ``y_pred`` is not ``(n, 2)`` or ``side`` is unknown.
    """
    q_lo, q_hi = self._split_quantiles(y_pred)
    intervals = np.zeros((q_lo.shape[0], 2), dtype=float)
    if self.side == "both":
      intervals[:, 0] = q_lo - quantile
      intervals[:, 1] = q_hi + quantile
    elif self.side == "lower":
      intervals[:, 0] = q_lo - quantile
      intervals[:, 1] = self.range[1]
    elif self.side == "upper":
      intervals[:, 0] = self.range[0]
      intervals[:, 1] = q_hi + quantile
    else:
      raise ValueError(
          f"side must be one of 'both', 'lower', 'upper'; got {self.side!r}"
      )
    return intervals


@dataclasses.dataclass
class AbsoluteResidualScore(NonConformityScore):
  """Absolute-residual non-conformity score for point predictions.

  ``score`` returns ``|y_true - y_pred|``. ``invert`` builds symmetric
  intervals ``[y_pred - quantile, y_pred + quantile]``.
  """

  # NOTE(review): this class implements ``score`` while CQRScore in the same
  # module implements ``compute``; confirm which name NonConformityScore
  # expects.

  def score(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return the element-wise absolute difference of targets and predictions."""
    residual = y_true - y_pred
    return np.abs(residual)

  def invert(self, y_pred: np.ndarray, quantile: float) -> np.ndarray:
    """Return columns ``y_pred - quantile`` and ``y_pred + quantile``.

    For a 1-D ``y_pred`` of shape ``(n,)`` the result has shape ``(n, 2)``.
    """
    lower = y_pred - quantile
    upper = y_pred + quantile
    return np.column_stack((lower, upper))
