from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, Tuple, Type

import numpy as np


@dataclass
class DatasetConfig:
    """Configuration record describing a dataset.

    Instances are normally created through :meth:`from_dict`, which forwards
    a plain dictionary to the dataclass constructor.
    """

    # Dataset type name. NOTE(review): presumably the key used to look the
    # dataset up in the registry -- confirm against callers.
    type: str
    # NOTE(review): presumably the number of samples in the test split --
    # confirm against the dataset implementations.
    test_size: int
    # NOTE(review): presumably the number of samples reserved for the
    # predictor -- confirm against the dataset implementations.
    predictor_size: int

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetConfig":
        """Create a config from a dictionary of field values.

        Args:
            config_dict: Mapping whose entries are passed to the constructor
                as keyword arguments. Missing or unexpected keys make the
                constructor raise ``TypeError``.

        Returns:
            A new instance of ``cls`` (the class the method is called on).
        """
        instance = cls(**config_dict)
        return instance


class Dataset(ABC):
    """Base class for datasets.

    None of the methods are marked ``@abstractmethod``, so a subclass can be
    instantiated without overriding all of them. Calling a method that was
    not overridden raises ``NotImplementedError`` with a message naming the
    concrete class and the missing method.
    """

    def __init__(self, config: "DatasetConfig"):
        """Store the configuration on ``self.config``.

        Args:
            config: Configuration object for this dataset.
        """
        self.config = config

    def load_test(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load the test data.

        Returns:
            A pair of arrays (presumably inputs and targets -- confirm
            against the concrete implementations).

        Raises:
            NotImplementedError: If the subclass does not override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement load_test()"
        )

    def load_predictor(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load the predictor data.

        Returns:
            A pair of arrays (presumably inputs and targets -- confirm
            against the concrete implementations).

        Raises:
            NotImplementedError: If the subclass does not override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement load_predictor()"
        )

    def load_calibrator(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load the calibrator data.

        Returns:
            A pair of arrays (presumably inputs and targets -- confirm
            against the concrete implementations).

        Raises:
            NotImplementedError: If the subclass does not override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement load_calibrator()"
        )

    def decision_function(self, y_pred: np.ndarray) -> np.ndarray:
        """Turn predicted values into decisions.

        The base class provides no default behaviour; subclasses must
        override this method. (An earlier docstring described a top-k
        one-hot encoding -- whether subclasses implement that is not
        visible from this class.)

        Args:
            y_pred: Predicted values.

        Returns:
            The decisions derived from ``y_pred``.

        Raises:
            NotImplementedError: If the subclass does not override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement decision_function()"
        )

    @classmethod
    def scale(cls) -> float:
        """Return the scale associated with this dataset class.

        Raises:
            NotImplementedError: If the subclass does not override this.
        """
        # Message kept unchanged in case callers or tests match on it.
        raise NotImplementedError("Scale method not implemented")


class DatasetRegistry:
    """Class-level registry mapping a dataset type name to its classes.

    Both tables are class attributes, so every registration is shared by all
    users of ``DatasetRegistry``. All methods are classmethods; the class is
    not meant to hold per-instance state.
    """

    # type name -> dataset class
    _datasets: Dict[str, Type[Dataset]] = {}
    # type name -> config class used to build that dataset's configuration
    _configs: Dict[str, Type[DatasetConfig]] = {}

    @classmethod
    def register(
        cls,
        dataset_type: str,
        config_class: Type[DatasetConfig],
        dataset_class: Type[Dataset],
    ):
        """Register a dataset class and its config class under one name.

        An existing registration under ``dataset_type`` is overwritten.

        Args:
            dataset_type: Name to register under.
            config_class: Config class associated with the dataset.
            dataset_class: Dataset class to register.
        """
        cls._datasets[dataset_type] = dataset_class
        cls._configs[dataset_type] = config_class

    @classmethod
    def get(cls, dataset_type: str) -> Type[Dataset]:
        """Look up the dataset class registered under ``dataset_type``.

        Args:
            dataset_type: Registered name of the dataset.

        Returns:
            The registered dataset class.

        Raises:
            ValueError: If nothing is registered under ``dataset_type``.
        """
        known = cls._datasets
        if dataset_type in known:
            return known[dataset_type]
        available = list(known)
        raise ValueError(
            f"Dataset type '{dataset_type}' not found. Available types: {available}"
        )

    @classmethod
    def get_config(
        cls, dataset_type: str, config_dict: Dict[str, Any]
    ) -> DatasetConfig:
        """Build a config object for ``dataset_type`` from a dictionary.

        Args:
            dataset_type: Registered name of the dataset.
            config_dict: Values handed to the config class's ``from_dict``.

        Returns:
            The config instance produced by the registered config class.

        Raises:
            ValueError: If no config class is registered under
                ``dataset_type``.
        """
        known = cls._configs
        if dataset_type in known:
            return known[dataset_type].from_dict(config_dict)
        available = list(known)
        raise ValueError(
            f"Dataset config type '{dataset_type}' not found. Available types: {available}"
        )

    @classmethod
    def get_available_types(cls) -> list[str]:
        """Return the registered dataset type names in registration order."""
        return [*cls._datasets]


def register_dataset(
    dataset_type: str,
    config_class: Type[DatasetConfig],
    dataset_class: Type[Dataset],
):
    """Register a dataset with ``DatasetRegistry``.

    Module-level convenience wrapper around ``DatasetRegistry.register``.

    Args:
        dataset_type: Name to register under.
        config_class: Config class associated with the dataset.
        dataset_class: Dataset class to register.
    """
    DatasetRegistry.register(
        dataset_type=dataset_type,
        config_class=config_class,
        dataset_class=dataset_class,
    )


def _register() -> None:
    """Import the dataset modules from the ``.synthetic`` subpackage.

    The imported names are not used here; the import is executed only for
    whatever those modules do when they are imported.
    """
    # NOTE(review): presumably these modules call ``register_dataset`` at
    # import time, which would explain deferring the import until the
    # registry above is defined -- confirm against ``.synthetic``.
    from .synthetic import simple_synthetic, synthetic_graph  # noqa: F401


# Executed when this module is imported, after the definitions above exist.
_register()


# Names exported by ``from <this module> import *``.
__all__ = ["DatasetConfig", "Dataset", "DatasetRegistry", "register_dataset"]
