"""
Result class is used to generate tables and visualizations from the results.
"""
import os
import pickle

from collections import defaultdict, Counter
import json
import logging

from matplotlib import pyplot as plt

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

import seaborn as sns
from gmpy2 import hamdist

from module.experiment import Experiment
from module.model.rset_wrapper import RsetWrapper
from matplotlib.ticker import FormatStrFormatter

from module.utils import Hparams



class Result:
    def __init__(self, files, args):
        """Store the experiment file index and the plotting configuration.

        Only file paths are kept (instead of loaded experiments) to save
        memory; every method loads what it needs on demand.

        Args:
            files: Nested mapping of experiment files. Two layouts are
                consumed by the methods of this class:

                * ``generate_table`` iterates
                  ``{root: {dataset: {model_name: [filename, ...]}}}``.
                * the plotting and reporting methods iterate
                  ``{dataset: {model_name: {fold_idx: {"model": path,
                  "result": path, "param": path}}}}``.
            args: Object exposing the ``result_dir``, ``output_dir`` and
                ``fig_dir`` attributes.
        """
        self.files = files
        self.result_dir = args.result_dir
        self.output_dir = args.output_dir
        self.fig_dir = args.fig_dir
        # Single reusable loader; generate_table() reloads it for each file.
        self.experiment = Experiment()
        self.cmap = plt.get_cmap("tab20")
        # Fixed colour per model name so figures stay visually consistent.
        self.color_map = {
            'RSET_kan': self.cmap(0),
            'RSET_min': "aqua",
            'RSET_opt': self.cmap(18),
            'RSET_max': "blue",
            'CART': "gray",
            'GROOT': self.cmap(12),
            'ROCT-N': self.cmap(6),
            'ROCT-V': "pink",
            'FPRDT': "salmon",
        }

    def generate_table(self, skip_time=True):
        results = []
        results_mean = []
        results_std = []
        results_fit_time = []
        logger = logging.getLogger("Result.generate_table")
        train_time = []
        for root, datasets in self.files.items():
            for dataset, models in datasets.items():
                logger.debug(f"Loading {dataset}")
                for model_name, filenames in models.items():
                    logger.debug(f"Loading {model_name}")
                    data = defaultdict(list)
                    fit_times = []
                    for filename in filenames:
                        self.experiment.load_experiment(filename)
                        fit_time = self.experiment.fit_time
                        fit_times.append(fit_time)
                        if self.experiment.result != {}:
                            for key, value in self.experiment.result.items():
                                if type(value) == np.ndarray:
                                    continue
                                data[key].append(value)

                    invalid_keys = []
                    for key, value in data.items():
                        if len(value) != 5:
                            logger.warning(f"Dataset: {dataset}, Model: {model_name}, Key: {key}, Length: {len(value)}")
                            invalid_keys.append(key)

                    for key in invalid_keys:
                        del data[key]
                    metrics_df = pd.DataFrame(data)
                    non_numeric_cols = metrics_df.columns[~metrics_df.apply(lambda col: pd.to_numeric(col, errors='coerce').notnull().all())]
                    if non_numeric_cols.size > 0:
                        print("Non-numeric columns:", non_numeric_cols.tolist())
                        metrics_df = metrics_df.drop(columns=non_numeric_cols)

                    if len(metrics_df) == 0:
                        continue
                    
                    time_cols = metrics_df.columns[metrics_df.columns.str.contains("time")]
                    if time_cols.size > 0 and skip_time:
                        metrics_df = metrics_df.drop(columns=time_cols)

                    for col in metrics_df.columns:
                        metrics_df[col] = pd.to_numeric(metrics_df[col])


                    metrics_mean = metrics_df.mean()
                    metrics_std = metrics_df.std()
                    fit_time_mean = np.mean(fit_times)
                    # Combine the mean and std into a single entry
                    formatted_results = {
                        "Dataset": dataset,
                        "Model": model_name
                    }
                    formatted_results_mean = {
                        "Dataset": dataset,
                        "Model": model_name
                    }
                    formatted_results_std = {
                        "Dataset": dataset,
                        "Model": model_name
                    }
                    formatted_fit_time = {
                        "Dataset": dataset,
                        "Model": model_name,
                    }
                    for metric in metrics_mean.index:
                        mean = metrics_mean[metric]
                        std = metrics_std[metric]
                        formatted_results[metric] = f"{mean:.3f} ± {std:.3f}"
                        formatted_results_mean[metric] = mean
                        formatted_results_std[metric] = std
                    formatted_fit_time['fit_time'] = fit_time_mean

                    # Append the formatted results to the list
                    results.append(formatted_results)
                    results_mean.append(formatted_results_mean)
                    results_std.append(formatted_results_std)
                    results_fit_time.append(formatted_fit_time)
        results_df = pd.DataFrame(results)
        results_mean_df = pd.DataFrame(results_mean)
        results_std_df = pd.DataFrame(results_std)
        results_fit_time = pd.DataFrame(results_fit_time)
        print(results_df)
        return results_df, results_mean_df, results_std_df, results_fit_time

    # def plot_scatter(self, metric1, metric2):
    #     for root, datasets in self.files.items():
    #         for dataset, models in datasets.items():
    #             plt.figure(figsize=(8,6))
    #             for model_name, filenames in models.items():
    #                 score1_list, score2_list = [], []
    #                 for file in filenames:
    #                     self.experiment.load
    #                     score1 = self.experiment.result[metric1]
    #                     score2 = self.experiment.result[metric2]
    #                     if isinstance(self.experiment.model, RsetWrapper):
    #                         score1 = score1[:, 0]
    #                         score2 = score2[:, 0]
    #                         score1_list.extend(score1)
    #                         score2_list.extend(score2)
    #                     else:
    #                         score1_list.append(score1)
    #                         score2_list.append(score2)
    #                 if not isinstance(self.experiment.model, RsetWrapper):
    #                     score1 = np.mean(score1_list)
    #                     score2 = np.mean(score2_list)
    #                 plt.scatter(score2, score1, label=f"{self.experiment.model_type}")
    #             plt.ylabel(metric1)
    #             plt.xlabel(metric2)
    #             plt.title(f"{metric1} vs {metric2}")
    #             plt.legend()
    #             out_file = f"{self.result_dir}/{root}/{dataset}_{metric1}_{metric2}.png"
    #             os.makedirs(f"{self.result_dir}/{root}", exist_ok=True)
    #             plt.savefig(out_file, dpi=300, bbox_inches="tight")
    #             plt.close()

    def plot_density(self, metrics):
        """Plot a histogram of the RSET test scores for every dataset.

        ``self.files`` is read as ``{dataset: {model_name: {fold_idx:
        {"result": path, ...}}}}``; each result file is a pickled dict.

        For the ``"RSET"`` model the ``kantch_test_scores`` arrays of all
        folds are concatenated into one histogram, and the fold means of the
        optimal, min-leaf and max-leaf tree scores are drawn as vertical
        lines.  For any other model ``result[metrics]`` is collected under
        the model's name; those models are only drawn if added to ``orders``.
        Folds whose result has no optimal-tree score are skipped, and a
        dataset without any usable RSET fold produces no figure.

        Args:
            metrics: Key looked up in the result dict of non-RSET models.

        The figure is saved to ``{output_dir}/{fig_dir}/{dataset}.png``.
        """
        logger = logging.getLogger("Result.plot_density")
        # Legend name -> result key holding that tree's score.
        rset_keys = {
            "RSET_opt": "optimal_tree_kantch_evasion_attack_accuracy_on_test_from_optimal_tree",
            "RSET_min": "min_leaf_tree_kantch_evasion_attack_accuracy_on_test_from_optimal_tree",
            "RSET_max": "max_leaf_tree_kantch_evasion_attack_accuracy_on_test_from_optimal_tree",
        }
        fig_root = f"{self.output_dir}/{self.fig_dir}"

        for dataset, models in self.files.items():
            model_scores = defaultdict(list)
            density_scores = []
            logger.info(f"Generating density plot for {metrics} on {dataset}")
            for model_name, folds in models.items():
                for fold_idx, experiments in folds.items():
                    # NOTE: pickle.load executes arbitrary code from the file;
                    # only point this at result files from trusted runs.
                    with open(experiments["result"], "rb") as f:
                        result = pickle.load(f)

                    if model_name != "RSET":
                        model_scores[model_name].append(result[metrics])
                        continue

                    if result.get(rset_keys["RSET_opt"]) is None:
                        logger.warning(f"Skipping {dataset} - {model_name} - {fold_idx} as no optimal tree found")
                        continue
                    for legend_name, result_key in rset_keys.items():
                        model_scores[legend_name].append(result[result_key])
                    density_scores.append(result["kantch_test_scores"])

            # np.concatenate raises on an empty list, so there is nothing to
            # draw when no RSET fold survived the filter above.
            if not density_scores:
                logger.warning("Skipping %s as it has no RSET scores to plot", dataset)
                continue

            fig, ax = plt.subplots(figsize=(30, 6))
            density = np.concatenate(density_scores)
            sns.histplot(density, fill=True, color=self.cmap(1), common_norm=True, kde=True, stat='count', label="RSET", edgecolor=None, ax=ax)
            y_min, y_max = ax.get_ylim()
            yticks = np.linspace(y_min, y_max, num=5)  # Adjust 'num' to set the number of ticks
            ax.set_yticks(yticks)
            # Tick labels show the pooled count divided by 5.
            # NOTE(review): 5 is presumably the number of folds - confirm
            # before running with a different fold count.
            ax.set_yticklabels([f"{int(ytick//5)}" for ytick in yticks])

            # Reference lines to draw; any other key of self.color_map whose
            # scores were collected above can be appended here.
            orders = ["RSET_opt", "RSET_min", "RSET_max"]
            legend_entries = []
            for line_name in orders:
                mean = np.mean(model_scores[line_name])
                line = ax.axvline(mean, color=self.color_map[line_name], linestyle='-', label = f"{line_name}: {str(round(mean, 2)).lstrip('0')}", ymax=0.95, lw=5)
                legend_entries.append(line)

            ax.set_xlabel("Accuracy", fontsize=56)
            ax.set_ylabel("Count", fontsize=56)
            ax.set_title(f"{dataset}", fontsize=64)
            ax.tick_params(axis='x', labelsize=48)
            ax.tick_params(axis='y', labelsize=48)
            fig.tight_layout()
            fig.legend(
                handles=legend_entries,
                loc="upper center",
                bbox_to_anchor=(0.5, 0),  # Position below the plot
                fontsize=56,
                ncol=3,
                frameon=False,
                columnspacing=0.5,
            )

            # Save the plot as a PNG file
            os.makedirs(fig_root, exist_ok=True)
            out_file = f"{fig_root}/{dataset}.png"
            fig.savefig(out_file, dpi=300, bbox_inches="tight")
            # Close this figure explicitly rather than whatever is current.
            plt.close(fig)

    def report_leaf_dist(self) -> None:
        logger = logging.getLogger("Result.report_leaf_dist")
        for dataset, models in self.files.items():
            for model_name, folds in models.items():
                for fold_idx, exp_files in folds.items():
                    logger.info(f"Loading {dataset} - {model_name} - {fold_idx}")
                    model_file = exp_files["model"]
                    result_file = exp_files["result"]
                    param_file = exp_files["param"]
                    with open(result_file, "rb") as f:
                        result = pickle.load(f)
                    if result.get("optimal_leaf_box") is None:
                        logger.warning(f"Skipping {dataset} - {model_name} - {fold_idx} as no optimal tree found")
                        continue
                    opt_leaf_box = result["optimal_leaf_box"]
                    best_leaf_box = result["best_leaf_box"]

                    logger.info("Distribution of leaf boxes:")
                    logger.info(f"Optimal leaf box: {Counter(opt_leaf_box)}")
                    logger.info(f"Best leaf box: {Counter(best_leaf_box)}")

                    box_counts = Counter(opt_leaf_box)
                    most_common_box = box_counts.most_common(1)[0][0]
                    logger.info(f"Most common box: {most_common_box} with count: {box_counts[most_common_box]}")

                    indices = np.where(opt_leaf_box == most_common_box)[0]

                    corresponding_leafs = best_leaf_box[indices]
                    distribution = Counter(corresponding_leafs)
                    logger.info(f"Distribution of best tree leaves at those indices: {dict(distribution)}")

                    # Step 5: Visualize
# plt.bar(distribution.keys(), distribution.values(), tick_label=list(distribution.keys()))
# plt.title(f'Distribution of arr2 labels where arr1 == {largest_label}')
# plt.xlabel('arr2 label')
# plt.ylabel('Count')
# plt.grid(axis='y')
# plt.show()


    def plot_scatter_sparsity(self) -> None:
        """Scatter ``kantch_test_scores`` against the number of leaves per tree.

        ``self.files`` is read as ``{dataset: {model_name: {fold_idx:
        {"model": path, "result": path, ...}}}}``.  For every fold one figure
        is written with all trees of the pickled model, highlighting the
        trees indexed by ``model.special_tree["optimal_tree"]`` and
        ``model.best_tree["best_Kantch_test_tree"]``.  A further figure per
        dataset/model pools the points of all folds.  Figures are saved
        under ``{output_dir}/{fig_dir}/sparsity``.
        """
        logger = logging.getLogger("Result.plot_scatter_sparsity")
        # Create the target directory once so the combined figure can be
        # saved even when the per-fold loop never ran.
        out_dir = f"{self.output_dir}/{self.fig_dir}/sparsity"
        os.makedirs(out_dir, exist_ok=True)

        for dataset, models in self.files.items():
            for model_name, folds in models.items():
                combined_adv_scores = []
                combined_num_leaves = []
                logger.info("Generating scatter plot for %s - %s", dataset, model_name)
                for fold_idx, exp_files in folds.items():
                    # NOTE: pickle.load executes arbitrary code from the file;
                    # only point this at files produced by trusted runs.
                    with open(exp_files["result"], "rb") as f:
                        result = pickle.load(f)
                    with open(exp_files["model"], "rb") as f:
                        model = pickle.load(f)

                    optimal_idx = model.special_tree["optimal_tree"]
                    best_kantch_idx = model.best_tree["best_Kantch_test_tree"]

                    adv_scores = result["kantch_test_scores"]
                    num_leaves = [model.get_tree_num_leaves(tree_idx) for tree_idx in range(model.ntrees())]

                    combined_adv_scores.extend(adv_scores)
                    combined_num_leaves.extend(num_leaves)

                    # Draw on a fresh figure instead of whichever figure
                    # happens to be current in pyplot's global state.
                    fig = plt.figure()
                    plt.scatter(num_leaves, adv_scores, label=model_name, s=10)
                    plt.scatter(num_leaves[optimal_idx], adv_scores[optimal_idx], color="red", label="Optimal Tree")
                    plt.scatter(num_leaves[best_kantch_idx], adv_scores[best_kantch_idx], color="purple", label="Best Kantch Tree")
                    plt.title(f"{dataset} - {model_name} - {fold_idx}")
                    plt.xlabel("Number of Leaves")
                    plt.ylabel("Adv Score")
                    plt.legend(loc="upper left")
                    out_file = f"{out_dir}/{dataset}_{model_name}_{fold_idx}.png"
                    plt.savefig(out_file, dpi=300, bbox_inches="tight")
                    plt.close(fig)
                    logger.info("Saved plot to %s", out_file)

                if not combined_num_leaves:
                    logger.warning("Skipping combined plot for %s - %s as no trees were loaded", dataset, model_name)
                    continue

                fig = plt.figure()
                plt.scatter(combined_num_leaves, combined_adv_scores, label=model_name, s=10)
                plt.title(f"{dataset} - {model_name} - Combined")
                plt.xlabel("Number of Leaves")
                plt.ylabel("Adv Score")
                out_file = f"{out_dir}/{dataset}_{model_name}_combined.png"
                plt.savefig(out_file, dpi=300, bbox_inches="tight")
                plt.close(fig)

    def plot_lineplot_patterns(self) -> None:
        """Plot mean score versus prediction distance to the optimal tree.

        ``self.files`` is read as ``{dataset: {model_name: {fold_idx:
        {"model": path, "result": path, "param": path}}}}``.  For each fold,
        ``hamdist`` is taken between every entry of ``result["test_preds"]``
        and the entry of the optimal tree, and ``kantch_test_scores`` is
        averaged per distance value.  Three kinds of figures are written to
        ``{output_dir}/{fig_dir}/pattern_line``:

        * one line plot per fold (raw distances),
        * one plot per dataset/model pooling all folds, with distances
          divided by the average ``len(test_preds[0])`` and truncated to the
          smallest per-fold maximum distance,
        * ``combined_all_datasets.png`` with one subplot per dataset.

        Models without folds are skipped, and nothing is drawn for the final
        figure when no dataset produced data.
        """
        logger = logging.getLogger("Result.plot_lineplot_patterns")
        # Resolve the output directory up front: it is needed by the combined
        # and final figures even if no fold was ever processed.
        out_path = os.path.join(self.output_dir, self.fig_dir, "pattern_line")
        os.makedirs(out_path, exist_ok=True)
        subplot_data_accumulator = defaultdict(dict)

        for dataset, models in self.files.items():
            for model_name, folds in models.items():
                logger.info("Processing %s - %s", dataset, model_name)
                if not folds:
                    # Averaging over zero folds would divide by zero below.
                    logger.warning("Skipping %s - %s as it has no folds", dataset, model_name)
                    continue

                combined_distance = []
                combined_adv_scores = []
                optimal_indices = []
                best_kantch_indices = []
                min_max_dist = float("inf")
                total_samples = 0

                for fold_idx, exp_files in folds.items():
                    # NOTE: pickle.load executes arbitrary code from the file;
                    # only point this at files produced by trusted runs.
                    with open(exp_files["result"], "rb") as f:
                        result = pickle.load(f)
                    with open(exp_files["model"], "rb") as f:
                        model = pickle.load(f)
                    if fold_idx == "0":
                        # Hyper-parameters are only logged, once per model.
                        with open(exp_files["param"], "rb") as f:
                            hparam = Hparams()
                            hparam.set_state(pickle.load(f))
                        logger.info("Model param: %s", hparam.model_params)

                    optimal_idx = model.special_tree["optimal_tree"]
                    best_kantch_idx = model.best_tree["best_Kantch_test_tree"]

                    test_preds = result["test_preds"]
                    adv_scores = result["kantch_test_scores"]

                    reference = test_preds[optimal_idx]
                    distances = [hamdist(pattern, reference) for pattern in test_preds]
                    min_max_dist = min(min_max_dist, max(distances))
                    total_samples += len(test_preds[0])

                    # Accumulate for combined analysis; the offset maps this
                    # fold's tree indices into the pooled arrays.
                    offset = len(combined_distance)
                    combined_distance.extend(distances)
                    combined_adv_scores.extend(adv_scores)
                    optimal_indices.append(optimal_idx + offset)
                    best_kantch_indices.append(best_kantch_idx + offset)

                    # Plot per fold
                    avg_scores, std_scores = self._score_stats_by_distance(distances, adv_scores)

                    fig = plt.figure()
                    plt.plot(avg_scores[:, 0], avg_scores[:, 1], label=model_name, color="blue", lw=2)
                    plt.fill_between(avg_scores[:, 0], avg_scores[:, 1] - std_scores[:, 1], avg_scores[:, 1] + std_scores[:, 1], alpha=0.2, color="blue")
                    plt.scatter(distances[optimal_idx], adv_scores[optimal_idx], color="red", label="Optimal Tree")
                    plt.scatter(distances[best_kantch_idx], adv_scores[best_kantch_idx], color="purple", label="Best Kantch Tree")
                    corr = spearmanr(distances, adv_scores).correlation
                    plt.title(f"{dataset} - {model_name} - {fold_idx} - Spearman: {corr:.2f}")
                    plt.xlabel("Distance to Optimal Tree")
                    plt.ylabel("Adversarial Score")
                    plt.legend(loc="upper left")
                    plt.savefig(os.path.join(out_path, f"{dataset}_{model_name}_{fold_idx}.png"), dpi=300, bbox_inches="tight")
                    plt.close(fig)

                # Combined plot
                combined_distance = np.array(combined_distance)
                combined_adv_scores = np.array(combined_adv_scores)
                avg_samples = total_samples / len(folds)
                combined_corr = spearmanr(combined_distance, combined_adv_scores)

                optimal_dists = combined_distance[optimal_indices] / avg_samples
                optimal_scores = combined_adv_scores[optimal_indices]
                best_dists = combined_distance[best_kantch_indices] / avg_samples
                best_scores = combined_adv_scores[best_kantch_indices]

                opt_dist_mean, opt_dist_std = np.mean(optimal_dists), np.std(optimal_dists)
                opt_score_mean, opt_score_std = np.mean(optimal_scores), np.std(optimal_scores)
                best_dist_mean, best_dist_std = np.mean(best_dists), np.std(best_dists)
                best_score_mean, best_score_std = np.mean(best_scores), np.std(best_scores)

                # Only distances that occur in every fold's range are kept,
                # and they are rescaled by the average sample count.
                avg_scores, std_scores = self._score_stats_by_distance(
                    combined_distance,
                    combined_adv_scores,
                    max_distance=min_max_dist,
                    scale=avg_samples,
                )

                fig = plt.figure(figsize=(8, 8))
                plt.plot(avg_scores[:, 0], avg_scores[:, 1], color="blue", lw=2)
                plt.fill_between(avg_scores[:, 0], avg_scores[:, 1] - std_scores[:, 1], avg_scores[:, 1] + std_scores[:, 1], alpha=0.2, color="blue")
                plt.errorbar(opt_dist_mean, opt_score_mean, xerr=opt_dist_std, yerr=opt_score_std, fmt='.', color='red', capsize=5, label="Optimal Tree (mean ± std)")
                plt.errorbar(best_dist_mean, best_score_mean, xerr=best_dist_std, yerr=best_score_std, fmt='.', color='purple', capsize=5, label="Best Kantch Tree (mean ± std)")
                plt.title(f"{dataset} Spearman: {combined_corr.correlation:.2f} (p: {combined_corr.pvalue:.2f})", fontsize=18)
                plt.xlabel("Normalized Distance to Optimal Tree", fontsize=16)
                plt.ylabel("Adversarial Score", fontsize=16)
                plt.legend(loc="upper left")
                plt.savefig(os.path.join(out_path, f"{dataset}_{model_name}_combined.png"), dpi=300, bbox_inches="tight")
                plt.close(fig)

                # Save summary for final subplot
                subplot_data_accumulator[dataset][model_name] = {
                    "optimal_tree": dict(distance=opt_dist_mean, score=opt_score_mean, std_distance=opt_dist_std, std_score=opt_score_std),
                    "best_kantch_tree": dict(distance=best_dist_mean, score=best_score_mean, std_distance=best_dist_std, std_score=best_score_std),
                    "spearman_corr": combined_corr.correlation,
                    "p_value": combined_corr.pvalue,
                    "avg_score_per_distance": avg_scores,
                    "std_score_per_distance": std_scores,
                }

        if not subplot_data_accumulator:
            # plt.subplots cannot create a grid with zero columns.
            logger.warning("No data collected; skipping the combined figure")
            return

        # Final combined subplot for all datasets
        n_datasets = len(subplot_data_accumulator)
        fig, axes = plt.subplots(1, n_datasets, figsize=(8 * n_datasets, 8), squeeze=False)
        for i, (dataset, data) in enumerate(sorted(subplot_data_accumulator.items())):
            ax = axes[0, i]
            for model_name, info in data.items():
                avg_scores = info["avg_score_per_distance"]
                std_scores = info["std_score_per_distance"]
                ax.plot(avg_scores[:, 0], avg_scores[:, 1], lw=2)
                ax.fill_between(avg_scores[:, 0], avg_scores[:, 1] - std_scores[:, 1], avg_scores[:, 1] + std_scores[:, 1], alpha=0.2)

                ax.errorbar(info["optimal_tree"]["distance"], info["optimal_tree"]["score"],
                            xerr=info["optimal_tree"]["std_distance"], yerr=info["optimal_tree"]["std_score"],
                            fmt='.', color='red', capsize=5, label="Optimal Tree (mean ± std)")
                ax.errorbar(info["best_kantch_tree"]["distance"], info["best_kantch_tree"]["score"],
                            xerr=info["best_kantch_tree"]["std_distance"], yerr=info["best_kantch_tree"]["std_score"],
                            fmt='.', color='purple', capsize=5, label="Best Kantch Tree (mean ± std)")

            # The title reports the statistics of the last model drawn on
            # this axis (``info`` from the loop above).
            ax.set_title(f"{dataset} - Spearman: {info['spearman_corr']:.2f} (p: {info['p_value']:.2f})", fontsize=28)
            ax.grid(True)
            ax.tick_params(axis='both', labelsize=24)

        # A single shared legend, taken from the first subplot.
        handles, labels = axes[0, 0].get_legend_handles_labels()
        fig.legend(
            handles,
            labels,
            loc="upper center",
            bbox_to_anchor=(0.5, 1.05),
            fontsize=28,
            ncol=3,
            frameon=False
        )
        for ax in axes.flatten():
            ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))

        fig.text(0.5, 0.04, 'Normalized Distance to Optimal Tree', ha='center', fontsize=28)
        fig.text(0.04, 0.5, 'Adversarial Score', va='center', rotation='vertical', fontsize=28)
        plt.tight_layout(rect=[0.05, 0.075, 1, 0.925])
        plt.savefig(os.path.join(out_path, "combined_all_datasets.png"), dpi=300, bbox_inches="tight")
        plt.close(fig)

    @staticmethod
    def _score_stats_by_distance(distances, scores, max_distance=None, scale=1.0):
        """Return per-distance mean and std of ``scores`` as two (k, 2) arrays.

        Scores are grouped by their distance value; pairs whose distance
        exceeds ``max_distance`` (when given) are ignored.  Both returned
        arrays are sorted by distance and hold ``distance / scale`` in
        column 0, and the group mean (first array) or group standard
        deviation (second array) in column 1.
        """
        grouped = defaultdict(list)
        for dist, score in zip(distances, scores):
            if max_distance is None or dist <= max_distance:
                grouped[dist].append(score)

        if not grouped:
            # Keep the two-column shape so callers can still slice columns.
            return np.empty((0, 2)), np.empty((0, 2))

        rows = [
            (dist / scale, np.mean(grouped[dist]), np.std(grouped[dist]))
            for dist in sorted(grouped)
        ]
        avg_scores = np.array([(d, mean) for d, mean, _ in rows])
        std_scores = np.array([(d, std) for d, _, std in rows])
        return avg_scores, std_scores


    def plot_scatter_patterns(self) -> None:
        """Scatter score against normalized prediction distance to the optimal tree.

        ``self.files`` is read as ``{dataset: {model_name: {fold_idx:
        {"model": path, "result": path, ...}}}}``.  For each fold the
        distance of every entry of ``result["test_preds"]`` to the optimal
        tree's entry is computed with ``hamdist`` and divided by
        ``len(test_preds[0])``; it is plotted against ``kantch_test_scores``
        together with the Spearman correlation.  One figure is written per
        fold plus one pooling all folds of the dataset/model pair, under
        ``{output_dir}/{fig_dir}/pattern_norm``.  Models without folds are
        skipped.
        """
        # NOTE(review): the logger name predates the method name; kept as-is
        # so existing logging configuration keeps matching.
        logger = logging.getLogger("Result.plot_scatter")
        out_dir = f"{self.output_dir}/{self.fig_dir}/pattern_norm"
        os.makedirs(out_dir, exist_ok=True)

        for dataset, models in self.files.items():
            for model_name, folds in models.items():
                logger.info("Generating scatter plot for %s - %s", dataset, model_name)
                if not folds:
                    logger.warning("Skipping %s - %s as it has no folds", dataset, model_name)
                    continue

                combined_distance = []
                combined_adv_scores = []
                optimal_indices = []
                best_kantch_indices = []
                for fold_idx, exp_files in folds.items():
                    # NOTE: pickle.load executes arbitrary code from the file;
                    # only point this at files produced by trusted runs.
                    with open(exp_files["result"], "rb") as f:
                        result = pickle.load(f)
                    with open(exp_files["model"], "rb") as f:
                        model = pickle.load(f)

                    optimal_idx = model.special_tree["optimal_tree"]
                    best_kantch_idx = model.best_tree["best_Kantch_test_tree"]

                    test_preds = result["test_preds"]
                    adv_scores = result["kantch_test_scores"]

                    num_samples = len(test_preds[0])
                    reference = test_preds[optimal_idx]
                    distances = [hamdist(pattern, reference) / num_samples for pattern in test_preds]
                    correlations = spearmanr(distances, adv_scores)

                    # The offset maps this fold's tree indices into the
                    # pooled arrays used for the combined figure.
                    offset = len(combined_distance)
                    combined_distance.extend(distances)
                    combined_adv_scores.extend(adv_scores)
                    optimal_indices.append(optimal_idx + offset)
                    best_kantch_indices.append(best_kantch_idx + offset)

                    # Draw on a fresh figure instead of whichever figure
                    # happens to be current in pyplot's global state.
                    fig = plt.figure()
                    plt.scatter(distances, adv_scores, label=model_name, s=5, alpha=0.2, rasterized=True)
                    plt.scatter(distances[optimal_idx], adv_scores[optimal_idx], s=10, color="red", label="Optimal Tree")
                    plt.scatter(distances[best_kantch_idx], adv_scores[best_kantch_idx], s=20, color="purple", label="Best Kantch Tree")
                    plt.title(f"{dataset} - {model_name} - {fold_idx} - Spearman: {correlations.correlation:.2f}")
                    plt.xlabel("Normalized Distance to Optimal Tree")
                    plt.ylabel("Adv Score")
                    plt.legend(loc="upper left")
                    out_file = f"{out_dir}/{dataset}_{model_name}_{fold_idx}.png"
                    plt.savefig(out_file, dpi=300, bbox_inches="tight")
                    plt.close(fig)
                    logger.info("Saved plot to %s", out_file)

                combined_distance = np.array(combined_distance)
                combined_adv_scores = np.array(combined_adv_scores)
                correlations = spearmanr(combined_distance, combined_adv_scores)
                corr = correlations.correlation
                p_value = correlations.pvalue

                fig = plt.figure()
                plt.scatter(combined_distance, combined_adv_scores, s=5, alpha=0.2, rasterized=True, label="RSET")
                plt.scatter(combined_distance[optimal_indices], combined_adv_scores[optimal_indices], s=10, color="red", label="Optimal Tree", rasterized=True)
                plt.scatter(combined_distance[best_kantch_indices], combined_adv_scores[best_kantch_indices], s=10, color="purple", label="Best Validation Tree", rasterized=True)

                plt.title(f"{dataset} - Spearman: {corr:.2f} (p-value: {p_value:.2f})")
                plt.xlabel("Normalized Prediction (Hamming) Distance to Optimal Tree")
                plt.ylabel("Adversarial Score")
                plt.legend(loc="upper left")
                out_file = f"{out_dir}/{dataset}_{model_name}_combined.png"
                plt.savefig(out_file, dpi=300, bbox_inches="tight")
                plt.close(fig)
                



                    



    # def exploration(self) -> None:
    #     """
    #     Exploration function for collecting stats and stuff
    #     """

    #     #TODO move this to results file
    #     logger = logging.getLogger("Metrics.exploration")
    #     adv_scores = eval_result["kantch_scores"]

    #     adv_scores = sorted(adv_scores, key=lambda x: x[0], reverse=True)

    #     X = self.encoder.transform(self.X_test) \
    #             if self.encoder is not None else self.X_test

    #     patterns = self.model.predicting_patterns(X)
    #     optimal_tree_idx = self.model.special_tree["optimal_tree"]
    #     logger.info("Optimal tree index: %s", optimal_tree_idx)
        
    #     dists = []
    #     scores = []
    #     for score, idx in adv_scores:
    #         idx = int(idx)
    #         # print(pattenrs[optimal_tree_idx], patterns[idx])
    #         dist = self.model.pattern_dist(patterns[optimal_tree_idx], patterns[idx])
    #         dists.append(dist)
    #         scores.append(score)
        
    #     dists = np.array(dists)
    #     dists = dists/np.max(dists)

    #     plt.plot(dists, label="Pred Dist with Opt Tree")
    #     plt.plot(scores, label="Adv Scores")
    #     plt.savefig("exploration.png")

    #     return
