from learning.model.base_models import BaseModel


class LearningSetting:
    """Bundle a model, optimizer and loss criterion and check they fit together.

    Compatibility is validated eagerly in ``__init__`` so that a mismatched
    model/criterion pair is rejected at construction time.
    """

    # Criterion names accepted when ``model.is_classifier`` is truthy.
    # Tuples (not lists) so the shared class-level defaults cannot be mutated
    # in place; subclasses may override either attribute to allow more losses.
    CLASSIFICATION_CRITERIA = ("BCE", "CrossEntropy")
    # Criterion names accepted when ``model.is_classifier`` is falsy.
    REGRESSION_CRITERIA = ("MSE", "MAE")

    def __init__(self, model, optimizer, criterion):
        """Store the components and immediately validate their compatibility.

        Args:
            model: The model. Only ``BaseModel`` instances are validated;
                any other object is accepted without checks.
            optimizer: The optimizer; stored but not inspected here.
            criterion: The loss criterion; must expose ``criterion_name``
                when ``model`` is a ``BaseModel``.

        Raises:
            ValueError: If the criterion does not suit the model's task type.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.validate_compatibility()

    def validate_compatibility(self):
        """Check that the criterion matches the model's task type.

        For a ``BaseModel``: if ``is_classifier`` is truthy the criterion name
        must be in ``CLASSIFICATION_CRITERIA``, otherwise it must be in
        ``REGRESSION_CRITERIA``. Models that are not ``BaseModel`` instances
        are not checked.

        Raises:
            ValueError: If ``criterion.criterion_name`` is not allowed for the
                model's task type. The message starts with the same sentence
                as before, followed by the received and accepted names.
        """
        if isinstance(self.model, BaseModel):
            # Read is_classifier before criterion_name, matching the original
            # attribute access order.
            if self.model.is_classifier:
                allowed = self.CLASSIFICATION_CRITERIA
                message = (
                    "Criterion must be a classification loss function for classifiers."
                )
            else:
                allowed = self.REGRESSION_CRITERIA
                message = (
                    "Criterion must be a regression loss function for regressors."
                )
            name = self.criterion.criterion_name
            if name not in allowed:
                # sorted() keeps the listing deterministic even if a subclass
                # overrides the allowed names with an unordered set.
                raise ValueError(
                    f"{message} Got {name!r}; expected one of {sorted(allowed)}."
                )

        # NOTE(review): this is printed even when the model is not a BaseModel
        # and therefore no check actually ran; consider logging instead of
        # print and distinguishing the "skipped" case.
        print("Compatibility checks passed successfully.")