from dataclasses import dataclass

from ..calibrators import CalibratorRegistry
from ..data import DatasetRegistry
from ..predictors import PredictorRegistry
from ..utils import get_logger
from . import ExperimentConfig, register_experiment

# Module-level logger, requested from the project's get_logger helper under
# this module's dotted name.
logger = get_logger(__name__)


@dataclass
class DummyConfig(ExperimentConfig):
    """Configuration for the dummy experiment.

    Declares no fields of its own; everything it carries is inherited from
    ``ExperimentConfig``.
    """


def run_dummy_experiment(config: ExperimentConfig):
    """Build and fit the predictor, dataset, and calibrator of the dummy experiment.

    The predictor and the dataset are each looked up in their registry using
    the ``"type"`` entry of the corresponding config mapping. The predictor is
    fitted on the dataset first; only afterwards is the calibrator looked up,
    constructed around that predictor, and fitted on the same dataset.

    Args:
        config: Experiment configuration. Its ``predictor_config``,
            ``dataset_config`` and ``calibrator_config`` attributes are each
            indexed with the key ``"type"``.
    """

    def lookup(registry, raw_config):
        # Resolve the entry named by raw_config["type"]: first ask the
        # registry for the component config, then for the registered class.
        kind = raw_config["type"]
        component_config = registry.get_config(kind, raw_config)
        return registry.get(kind), component_config

    logger.info("Running dummy experiment")
    logger.info(config)

    predictor_cls, predictor_cfg = lookup(PredictorRegistry, config.predictor_config)
    predictor = predictor_cls(predictor_cfg)

    dataset_cls, dataset_cfg = lookup(DatasetRegistry, config.dataset_config)
    dataset = dataset_cls(dataset_cfg)

    # The predictor is fitted before the calibrator is even resolved.
    predictor.fit(dataset)

    # The calibrator receives the predictor, its own config, and the dataset.
    calibrator_cls, calibrator_cfg = lookup(
        CalibratorRegistry, config.calibrator_config
    )
    calibrator = calibrator_cls(predictor, calibrator_cfg, dataset)
    calibrator.fit(dataset)


# Runs at import time: hands the name "dummy", its config class, and its
# runner function to register_experiment.
register_experiment("dummy", DummyConfig, run_dummy_experiment)
