import itertools

from learning.criterion import Criterion
from learning.model.classifiers import (
    LogisticRegressionModel,
    MLPClassifier,
)
from learning.model.regressors import (
    LinearRegressionModel,
    MLPRegressor,
    PolynomialRegressionModel,
)
from learning.optimizer import Optimizer
from utility_func.utility import Utility

# Registry of selectable model classes, keyed by the short name that a
# configuration supplies under its "model" entry.
MODEL_MAPPING = dict(
    logreg=LogisticRegressionModel,
    mlp_class=MLPClassifier,
    linreg=LinearRegressionModel,
    polyreg=PolynomialRegressionModel,
    mlp_reg=MLPRegressor,
)


class ExperimentConfig:
    """Expand a parameter grid into ready-to-run experiment configurations.

    ``fixed_params`` are shared by every configuration. Each key of
    ``changing_params`` maps to an iterable of candidate values, and one
    configuration is produced per element of their Cartesian product
    (changing values override fixed ones on key clashes).

    Criterion, optimizer and utility instances are memoised in
    ``_instance_cache``, so configurations that resolve to identical settings
    share the same object.
    """

    def __init__(self, fixed_params=None, changing_params=None):
        self.fixed_params = fixed_params or {}
        self.changing_params = changing_params or {}
        # Memoised Criterion / Optimizer / Utility instances. Keys are
        # namespaced tuples such as ("optimizer", ...) so entries of different
        # kinds can never collide.
        self._instance_cache = {}

    # ------------------------------------------------------------------ #
    # Grid expansion
    # ------------------------------------------------------------------ #
    def generate_param_combinations(self):
        """Yield one merged parameter dict per point of the changing-parameter grid.

        With no changing parameters, ``fixed_params`` is yielded exactly once,
        because the Cartesian product of zero iterables has one empty element.
        A changing parameter with no candidate values yields nothing.
        """
        keys = list(self.changing_params)
        candidate_lists = [self.changing_params[key] for key in keys]
        for combination in itertools.product(*candidate_lists):
            yield {**self.fixed_params, **dict(zip(keys, combination))}

    def get_configurations(self):
        """Build every configuration of the grid.

        Returns:
            dict mapping a 1-based integer index to the configuration dict
            returned by :meth:`build_configuration`.
        """
        return {
            index: self.build_configuration(combination)
            for index, combination in enumerate(
                self.generate_param_combinations(), start=1
            )
        }

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _first_set(*candidates, default=None):
        """Return the first candidate that is not ``None``, else ``default``.

        Unlike chaining with ``or``, legitimate falsy values such as ``0``,
        ``0.0`` or ``{}`` are kept instead of being replaced by a fallback.
        """
        for candidate in candidates:
            if candidate is not None:
                return candidate
        return default

    def _resolve(self, params, key, default=None):
        """Look up ``key`` in ``params``, then ``fixed_params``, then ``default``.

        A value of ``None`` counts as "not provided" at every level.
        """
        return self._first_set(
            params.get(key), self.fixed_params.get(key), default=default
        )

    @classmethod
    def _freeze(cls, value):
        """Recursively turn ``value`` into a hashable cache-key component.

        - Dicts become frozensets of frozen items.
        - Lists and tuples become type-tagged tuples.
        - Sets become tagged frozensets.

        This lets nested or list-valued kwargs (e.g.
        ``{"betas": [0.9, 0.999]}``) be part of a cache key. Any other value
        is returned unchanged and must already be hashable.
        """
        if isinstance(value, dict):
            return frozenset((k, cls._freeze(v)) for k, v in value.items())
        if isinstance(value, (list, tuple)):
            return (type(value).__name__, tuple(cls._freeze(v) for v in value))
        if isinstance(value, (set, frozenset)):
            return ("set", frozenset(cls._freeze(v) for v in value))
        return value

    def _cached(self, key, factory):
        """Return the instance cached under ``key``, creating it via ``factory()`` once."""
        if key not in self._instance_cache:
            self._instance_cache[key] = factory()
        return self._instance_cache[key]

    # ------------------------------------------------------------------ #
    # Component builders
    # ------------------------------------------------------------------ #
    def _build_criterion(self, params):
        """Return the (cached) Criterion described by ``params``."""
        name = self._resolve(params, "criterion")
        regularization = self._resolve(params, "regularization", default={})
        # Key on the resolved values so the key always matches the instance.
        key = ("criterion", name, self._freeze(regularization))
        return self._cached(
            key,
            lambda: Criterion(criterion_name=name, regularization=regularization),
        )

    def _build_optimizer(self, params):
        """Return the (cached) Optimizer described by ``params``."""
        name = self._resolve(params, "optimizer")
        base_kwargs = {
            "lr": self._resolve(params, "lr"),
            "weight_init": self._resolve(params, "weight_init", default="normal"),
            "batch_size": self._resolve(params, "batch_size"),
            "random_state": self._resolve(params, "random_state", default=42),
        }
        extra_kwargs = self._resolve(params, "optimizer_kwargs", default={})
        key = (
            "optimizer",
            name,
            self._freeze(base_kwargs),
            self._freeze(extra_kwargs),
        )
        return self._cached(
            key,
            lambda: Optimizer(optimizer_name=name, **base_kwargs, **extra_kwargs),
        )

    def _build_utility(self, params):
        """Return the (cached) Utility described by ``params``.

        ``utility_name`` and ``threshold`` are resolved field by field: first
        from ``params["utility"]``, then from ``fixed_params["utility"]``.
        The threshold defaults to 0.5.
        """
        utility_cfg = params.get("utility") or {}
        fixed_cfg = self.fixed_params.get("utility") or {}
        name = self._first_set(
            utility_cfg.get("utility_name"), fixed_cfg.get("utility_name")
        )
        threshold = self._first_set(
            utility_cfg.get("threshold"), fixed_cfg.get("threshold"), default=0.5
        )
        key = ("utility", name, self._freeze(threshold))
        return self._cached(
            key, lambda: Utility(utility_name=name, threshold=threshold)
        )

    # ------------------------------------------------------------------ #
    # Public assembly
    # ------------------------------------------------------------------ #
    def build_configuration(self, param_combination):
        """Turn one merged parameter dict into a runnable configuration.

        Args:
            param_combination: A dict as produced by
                :meth:`generate_param_combinations`. Missing or ``None``
                entries fall back to ``fixed_params`` and then to the
                built-in defaults. Falsy values such as ``0`` are respected.

        Returns:
            A dict with these keys:

            - ``"learning_setting"``: holds the ``model`` class, a copy of
              ``model_kwargs``, and the ``optimizer`` and ``criterion``
              instances.
            - ``"trainset"`` and ``"testset"``.
            - ``"utility"``: the utility instance.

        Raises:
            ValueError: If the resolved model name is not in ``MODEL_MAPPING``.
        """
        criterion_instance = self._build_criterion(param_combination)
        optimizer_instance = self._build_optimizer(param_combination)
        utility_instance = self._build_utility(param_combination)

        model_name = self._resolve(param_combination, "model")
        model_class = MODEL_MAPPING.get(model_name)
        if model_class is None:
            raise ValueError(f"Model '{model_name}' not found in MODEL_MAPPING.")
        # Copy so that downstream mutation cannot leak between configurations
        # that were expanded from the same fixed/changing dicts.
        model_params = dict(
            self._resolve(param_combination, "model_kwargs", default={})
        )

        return {
            "learning_setting": {
                "model": model_class,
                "model_kwargs": model_params,
                "optimizer": optimizer_instance,
                "criterion": criterion_instance,
            },
            "trainset": self._resolve(param_combination, "trainset"),
            "testset": self._resolve(param_combination, "testset"),
            "utility": utility_instance,
        }
