import dataclasses
import typing as tp
from typing import Optional

import numpy as np
import scipy.stats
import torch
import upath
from loguru import logger
from sklearn.base import BaseEstimator
from sklearn.linear_model import LinearRegression

from generative_prediction_sets.core import CalibrationContext, Predictor


class SklearnPredictor(Predictor):
  """Adapter exposing a scikit-learn estimator through the Predictor interface.

  Predictions come from the estimator method named by ``predict_fn``; the
  auxiliary output dictionary returned alongside them is always empty.
  """

  def __init__(
    self,
    model: BaseEstimator,
    predict_fn: tp.Literal["predict", "predict_proba"] = "predict",
    prefit: bool = False,
  ):
    """Store the estimator and how to query it.

    Args:
      model: The scikit-learn estimator to wrap.
      predict_fn: Name of the estimator method used to produce predictions.
      prefit: Set to True when ``model`` was already trained elsewhere, so
        predicting is allowed without calling ``fit`` on this wrapper first.
    """
    self.model = model
    self.predict_fn = predict_fn
    self.prefit = prefit
    # Flipped to True by fit().
    self._is_fit = False

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Train the wrapped estimator on ``(X, y)`` and mark the wrapper as fit."""
    self.model.fit(X, y)
    self._is_fit = True

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Return the estimator's predictions for ``X`` and an empty aux dict.

    Raises:
      RuntimeError: If ``fit`` has not been called and ``prefit`` is False.
    """
    ready = self._is_fit or self.prefit
    if not ready:
      raise RuntimeError("Model is not fit.")
    # Resolve the configured method by name on the wrapped estimator.
    preds = getattr(self.model, self.predict_fn)(X)
    return preds, {}


def p_map_estimate(y: np.ndarray) -> np.ndarray:
  """Estimate a smoothed success probability for each row of ``y``.

  Each row is summed along axis 1, then one is added to the numerator and two
  to the denominator: ``(row_sum + 1) / (n_columns + 2)``.

  Args:
    y: 2D array; each row is summed along axis 1.

  Returns:
    One estimate per row of ``y``.
  """
  n_trials = y.shape[1]
  successes = np.sum(y, axis=1)
  return (successes + 1) / (n_trials + 2)


@dataclasses.dataclass
class BernoulliPPredictor(Predictor):
  """Regress per-row success probabilities from matrices of Bernoulli trials.

  Training targets are produced by applying ``p_estimate`` to the trial matrix
  and are then fit with ``estimator``.
  """

  # Ready-made estimator instance; when None, ``estimator_cls()`` is used.
  base_estimator: Optional[BaseEstimator] = None
  # Estimator class instantiated with no arguments when no instance is given.
  estimator_cls: tp.Type[BaseEstimator] = LinearRegression
  # Maps the 2D trial array to the regression targets.
  p_estimate: tp.Callable[[np.ndarray], np.ndarray] = dataclasses.field(
    default_factory=lambda: p_map_estimate
  )
  # The estimator actually used; populated in __post_init__.
  estimator: BaseEstimator = dataclasses.field(init=False)

  def __post_init__(self):
    """Select the supplied estimator, or build one from ``estimator_cls``."""
    if self.base_estimator is None:
      self.estimator = self.estimator_cls()
    else:
      self.estimator = self.base_estimator

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Fit the estimator on probabilities derived from the trials in ``y``.

    Raises:
      ValueError: If ``y`` is one-dimensional.
    """
    if y.ndim == 1:
      raise ValueError("y must be a 2D array of bernoulli trials")
    targets = self.p_estimate(y)
    self.estimator.fit(X, targets)

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Predict probabilities for ``X``.

    Returns:
      The estimator output clipped to ``[1e-12, 1 - 1e-12]``, together with an
      aux dict whose ``"raw_ps"`` entry holds the unclipped output.
    """
    raw = self.estimator.predict(X)
    # Keep results strictly inside (0, 1).
    bounded = np.clip(raw, 1e-12, 1 - 1e-12)
    return bounded, {"raw_ps": raw}


@dataclasses.dataclass
class ConstantPredictor(Predictor):
  """Predictor that outputs the same fixed value for every input row."""

  # Value that fills the prediction array.
  value: float = 0
  # Number of output columns; None yields a 1D array with one entry per row.
  ndim: int | None = None

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Do nothing; there is nothing to learn."""
    pass

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Return an array filled with ``value`` and an empty aux dict.

    The array has shape ``(X.shape[0],)`` when ``ndim`` is None and
    ``(X.shape[0], ndim)`` otherwise.
    """
    n_rows = X.shape[0]
    out_shape = (n_rows,) if self.ndim is None else (n_rows, self.ndim)
    return np.full(out_shape, self.value), {}


@dataclasses.dataclass
class ConstantCIPredictor(Predictor):
  """Predictor that outputs a two-column array filled with one fixed value."""

  # Value written to both columns of every row.
  value: float = 0

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Do nothing; there is nothing to learn."""
    pass

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Return a ``(X.shape[0], 2)`` array of ``value`` and an empty aux dict."""
    out_shape = (X.shape[0], 2)
    return np.full(out_shape, self.value), {}


@dataclasses.dataclass
class GeometricQuantilePredictor(Predictor):
  """Turn predicted success probabilities into geometric-distribution quantiles.

  The quantile level is ``1 - alpha``, where ``alpha`` is read from the active
  ``CalibrationContext`` at prediction time.
  """

  # Predictor supplying the probability parameter passed to the geometric ppf.
  base_predictor: Predictor = dataclasses.field(default_factory=BernoulliPPredictor)
  # When True, quantiles are clipped to [1, k_max].
  clip: bool = False
  # Upper clipping bound; overwritten by fit() with ``y.shape[1] + 1``.
  k_max: Optional[int] = None

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Record ``k_max`` from the width of ``y`` and fit the base predictor."""
    n_columns = y.shape[1]
    self.k_max = n_columns + 1
    return self.base_predictor.fit(X, y)

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Predict one-sided intervals whose upper end is a geometric quantile.

    Returns:
      A two-column array whose first column is ``-inf`` and whose second column
      is the quantile, plus an aux dict with ``"p0_preds"`` (base predictor
      output) and ``"raw_preds"`` (the quantiles, after clipping when ``clip``
      is enabled).

    Raises:
      ValueError: If no calibration context is active.
    """
    p_success = self.base_predictor(X)
    ctx = CalibrationContext.get()
    if ctx is None:
      raise ValueError(
        "No calibration context found. GeometricQuantilePredictor requires a calibration context."
      )
    level = 1 - ctx.alpha
    quantiles = scipy.stats.geom.ppf(level, p_success)
    if self.clip:
      quantiles = np.clip(quantiles, 1, self.k_max)

    lower_bounds = np.full(quantiles.shape[0], -np.inf)
    intervals = np.column_stack([lower_bounds, quantiles])
    aux = {
      "p0_preds": p_success,
      "raw_preds": quantiles,
    }
    return intervals, aux


@dataclasses.dataclass
class QuantilePredictor(Predictor):
  """Combine a lower and an upper predictor into two-column intervals.

  ``side`` selects which ends are predicted. For a one-sided interval the
  missing end is filled with the matching entry of ``range``.
  """

  # Predictor for the lower interval end (used when side is "both" or "lower").
  q_lo_predictor: Predictor
  # Predictor for the upper interval end (used when side is "both" or "upper").
  q_hi_predictor: Predictor
  # Which interval ends are predicted rather than filled from ``range``.
  side: tp.Literal["both", "lower", "upper"] = "both"
  # (lower fill, upper fill) used for the end that is not predicted.
  range: tuple[float, float] = (-np.inf, np.inf)

  def _validate_side(self) -> None:
    """Raise ValueError when ``side`` is not one of the supported values."""
    if self.side not in ("both", "lower", "upper"):
      raise ValueError(
        f"side must be one of 'both', 'lower' or 'upper', got {self.side!r}"
      )

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Fit the predictor(s) required by ``side``.

    Only the lower predictor is fit for "lower", only the upper predictor for
    "upper", and both for "both".

    Raises:
      ValueError: If ``side`` is not a supported value.
    """
    # Fail early instead of silently fitting nothing for an unknown side.
    self._validate_side()
    if self.side in ("both", "lower"):
      self.q_lo_predictor.fit(X, y)
    if self.side in ("both", "upper"):
      self.q_hi_predictor.fit(X, y)

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Predict intervals for ``X``.

    Returns:
      A two-column array of (lower, upper) ends and an empty aux dict.

    Raises:
      ValueError: If ``side`` is not a supported value.
    """
    # Without this check an unknown side would leave the result unbound.
    self._validate_side()
    aux_outputs = {}

    if self.side == "both":
      lo_preds = self.q_lo_predictor(X)
      hi_preds = self.q_hi_predictor(X)
    elif self.side == "upper":
      hi_preds = self.q_hi_predictor(X)
      lo_preds = np.full(X.shape[0], self.range[0])
    else:  # "lower"
      lo_preds = self.q_lo_predictor(X)
      hi_preds = np.full(X.shape[0], self.range[1])

    return np.column_stack([lo_preds, hi_preds]), aux_outputs


class OraclePredictor(Predictor):
  """Predictor that returns its input as the prediction, flattened to 1D."""

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Do nothing; there is nothing to learn."""
    pass

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Return ``X`` (reshaped to 1D when it has more than one axis) and an empty aux dict."""
    flat = X.reshape(-1) if X.ndim > 1 else X
    return flat, {}


@dataclasses.dataclass
class PassThroughPredictor(Predictor):
  """Predictor that hands its input back as the prediction."""

  # When True, 2D inputs are flattened to 1D; other ranks pass through as-is.
  flatten: bool = False

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """Do nothing; there is nothing to learn."""
    pass

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Return ``X`` (flattened when ``flatten`` is set and ``X`` is 2D) and an empty aux dict."""
    should_flatten = self.flatten and X.ndim == 2
    result = X.flatten() if should_flatten else X
    return result, {}


class MLP(torch.nn.Module):
  """A multi-layer perceptron with ReLU hidden layers and one sigmoid output.

  The constructor arguments are kept in ``self.config`` so that a checkpoint
  carries everything needed to rebuild the architecture.
  """

  def __init__(self, input_dims: int, hidden_dims: int, num_layers: int):
    """Build the network.

    Args:
      input_dims: Size of the last dimension of the input.
      hidden_dims: Width of every hidden layer.
      num_layers: Number of hidden Linear+ReLU stages. The first stage is
        always created, so values below 1 still give one hidden layer.
    """
    super().__init__()
    self.config = {
      "input_dims": input_dims,
      "hidden_dims": hidden_dims,
      "num_layers": num_layers,
    }

    layers = [torch.nn.Linear(input_dims, hidden_dims), torch.nn.ReLU()]
    for _ in range(num_layers - 1):
      layers.append(torch.nn.Linear(hidden_dims, hidden_dims))
      layers.append(torch.nn.ReLU())

    # Single output unit squashed into (0, 1).
    layers.append(torch.nn.Linear(hidden_dims, 1))
    layers.append(torch.nn.Sigmoid())

    self.model = torch.nn.Sequential(*layers)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run ``x`` through the network; the output's last dimension has size 1."""
    return self.model(x)

  def save_checkpoint(self, path: str | upath.UPath) -> None:
    """Write the state dict and constructor config to ``path``.

    Args:
      path: Destination, opened through ``upath.UPath`` in binary write mode.
    """
    checkpoint = {"model_state_dict": self.state_dict(), "config": self.config}
    with upath.UPath(path).open("wb") as f:
      torch.save(checkpoint, f)

  @classmethod
  def load_checkpoint(cls, path: str | upath.UPath) -> "MLP":
    """Rebuild a model from a checkpoint written by ``save_checkpoint``.

    Args:
      path: Location of the checkpoint file.

    Returns:
      A new model with the stored weights loaded. Move it to an accelerator
      with ``.to(...)`` if needed.

    Raises:
      FileNotFoundError: If nothing exists at ``path``.
    """
    path = upath.UPath(path)
    if not path.exists():
      raise FileNotFoundError(f"No checkpoint found at {path.as_posix()}")
    with path.open("rb") as f:
      # map_location="cpu" lets checkpoints saved from CUDA tensors load on
      # hosts without a GPU; the freshly built model lives on the CPU anyway.
      # NOTE(security): weights_only=False unpickles arbitrary objects, so only
      # load checkpoints from trusted sources. Checkpoints produced by
      # save_checkpoint hold only tensors and a plain dict, so
      # weights_only=True may be sufficient -- confirm before switching.
      checkpoint = torch.load(f, map_location="cpu", weights_only=False)

    model = cls(**checkpoint["config"])
    model.load_state_dict(checkpoint["model_state_dict"])

    logger.debug(f"Loaded MLP from {path.as_posix()}")
    return model


def get_gpu_with_most_free_memory() -> Optional[int]:
  """Return the device ID ranked first by ``get_gpu_memory_info``.

  Returns:
    The device ID of the first entry (the list is sorted by free memory,
    largest first), or None when no memory information is available.
  """
  ranked = get_gpu_memory_info()
  if not ranked:
    return None
  best_entry = ranked[0]
  return best_entry[0]


def get_gpu_memory_info() -> Optional[list[tuple[int, int, int]]]:
  """Get memory information for all available GPUs.

  Returns:
    None when CUDA is unavailable. Otherwise a list of
    ``(device_id, total_bytes, free_bytes)`` tuples sorted by free memory,
    largest first.
  """
  if not torch.cuda.is_available():
    return None

  memory_info = []
  for device_id in range(torch.cuda.device_count()):
    # mem_get_info reports the device-wide free/total bytes. The difference
    # between memory_reserved and memory_allocated only measures this process's
    # caching-allocator slack, which is 0 for every GPU in a fresh process.
    # Passing the device explicitly also avoids changing the current device.
    free_memory, total_memory = torch.cuda.mem_get_info(device_id)
    memory_info.append((device_id, total_memory, free_memory))

  return sorted(memory_info, key=lambda entry: entry[2], reverse=True)


class MlpPredictor(Predictor):
  """A predictor backed by a pre-trained PyTorch ``MLP`` checkpoint."""

  def __init__(self, checkpoint: str):
    """Load the model and place it on the selected device.

    Args:
      checkpoint: Path of the checkpoint passed to ``MLP.load_checkpoint``.
    """
    # None when CUDA is unavailable; torch then uses its default device.
    self.device = get_gpu_with_most_free_memory()
    self.model = MLP.load_checkpoint(checkpoint).to(self.device)
    # The model is only used for inference, so keep it in eval mode.
    self.model.eval()

  def fit(self, X: np.ndarray, y: np.ndarray) -> None:
    """No-op as model is pre-trained."""
    pass

  def predict_with_aux(
    self,
    X: np.ndarray,
  ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Run the model on ``X`` without tracking gradients.

    Returns:
      The model output as a NumPy array with its trailing size-1 axis removed,
      and an empty aux dict.
    """
    with torch.inference_mode():
      X_tensor = torch.tensor(X, dtype=torch.float32, device=self.device)
      # Drop only the trailing output axis. A bare squeeze() would also collapse
      # the batch axis, turning a one-row batch into a 0-d array.
      preds = self.model(X_tensor).squeeze(-1)
      return preds.cpu().numpy(), {}
