import dataclasses
from generative_prediction_sets.core import Predictor, CalibrationContext
import numpy as np
import pytest


class BaseRegressor(Predictor):
  """Stub regressor for tests: echoes its input and reports constant aux arrays."""

  def fit(self, X, y):
    """Do nothing; this stub has no trainable state."""

  def predict_with_aux(self, X: np.ndarray) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Return ``X`` as the prediction together with fixed-value aux arrays.

    Args:
      X: Input array; it is returned unchanged as the prediction.

    Returns:
      A ``(predictions, aux)`` pair. ``aux`` maps ``"confidence_scores"`` to
      an array of 0.8 and ``"feature_importance"`` to an array of 0.5, each
      with the same shape as ``X``.
    """
    # Scaling a ones-array (rather than filling in X's dtype) lets the
    # fractional constants survive even when X has an integer dtype.
    unit = np.ones_like(X)
    aux = {
      "confidence_scores": unit * 0.8,
      "feature_importance": unit * 0.5,
    }
    return X, aux


@pytest.fixture
def base_regressor():
  """Provide a newly constructed ``BaseRegressor`` to the requesting test."""
  regressor = BaseRegressor()
  return regressor


@dataclasses.dataclass
class EnsembleRegressor(Predictor):
  """An ensemble regressor that averages the predictions of two base models.

  Attributes:
    model1: First member model; defaults to a new ``BaseRegressor``.
    model2: Second member model; defaults to a new ``BaseRegressor``.
  """

  model1: BaseRegressor = dataclasses.field(default_factory=BaseRegressor)
  model2: BaseRegressor = dataclasses.field(default_factory=BaseRegressor)

  def fit(self, X, y):
    """Do nothing; neither the ensemble nor its members are trained here."""

  def predict_with_aux(self, X: np.ndarray) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Average both members' predictions and collect all auxiliary values.

    Args:
      X: Input array forwarded unchanged to both member models.

    Returns:
      A ``(predictions, aux)`` pair. ``predictions`` is the equal-weight mean
      of the two members' predictions. ``aux`` holds the ensemble-level
      arrays ``"ensemble_confidence"`` (0.9) and ``"model_weights"`` (0.5),
      plus each member's ``"confidence_scores"`` and ``"feature_importance"``
      re-keyed as ``"model{1,2}_confidence"`` / ``"model{1,2}_importance"``.
    """
    # Get predictions from base models
    preds1, aux1 = self.model1.predict_with_aux(X)
    preds2, aux2 = self.model2.predict_with_aux(X)

    # Combine predictions with equal weights
    ensemble_preds = (preds1 + preds2) / 2

    # Build the constant arrays by scaling a ones-array, the same way
    # BaseRegressor does. np.full_like(X, 0.9) would cast the fill value to
    # X's dtype, silently turning 0.9 / 0.5 into 0 for integer inputs.
    ones = np.ones_like(X)
    ensemble_confidence = ones * 0.9  # Higher confidence from ensemble
    model_weights = ones * 0.5  # Equal weights for both models

    # Combine all auxiliary values
    aux = {
      "ensemble_confidence": ensemble_confidence,
      "model_weights": model_weights,
      "model1_confidence": aux1["confidence_scores"],
      "model1_importance": aux1["feature_importance"],
      "model2_confidence": aux2["confidence_scores"],
      "model2_importance": aux2["feature_importance"],
    }

    return ensemble_preds, aux


def test_base_regressor_aux(base_regressor):
  """Check the base regressor's call result and its auxiliary outputs."""
  inputs = np.array([1.0, 2.0, 3.0])

  # Calling the predictor inside an explicit calibration context should
  # hand back the inputs unchanged.
  with CalibrationContext.set(alpha=0.1):
    outputs = base_regressor(inputs)
  np.testing.assert_array_equal(outputs, inputs)

  # Aux values are read straight from predict_with_aux.
  _, extras = base_regressor.predict_with_aux(inputs)

  # Expected constant fill value for each aux entry.
  expected_fill = {
    "confidence_scores": 0.8,
    "feature_importance": 0.5,
  }

  # Structure first, then contents.
  for key in expected_fill:
    assert key in extras
  for key, fill in expected_fill.items():
    np.testing.assert_array_equal(extras[key], np.full_like(inputs, fill))


def test_ensemble_regressor_aux():
  """Check the ensemble's call result and every entry of its aux dict."""
  inputs = np.array([1.0, 2.0, 3.0])
  ensemble = EnsembleRegressor()

  with CalibrationContext.set(alpha=0.1):
    outputs = ensemble(inputs)

  # Both members echo the inputs, so their average equals the inputs.
  np.testing.assert_array_equal(outputs, inputs)

  # Aux values are read straight from predict_with_aux.
  _, extras = ensemble.predict_with_aux(inputs)

  # Expected constant fill per aux key, grouped by origin:
  # ensemble-level values, then model1's, then model2's.
  expected_groups = (
    {"ensemble_confidence": 0.9, "model_weights": 0.5},
    {"model1_confidence": 0.8, "model1_importance": 0.5},
    {"model2_confidence": 0.8, "model2_importance": 0.5},
  )

  for group in expected_groups:
    # Within each group: presence of the keys first, then their contents.
    for key in group:
      assert key in extras
    for key, fill in group.items():
      np.testing.assert_array_equal(extras[key], np.full_like(inputs, fill))


def test_predictor_call_without_context():
  """Calling a predictor outside any ``CalibrationContext`` block still works."""
  inputs = np.array([1.0, 2.0, 3.0])

  # No context is entered here on purpose; the call must still succeed.
  outputs = EnsembleRegressor()(inputs)

  # Both members echo the inputs, so their average equals the inputs.
  np.testing.assert_array_equal(outputs, inputs)


def test_aux_isolation():
  """Test that aux values are independent between different predictors.

  Two separately constructed ensembles must each produce their own aux dict
  and their own backing arrays, while still agreeing on the values because
  the predictors are configured identically.
  """
  X = np.array([1.0, 2.0, 3.0])
  predictor1 = EnsembleRegressor()
  predictor2 = EnsembleRegressor()

  # Get aux values from both predictors
  _, aux1 = predictor1.predict_with_aux(X)
  _, aux2 = predictor2.predict_with_aux(X)

  # Each predictor should have its own aux values
  assert "ensemble_confidence" in aux1
  assert "ensemble_confidence" in aux2

  # Equality alone cannot detect shared state, so check independence
  # explicitly: distinct dict objects and arrays with separate memory.
  assert aux1 is not aux2
  assert not np.shares_memory(
    aux1["ensemble_confidence"], aux2["ensemble_confidence"]
  )

  # Values should be the same since predictors are identical
  np.testing.assert_array_equal(
    aux1["ensemble_confidence"], aux2["ensemble_confidence"]
  )
