from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Type

import numpy as np

from ..data import Dataset
from ..predictors import Predictor


@dataclass
class CalibratorConfig:
    """Base configuration record for a calibrator.

    The registry in this module stores config classes by type name and
    builds instances from plain dictionaries through :meth:`from_dict`.
    """

    # Name identifying the calibrator type this configuration belongs to.
    type: str

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "CalibratorConfig":
        """Build a config instance from a dictionary of field values.

        Args:
            config_dict: Mapping whose keys are passed as keyword arguments
                to the constructor of ``cls``.

        Returns:
            A new instance of ``cls`` (the class the method is called on).

        Raises:
            TypeError: If the keys do not match the constructor's
                parameters (missing or unexpected fields).
        """
        instance = cls(**config_dict)
        return instance


class Calibrator(ABC):
    """Abstract base class for calibrators.

    Concrete subclasses must implement :meth:`predict` and :meth:`fit`.
    The constructor stores the three collaborators it is given so that
    subclasses can reach them as attributes.

    Attributes:
        predictor: The predictor passed to the constructor.
        config: The calibrator configuration passed to the constructor.
        dataset: The dataset passed to the constructor.
    """

    def __init__(
        self,
        predictor: "Predictor",
        config: "CalibratorConfig",
        dataset: "Dataset",
    ):
        """Store the predictor, configuration and dataset on the instance.

        Args:
            predictor: Predictor associated with this calibrator.
            config: Configuration for this calibrator.
            dataset: Dataset associated with this calibrator.
        """
        # Previously the predictor argument was accepted but dropped, so
        # the base class never exposed it; keep it alongside the others.
        self.predictor = predictor
        self.config = config
        self.dataset = dataset

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return predictions for ``X``; must be overridden by subclasses.

        Args:
            X: Input array.

        Returns:
            An array of predictions.
        """
        raise NotImplementedError

    @abstractmethod
    def fit(self, dataset: "Dataset") -> None:
        """Fit the calibrator on ``dataset``; must be overridden by subclasses.

        Args:
            dataset: Dataset to fit on.
        """
        raise NotImplementedError


class CalibratorRegistry:
    """Class-level registry of calibrator and config classes.

    Two dictionaries, both keyed by the calibrator type name, are kept as
    class attributes: one maps to the calibrator class and the other to its
    config class. Every method is a classmethod operating on that shared
    state.
    """

    _calibrators: Dict[str, Type[Calibrator]] = {}
    _configs: Dict[str, Type[CalibratorConfig]] = {}

    @classmethod
    def register(
        cls,
        calibrator_type: str,
        config_class: Type[CalibratorConfig],
        calibrator_class: Type[Calibrator],
    ):
        """Record a calibrator class and its config class under a type name.

        An existing entry under the same name is overwritten.

        Args:
            calibrator_type: Name to register the pair under.
            config_class: Config class associated with the calibrator.
            calibrator_class: Calibrator implementation class.
        """
        cls._calibrators[calibrator_type] = calibrator_class
        cls._configs[calibrator_type] = config_class

    @classmethod
    def get(cls, calibrator_type: str) -> Type[Calibrator]:
        """Return the calibrator class registered under ``calibrator_type``.

        Raises:
            ValueError: If no calibrator is registered under that name.
        """
        registered = cls._calibrators
        if calibrator_type in registered:
            return registered[calibrator_type]
        raise ValueError(
            f"Calibrator type '{calibrator_type}' not found. "
            f"Available types: {list(registered.keys())}"
        )

    @classmethod
    def get_config(
        cls, calibrator_type: str, config_dict: Dict[str, Any]
    ) -> CalibratorConfig:
        """Build a config object for ``calibrator_type`` from a dictionary.

        The config class registered under the name is looked up and its
        ``from_dict`` constructor is called with ``config_dict``.

        Raises:
            ValueError: If no config class is registered under that name.
        """
        registered = cls._configs
        if calibrator_type in registered:
            return registered[calibrator_type].from_dict(config_dict)
        raise ValueError(
            f"Calibrator config type '{calibrator_type}' not found. "
            f"Available types: {list(registered.keys())}"
        )

    @classmethod
    def get_available_types(cls) -> list[str]:
        """Return the registered calibrator type names as a new list."""
        return [*cls._calibrators]


def register_calibrator(
    calibrator_type: str,
    config_class: Type[CalibratorConfig],
    calibrator_class: Type[Calibrator],
):
    """Register a calibrator and its config class with ``CalibratorRegistry``.

    Module-level convenience wrapper that forwards its arguments unchanged
    to ``CalibratorRegistry.register``.

    Args:
        calibrator_type: Name to register the pair under.
        config_class: Config class associated with the calibrator.
        calibrator_class: Calibrator implementation class.
    """
    CalibratorRegistry.register(
        calibrator_type=calibrator_type,
        config_class=config_class,
        calibrator_class=calibrator_class,
    )


# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "CalibratorConfig",
    "Calibrator",
    "CalibratorRegistry",
    "register_calibrator",
]


def _register() -> None:
    """Import the sibling ``grid_boost`` module for its import-time effects."""
    # The imported name is never used; only the act of importing matters.
    # NOTE(review): presumably grid_boost calls register_calibrator() when
    # imported, and imports names from this module, which would explain why
    # the import is deferred until after the definitions above - confirm.
    from . import grid_boost  # noqa: F401


# Executed at import time, after every definition above already exists.
_register()
