import os
import pickle

import torch

from src.experiments.predictive_shapley_values.PredictiveExplanationExperiment import (
    PredictiveExplanationExperiment,
    loss_function
)

# How many times the whole experiment is repeated (the loop index is called
# `cv`); every repetition builds and runs fresh experiments.
num_cv: int = 10
# Target explanation types; one PredictiveExplanationExperiment is created per
# entry in every repetition.
explanation_types: list = ["gp", "rf", "deep"]


def main(output_dir="data/predictive_experiment"):
    """Run every explanation experiment ``num_cv`` times and collect prediction errors.

    For each repetition ``cv`` a fresh ``PredictiveExplanationExperiment`` is
    built for every entry of ``explanation_types`` and then run. The GP, RF and
    deep explanation predictions of each experiment are scored against the
    experiment's test targets with ``loss_function``.

    Args:
        output_dir: Directory that receives one ``diabetes_cv_<cv>.pkl`` pickle
            per repetition. It is created if it does not exist yet. Without
            this, a missing directory crashed the script only after the first
            repetition had already been computed.

    Returns:
        A list with one entry per repetition. Each entry is a list of rows
        ``[target_explanation_type, error, predictor_name, cv]`` where
        ``predictor_name`` is one of ``"gp_pred"``, ``"rf_pred"``,
        ``"deep_pred"`` and ``error`` is ``loss_function(...).numpy()``.
    """
    os.makedirs(output_dir, exist_ok=True)

    results = []
    for cv in range(num_cv):
        print(f"at cv {cv}")
        # All experiments of a repetition are constructed up front, before any
        # of them runs. This ordering is kept on purpose in case construction
        # draws from a shared random state. NOTE(review): confirm.
        experiments = [
            PredictiveExplanationExperiment(target_explanation_type=explanation_type)
            for explanation_type in explanation_types
        ]
        cv_result = []
        for explanation_type, experiment in zip(explanation_types, experiments):
            print(f"doing experiment type {explanation_type}")
            experiment.run()

            with torch.no_grad():
                # Only the GP prediction is wrapped with torch.tensor here. The
                # RF and deep predictions are passed through unchanged.
                # NOTE(review): presumably those are already tensors; confirm.
                # Insertion order (gp, rf, deep) fixes both the row order and
                # the order in which loss_function is called.
                predictions = {
                    "gp_pred": torch.tensor(experiment.explanation_prediction_gp),
                    "rf_pred": experiment.explanation_prediction_rf,
                    "deep_pred": experiment.explanation_prediction_deep,
                }
                # Targets are reshaped to 10 columns. That presumably matches
                # one explanation value per input feature. TODO confirm.
                explanation_target = torch.tensor(experiment.target_test.reshape(-1, 10))

                for predictor_name, prediction in predictions.items():
                    error = loss_function(prediction, explanation_target).numpy()
                    cv_result.append([explanation_type, error, predictor_name, cv])

        # Persist each repetition immediately so finished work survives a
        # crash in a later repetition.
        fold_path = os.path.join(output_dir, f"diabetes_cv_{cv}.pkl")
        with open(fold_path, "wb") as f:
            pickle.dump(cv_result, f)

        results.append(cv_result)

    return results


if __name__ == "__main__":
    # Run all repetitions, then store the combined per-repetition rows in a
    # single pickle next to the per-repetition files written by main().
    results = main()
    full_result_path = "data/predictive_experiment/diabetes_full_result.pkl"
    with open(full_result_path, mode="wb") as out_file:
        pickle.dump(results, out_file)
