from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Type

import numpy as np

from ..data import Dataset


@dataclass
class PredictorConfig:
    """Base configuration record for a predictor.

    Concrete predictors may subclass this dataclass to add their own fields;
    ``from_dict`` then builds whichever class it is called on from a plain
    mapping.
    """

    # Name of the predictor kind this config describes.
    # NOTE(review): presumably the same key used with PredictorRegistry —
    # confirm against callers.
    type: str

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "PredictorConfig":
        """Construct ``cls`` using the entries of ``config_dict`` as keyword arguments.

        Keys must match the dataclass fields of ``cls``. An unknown key or a
        missing required field raises ``TypeError`` from the generated
        ``__init__``.
        """
        built = cls(**config_dict)
        return built


class Predictor(ABC):
    """Abstract interface that every predictor implementation must satisfy.

    Subclasses must override both ``fit`` and ``predict``. Because both are
    marked ``@abstractmethod``, ``ABC`` refuses to instantiate a subclass
    that leaves either one unimplemented.
    """

    def __init__(self, config: PredictorConfig):
        """Store ``config`` on the instance as ``self.config``."""
        self.config = config

    @abstractmethod
    def fit(self, dataset: "Dataset") -> None:
        """Fit this predictor to ``dataset``. Returns ``None``."""
        raise NotImplementedError

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return predictions for the inputs in ``X``.

        NOTE(review): the expected shape/dtype of ``X`` and of the result is
        left to implementations — confirm per subclass.
        """
        raise NotImplementedError


class PredictorRegistry:
    """Lookup tables mapping a predictor-type name to its classes.

    Both tables are class attributes, so every caller shares the same
    registry; all access goes through classmethods.
    """

    # predictor-type name -> Predictor subclass
    _predictors: Dict[str, Type[Predictor]] = {}
    # predictor-type name -> PredictorConfig subclass
    _configs: Dict[str, Type[PredictorConfig]] = {}

    @classmethod
    def register(
        cls,
        predictor_type: str,
        config_class: Type[PredictorConfig],
        predictor_class: Type[Predictor],
    ):
        """Associate ``predictor_type`` with a config class and a predictor class.

        An existing entry under the same name is silently replaced.
        """
        cls._configs[predictor_type] = config_class
        cls._predictors[predictor_type] = predictor_class

    @classmethod
    def get(cls, predictor_type: str) -> Type[Predictor]:
        """Return the predictor class registered under ``predictor_type``.

        Raises:
            ValueError: if nothing is registered under that name; the message
                lists the registered names.
        """
        known = cls._predictors
        if predictor_type in known:
            return known[predictor_type]
        raise ValueError(
            f"Predictor type '{predictor_type}' not found. "
            f"Available types: {list(known.keys())}"
        )

    @classmethod
    def get_config(
        cls, predictor_type: str, config_dict: Dict[str, Any]
    ) -> PredictorConfig:
        """Build the config registered for ``predictor_type`` from ``config_dict``.

        The registered config class's ``from_dict`` does the construction.

        Raises:
            ValueError: if no config class is registered under that name.
        """
        config_types = cls._configs
        if predictor_type in config_types:
            return config_types[predictor_type].from_dict(config_dict)
        raise ValueError(
            f"Predictor config type '{predictor_type}' not found. "
            f"Available types: {list(config_types.keys())}"
        )

    @classmethod
    def get_available_types(cls) -> list[str]:
        """Return registered predictor-type names in registration order."""
        return list(cls._predictors)


def register_predictor(
    predictor_type: str,
    config_class: Type[PredictorConfig],
    predictor_class: Type[Predictor],
):
    """Register a predictor implementation and its config class.

    This is a module-level convenience wrapper that forwards all three
    arguments unchanged to ``PredictorRegistry.register``.
    """
    PredictorRegistry.register(
        predictor_type=predictor_type,
        config_class=config_class,
        predictor_class=predictor_class,
    )


# Public names exported by a star-import of this module.
__all__ = [
    "PredictorConfig", "Predictor",
    "PredictorRegistry", "register_predictor",
]


def _register():
    """Import the bundled ``simple_net`` submodule for its import-time effects.

    NOTE(review): presumably ``simple_net`` registers itself via
    ``register_predictor`` when imported — confirm in that module.
    """
    # The call below sits at the very end of the module, so every name defined
    # above already exists when the submodule is loaded (presumably it imports
    # from this package).
    from . import simple_net as _simple_net  # noqa: F401


_register()
