import os
import pickle
import config.config as c
from utils.prepare_datasets import (
    prepare_seaborn_dataset,
    prepare_kaggle_dataset,
    prepare_open_ml_dataset,
    prepare_sklearn_dataset,
)
from data.dataset import Dataset
from experiments.experiment import Experiment


def run_experiment_for_dataset(X_train, X_test, y_train, y_test, random_state=c.DEFAULT_RANDOM_STATE):
    """Run one configured experiment on a single train/test split.

    The split is wrapped in two ``Dataset`` objects, which are merged into a
    copy of ``c.FIXED_PARAMS`` under the keys ``"trainset"`` and
    ``"testset"``. An ``Experiment`` is then built from those parameters and
    ``c.CHANGING_PARAMS`` with ``n_runs=c.NB_PERMUTATION_SAMPLING`` and run.

    Args:
        X_train: Features handed to the training ``Dataset``.
        X_test: Features handed to the test ``Dataset``.
        y_train: Targets handed to the training ``Dataset``.
        y_test: Targets handed to the test ``Dataset``.
        random_state: Forwarded to both ``Dataset`` constructors.

    Returns:
        The tuple ``(experiment.values, experiment.marg_contrib_dict)``,
        read from the experiment after ``run_experiment()`` has been called.
    """
    train_data = Dataset(X_train, y_train, random_state=random_state)
    test_data = Dataset(X_test, y_test, random_state=random_state)

    # Work on a copy so the two extra keys never land in the shared
    # module-level configuration dict.
    params = c.FIXED_PARAMS.copy()
    params.update(trainset=train_data, testset=test_data)

    runner = Experiment(params, c.CHANGING_PARAMS, n_runs=c.NB_PERMUTATION_SAMPLING)
    runner.run_experiment()

    return runner.values, runner.marg_contrib_dict


def main():
    """Run the experiment on every configured dataset and pickle the results.

    Datasets are processed in a fixed order: sklearn, seaborn, kaggle, then
    OpenML. For each one the matching ``prepare_*`` loader produces the
    ``(X_train, X_test, y_train, y_test)`` split, which is handed to
    ``run_experiment_for_dataset``. The two values it returns are collected
    in dictionaries keyed by dataset name and written to
    ``results/all_marg_contrib.pkl`` and ``results/all_values.pkl``.

    Nothing is written if any dataset raises: results are only saved after
    every dataset has completed.
    """
    random_state = c.DEFAULT_RANDOM_STATE
    train_size = c.TRAIN_SIZE
    test_size = c.TEST_SIZE

    sklearn_datasets_info = ["breast_cancer"]
    seaborn_datasets_info = ["titanic"]
    kaggle_datasets_info = {
        "credit": {"file_path": "data/kaggle/fraud.csv", "target": "default.payment.next.month"},
        "heart": {"file_path": "data/kaggle/heart.csv", "target": "target"},
    }
    openml_datasets_info = {
        "wind": {"id": 847, "target": "binaryClass"},
        "cpu": {"id": 761, "target": "binaryClass"},
        "2dplanes": {"id": 727, "target": "binaryClass"},
        "pol": {"id": 722, "target": "binaryClass"},
    }

    # One entry per dataset: (source label used in the log line, result key,
    # loader function, loader-specific leading arguments). Every loader takes
    # (random_state, train_size, test_size) after its leading arguments, so a
    # single loop can drive all four sources.
    jobs = []
    jobs.extend(
        ("sklearn", name, prepare_sklearn_dataset, (name,))
        for name in sklearn_datasets_info
    )
    jobs.extend(
        ("seaborn", name, prepare_seaborn_dataset, (name,))
        for name in seaborn_datasets_info
    )
    jobs.extend(
        ("kaggle", name, prepare_kaggle_dataset, (info["file_path"], info["target"]))
        for name, info in kaggle_datasets_info.items()
    )
    jobs.extend(
        ("OpenML", name, prepare_open_ml_dataset, (info["id"], info["target"]))
        for name, info in openml_datasets_info.items()
    )

    all_values = {}
    all_marg_contrib = {}
    for source, name, prepare, leading_args in jobs:
        print(f"Computing values for {source} dataset {name}")
        X_train, X_test, y_train, y_test = prepare(
            *leading_args, random_state, train_size, test_size
        )
        # Pass the seed explicitly so the split and the experiment always use
        # the same random_state, even if the local value above is changed.
        values, marg_contrib = run_experiment_for_dataset(
            X_train, X_test, y_train, y_test, random_state=random_state
        )
        all_values[name] = values
        all_marg_contrib[name] = marg_contrib

    os.makedirs("results", exist_ok=True)
    with open("results/all_marg_contrib.pkl", "wb") as f:
        pickle.dump(all_marg_contrib, f)

    with open("results/all_values.pkl", "wb") as f:
        pickle.dump(all_values, f)


if __name__ == "__main__":
    main()
