import methods.histogram_binning
import methods.platt_scaling
import methods.isotonic_regression
import methods.temp_scaling
import methods.mlp
import methods.group_calibration
import methods.min_var_p_grouping


def get_calibrate_fn(method_config):
    """Return the ``calibrate`` callable of the calibrator named by *method_config*.

    Recognised values of ``method_config.name`` are ``"temperature_scaling"``,
    ``"histogram_binning"``, ``"isotonic_regression"``, ``"platt_scaling"``
    and ``"mlp"``. The function is returned, not called.

    Raises:
        ValueError: if ``method_config.name`` is none of the names above.
            The message embeds the whole *method_config*, not just its name.
    """
    requested = method_config.name

    # Each resolver is deferred so that only the selected module's attribute
    # is looked up, exactly as a chain of if/elif branches would do.
    resolvers = (
        ("temperature_scaling", lambda: methods.temp_scaling.calibrate),
        ("histogram_binning", lambda: methods.histogram_binning.calibrate),
        ("isotonic_regression", lambda: methods.isotonic_regression.calibrate),
        ("platt_scaling", lambda: methods.platt_scaling.calibrate),
        ("mlp", lambda: methods.mlp.calibrate),
    )
    for known_name, resolve in resolvers:
        if requested == known_name:
            return resolve()

    raise ValueError("config_name {} not found".format(method_config))


def calibrate(method_config,
              val_data,
              test_train_data,
              test_test_data,
              seed,
              cfg):
    """Calibrate the test logits with the method selected by *method_config*.

    Dispatch is on ``method_config.name``, checked in this order:

    * ``"none"``: the test logits are returned unchanged.
    * ``"histogram_binning"``, ``"platt_scaling"``, ``"isotonic_regression"``,
      ``"temperature_scaling"``, ``"mlp"``: the matching module's
      ``calibrate`` is called with the logits and labels of
      *test_train_data* and the logits of *test_test_data*.
    * any name containing ``"PCE"``: ``methods.group_calibration.calibrate``.
    * any name containing ``"MinVarCE"``:
      ``methods.min_var_p_grouping.calibrate``.

    The two grouping methods additionally receive the ``"features"`` entries,
    the validation split, a base calibrator resolved from
    ``method_config.base_calibrator`` via ``get_calibrate_fn``, *seed* and
    *cfg*.

    Args:
        method_config: object with a ``name`` attribute; grouping methods
            also read its ``base_calibrator`` attribute.
        val_data: mapping with ``"features"``, ``"logits"`` and ``"labels"``;
            only read by the grouping methods.
        test_train_data: mapping with ``"logits"`` and ``"labels"`` (plus
            ``"features"`` for the grouping methods).
        test_test_data: mapping with ``"logits"`` (plus ``"features"`` for
            the grouping methods).
        seed: forwarded to the grouping methods only.
        cfg: forwarded to the grouping methods only.

    Returns:
        ``{'logits': test_test_data["logits"]}`` for ``"none"``; otherwise
        whatever the selected ``calibrate`` function returns.

    Raises:
        ValueError: if ``method_config.name`` matches none of the methods
            above.
    """
    name = method_config.name

    if name == "none":
        return {'logits': test_test_data["logits"]}

    # Calibrators that are fitted on logits/labels alone all share one call
    # signature, so pick the function first and invoke it once below.
    if name == "histogram_binning":
        logit_calibrate = methods.histogram_binning.calibrate
    elif name == "platt_scaling":
        logit_calibrate = methods.platt_scaling.calibrate
    elif name == "isotonic_regression":
        logit_calibrate = methods.isotonic_regression.calibrate
    elif name == "temperature_scaling":
        logit_calibrate = methods.temp_scaling.calibrate
    elif name == 'mlp':
        logit_calibrate = methods.mlp.calibrate
    else:
        logit_calibrate = None

    if logit_calibrate is not None:
        return logit_calibrate(
            train_logits=test_train_data['logits'],
            train_labels=test_train_data['labels'],
            test_logits=test_test_data['logits'],
        )

    # Grouping-based calibrators are matched by substring; "PCE" is tested
    # first so it wins if a name happens to contain both markers.
    if "PCE" in name:
        group_calibrate = methods.group_calibration.calibrate
    elif "MinVarCE" in name:
        group_calibrate = methods.min_var_p_grouping.calibrate
    else:
        # A bare ``raise`` here would only produce "RuntimeError: No active
        # exception to reraise"; report the unknown method name instead.
        raise ValueError("config_name {} not found".format(name))

    return group_calibrate(
        val_features=val_data["features"],
        val_logits=val_data["logits"],
        val_labels=val_data["labels"],
        test_train_features=test_train_data["features"],
        test_train_logits=test_train_data["logits"],
        test_train_labels=test_train_data["labels"],
        test_test_features=test_test_data["features"],
        test_test_logits=test_test_data["logits"],
        base_calibrate_fn=get_calibrate_fn(
            method_config=method_config.base_calibrator),
        method_config=method_config,
        seed=seed,
        cfg=cfg
    )
