import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from code_demeanor.utils.notebook_visualisations import (
    get_layer_wise_violin,
    get_attention_layer_heatmap,
    get_lda_visualisation,
    get_pca_visualisation,
    get_stats_and_labels,
    normalise_stats,
    train_mlp_and_evaluate,
    get_confusion_matrix_plot,
)

# Model identifiers as they appear inside the scan filenames
# ("<split>_<kind>_scan_<model_name><settings>.jsonl", see ExperimentFiles).
MODEL_NAME_1 = "meta-llama_Llama-3.2-3B-Instruct"
MODEL_NAME_2 = "meta-llama_Llama-3.1-8B-Instruct"
MODEL_NAME_3 = "Qwen_Qwen2.5-14B-Instruct"
# Filename suffix of the main (microsaccade-intervention) scan. This is the
# default ExperimentFiles.config_settings for ALL three models.
# NOTE(review): its layer list (0_14_27) matches model 1's baselines, while
# model 2's baselines use 0_18_35 — confirm the shared default is intended.
CONFIG_SETTINGS = "_layers_0_14_27_heads_0_12_23_seqLen_None_useLayerIntervention_False_useAttentionIntervention_False_useMicrosaccadeIntervention_True"
# Filename suffixes of the layer-intervention baseline scans, one per model.
BASELINE_SETTINGS_1 = "_layers_0_14_27_heads_0_12_23_seqLen_None_useLayerIntervention_True_useAttentionIntervention_False_useMicrosaccadeIntervention_False"
BASELINE_SETTINGS_2 = "_layers_0_18_35_heads_0_12_23_seqLen_None_useLayerIntervention_True_useAttentionIntervention_False_useMicrosaccadeIntervention_False"
# NOTE(review): identical to BASELINE_SETTINGS_1 — confirm this is correct
# for MODEL_NAME_3.
BASELINE_SETTINGS_3 = "_layers_0_14_27_heads_0_12_23_seqLen_None_useLayerIntervention_True_useAttentionIntervention_False_useMicrosaccadeIntervention_False"

# Filename suffixes of the gaussian / random noise-baseline scans: the main
# settings string with a "_gaussian" / "_random" tag appended.
GAUSSIAN_BASELINE_SETTINGS_1 = "_layers_0_14_27_heads_0_12_23_seqLen_None_useLayerIntervention_False_useAttentionIntervention_False_useMicrosaccadeIntervention_True_gaussian"
RANDOM_BASELINE_SETTINGS_1 = "_layers_0_14_27_heads_0_12_23_seqLen_None_useLayerIntervention_False_useAttentionIntervention_False_useMicrosaccadeIntervention_True_random"
GAUSSIAN_BASELINE_SETTINGS_2 = "_layers_0_18_35_heads_0_12_23_seqLen_None_useLayerIntervention_False_useAttentionIntervention_False_useMicrosaccadeIntervention_True_gaussian"
RANDOM_BASELINE_SETTINGS_2 = "_layers_0_18_35_heads_0_12_23_seqLen_None_useLayerIntervention_False_useAttentionIntervention_False_useMicrosaccadeIntervention_True_random"

# Passed to the layer/head plotting helpers as n_heads / n_layers.
# NOTE(review): these may describe the scanned feature layout rather than the
# published model architecture — confirm against the feature-extraction code.
MODEL_NAME_1_N_HEADS = 28
MODEL_NAME_2_N_HEADS = 36
MODEL_NAME_1_N_LAYERS = 24
MODEL_NAME_2_N_LAYERS = 18

# model_name -> (n_heads, n_layers). train_and_evaluate reads index 0 as
# n_heads and index 1 as n_layers.
# NOTE(review): MODEL_NAME_3 has no entry, so MODEL_PARAMETERS.get returns
# None for it.
MODEL_PARAMETERS = {
    MODEL_NAME_1: (MODEL_NAME_1_N_HEADS, MODEL_NAME_1_N_LAYERS),
    MODEL_NAME_2: (MODEL_NAME_2_N_HEADS, MODEL_NAME_2_N_LAYERS),
}



@dataclass
class ExperimentFiles:
    """Locations and filename suffixes of one dataset's scan files.

    Scan files live in ``base_dir`` and are named
    ``<split>_<kind>_scan_<model_name><settings>.jsonl``, where ``settings``
    selects the main run or one of its baselines. Summary artefacts go under
    ``output_category``.
    """

    dataset_name: str
    base_dir: str
    output_category: str
    model_name_1: str = MODEL_NAME_1
    model_name_2: str = MODEL_NAME_2
    model_name_3: str = MODEL_NAME_3
    config_settings: str = CONFIG_SETTINGS
    baseline_settings_1: str = BASELINE_SETTINGS_1
    baseline_settings_2: str = BASELINE_SETTINGS_2
    baseline_settings_3: str = BASELINE_SETTINGS_3
    gaussian_baseline_1: str = GAUSSIAN_BASELINE_SETTINGS_1
    random_baseline_1: str = RANDOM_BASELINE_SETTINGS_1
    gaussian_baseline_2: str = GAUSSIAN_BASELINE_SETTINGS_2
    random_baseline_2: str = RANDOM_BASELINE_SETTINGS_2
    gaussian_baseline_3: str = GAUSSIAN_BASELINE_SETTINGS_1  # TODO still: placeholder copied from model 1
    random_baseline_3: str = RANDOM_BASELINE_SETTINGS_1  # TODO: placeholder copied from model 1

    def base_output_dir(self, output_category, model_name) -> str:
        """Ensure ``output_category`` exists and return this run's output prefix.

        The returned ``"<output_category>/<dataset_name>_<model_name>"`` is a
        path prefix; only ``output_category`` itself is created.
        """
        os.makedirs(output_category, exist_ok=True)
        return f"{output_category}/{self.dataset_name}_{model_name}"

    def model_name_files(
        self,
        model_name,
        baseline: Optional[str] = None,
        gaussian_baseline: Optional[str] = None,
        random_baseline: Optional[str] = None,
    ) -> Tuple[str, str, str, str]:
        """Return the four scan-file paths for ``model_name`` under one settings suffix.

        The first truthy of ``baseline``, ``gaussian_baseline`` and
        ``random_baseline`` (in that order) is used as the settings suffix;
        if none is given, ``self.config_settings`` is used.

        Returns:
            ``(train_labels, train_data, test_labels, test_data)`` paths.

        Raises:
            TypeError: if the selected settings value is not a string (for
                example a bare ``True`` flag), since it must name a file.
        """
        config = baseline or gaussian_baseline or random_baseline or self.config_settings
        if not isinstance(config, str):
            raise TypeError(
                f"settings suffix for {model_name!r} must be a str, got {config!r}"
            )

        stem = f"_scan_{model_name}{config}.jsonl"
        return (
            f"{self.base_dir}/train_labels{stem}",
            f"{self.base_dir}/train_data{stem}",
            f"{self.base_dir}/test_labels{stem}",
            f"{self.base_dir}/test_data{stem}",
        )


def write_results_to_jsonl(
    data: Dict[str, Any], path: str = "summary_results.jsonl"
) -> None:
    """Append ``data`` as a single JSON line to ``path``.

    Append mode creates the file if it does not exist, so no separate
    existence check is needed. Existing records are never truncated, so the
    file accumulates results across runs.

    Args:
        data: JSON-serialisable record to write.
        path: Target JSONL file; defaults to ``summary_results.jsonl`` in the
            current working directory.
    """
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(data) + "\n")


def train_and_evaluate(
    train_stats: List[List[float]],
    train_labels: List[int],
    test_stats: List[List[float]],
    test_labels: List[int],
    experiment_files: ExperimentFiles,
    model_name: str,
    baseline: bool = False,
    combined: bool = False,
    gaussian_baseline: bool = False,
    random_baseline: bool = False,
):
    """Normalise features, train/evaluate the MLP probe, record results and plot.

    Args:
        train_stats, test_stats: Feature matrices passed to ``normalise_stats``
            (``model_main_combine`` passes torch tensors here despite the
            ``List`` annotation).
        train_labels, test_labels: Labels for the corresponding rows.
        experiment_files: Dataset layout; supplies the output prefix and the
            ``dataset_name`` written to the summary record.
        model_name: Model identifier; also the key into ``MODEL_PARAMETERS``.
        baseline, combined, gaussian_baseline, random_baseline: Truthy flags
            naming the run variant (precedence: combined > baseline >
            gaussian > random). They only affect the record label, the output
            prefix and which plots are produced.

    Returns:
        ``(accuracy, auc)`` as returned by ``train_mlp_and_evaluate``.
    """
    if combined:
        variant = "combined"
    elif baseline:
        variant = "baseline"
    elif gaussian_baseline:
        variant = "gaussian_baseline"
    elif random_baseline:
        variant = "random_baseline"
    else:
        variant = ""

    train_norm, test_norm = normalise_stats(train_stats, test_stats)
    output_dir = experiment_files.base_output_dir(
        experiment_files.output_category, model_name
    )
    if variant:
        # Without a suffix every variant shares the main run's prefix and
        # overwrites its artefacts; the main run keeps the original prefix.
        output_dir = f"{output_dir}_{variant}"

    acc, auc, y_true, y_pred = train_mlp_and_evaluate(
        train_norm,
        train_labels,
        test_norm,
        test_labels,
        output_dir,
    )

    # The label keeps the historical "<model>_<variant>" shape (trailing "_"
    # for the main run) so it matches records already in the summary file.
    write_results_to_jsonl(
        {
            "model_name": f"{model_name}_{variant}",
            "dataset_name": experiment_files.dataset_name,
            "auc": auc,
            "accuracy": acc,
        }
    )
    print(f"MLP Test Accuracy: {acc:.4f}, AUC: {auc:.4f}")

    # Layer-intervention baseline runs are scored only, never plotted.
    if baseline:
        return acc, auc

    get_confusion_matrix_plot(y_true, y_pred, output_dir, show_figure=False)
    get_pca_visualisation(train_norm, train_labels, output_dir, show_figure=False)
    get_lda_visualisation(train_norm, train_labels, output_dir, show_figure=False)

    model_parameters = MODEL_PARAMETERS.get(model_name)
    if model_parameters is None:
        # Previously this crashed on None[0] after the results were already
        # written, so the caller lost the computed accuracy/AUC.
        print(
            f"Skipping layer/head plots for {model_name}: "
            "no entry in MODEL_PARAMETERS"
        )
        return acc, auc

    n_heads, n_layers = model_parameters
    get_layer_wise_violin(
        train_norm,
        train_labels,
        n_heads=n_heads,
        n_layers=n_layers,
        base_output_name=output_dir,
        show_figure=False,
        save_fig=True,
    )
    get_attention_layer_heatmap(
        train_norm,
        train_labels,
        n_heads=n_heads,
        n_layers=n_layers,
        base_output_name=output_dir,
        show_figure=False,
        save_fig=True,
    )

    return acc, auc


def model_main(
    model_name: str,
    experiment_files: ExperimentFiles,
    baseline: Optional[str] = None,
    gaussian_baseline: Optional[str] = None,
    random_baseline: Optional[str] = None,
) -> Tuple[float, float]:
    """Load one scan configuration of ``model_name`` and train/evaluate on it.

    At most one of ``baseline``, ``gaussian_baseline`` and ``random_baseline``
    is normally given: it is the settings suffix of the scan files to load
    (see ``ExperimentFiles.model_name_files``). With none given, the main-run
    settings are used.

    Returns:
        ``(accuracy, auc)`` from ``train_and_evaluate``.
    """
    train_labels_path, train_data_path, test_labels_path, test_data_path = (
        experiment_files.model_name_files(
            model_name,
            baseline=baseline,
            gaussian_baseline=gaussian_baseline,
            random_baseline=random_baseline,
        )
    )

    # get_stats_and_labels takes both label paths first, then both data paths.
    train_stats, train_labels, test_stats, test_labels = get_stats_and_labels(
        train_labels_path,
        test_labels_path,
        train_data_path,
        test_data_path,
    )

    return train_and_evaluate(
        train_stats,
        train_labels,
        test_stats,
        test_labels,
        experiment_files,
        model_name,
        baseline=bool(baseline),
        gaussian_baseline=bool(gaussian_baseline),
        random_baseline=bool(random_baseline),
    )


def model_main_combine(
    model_name: str, experiment_files: ExperimentFiles, baseline: bool = False
) -> Tuple[float, float]:
    """Train/evaluate on main-run features concatenated with layer-baseline features.

    The main-run features and the model's layer-intervention baseline features
    (``baseline_settings_1/2/3``, matched to ``model_name``) are concatenated
    with ``torch.cat(..., dim=1)``. The main run's labels are used for
    training and evaluation.
    NOTE(review): this assumes both scans list the same samples in the same
    row order; that is not checked here.

    Args:
        model_name: One of ``experiment_files.model_name_1/2/3``.
        experiment_files: Dataset layout.
        baseline: Kept for backward compatibility. A settings *string* is used
            as the primary (non-baseline) config. A bool flag cannot name a
            file, so the main-run config is used instead; previously ``True``
            crashed with a TypeError.

    Returns:
        ``(accuracy, auc)`` from ``train_and_evaluate``.

    Raises:
        ValueError: if ``model_name`` is not configured on ``experiment_files``.
    """
    baseline_by_model = {
        experiment_files.model_name_1: experiment_files.baseline_settings_1,
        experiment_files.model_name_2: experiment_files.baseline_settings_2,
        experiment_files.model_name_3: experiment_files.baseline_settings_3,
    }
    baseline_settings = baseline_by_model.get(model_name)
    if baseline_settings is None:
        raise ValueError(
            f"No baseline settings configured for model {model_name!r}"
        )

    primary_config = baseline if isinstance(baseline, str) else None
    (
        train_labels_path,
        train_data_path,
        test_labels_path,
        test_data_path,
    ) = experiment_files.model_name_files(model_name, baseline=primary_config)
    train_stats, train_labels, test_stats, test_labels = get_stats_and_labels(
        train_labels_path,
        test_labels_path,
        train_data_path,
        test_data_path,
    )

    # Previously this passed baseline=True, which made the settings suffix a
    # bool and broke path construction; pass the model's settings string.
    (
        baseline_train_labels_path,
        baseline_train_data_path,
        baseline_test_labels_path,
        baseline_test_data_path,
    ) = experiment_files.model_name_files(model_name, baseline=baseline_settings)
    train_stats_baseline, _, test_stats_baseline, _ = get_stats_and_labels(
        baseline_train_labels_path,
        baseline_test_labels_path,
        baseline_train_data_path,
        baseline_test_data_path,
    )

    train_stats_combined = torch.cat((train_stats, train_stats_baseline), dim=1)
    test_stats_combined = torch.cat((test_stats, test_stats_baseline), dim=1)
    return train_and_evaluate(
        train_stats_combined,
        train_labels,
        test_stats_combined,
        test_labels,
        experiment_files,
        model_name,
        combined=True,
    )


@dataclass
class ExperimentResults:
    """Per-dataset metrics returned by ``main``.

    Every field defaults to ``None``. ``main`` leaves a field ``None`` when the
    corresponding run raised (the error is printed and swallowed) or was not
    attempted. The numeric suffix 1/2/3 refers to
    ``ExperimentFiles.model_name_1``/``model_name_2``/``model_name_3``.
    """

    # Main-run (experiment_files.config_settings) accuracy / AUC per model.
    acc1: Optional[float] = None
    auc1: Optional[float] = None
    acc2: Optional[float] = None
    auc2: Optional[float] = None
    acc3: Optional[float] = None
    auc3: Optional[float] = None
    # Gaussian / random noise-baseline accuracy and AUC per model.
    gaussian_acc_1: Optional[float] = None
    gaussian_auc_1: Optional[float] = None
    random_acc_1: Optional[float] = None
    random_auc_1: Optional[float] = None
    gaussian_acc_2: Optional[float] = None
    gaussian_auc_2: Optional[float] = None
    random_acc_2: Optional[float] = None
    random_auc_2: Optional[float] = None
    gaussian_acc_3: Optional[float] = None
    gaussian_auc_3: Optional[float] = None
    random_acc_3: Optional[float] = None
    random_auc_3: Optional[float] = None
    # Relative accuracy change of model 1 over its gaussian / random baseline,
    # in percent: (acc1 - baseline_acc) / baseline_acc * 100.
    gaussian_delta_perc: Optional[float] = None
    random_delta_perc: Optional[float] = None


def _run_model_safely(
    model_name: str,
    experiment_files: ExperimentFiles,
    error_prefix: str = "Error processing model",
    **settings: Optional[str],
) -> Tuple[Optional[float], Optional[float]]:
    """Run ``model_main`` and return ``(None, None)`` instead of raising.

    This is a deliberate best-effort boundary: one missing or broken
    configuration should not abort the remaining runs of the sweep. The
    error is printed as ``"<error_prefix> <model_name>: <error>"``.
    ``settings`` is forwarded as ``baseline``/``gaussian_baseline``/
    ``random_baseline`` keyword arguments.
    """
    try:
        return model_main(model_name, experiment_files=experiment_files, **settings)
    except Exception as e:
        print(f"{error_prefix} {model_name}: {e}")
        return None, None


def _record_delta(
    experiment_files: ExperimentFiles,
    model_name: str,
    kind: str,
    acc: Optional[float],
    auc: Optional[float],
    baseline_acc: Optional[float],
    baseline_auc: Optional[float],
) -> Optional[float]:
    """Print and persist ``model_name``'s improvement over its ``kind`` baseline.

    Returns:
        The relative accuracy delta in percent, or ``None`` when either run
        is missing or the baseline accuracy is zero (ratio undefined).
    """
    if acc is None or auc is None or baseline_auc is None or not baseline_acc:
        return None

    delta_perc = ((acc - baseline_acc) / baseline_acc) * 100
    print(
        f"Delta accuracy between {model_name} and its {kind} baseline: "
        f"{delta_perc:.4f}%"
    )
    write_results_to_jsonl(
        {
            "model_name": f"{model_name}_vs_{kind}_baseline",
            "dataset_name": experiment_files.dataset_name,
            "delta_accuracy": delta_perc,
            "delta_auc": auc - baseline_auc,
        }
    )
    return delta_perc


def main(experiment_files: ExperimentFiles) -> ExperimentResults:
    """Run every configured model/baseline for one dataset and collect metrics.

    Each run is best-effort: a failure is printed and the corresponding
    metrics stay ``None``. Layer-intervention baseline results are persisted
    to the summary JSONL by ``train_and_evaluate`` but not returned.
    """
    ef = experiment_files

    # Noise baselines (model 1 only).
    gaussian_acc_1, gaussian_auc_1 = _run_model_safely(
        ef.model_name_1, ef, gaussian_baseline=ef.gaussian_baseline_1
    )
    random_acc_1, random_auc_1 = _run_model_safely(
        ef.model_name_1, ef, random_baseline=ef.random_baseline_1
    )
    # Noise baselines for models 2 and 3 are disabled for now (the model-3
    # settings are still placeholders); their result fields stay None.
    gaussian_acc_2, gaussian_auc_2 = None, None
    random_acc_2, random_auc_2 = None, None
    gaussian_acc_3, gaussian_auc_3 = None, None
    random_acc_3, random_auc_3 = None, None

    # Main runs.
    acc1, auc1 = _run_model_safely(ef.model_name_1, ef)
    acc2, auc2 = _run_model_safely(ef.model_name_2, ef)
    acc3, auc3 = _run_model_safely(ef.model_name_3, ef)

    # Layer-intervention baselines.
    for model_name, settings in (
        (ef.model_name_1, ef.baseline_settings_1),
        (ef.model_name_2, ef.baseline_settings_2),
        (ef.model_name_3, ef.baseline_settings_3),
    ):
        _run_model_safely(
            model_name,
            ef,
            error_prefix="Error processing baseline model",
            baseline=settings,
        )
    # A combined (main + layer-baseline features) run is available through
    # model_main_combine but is not part of this sweep.

    # Previously these used acc1/auc1 unguarded, so a failed main run of
    # model 1 raised a TypeError that aborted the whole sweep.
    gaussian_delta_perc = _record_delta(
        ef, ef.model_name_1, "gaussian", acc1, auc1, gaussian_acc_1, gaussian_auc_1
    )
    random_delta_perc = _record_delta(
        ef, ef.model_name_1, "random", acc1, auc1, random_acc_1, random_auc_1
    )

    return ExperimentResults(
        acc1=acc1,
        auc1=auc1,
        acc2=acc2,
        auc2=auc2,
        acc3=acc3,
        auc3=auc3,
        gaussian_acc_1=gaussian_acc_1,
        gaussian_auc_1=gaussian_auc_1,
        random_acc_1=random_acc_1,
        random_auc_1=random_auc_1,
        gaussian_acc_2=gaussian_acc_2,
        gaussian_auc_2=gaussian_auc_2,
        random_acc_2=random_acc_2,
        random_auc_2=random_auc_2,
        gaussian_acc_3=gaussian_acc_3,
        gaussian_auc_3=gaussian_auc_3,
        random_acc_3=random_acc_3,
        random_auc_3=random_auc_3,
        gaussian_delta_perc=gaussian_delta_perc,
        random_delta_perc=random_delta_perc,
    )


# Per-dataset scan layouts. Positional arguments follow the ExperimentFiles
# field order: (dataset_name, base_dir, output_category).

# --- backdoor ---
MTBA_EXPERIMENT_FILES = ExperimentFiles(
    "backdoor_mtba", "backdoor/mtba", "summary/backdoor"
)
SLEEPER_AGENT_EXPERIMENT_FILES = ExperimentFiles(
    "backdoor_sleeper_agent", "backdoor/sleeper_agent", "summary/backdoor"
)
VPI_EXPERIMENT_FILES = ExperimentFiles(
    "backdoor_vpi", "backdoor/vpi", "summary/backdoor"
)
CTBA_EXPERIMENT_FILES = ExperimentFiles(
    "backdoor_ctba", "backdoor/ctba", "summary/backdoor"
)

# --- jailbreaking ---
AUTODAN_EXPERIMENT_FILES = ExperimentFiles(
    "jailbreaking_autodan", "jailbreaking/autodan", "summary/jailbreaking"
)
GCG_EXPERIMENT_FILES = ExperimentFiles(
    "jailbreaking_gcg", "jailbreaking/gcg", "summary/jailbreaking"
)
PAP_EXPERIMENT_FILES = ExperimentFiles(
    "jailbreaking_pap", "jailbreaking/pap", "summary/jailbreaking"
)

# --- lie detection ---
QUESTIONS1000_EXPERIMENT_FILES = ExperimentFiles(
    "lie_questions_1000", "lie/questions_1000", "summary/lie"
)
SCIQ_EXPERIMENT_FILES = ExperimentFiles(
    "lie_sciq", "lie/sciq", "summary/lie"
)
WIKIDATA_EXPERIMENT_FILES = ExperimentFiles(
    "lie_wikidata", "lie/wikidata", "summary/lie"
)

# --- toxicity ---
SURGEAI_EXPERIMENT_FILES = ExperimentFiles(
    "toxicity_surge_ai", "toxicity/surge_ai_toxicity", "summary/toxicity"
)
SOCIALCHEM101_EXPERIMENT_FILES = ExperimentFiles(
    "toxicity_social_chem_101", "toxicity/social_chem_101", "summary/toxicity"
)
JIGSAW_EXPERIMENT_FILES = ExperimentFiles(
    "toxicity_jigsaw", "toxicity/jigsaw_conversation_ai", "summary/toxicity"
)
CIVIL_COMMENTS_EXPERIMENT_FILES = ExperimentFiles(
    "toxicity_civil_comments", "toxicity/civil_comments", "summary/toxicity"
)
REAL_TOXIC_PROMPTS_EXPERIMENT_FILES = ExperimentFiles(
    "toxicity_real_toxicity_prompts",
    "toxicity/real_toxicity_prompts",
    "summary/toxicity",
)

if __name__ == "__main__":
    import numpy as np

    # Also available but not in the default sweep:
    # SOCIALCHEM101_EXPERIMENT_FILES, JIGSAW_EXPERIMENT_FILES,
    # CIVIL_COMMENTS_EXPERIMENT_FILES.
    experiment_files = [
        QUESTIONS1000_EXPERIMENT_FILES,
        WIKIDATA_EXPERIMENT_FILES,
        SCIQ_EXPERIMENT_FILES,
        SLEEPER_AGENT_EXPERIMENT_FILES,
        MTBA_EXPERIMENT_FILES,
        VPI_EXPERIMENT_FILES,
        AUTODAN_EXPERIMENT_FILES,
        GCG_EXPERIMENT_FILES,
        PAP_EXPERIMENT_FILES,
        SURGEAI_EXPERIMENT_FILES,
        CTBA_EXPERIMENT_FILES,
        REAL_TOXIC_PROMPTS_EXPERIMENT_FILES,
    ]

    # (ExperimentResults field, printed label). Main-model metrics are always
    # reported; baseline metrics only when at least one dataset produced them.
    model_fields = [
        ("acc1", "accuracy model 1"),
        ("auc1", "AUC model 1"),
        ("acc2", "accuracy model 2"),
        ("auc2", "AUC model 2"),
        ("acc3", "accuracy model 3"),
        ("auc3", "AUC model 3"),
    ]
    baseline_fields = [
        (f"{kind}_{metric}_{idx}", f"{label} {kind} baseline model {idx}")
        for idx in (1, 2, 3)
        for kind in ("gaussian", "random")
        for metric, label in (("acc", "accuracy"), ("auc", "AUC"))
    ]

    collected = {field: [] for field, _ in model_fields + baseline_fields}
    gaussian_deltas = []
    random_deltas = []

    for ef in experiment_files:
        print(f"Processing dataset: {ef.dataset_name}")
        results = main(ef)
        if results is None:
            continue
        for field, values in collected.items():
            value = getattr(results, field)
            if value is not None:
                values.append(value)
        # The two medians are reported separately, so keep each delta even
        # when the other baseline failed for this dataset.
        if results.gaussian_delta_perc is not None:
            gaussian_deltas.append(results.gaussian_delta_perc)
        if results.random_delta_perc is not None:
            random_deltas.append(results.random_delta_perc)

    if gaussian_deltas:
        print(f"Median delta accuracy across datasets (gaussian baseline): {np.median(gaussian_deltas):.4f}%")
    if random_deltas:
        print(f"Median delta accuracy across datasets (random baseline): {np.median(random_deltas):.4f}%")

    print("Done processing all datasets.")
    for field, label in model_fields:
        values = collected[field]
        if values:
            print(f"Average {label}: {np.mean(values):.4f} ± {np.std(values):.4f}")
        else:
            # np.mean([]) would print "nan" with a RuntimeWarning.
            print(f"Average {label}: n/a (no successful runs)")
    for field, label in baseline_fields:
        values = collected[field]
        if values:
            print(f"Average {label}: {np.mean(values):.4f} ± {np.std(values):.4f}")