import copy
from abc import ABC, abstractmethod
import networkx as nx
import concurrent.futures
from typing import List, Type
import json
import pandas as pd
from gnn_xai_common.datasets import BaseGraphDataset, RomeDataset
from gnn_xai_common import GraphSampler, Trainer
from scripts import *
from gnnboundary.utils.boundary_evaluator import BoundaryEvaluator
from gnnboundary.utils.boundary_generator import GraphGenerator
from scripts.default_configs import *
from scripts.experiments_logging import initialize_logger


# Module-wide logger used by every experiment class in this file; configured with the
# log file "experiments.log" and the DEBUG level.
# NOTE(review): `logging` is not imported explicitly here — presumably it is re-exported by one
# of the star imports above; confirm.
logger = initialize_logger(log_file="experiments.log", log_level=logging.DEBUG)


class AbstractExperiment(ABC):
    """
    Common contract of all experiments in this module.

    A concrete experiment has to provide ``run``, ``plot`` and ``save``. The constructor only
    records the base output directory in ``root_dir``.
    """

    def __init__(self):
        # Base output directory for experiment artefacts.
        self.root_dir: str = "./experiments"

    @abstractmethod
    def run(self):
        """Execute the experiment."""

    @abstractmethod
    def plot(self):
        """Create the figures belonging to the experiment."""

    @abstractmethod
    def save(self):
        """Persist the results of the experiment."""



class ReproduceAdjacencyExperiment(AbstractExperiment):
    """
    This experiment runs an adjacency analysis on the dataset specified in the argument.

    Artefacts are written to ``./experiments/adjacency_analysis/<dataset.name>``; file names carry a
    random ``experiment_id`` so repeated runs do not overwrite each other.
    """
    def __init__(self, dataset: BaseGraphDataset):
        """
        :param dataset: dataset to analyse; its ``name`` selects the model checkpoint and the output directory.
        """
        super().__init__()

        logger.info("Setting up adjacency analysis experiment...")
        self.dataset = dataset
        self.config = AnalyzerConfig(model_checkpoint=CKPT_PATHS[dataset.name],
                                     model_kwargs=get_model_kwargs(dataset, dataset.name))
        self.analyzer = Analyzer(dataset, self.config)

        self.experiment_id = np.random.randint(0, 1000000)
        self.root_dir = f"{self.root_dir}/adjacency_analysis/{dataset.name}"

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(self.root_dir, exist_ok=True)

        self.result = None
        self.finished_run = False
        self.figure = None
        logger.info("Finished setting up adjacency analysis experiment.")

    def run(self):
        """Run the analyzer's adjacency analysis, store its output in ``self.result`` and return it."""
        logger.info("Running adjacency analysis...")
        self.result = self.analyzer.analyze_adjacency()
        logger.info("Finished adjacency analysis.")
        self.finished_run = True
        return self.result

    def plot(self, draw=False):
        """
        Draw the adjacency matrix (``self.result[0]``) and keep whatever ``draw_matrix_adj`` returns in ``self.figure``.

        :param draw: forwarded inverted as ``return_fig`` to ``draw_matrix_adj``.
        :raises ValueError: if the experiment has not been run yet.
        """
        if self.result is None:
            raise ValueError("No results to plot. Run the experiment first.")

        figure = draw_matrix_adj(self.result[0], names=self.dataset.GRAPH_CLS.values(), fmt='.2f', return_fig=not draw,
                                 xlabel='', ylabel='', title="Adjacency Matrix")
        self.figure = figure

    def save(self):
        """
        Save the adjacency matrix figure (png) and the matrix itself (csv) into ``self.root_dir``.

        :raises ValueError: if the experiment has not been run yet.
        """
        if self.result is None:
            raise ValueError("No results to save. Run the experiment first.")

        logger.info("Saving adjacency analysis results...")
        figure_path = f"{self.root_dir}/adjacency_matrix_{self.experiment_id}.png"
        matrix_path = f"{self.root_dir}/adjacency_matrix_{self.experiment_id}.csv"

        # Prefer the figure stored by plot(): plt.savefig would save whichever figure happens to be
        # current. Fall back to the current figure when plot() did not hand back a figure object.
        if hasattr(self.figure, "savefig"):
            self.figure.savefig(figure_path, bbox_inches="tight")
        else:
            plt.savefig(figure_path, bbox_inches="tight")

        # save results
        np.savetxt(matrix_path, self.result[0], delimiter=",", fmt='%f')
        logger.info("Finished saving adjacency analysis results into " + matrix_path)

    def load(self):
        """Placeholder; loading stored results is not implemented."""
        pass

    def setup(self):
        """Placeholder; all setup currently happens in ``__init__``."""
        pass


class ReproduceGraphGenerationExperiment(AbstractExperiment):
    """
    This experiment runs a boundary graph generation on the class pair specified in the argument.

    Results are stored per class pair: evaluation results in ``self.eval_results``, the generated
    graphs in ``self.graph_generations`` and their drawings in ``self.figures``.
    """
    def __init__(self, dataset: BaseGraphDataset,
                 cls_pair: Tuple[int, int],
                 training_params: TrainingParams = None,
                 model_kwargs: Dict[str, int] = None,
                 ckpt_path: str = None,
                 model_architecture: object = None,
                 temperature: float = 0.5,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 criterion: WeightedCriterion = None
                 ):
        """
        :param dataset: dataset the classifier was trained on.
        :param cls_pair: the two class indices to generate a boundary graph for.
        :param training_params: training parameters; defaults to ``get_default_training_params`` for the pair.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; defaults to ``get_model_kwargs``.
        :param ckpt_path: model checkpoint; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param temperature: forwarded to the sampling configuration.
        :param learn_node_feat: forwarded to the sampling configuration.
        :param lr: learning rate forwarded to the sampling configuration (SGD optimizer).
        :param criterion: optional criterion injected into the trainer.
        """

        super().__init__()

        logger.info("Setting up graph generation experiment...")
        self.dataset = dataset

        if training_params is None:
            training_params = get_default_training_params(dataset.name, cls_pair[0], cls_pair[1])

        self.training_params = training_params

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        sampling_config = SamplingTrainerConfig(
            cls_pair=[cls_pair],
            ckpt_path=ckpt_path,
            model_architecture=GCNClassifier if model_architecture is None else model_architecture,
            hidden_channels=model_kwargs["hidden_channels"],
            num_layers=model_kwargs["num_layers"],
            temperature=temperature,
            learn_node_feat=learn_node_feat,
            optimizer=torch.optim.SGD,
            lr=lr
        )

        self.trainer = SamplingTrainer(dataset, sampling_config)

        if criterion is not None:
            self.trainer.inject_criterion(criterion)

        # NOTE(review): this reads `target_probs[0].values()` while the sibling experiments read
        # `target_probs.values()` — confirm which shape `TrainingParams.target_probs` really has.
        self.experiment_id = (f"_temperature={temperature}_lr={lr}"
                              f"_iterations={training_params.iterations}_target_probs=[{str(list(training_params.target_probs[0].values()))}]"
                              f"w_budget_init={training_params.w_budget_init}_w_budget_dec={training_params.w_budget_dec}"
                              f"w_budget_inc={training_params.w_budget_inc}_ksamples={training_params.k_samples}"
                              f"eval_thresh={training_params.eval_threshold}_{np.random.randint(0, 100000000)}")

        self.root_dir = f"{self.root_dir}/simple_graph_generation/{dataset.name}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(self.root_dir, exist_ok=True)

        self.cls_pair = cls_pair
        self.eval_results = None
        self.graph_generations = None
        self.figures = {}
        logger.info("Finished setting up graph generation experiment.")

    def run(self):
        """Train and evaluate via the trainer; stores ``results[0]`` as evaluation results and ``results[1]`` as graphs."""
        logger.info(f"Running graph generation experiment for dataset {self.dataset.name} and adjacent class pairs...")
        results = self.trainer.train_evaluate_all(**self.training_params.__dict__)
        self.eval_results = results[0]
        self.graph_generations = results[1]
        logger.info("Finished graph generation experiment.")

    def plot(self, draw=False):
        """
        Draw every generated graph with networkx.

        :param draw: additionally show each figure via ``plt.show()``.
        :raises ValueError: if the experiment has not been run yet.
        """
        if self.graph_generations is None:
            raise ValueError("No figures to plot. Run the experiment first.")

        for cls_pair in self.graph_generations:
            graph = self.graph_generations[cls_pair]
            fig, ax = plt.subplots()
            nx.draw(graph, ax=ax)
            # Always keep the figure: previously plot(draw=True) stored nothing, so a later
            # save() failed with a KeyError.
            self.figures[cls_pair] = fig
            if draw:
                plt.show()

    def save(self):
        """
        Write logits and probabilities (csv) and the graph drawing (png) for every class pair.

        :raises ValueError: if the experiment has not been run yet.
        """
        if self.graph_generations is None or self.eval_results is None:
            raise ValueError("No results to save. Run the experiment first.")

        for cls_pair in self.graph_generations:
            prefix = f"{self.root_dir}/{cls_pair[0]}_{cls_pair[1]}"
            # save eval results
            np.savetxt(f"{prefix}_logits.csv",
                       self.eval_results[cls_pair][0], delimiter=",", fmt='%f')
            np.savetxt(f"{prefix}_probs.csv",
                       self.eval_results[cls_pair][1], delimiter=",", fmt='%f')

            figure = self.figures.get(cls_pair)
            if figure is None:
                # plot() has not been called for this pair; the numeric results are still saved
                logger.warning(f"No figure available for class pair {cls_pair}. Call plot() before save() to store it.")
                continue
            figure.savefig(f"{prefix}.png")


class AdjacencyAndBoundaryGraphGenerationExperiment(AbstractExperiment):
    """
    This experiment runs a boundary graph generation based on the adjacent class pairs found in the adjacency analysis.
    We can insert the number of graphs we want to generate for each class pair. The logits and probabilities are averaged
    over the number of graphs and saved. This experiment is meant to try out different parameters and criteria and then
    evaluate the results. It is not meant to reproduce the paper results. For that we will run an experiment that loads
    saved boundary graphs and evaluates them.
    """
    def __init__(self, dataset: BaseGraphDataset,
                 temperature: float,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 adj_threshold: float = 0.8,
                 model_kwargs=None,
                 training_params: TrainingParams = None,
                 ckpt_path=None,
                 model_architecture=None,
                 run_new_adjacency_analysis: bool = True,
                 criterion: WeightedCriterion = None,
                 num_graphs: int = 1,):
        """
        :param dataset: dataset the classifier was trained on.
        :param temperature: forwarded to the sampling configuration.
        :param learn_node_feat: forwarded to the sampling configuration.
        :param lr: learning rate forwarded to the sampling configuration (SGD optimizer).
        :param adj_threshold: minimum adjacency value for two classes to count as adjacent.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; defaults to ``get_model_kwargs``.
        :param training_params: training parameters; defaults to ``get_default_training_params(dataset.name, 0, 1)``.
        :param ckpt_path: model checkpoint; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param run_new_adjacency_analysis: run a fresh adjacency analysis instead of reusing a stored matrix.
        :param criterion: optional criterion injected into the trainer.
        :param num_graphs: number of train/evaluate repetitions whose evaluation results are averaged.
        """
        super().__init__()
        logger.info("Setting up adjacency and boundary graph generation experiment...")
        self.dataset = dataset
        if training_params is None:
            training_params = get_default_training_params(dataset.name, 0, 1)

        self.num_graphs = num_graphs
        self.training_params = training_params

        self.experiment_id = (f"adj_threshold={adj_threshold}_temperature={temperature:.4f}_lr={lr:.4f}"
                              f"_iterations={training_params.iterations}_target_probs=[{str(list(training_params.target_probs.values()))}]"
                              f"w_budget_init={training_params.w_budget_init}_w_budget_dec={training_params.w_budget_dec:.3f}"
                              f"w_budget_inc={training_params.w_budget_inc:.3f}_ksamples={training_params.k_samples}"
                              f"eval_thresh={training_params.eval_threshold}_{np.random.randint(0, 100000000)}")

        self.adjacency_dir = f"{self.root_dir}/adjacency_analysis/{dataset.name}"
        self.root_dir = f"{self.root_dir}/graph_generation/{dataset.name}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        logger.info("Finding class pairs...")
        cls_pairs = self._find_class_pairs(adj_threshold, run_new_adjacency_analysis)
        logger.info("Finished finding class pairs." + f" Found {len(cls_pairs)} class pairs." + " Class pairs: " + str(cls_pairs))

        sampling_config = SamplingTrainerConfig(
            cls_pair=cls_pairs,
            ckpt_path=ckpt_path,
            model_architecture=GCNClassifier if model_architecture is None else model_architecture,
            hidden_channels=model_kwargs["hidden_channels"],
            num_layers=model_kwargs["num_layers"],
            temperature=temperature,
            learn_node_feat=learn_node_feat,
            optimizer=torch.optim.SGD,
            lr=lr
        )

        self.trainer = SamplingTrainer(dataset, sampling_config)
        if criterion is not None:
            self.trainer.inject_criterion(criterion)

        self.eval_results = None
        self.graph_generations = None
        self.figures = {}
        logger.info("Finished setting up adjacency and boundary graph generation experiment.")

    def _load_stored_adjacency_matrix(self):
        """Return the most recently written adjacency matrix csv in ``self.adjacency_dir``, or ``None`` if there is none."""
        if not os.path.isdir(self.adjacency_dir):
            return None

        # ReproduceAdjacencyExperiment.save writes "adjacency_matrix_<experiment_id>.csv"; the bare
        # "adjacency_matrix.csv" that used to be expected here is matched by the same prefix.
        candidates = [os.path.join(self.adjacency_dir, name)
                      for name in os.listdir(self.adjacency_dir)
                      if name.startswith("adjacency_matrix") and name.endswith(".csv")]
        if not candidates:
            return None

        newest = max(candidates, key=os.path.getmtime)
        logger.info("Loading stored adjacency matrix from " + newest)
        return np.loadtxt(newest, delimiter=",")

    def _find_class_pairs(self, threshold: float, run_new_adjacency_analysis: bool):
        """
        Determine the adjacent class pairs from the adjacency matrix.

        :param threshold: off-diagonal entries greater or equal to this value mark a pair as adjacent.
        :param run_new_adjacency_analysis: if False, a stored matrix is reused when one exists.
        :return: list of ``(i, j)`` pairs; a pair and its mirror ``(j, i)`` are reported only once.
        """
        adjacency_matrix = None if run_new_adjacency_analysis else self._load_stored_adjacency_matrix()

        if adjacency_matrix is None:
            if run_new_adjacency_analysis:
                logger.info("Running new adjacency analysis...")
            else:
                logger.info("Could not find stored experiment. Will run new adjacency analysis...")
            adjacency_analysis = ReproduceAdjacencyExperiment(self.dataset)
            adjacency_analysis.run()
            adjacency_analysis.plot()
            adjacency_analysis.save()
            adjacency_matrix = adjacency_analysis.result[0]

        # extract adjacent pairs (the diagonal is excluded)
        rows, cols = np.where((adjacency_matrix >= threshold) & (np.eye(adjacency_matrix.shape[0]) == 0))

        unique_tuples = set()
        filtered_tuples = []

        for row, col in zip(rows, cols):
            # plain ints keep log output and plot labels free of numpy scalar reprs
            pair = (int(row), int(col))
            sorted_pair = tuple(sorted(pair))  # Sort each tuple to make it order-invariant
            if sorted_pair not in unique_tuples:
                unique_tuples.add(sorted_pair)
                filtered_tuples.append(pair)

        return filtered_tuples

    def run(self):
        """
        Train/evaluate ``num_graphs`` times and average the evaluation results per class pair.

        The generated graphs of the last repetition are kept in ``self.graph_generations``.

        :raises ValueError: if ``num_graphs`` is smaller than 1.
        """
        if self.num_graphs < 1:
            raise ValueError("num_graphs must be at least 1.")

        logger.info(f"Running graph generation experiment for dataset {self.dataset.name} and adjacent class pairs...")
        run_results = []
        for _ in range(self.num_graphs):
            results = self.trainer.train_evaluate_all(**self.training_params.__dict__)
            eval_results = results[0]

            run_results.append(eval_results)

        self.eval_results = {k: np.mean(np.array([x[k] for x in run_results]), axis=0) for k in run_results[0].keys()}
        self.graph_generations = results[1]
        logger.info("Finished graph generation experiment.")

    def plot(self, draw=False):
        """
        Draw every generated graph with networkx.

        :param draw: additionally show each figure via ``plt.show()``.
        :raises ValueError: if the experiment has not been run yet.
        """
        if self.graph_generations is None:
            raise ValueError("No figures to plot. Run the experiment first.")

        for cls_pair in self.graph_generations:
            graph = self.graph_generations[cls_pair]
            fig, ax = plt.subplots()
            nx.draw(graph, ax=ax)
            # Always keep the figure: previously plot(draw=True) stored nothing, so a later
            # save() failed with a KeyError.
            self.figures[cls_pair] = fig
            if draw:
                plt.show()

    def save(self):
        """
        Write per-pair logits/probabilities (csv), graph drawings (png) and a probability overview (csv + bar chart).

        :raises ValueError: if the experiment has not been run yet.
        """
        if self.graph_generations is None or self.eval_results is None:
            raise ValueError("No results to save. Run the experiment first.")

        for cls_pair in self.graph_generations:
            prefix = f"{self.root_dir}/{cls_pair[0]}_{cls_pair[1]}"
            # save eval results
            np.savetxt(f"{prefix}_logits.csv",
                       self.eval_results[cls_pair][0], delimiter=",", fmt='%f')
            np.savetxt(f"{prefix}_probs_.csv",
                       self.eval_results[cls_pair][1], delimiter=",", fmt='%f')

            figure = self.figures.get(cls_pair)
            if figure is None:
                # plot() has not been called for this pair; the numeric results are still saved
                logger.warning(f"No figure available for class pair {cls_pair}. Call plot() before save() to store it.")
            else:
                figure.savefig(f"{prefix}.png")

        prob_results = {k: v[1] for k, v in self.eval_results.items()}
        df = pd.DataFrame.from_dict(prob_results, orient='index')
        df.to_csv(f"{self.root_dir}/prob_overview.csv")

        if not prob_results:
            # nothing to draw; indexing the empty value array below would fail
            logger.info("Finished experiment: [Adjacency and boundary graph generation] and saved into " + self.root_dir)
            return

        # for every pair: probability of its first and of its second class
        values = np.array([(probs[pair[0]], probs[pair[1]]) for pair, probs in prob_results.items()])
        labels = [f"({pair[0]}, {pair[1]})" for pair in prob_results]
        bar_width = 0.35
        x = np.arange(len(labels))

        fig, ax = plt.subplots()
        # labelled bars, otherwise ax.legend() has nothing to show
        ax.bar(x - bar_width / 2, values[:, 0], bar_width, label="p(first class of pair)")
        ax.bar(x + bar_width / 2, values[:, 1], bar_width, label="p(second class of pair)")

        ax.set_ylabel('Probabilities')
        ax.set_title("Dataset:" + self.dataset.name)
        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        ax.legend()
        fig.savefig(f"{self.root_dir}/prob_overview.png")
        plt.close(fig)  # the overview figure is not referenced anywhere else
        logger.info("Finished experiment: [Adjacency and boundary graph generation] and saved into " + self.root_dir)


class GraphGenerationAllClassPairsExperiment(AdjacencyAndBoundaryGraphGenerationExperiment):
    """
    This experiment runs a boundary graph generation based on all class pairs, not just the adjacent ones.
    We can insert the number of graphs we want to generate for each class pair. The logits and probabilities are averaged
    like in its parent class. Again, this experiment is meant to try out different parameters and criteria and then
    evaluate the results. It is not meant to reproduce the paper results. For that we will run an experiment that loads
    saved boundary graphs and evaluates them.

    ``run``, ``plot`` and ``save`` are inherited from the parent class.
    """

    def __init__(self, dataset: BaseGraphDataset,
                 temperature: float,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 adj_threshold: float = 0.8,
                 model_kwargs=None,
                 training_params: TrainingParams = None,
                 ckpt_path=None,
                 model_architecture=None,
                 criterion: WeightedCriterion = None,
                 num_graphs: int = 1,
                 cls_pairs: List[Tuple[int, int]] = None,
        ):
        """
        :param dataset: dataset the classifier was trained on.
        :param temperature: forwarded to the sampling configuration.
        :param learn_node_feat: forwarded to the sampling configuration.
        :param lr: learning rate forwarded to the sampling configuration (SGD optimizer).
        :param adj_threshold: only used as part of the experiment id (no adjacency filtering happens here).
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; defaults to ``get_model_kwargs``.
        :param training_params: training parameters; defaults to ``get_default_training_params(dataset.name, 0, 1)``.
        :param ckpt_path: model checkpoint; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param criterion: optional criterion injected into the trainer.
        :param num_graphs: number of train/evaluate repetitions whose evaluation results are averaged.
        :param cls_pairs: class pairs to generate for; defaults to every ordered pair ``(i, j)`` with ``i != j``.
        """
        # Deliberately initialise only the abstract base instead of the direct parent. The parent
        # constructor would (1) run a complete adjacency analysis and build a trainer that are both
        # discarded again, (2) ignore the ckpt_path / model_architecture / lr given here, and
        # (3) leave self.root_dir pointing into the parent's experiment directory, which nested this
        # experiment's output below "graph_generation/<name>/<id>/". Every attribute the parent
        # constructor sets is assigned below.
        AbstractExperiment.__init__(self)

        logger.info("Setting up graph generation experiment for all class pairs...")
        self.dataset = dataset
        if training_params is None:
            training_params = get_default_training_params(dataset.name, 0, 1)

        self.num_graphs = num_graphs
        self.training_params = training_params

        self.experiment_id = (f"adj_threshold={adj_threshold}_temperature={temperature:.4f}_lr={lr:.4f}"
                              f"_iterations={training_params.iterations}_target_probs=[{str(list(training_params.target_probs.values()))}]"
                              f"w_budget_init={training_params.w_budget_init}_w_budget_dec={training_params.w_budget_dec:.3f}"
                              f"w_budget_inc={training_params.w_budget_inc:.3f}_ksamples={training_params.k_samples}"
                              f"eval_thresh={training_params.eval_threshold}_{np.random.randint(0, 100000000)}")

        self.adjacency_dir = f"{self.root_dir}/adjacency_analysis/{dataset.name}"
        self.root_dir = f"{self.root_dir}/graph_generation_all_pairs/{dataset.name}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        if cls_pairs is None:
            num_classes = len(dataset.GRAPH_CLS)
            cls_pairs = [(i, j) for i in range(num_classes) for j in range(num_classes) if i != j]

        sampling_config = SamplingTrainerConfig(
            cls_pair=cls_pairs,
            ckpt_path=ckpt_path,
            model_architecture=GCNClassifier if model_architecture is None else model_architecture,
            hidden_channels=model_kwargs["hidden_channels"],
            num_layers=model_kwargs["num_layers"],
            temperature=temperature,
            learn_node_feat=learn_node_feat,
            optimizer=torch.optim.SGD,
            lr=lr
        )

        self.trainer = SamplingTrainer(dataset, sampling_config)
        if criterion is not None:
            self.trainer.inject_criterion(criterion)

        self.eval_results = None
        self.graph_generations = None
        self.figures = {}
        logger.info("Finished setting up graph generation experiment for all class pairs.")


class ReproduceBoundaryComplexityExperimentOnePair(AbstractExperiment):
    """
    This experiment reproduces the boundary complexity experiment for a single class pair. It loads the saved boundary
    graphs and evaluates them. The experiment is meant to reproduce the paper results.
    The corresponding result is table 1 in the paper.
    """
    def __init__(self,
                 dataset: BaseGraphDataset,
                 cls_pair: Tuple[int, int],
                 num_graphs: int,
                 temperature: float,
                 training_params: TrainingParams = None,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 model_kwargs=None,
                 ckpt_path=None,
                 model_architecture=None,
                 max_nodes: int = 25,
                 criterion: WeightedCriterion = None,
                 graph_directory: str = None,
                 random_id: int = None):
        """
        :param dataset: dataset the classifier was trained on.
        :param cls_pair: the two class indices whose boundary is evaluated.
        :param num_graphs: number of graphs passed to the evaluator.
        :param temperature: forwarded to the graph sampler and the evaluator.
        :param training_params: training parameters; defaults to ``get_default_training_params`` for the pair.
        :param learn_node_feat: forwarded to the graph sampler.
        :param lr: learning rate of the SGD optimizer over the sampler parameters.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; defaults to ``get_model_kwargs``.
        :param ckpt_path: model checkpoint; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param max_nodes: forwarded to the graph sampler and the evaluator.
        :param criterion: training criterion; defaults to ``get_default_criteria_dynamic_boundary`` for the pair.
        :param graph_directory: directory with saved boundary graphs; defaults to the entry of
            ``BOUNDARY_GRAPHS_DIRs_TO_DATASET`` for this dataset and pair.
        :param random_id: suffix of the experiment id; drawn at random when omitted.
        """

        super().__init__()
        logger.info("Setting up boundary complexity experiment...")
        self.dataset = dataset
        self.num_graphs = num_graphs

        if training_params is None:
            training_params = get_default_training_params(dataset.name, cls_pair[0], cls_pair[1])

        self.training_params = training_params

        random_id = np.random.randint(0, 1000000) if random_id is None else random_id
        self.experiment_id = (f"temperature={temperature}_lr={lr}_numgraphs={num_graphs}_maxnodes={max_nodes}"
                              f"_{random_id}")

        self.root_dir = f"{self.root_dir}/boundary_complexity/{dataset.name}/{cls_pair[0]}_{cls_pair[1]}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        sampler = GraphSampler(
            max_nodes=max_nodes,
            temperature=temperature,
            num_node_cls=len(dataset.NODE_CLS),
            learn_node_feat=learn_node_feat
        )

        model_arch = GCNClassifier if model_architecture is None else model_architecture
        model = model_arch(node_features=len(dataset.NODE_CLS),
                           num_classes=len(dataset.GRAPH_CLS),
                           hidden_channels=model_kwargs["hidden_channels"],
                           num_layers=model_kwargs["num_layers"])
        model.load_state_dict(torch.load(ckpt_path))

        # mean of the "embeds" model output per ground-truth class; input of the default criterion
        dataset_list_gt = self.dataset.split_by_class()
        mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

        criterion = criterion if criterion is not None else get_default_criteria_dynamic_boundary(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])
        trainer = Trainer(
            sampler=sampler,
            discriminator=model,
            criterion=criterion,
            optimizer=(o := torch.optim.SGD(sampler.parameters(), lr=lr)),
            scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1)
        )

        graph_directory = BOUNDARY_GRAPHS_DIRs_TO_DATASET[dataset.name][cls_pair] if graph_directory is None else graph_directory
        self.evaluator = BoundaryEvaluator(sampler,
                                           model,
                                           dataset,
                                           trainer,
                                           cls_pair=cls_pair,
                                           logger=logger,
                                           load_dir=graph_directory,
                                           temperature=temperature,
                                           max_nodes=max_nodes)
        self.complexity = None
        self.average_probs = None
        self.std_probs = None
        self.figs = []
        logger.info("Finished setting up boundary complexity experiment. Using saved boundary graphs from directory: " + graph_directory)

    def run(self):
        """Compute the boundary complexity and the average / standard deviation of the probabilities; returns the complexity."""
        logger.info("Starting boundary complexity run...")
        # eval_threshold is not forwarded to boundary_complexity
        training_params = {k: v for k, v in self.training_params.__dict__.items() if k != "eval_threshold"}
        self.complexity = self.evaluator.boundary_complexity(self.num_graphs, use_pca=True, **training_params)
        self.average_probs, self.std_probs = self.evaluator.average_probs(self.num_graphs)
        logger.info("Finished boundary complexity run...")
        return self.complexity

    def plot(self):
        """
        Create one figure for the average probabilities and one for their standard deviation.

        ``self.figs`` is rebuilt on every call so that ``save`` always writes the latest figures.

        :raises ValueError: if the experiment has not been run yet.
        """
        if self.average_probs is None or self.std_probs is None:
            raise ValueError("No results to plot. Run the experiment first.")

        average_fig, ax = plt.subplots()
        # NOTE(review): the complexity is formatted as a single number in save(); plotting it here
        # presumably adds just one point to the probability chart — confirm this is intended.
        ax.plot(self.complexity)
        ax.set_title(f"Average Probabilities over {self.num_graphs} graphs")
        ax.bar(np.arange(len(self.average_probs)), self.average_probs)

        std_fig, ax = plt.subplots()
        ax.plot(self.std_probs)
        ax.set_title(f"Standard Deviation probabilities over {self.num_graphs} graphs")
        ax.bar(np.arange(len(self.std_probs)), self.std_probs)

        # index 0: average probabilities, index 1: standard deviations (save() relies on this order)
        self.figs = [average_fig, std_fig]

    def save(self):
        """
        Write complexity, average and standard deviation of the probabilities (txt), a summary (json)
        and — if ``plot`` was called — both figures (png) into ``self.root_dir``.

        :raises ValueError: if the experiment has not been run yet.
        """
        if self.complexity is None or self.average_probs is None or self.std_probs is None:
            raise ValueError("No results to save. Run the experiment first.")

        with open(self.root_dir + "/" + "complexity.txt", "w") as f:
            f.write(str(self.complexity))

        with open(self.root_dir + "/" + "average_probs.txt", "w") as f:
            f.write(str(self.average_probs))

        with open(self.root_dir + "/" + "std_probs.txt", "w") as f:
            f.write(str(self.std_probs))

        # ".4f" (four decimals) — the former "4f" only set a minimum width of 4
        final_presentation = {"Complexity": f"{self.complexity:.4f}",
                              "p(c1)" : f"{self.average_probs[0]:.4f} +- {self.std_probs[0]: .4f}",
                              "p(c2)" : f"{self.average_probs[1]:.4f} +- {self.std_probs[1]: .4f}"}

        with open(self.root_dir + "/" + "final_presentation.json", "w") as f:
            json.dump(final_presentation, f)

        if len(self.figs) >= 2:
            self.figs[1].savefig(self.root_dir + "/std_deviation_analysis.png")
            self.figs[0].savefig(self.root_dir + "/probability_analysis.png")
        else:
            logger.warning("No figures available. Call plot() before save() to store them.")
        logger.info("Saved boundary complexity results into " + self.root_dir)


class ReproduceTable1PerDatasetExperiment(AbstractExperiment):
    """
    Runs the boundary complexity evaluation for every class pair listed in ``DATASET_TO_CLS_PAIRS``
    for the dataset. Sampler, criterion, trainer, evaluator and all results are kept in dictionaries
    keyed by class pair.
    """
    def __init__(self,
                 dataset: BaseGraphDataset,
                 num_graphs: int,
                 temperature: float,
                 training_params: TrainingParams = None,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 model_kwargs=None,
                 ckpt_path=None,
                 model_architecture=None,
                 graph_directory=None,
                 max_nodes: int = 25,
                 criterion: WeightedCriterion = None,
                 random_id: int = None):
        """
        :param dataset: dataset the classifier was trained on.
        :param num_graphs: number of graphs passed to each evaluator.
        :param temperature: forwarded to the graph samplers and the evaluators.
        :param training_params: training parameters shared by all pairs; defaults to
            ``get_default_training_params`` per pair.
        :param learn_node_feat: forwarded to the graph samplers.
        :param lr: learning rate of the SGD optimizers over the sampler parameters.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; defaults to ``get_model_kwargs``.
        :param ckpt_path: model checkpoint; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param graph_directory: base directory with one ``<c1>-<c2>`` sub-directory per pair; defaults to
            ``BOUNDARY_GRAPHS_DIRs_TO_DATASET[dataset.name]``.
        :param max_nodes: forwarded to the graph samplers and the evaluators.
        :param criterion: criterion shared by all pairs; defaults to
            ``get_default_criteria_dynamic_boundary`` per pair.
        :param random_id: suffix of the experiment id; drawn at random when omitted.
        """
        super().__init__()
        logger.info("Setting up table 1 reproduction experiment...")
        self.cls_pairs = DATASET_TO_CLS_PAIRS[dataset.name]
        self.dataset = dataset
        self.num_graphs = num_graphs

        random_id = np.random.randint(0, 1000000) if random_id is None else random_id
        self.experiment_id = (f"temperature={temperature}_lr={lr}_numgraphs={num_graphs}_maxnodes={max_nodes}"
                              f"_{random_id}")

        self.root_dir = f"{self.root_dir}/reproduce_table_1/{dataset.name}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        sampler = {cls_pair : GraphSampler(
            max_nodes=max_nodes,
            temperature=temperature,
            num_node_cls=len(dataset.NODE_CLS),
            learn_node_feat=learn_node_feat)
            for cls_pair in self.cls_pairs }

        # a single classifier instance is shared by all class pairs
        model_arch = GCNClassifier if model_architecture is None else model_architecture
        model = model_arch(node_features=len(dataset.NODE_CLS),
                           num_classes=len(dataset.GRAPH_CLS),
                           hidden_channels=model_kwargs["hidden_channels"],
                           num_layers=model_kwargs["num_layers"])
        model.load_state_dict(torch.load(ckpt_path))

        # mean of the "embeds" model output per ground-truth class; input of the default criterion
        dataset_list_gt = self.dataset.split_by_class()
        mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

        criterion = {cls_pair: criterion if criterion is not None
        else get_default_criteria_dynamic_boundary(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])
                     for cls_pair in self.cls_pairs}

        trainer = {cls_pair: Trainer(
            sampler=sampler[cls_pair],
            discriminator=model,
            criterion=criterion[cls_pair],
            optimizer=(o := torch.optim.SGD(sampler[cls_pair].parameters(), lr=lr)),
            scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1))
            for cls_pair in self.cls_pairs}

        graph_directory = BOUNDARY_GRAPHS_DIRs_TO_DATASET[dataset.name] if graph_directory is None else {cls_pair: graph_directory + f"/{cls_pair[0]}-{cls_pair[1]}" for cls_pair in self.cls_pairs}

        self.training_params = {cls_pair: get_default_training_params(dataset.name, cls_pair[0], cls_pair[1])
        if training_params is None else training_params for cls_pair in self.cls_pairs}

        self.evaluator = {cls_pair: BoundaryEvaluator(sampler[cls_pair],
                                             model,
                                             dataset,
                                             trainer[cls_pair],
                                             cls_pair=cls_pair,
                                             logger=logger,
                                             load_dir=graph_directory[cls_pair],
                                             temperature=temperature,
                                             max_nodes=max_nodes)
                                             for cls_pair in self.cls_pairs}

        self.complexity = {cls_pair: None for cls_pair in self.cls_pairs}
        self.average_probs = {cls_pair: None for cls_pair in self.cls_pairs}
        self.std_probs = {cls_pair: None for cls_pair in self.cls_pairs}
        self.figs = {cls_pair: [] for cls_pair in self.cls_pairs}
        logger.info(f"Finished setting up table 1 reproduction experiment for dataset {dataset.name}. "
                    f"Using saved boundary graphs from directory")

    def run(self):
        """Compute complexity and average / standard deviation of the probabilities for every class pair."""
        logger.info("Starting table 1 reproduction run...")
        for cls_pair in self.cls_pairs:
            # NOTE(review): ReproduceBoundaryComplexityExperimentOnePair.run drops "eval_threshold"
            # before calling boundary_complexity, this call forwards it — confirm which is intended.
            self.complexity[cls_pair] = self.evaluator[cls_pair].boundary_complexity(self.num_graphs, use_pca=True, **self.training_params[cls_pair].__dict__)
            self.average_probs[cls_pair], self.std_probs[cls_pair] = self.evaluator[cls_pair].average_probs(self.num_graphs)
        logger.info("Finished table 1 reproduction run...")

    def plot(self):
        """
        Create, per class pair, one figure for the average probabilities and one for their standard deviation.

        The figure list of a pair is rebuilt on every call so that ``save`` always writes the latest figures.

        :raises ValueError: if the experiment has not been run yet.
        """
        if any(self.average_probs[cls_pair] is None or self.std_probs[cls_pair] is None
               for cls_pair in self.cls_pairs):
            raise ValueError("No results to plot. Run the experiment first.")

        for cls_pair in self.cls_pairs:
            average_fig, ax = plt.subplots(figsize=(18,14))
            # NOTE(review): the complexity is formatted as a single number in save(); plotting it here
            # presumably adds just one point to the probability chart — confirm this is intended.
            ax.plot(self.complexity[cls_pair])
            ax.set_title(f"Average Probabilities over {self.num_graphs} graphs")
            ax.bar(np.arange(len(self.average_probs[cls_pair])), self.average_probs[cls_pair])

            std_fig, ax = plt.subplots(figsize=(18,14))
            ax.plot(self.std_probs[cls_pair])
            ax.set_title(f"Standard Deviation probabilities over {self.num_graphs} graphs")
            ax.bar(np.arange(len(self.std_probs[cls_pair])), self.std_probs[cls_pair])

            # index 0: average probabilities, index 1: standard deviations (save() relies on this order)
            self.figs[cls_pair] = [average_fig, std_fig]

    def save(self):
        """
        Write, per class pair, complexity / average / standard deviation (txt) and — if ``plot`` was
        called — both figures (png); finally a summary over all pairs (json).

        :raises ValueError: if the experiment has not been run yet.
        """
        if any(self.complexity[cls_pair] is None or self.average_probs[cls_pair] is None
               or self.std_probs[cls_pair] is None for cls_pair in self.cls_pairs):
            raise ValueError("No results to save. Run the experiment first.")

        for cls_pair in self.cls_pairs:
            prefix = self.root_dir + f"/{cls_pair[0]}_{cls_pair[1]}"
            with open(prefix + "_complexity.txt", "w") as f:
                f.write(str(self.complexity[cls_pair]))

            with open(prefix + "_average_probs.txt", "w") as f:
                f.write(str(self.average_probs[cls_pair]))

            with open(prefix + "_std_probs.txt", "w") as f:
                f.write(str(self.std_probs[cls_pair]))

            if len(self.figs[cls_pair]) >= 2:
                self.figs[cls_pair][1].savefig(prefix + "_std_deviation_analysis.png")
                self.figs[cls_pair][0].savefig(prefix + "_probability_analysis.png")
            else:
                logger.warning(f"No figures available for class pair {cls_pair}. Call plot() before save() to store them.")

        # ".4f" (four decimals) — the former "4f" only set a minimum width of 4
        final_representation = {"Complexity": {str(cls_pair): f"{self.complexity[cls_pair]:.4f}" for cls_pair in self.cls_pairs},
                              "p(c1)" : {str(cls_pair): f"{self.average_probs[cls_pair][0]:.4f} +-{self.std_probs[cls_pair][0]: .4f}" for cls_pair in self.cls_pairs},
                              "p(c2)" : {str(cls_pair): f"{self.average_probs[cls_pair][1]:.4f} +-{self.std_probs[cls_pair][1]: .4f}" for cls_pair in self.cls_pairs},
                                }

        with open(self.root_dir + "/final_presentation.json", "w") as f:
            json.dump(final_representation, f)

        logger.info("Saved boundary complexity results into " + self.root_dir)


class ReproduceTable2PerDatasetExperiment(AbstractExperiment):
    """
    Counts, for every class pair of a dataset, how often boundary graph generation succeeds and how many
    iterations it needs.

    One classifier is loaded from a checkpoint and shared by all class pairs; every class pair gets its own
    ``GraphSampler``, criterion, ``Trainer`` and ``GraphGenerator``. ``run`` calls ``success_counter`` on every
    generator and ``save`` writes the success rate and the average iteration count to
    ``final_presentation.json`` inside ``self.root_dir``.
    """
    def __init__(self,
                 dataset: BaseGraphDataset,
                 num_runs: int,
                 strategy: Literal["dynamic_boundary", "cross_entropy"],
                 num_iterations: int,
                 temperature: float,
                 training_params: TrainingParams = None,
                 learn_node_feat: bool = True,
                 cls_pairs: List[Tuple[int, int]] = None,
                 lr: float = 1,
                 target_size = 30,
                 w_budget_init: float = None,
                 w_budget_inc: float = None,
                 w_budget_dec: float = None,
                 model_kwargs = None,
                 ckpt_path = None,
                 model_architecture = None,
                 max_nodes: int = 25,
                 random_id: int = None,
                 save_dir: str = None,
                 target_probs: Tuple[float, float] = (0.45, 0.55)):
        """
        :param dataset: dataset the classifier was trained on.
        :param num_runs: number of runs handed to ``success_counter`` per class pair.
        :param strategy: ``"cross_entropy"`` selects ``get_default_criteria_entropy``; any other value selects
            ``get_default_criteria_dynamic_boundary``. The value is also part of the output path.
        :param num_iterations: overrides ``iterations`` of the training parameters when not ``None``.
        :param temperature: temperature of the graph samplers.
        :param training_params: template for the training parameters of every class pair; the per-dataset
            defaults are used when ``None``. The passed object is copied and never modified.
        :param learn_node_feat: forwarded to ``GraphSampler``.
        :param cls_pairs: class pairs to evaluate; defaults to ``DATASET_TO_CLS_PAIRS[dataset.name]``.
        :param lr: learning rate of the SGD optimizer of every sampler.
        :param target_size: overrides ``target_size`` of the training parameters when not ``None``.
        :param w_budget_init: overrides ``w_budget_init`` of the training parameters when not ``None``.
        :param w_budget_inc: overrides ``w_budget_inc`` of the training parameters when not ``None``.
        :param w_budget_dec: overrides ``w_budget_dec`` of the training parameters when not ``None``.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; looked up via
            ``get_model_kwargs`` when ``None``.
        :param ckpt_path: checkpoint of the classifier; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param max_nodes: forwarded to ``GraphSampler``.
        :param random_id: suffix of the experiment id; drawn randomly when ``None``.
        :param save_dir: forwarded to ``success_counter``.
        :param target_probs: probability range stored for both classes of every pair.
        """
        super().__init__()
        logger.info("Setting up table 2 reproduction experiment...")
        self.dataset = dataset
        self.num_runs = num_runs
        self.num_iterations = num_iterations
        self.save_dir = save_dir

        # Only consult the default table when no class pairs were passed, so that datasets without
        # an entry in DATASET_TO_CLS_PAIRS can still be used with explicit pairs.
        self.cls_pairs = DATASET_TO_CLS_PAIRS[dataset.name] if cls_pairs is None else cls_pairs

        random_id = np.random.randint(0, 1000000) if random_id is None else random_id
        self.experiment_id = (f"temperature={temperature}_lr={lr}_numruns={num_runs}_num_iteration={num_iterations}"
                              f"_{random_id}")

        self.root_dir = f"{self.root_dir}/reproduce_table_2/{dataset.name}/{strategy}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race when several experiments start concurrently.
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        # The samplers are deliberately created before the model (same order as before), so that the
        # sequence of random draws during initialisation stays unchanged.
        sampler = {cls_pair: GraphSampler(
            max_nodes=max_nodes,
            temperature=temperature,
            num_node_cls=len(dataset.NODE_CLS),
            learn_node_feat=learn_node_feat)
            for cls_pair in self.cls_pairs}

        model_arch = GCNClassifier if model_architecture is None else model_architecture
        model = model_arch(node_features=len(dataset.NODE_CLS),
                           num_classes=len(dataset.GRAPH_CLS),
                           hidden_channels=model_kwargs["hidden_channels"],
                           num_layers=model_kwargs["num_layers"])
        model.load_state_dict(torch.load(ckpt_path))

        # Mean embedding per ground-truth class, consumed by the default criteria.
        dataset_list_gt = self.dataset.split_by_class()
        mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

        default_crit_fn = get_default_criteria_entropy if strategy == "cross_entropy" else get_default_criteria_dynamic_boundary
        criterion = {cls_pair: default_crit_fn(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])
                     for cls_pair in self.cls_pairs}

        trainer = {cls_pair: Trainer(
            sampler=sampler[cls_pair],
            discriminator=model,
            criterion=criterion[cls_pair],
            optimizer=(o := torch.optim.SGD(sampler[cls_pair].parameters(), lr=lr)),
            scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1))
            for cls_pair in self.cls_pairs}

        # Every class pair gets its OWN parameter object. Previously a caller-supplied object was shared
        # by all pairs, so the per-pair ``target_probs`` assignment below overwrote the values of the
        # previous pairs (and modified the caller's object).
        self.training_params = {
            cls_pair: copy.deepcopy(
                get_default_training_params(dataset.name, cls_pair[0], cls_pair[1])
                if training_params is None else training_params)
            for cls_pair in self.cls_pairs}

        # Explicit constructor arguments take precedence over the values stored in the parameters.
        overrides = {"iterations": num_iterations,
                     "target_size": target_size,
                     "w_budget_init": w_budget_init,
                     "w_budget_inc": w_budget_inc,
                     "w_budget_dec": w_budget_dec}

        for cls_pair in self.cls_pairs:
            params = self.training_params[cls_pair]
            for attr_name, value in overrides.items():
                if value is not None:
                    setattr(params, attr_name, value)
            params.target_probs = {cls_pair[0]: target_probs, cls_pair[1]: target_probs}

        self.generator = {cls_pair: GraphGenerator(sampler[cls_pair], dataset, trainer[cls_pair], model)
                          for cls_pair in self.cls_pairs}

        self.success_counts = {}
        self.avg_iterations = {}
        logger.info(f"Finished setting up table 2 reproduction experiment for dataset {dataset.name}. ")

    def run(self):
        """
        Run ``success_counter`` for every class pair and store the success count and the mean number of
        convergence iterations in ``self.success_counts`` / ``self.avg_iterations``.
        """
        logger.info("Starting table 2 reproduction run...")
        for num, cls_pair in enumerate(self.cls_pairs):
            logger.info(f"Running class pair {cls_pair}...")
            # "eval_threshold" is not forwarded to success_counter.
            training_params = {k: v for k, v in vars(self.training_params[cls_pair]).items()
                               if k != "eval_threshold"}
            self.success_counts[cls_pair], num_convergence_iterations = self.generator[cls_pair].success_counter(
                self.num_runs,
                cls=cls_pair,
                save_dir=self.save_dir,
                **training_params)
            self.avg_iterations[cls_pair] = np.mean(num_convergence_iterations)
            logger.info(f"Class pair {cls_pair} finished. Success rate: {self.success_counts[cls_pair] / self.num_runs:.4f}. Average iterations: {self.avg_iterations[cls_pair]:.4f}")
            logger.info(f"Finished class pair: {num + 1}/{len(self.cls_pairs)}")
        logger.info("Finished table 2 reproduction run...")

    def plot(self):
        """This experiment produces no figures."""
        pass

    def save(self):
        """
        Write success rate and average iteration count per class pair to ``final_presentation.json``.
        Requires that ``run`` has filled ``self.success_counts`` and ``self.avg_iterations``.
        """
        final_representation = {
            "Success Rate": {str(cls_pair): f"{self.success_counts[cls_pair] / self.num_runs:.4f}"
                             for cls_pair in self.cls_pairs},
            "Average Iterations": {str(cls_pair): f"{self.avg_iterations[cls_pair]:.4f}"
                                   for cls_pair in self.cls_pairs}}

        with open(self.root_dir + "/final_presentation.json", "w") as f:
            json.dump(final_representation, f)


class ReproduceFigure3PerDatasetExperiment(AbstractExperiment):
    """
    Reproduces Figure 3 of the paper. The confusion matrix, boundary margin and thickness for each adjacent
    class pair in the dataset are computed. The experiment uses saved boundary graphs and, optionally,
    saved interpreter graphs.

    Every adjacent class pair from ``DATASET_TO_CLS_PAIRS`` is evaluated in both orders, i.e. ``(a, b)`` and
    ``(b, a)``; both orders load the boundary graphs from the directory of the base pair.
    """
    def __init__(self,
                 dataset: BaseGraphDataset,
                 num_graphs: int,
                 temperature: float,
                 graph_directory: str,
                 interpreter_directoy: str = None,
                 training_params: TrainingParams = None,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 model_kwargs=None,
                 ckpt_path=None,
                 model_architecture=None,
                 max_nodes: int = 25,
                 criterion: WeightedCriterion = None,
                 random_id: int = None,
                 ):
        """
        :param dataset: dataset the classifier was trained on.
        :param num_graphs: number of graphs handed to the margin and thickness computation.
        :param temperature: temperature of the graph samplers and the evaluators.
        :param graph_directory: root directory of the saved boundary graphs; the graphs of a base pair
            ``(a, b)`` are expected in the sub-directory ``a-b``.
        :param interpreter_directoy: optional root directory of saved interpreter graphs with one
            sub-directory per graph class.
        :param training_params: training parameters stored for every class pair; the per-dataset defaults
            are used when ``None``.
        :param learn_node_feat: forwarded to ``GraphSampler``.
        :param lr: learning rate of the SGD optimizer of every sampler.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; looked up via
            ``get_model_kwargs`` when ``None``.
        :param ckpt_path: checkpoint of the classifier; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param max_nodes: forwarded to ``GraphSampler`` and ``BoundaryEvaluator``.
        :param criterion: criterion shared by all class pairs; the default dynamic boundary criterion of
            each pair is used when ``None``.
        :param random_id: suffix of the experiment id; drawn randomly when ``None``.
        """
        super().__init__()
        logger.info("Setting up figure 3 reproduction experiment...")
        self.cls_pairs_base = DATASET_TO_CLS_PAIRS[dataset.name]
        # Work on a copy so that the shared default table is not extended in place.
        self.cls_pairs = copy.deepcopy(self.cls_pairs_base)
        self.cls_pairs.extend([(x[1], x[0]) for x in self.cls_pairs_base])
        self.dataset = dataset
        self.num_graphs = num_graphs

        random_id = np.random.randint(0, 1000000) if random_id is None else random_id
        self.experiment_id = (f"temperature={temperature}_lr={lr}_numgraphs={num_graphs}_maxnodes={max_nodes}"
                              f"{'interpreter' if interpreter_directoy is not None else 'dataset_sampling'}_{random_id}")

        self.root_dir = f"{self.root_dir}/reproduce_figure_3/{dataset.name}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race when several experiments start concurrently.
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        sampler = {cls_pair: GraphSampler(
            max_nodes=max_nodes,
            temperature=temperature,
            num_node_cls=len(dataset.NODE_CLS),
            learn_node_feat=learn_node_feat)
            for cls_pair in self.cls_pairs}

        model_arch = GCNClassifier if model_architecture is None else model_architecture
        model = model_arch(node_features=len(dataset.NODE_CLS),
                           num_classes=len(dataset.GRAPH_CLS),
                           hidden_channels=model_kwargs["hidden_channels"],
                           num_layers=model_kwargs["num_layers"])
        model.load_state_dict(torch.load(ckpt_path))

        self.model = model

        # Mean embedding per ground-truth class, consumed by the default criterion.
        dataset_list_gt = self.dataset.split_by_class()
        mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

        criteria = {cls_pair: criterion if criterion is not None
                    else get_default_criteria_dynamic_boundary(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])
                    for cls_pair in self.cls_pairs}

        trainer = {cls_pair: Trainer(
            sampler=sampler[cls_pair],
            discriminator=model,
            criterion=criteria[cls_pair],
            optimizer=(o := torch.optim.SGD(sampler[cls_pair].parameters(), lr=lr)),
            scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1))
            for cls_pair in self.cls_pairs}

        # Boundary graphs exist only once per base pair ...
        graph_dirs = {cls_pair: graph_directory + f"/{cls_pair[0]}-{cls_pair[1]}"
                      for cls_pair in self.cls_pairs_base}
        # ... whereas interpreter graphs are stored per graph class.
        if interpreter_directoy is not None:
            self.interpreter_directoy = {cls: interpreter_directoy + f"/{cls}" for cls in dataset.GRAPH_CLS}
        else:
            self.interpreter_directoy = interpreter_directoy

        self.training_params = {cls_pair: get_default_training_params(dataset.name, cls_pair[0], cls_pair[1])
                                if training_params is None else training_params
                                for cls_pair in self.cls_pairs}

        self.evaluator = {}

        for cls_pair in self.cls_pairs:
            cls_label_1, cls_label_2 = cls_pair
            # A reversed pair falls back to the directory of its base pair.
            if cls_pair in graph_dirs:
                load_dir = graph_dirs[cls_pair]
            else:
                load_dir = graph_dirs[tuple(reversed(cls_pair))]

            if interpreter_directoy is not None:
                interpreter_load_dir = self.interpreter_directoy[cls_label_1]
            else:
                interpreter_load_dir = None

            self.evaluator[cls_pair] = BoundaryEvaluator(sampler[cls_pair],
                                                         model,
                                                         dataset,
                                                         trainer[cls_pair],
                                                         cls_pair=cls_pair,
                                                         logger=logger,
                                                         load_dir=load_dir,
                                                         interpreter_load_dir=interpreter_load_dir,
                                                         temperature=temperature,
                                                         max_nodes=max_nodes,
                                                         order_of_pairs=(cls_label_1, cls_label_2))

        self.margins = {cls_pair: None for cls_pair in self.cls_pairs}
        self.thickness = {cls_pair: None for cls_pair in self.cls_pairs}
        self.confusion_matrix = None
        self.figs = {cls_pair: [] for cls_pair in self.cls_pairs}
        logger.info(f"Finished setting up figure 3 reproduction experiment for dataset {dataset.name}. "
                    f"Using saved boundary and interpreter graphs from directory")

    def run(self):
        """
        Compute boundary margin and thickness for every (ordered) class pair and the confusion matrix of
        the classifier on the dataset.
        """
        logger.info("Starting figure 3 reproduction run...")
        for cls_pair in self.cls_pairs:
            evaluator = self.evaluator[cls_pair]
            self.margins[cls_pair] = evaluator.boundary_margin(num_graphs=self.num_graphs)
            self.thickness[cls_pair] = evaluator.boundary_thickness(num_graphs=self.num_graphs)

        self.confusion_matrix = self.dataset.model_evaluate(self.model)["cm"]
        logger.info("Finished figure 3 reproduction run...")

    def plot(self):
        """
        Arrange margins and thickness in class-by-class matrices (``nan`` for pairs that were not
        evaluated) and draw them, together with the confusion matrix, via ``draw_matrix_adj``.
        The figures are stored in ``self.figs`` under the keys ``"margins_matrix"``,
        ``"thickness_matrix"`` and ``"confusion_matrix"``.
        """
        class_num = len(self.dataset.GRAPH_CLS)
        self.matrix_margins = np.full((class_num, class_num), np.nan)
        self.matrix_thickness = np.full((class_num, class_num), np.nan)
        for cls_pair in self.cls_pairs:
            row, col = cls_pair[0], cls_pair[1]
            self.matrix_margins[row, col] = self.margins[cls_pair]
            self.matrix_thickness[row, col] = self.thickness[cls_pair]

        fig_margins = draw_matrix_adj(self.matrix_margins, self.dataset.GRAPH_CLS.values(), fmt='.2f',
                                      return_fig=True, xlabel="Decision Boundary", ylabel="Decision Region",
                                      title="Boundary Margin")
        fig_thickness = draw_matrix_adj(self.matrix_thickness, self.dataset.GRAPH_CLS.values(), fmt='.2f',
                                        return_fig=True, xlabel="Decision Boundary", ylabel="Decision Region",
                                        title="Boundary Thickness")
        self.figs["margins_matrix"] = fig_margins
        self.figs["thickness_matrix"] = fig_thickness
        fig_confusion = draw_matrix_adj(self.confusion_matrix, self.dataset.GRAPH_CLS.values(), fmt='d', return_fig=True,
                                        xlabel="Predicted Label", ylabel="True Label", title="Confusion Matrix")
        self.figs["confusion_matrix"] = fig_confusion

    def save(self):
        """
        Write thickness and margin per class pair to ``final_presentation.json``, the confusion matrix to
        ``confusion_matrix.csv`` and the three figures created by ``plot`` to PNG files.
        """
        # ".4f" (precision) instead of "4f" (minimum width): values are rounded to four decimals.
        final_representation = {
            "Boundary thickness": {str(cls_pair): f"{self.thickness[cls_pair]:.4f}" for cls_pair in self.cls_pairs},
            "Boundary margin": {str(cls_pair): f"{self.margins[cls_pair]:.4f}" for cls_pair in self.cls_pairs},
        }

        with open(self.root_dir + "/final_presentation.json", "w") as f:
            json.dump(final_representation, f)

        np.savetxt(self.root_dir + "/confusion_matrix.csv", self.confusion_matrix, delimiter=",", fmt='%d')
        self.figs["margins_matrix"].savefig(self.root_dir + "/heatmaps_margins.png")
        self.figs["thickness_matrix"].savefig(self.root_dir + "/heatmaps_thickness.png")
        self.figs["confusion_matrix"].savefig(self.root_dir + "/confusion_matrix.png")
        logger.info("Saved boundary thickness and margin results into " + self.root_dir)


class BoundaryComplexityPerTargetRangeExperiment(AbstractExperiment):
    """
    This experiment runs a boundary complexity experiment for a single class pair and a target range of probabilities.
    Intuitively, the boundary complexity should decrease as the range of probabilities decreases, because smaller ranges
    correspond to better boundary graph approximations.
    """
    def __init__(self,
                 dataset: BaseGraphDataset,
                 cls_pair: Tuple[int, int],
                 num_graphs: int,
                 temperature: float,
                 ranges: List[Tuple[float, float]],
                 training_params: TrainingParams = None,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 model_kwargs=None,
                 ckpt_path=None,
                 model_architecture=None,
                 max_nodes: int = 25,
                 criterion: WeightedCriterion = None,
                 graph_directory: str = None,
                 random_id: int = None):
        """
        :param dataset: dataset the classifier was trained on.
        :param cls_pair: class pair whose boundary is analysed.
        :param num_graphs: number of graphs handed to the evaluator.
        :param temperature: temperature of the graph sampler and the evaluator.
        :param ranges: probability ranges for which the complexity is computed.
        :param training_params: training parameters forwarded to the evaluator in ``run``; the per-dataset
            defaults of the class pair are used when ``None``.
        :param learn_node_feat: forwarded to ``GraphSampler``.
        :param lr: learning rate of the SGD optimizer of the sampler.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; looked up via
            ``get_model_kwargs`` when ``None``.
        :param ckpt_path: checkpoint of the classifier; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param max_nodes: forwarded to ``GraphSampler`` and ``BoundaryEvaluator``.
        :param criterion: training criterion; the default dynamic boundary criterion is used when ``None``.
        :param graph_directory: directory of the saved boundary graphs; looked up in
            ``BOUNDARY_GRAPHS_DIRs_TO_DATASET`` when ``None``.
        :param random_id: suffix of the experiment id; drawn randomly when ``None``.
        """
        super().__init__()
        logger.info(f"Setting up boundary complexity experiment for target ranges: {str(ranges)}...")
        self.dataset = dataset
        self.num_graphs = num_graphs

        if training_params is None:
            training_params = get_default_training_params(dataset.name, cls_pair[0], cls_pair[1])

        self.training_params = training_params

        random_id = np.random.randint(0, 1000000) if random_id is None else random_id
        self.experiment_id = (f"temperature={temperature}_lr={lr}_numgraphs={num_graphs}_maxnodes={max_nodes}"
                              f"target-ranges={str(ranges)}_{random_id}")

        self.root_dir = f"{self.root_dir}/boundary_complexity_target_ranges/{dataset.name}/{cls_pair[0]}_{cls_pair[1]}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race when several experiments start concurrently.
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        sampler = GraphSampler(
            max_nodes=max_nodes,
            temperature=temperature,
            num_node_cls=len(dataset.NODE_CLS),
            learn_node_feat=learn_node_feat
        )

        model_arch = GCNClassifier if model_architecture is None else model_architecture
        model = model_arch(node_features=len(dataset.NODE_CLS),
                           num_classes=len(dataset.GRAPH_CLS),
                           hidden_channels=model_kwargs["hidden_channels"],
                           num_layers=model_kwargs["num_layers"])
        model.load_state_dict(torch.load(ckpt_path))

        # Mean embedding per ground-truth class, consumed by the default criterion.
        dataset_list_gt = self.dataset.split_by_class()
        mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

        if criterion is None:
            criterion = get_default_criteria_dynamic_boundary(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])

        optimizer = torch.optim.SGD(sampler.parameters(), lr=lr)
        trainer = Trainer(
            sampler=sampler,
            discriminator=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1)
        )

        if graph_directory is None:
            graph_directory = BOUNDARY_GRAPHS_DIRs_TO_DATASET[dataset.name][cls_pair]

        self.evaluator = BoundaryEvaluator(sampler,
                                           model,
                                           dataset,
                                           trainer,
                                           cls_pair=cls_pair,
                                           logger=logger,
                                           load_dir=graph_directory,
                                           temperature=temperature,
                                           max_nodes=max_nodes)
        self.ranges = ranges
        self.complexity = None
        self.figs = []
        logger.info("Finished setting up boundary complexity experiment. Using saved boundary graphs from directory: " + graph_directory)

    def run(self):
        """
        Compute the boundary complexity for every probability range.

        :return: the result of ``boundary_complexity_for_probability_range``, also stored in
            ``self.complexity``.
        """
        logger.info("Starting boundary complexity run for ranges...")
        self.complexity = self.evaluator.boundary_complexity_for_probability_range(
            ranges=self.ranges,
            num_graphs=self.num_graphs,
            **vars(self.training_params))
        logger.info("Finished boundary complexity run for ranges...")
        return self.complexity

    def plot(self):
        """
        Draw a bar chart of the complexity per probability range and append the figure to ``self.figs``.

        The figure and axes are handled explicitly instead of through pyplot's implicit "current figure",
        because experiments may be plotted from several threads at once.
        """
        labels = [f"{k[0]} - {k[1]}" for k in self.complexity.keys()]
        values = list(self.complexity.values())

        fig, ax = plt.subplots(figsize=(18, 14))
        ax.bar(labels, values, color='blue', alpha=0.7)
        ax.set_xlabel("Range")
        ax.set_ylabel("Boundary Complexity")
        ax.set_title(f"Dataset {self.dataset.name} and class pair {self.evaluator.cls_pair}")
        ax.tick_params(axis='x', labelrotation=45)
        ax.grid(axis='y', linestyle='--', alpha=0.6)

        self.figs.append(fig)

    def save(self):
        """
        Write the complexity per range to ``complexity.json`` and the first figure created by ``plot`` to
        ``boundary_complexity.png``.
        """
        # ".4f" (precision) instead of "4f" (minimum width): values are rounded to four decimals.
        final_representation = {str(k): f"{v:.4f}" for k, v in self.complexity.items()}
        with open(self.root_dir + "/" + "complexity.json", "w") as f:
            json.dump(final_representation, f)

        self.figs[0].savefig(self.root_dir + "/boundary_complexity.png")
        logger.info("Saved boundary complexity results into " + self.root_dir)

class BoundaryMarginPerTargetRangeExperiment(AbstractExperiment):
    """
    This experiment runs a boundary margin experiment for a single class pair and a target range of probabilities. The goal
    is to analyze how the margin changes as the probability of the target class changes. Intuitively, the margin should
    increase as the range of probabilities decreases, because smaller ranges correspond to better boundary graph
    approximations.
    """
    def __init__(self,
                 dataset: BaseGraphDataset,
                 cls_pair: Tuple[int, int],
                 num_graphs: int,
                 reference_class: int,
                 temperature: float,
                 ranges: List[Tuple[float, float]],
                 training_params: TrainingParams = None,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 model_kwargs=None,
                 ckpt_path=None,
                 model_architecture=None,
                 max_nodes: int = 25,
                 criterion: WeightedCriterion = None,
                 graph_directory: str = None,
                 interpreter_load_dir=None,
                 random_id: int = None):
        """
        :param dataset: dataset the classifier was trained on.
        :param cls_pair: class pair whose boundary is analysed.
        :param num_graphs: number of graphs handed to the evaluator.
        :param reference_class: member of ``cls_pair`` that is placed first in ``order_of_pairs``.
        :param temperature: temperature of the graph sampler and the evaluator.
        :param ranges: probability ranges for which the margin is computed.
        :param training_params: training parameters forwarded to the evaluator in ``run``; the per-dataset
            defaults of the class pair are used when ``None``.
        :param learn_node_feat: forwarded to ``GraphSampler``.
        :param lr: learning rate of the SGD optimizer of the sampler.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; looked up via
            ``get_model_kwargs`` when ``None``.
        :param ckpt_path: checkpoint of the classifier; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param max_nodes: forwarded to ``GraphSampler`` and ``BoundaryEvaluator``.
        :param criterion: training criterion; the default dynamic boundary criterion is used when ``None``.
        :param graph_directory: directory of the saved boundary graphs; looked up in
            ``BOUNDARY_GRAPHS_DIRs_TO_DATASET`` when ``None``.
        :param interpreter_load_dir: forwarded to ``BoundaryEvaluator``.
        :param random_id: suffix of the experiment id; drawn randomly when ``None``.
        """
        super().__init__()
        assert reference_class in cls_pair, "Reference class must be in the class pair."
        logger.info(f"Setting up boundary margin experiment for target ranges: {str(ranges)}...")
        self.dataset = dataset
        self.num_graphs = num_graphs

        if training_params is None:
            training_params = get_default_training_params(dataset.name, cls_pair[0], cls_pair[1])

        self.training_params = training_params

        random_id = np.random.randint(0, 1000000) if random_id is None else random_id
        self.experiment_id = (f"temperature={temperature}_lr={lr}_numgraphs={num_graphs}_maxnodes={max_nodes}"
                              f"target-ranges={str(ranges)}_{random_id}")

        self.root_dir = f"{self.root_dir}/boundary_margin_target_ranges/{dataset.name}/{cls_pair[0]}_{cls_pair[1]}/{reference_class}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race when several experiments start concurrently.
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        sampler = GraphSampler(
            max_nodes=max_nodes,
            temperature=temperature,
            num_node_cls=len(dataset.NODE_CLS),
            learn_node_feat=learn_node_feat
        )

        model_arch = GCNClassifier if model_architecture is None else model_architecture
        model = model_arch(node_features=len(dataset.NODE_CLS),
                           num_classes=len(dataset.GRAPH_CLS),
                           hidden_channels=model_kwargs["hidden_channels"],
                           num_layers=model_kwargs["num_layers"])
        model.load_state_dict(torch.load(ckpt_path))

        # Mean embedding per ground-truth class, consumed by the default criterion.
        dataset_list_gt = self.dataset.split_by_class()
        mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

        if criterion is None:
            criterion = get_default_criteria_dynamic_boundary(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])

        optimizer = torch.optim.SGD(sampler.parameters(), lr=lr)
        trainer = Trainer(
            sampler=sampler,
            discriminator=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1)
        )

        if graph_directory is None:
            graph_directory = BOUNDARY_GRAPHS_DIRs_TO_DATASET[dataset.name][cls_pair]

        # The reference class comes first, the remaining class of the pair second.
        other_class = cls_pair[0] if reference_class == cls_pair[1] else cls_pair[1]
        self.evaluator = BoundaryEvaluator(sampler,
                                           model,
                                           dataset,
                                           trainer,
                                           cls_pair=cls_pair,
                                           logger=logger,
                                           load_dir=graph_directory,
                                           temperature=temperature,
                                           max_nodes=max_nodes,
                                           interpreter_load_dir=interpreter_load_dir,
                                           order_of_pairs=(reference_class, other_class))
        self.ranges = ranges
        self.margin = None
        self.figs = []
        logger.info("Finished setting up boundary margin experiment. Using saved boundary graphs from directory: " + graph_directory)

    def run(self):
        """
        Compute the boundary margin for every probability range.

        :return: the result of ``boundary_margin_for_probability_range``, also stored in ``self.margin``.
        """
        logger.info("Starting boundary margin run for ranges...")
        self.margin = self.evaluator.boundary_margin_for_probability_range(
            ranges=self.ranges,
            num_graphs=self.num_graphs,
            **vars(self.training_params))
        logger.info("Finished boundary margin run for ranges...")
        return self.margin

    def plot(self):
        """
        Draw a bar chart of the margin per probability range and append the figure to ``self.figs``.

        The figure and axes are handled explicitly instead of through pyplot's implicit "current figure",
        because experiments may be plotted from several threads at once.
        """
        labels = [f"{k[0]} - {k[1]}" for k in self.margin.keys()]
        values = list(self.margin.values())

        fig, ax = plt.subplots(figsize=(18, 14))
        ax.bar(labels, values, color='blue', alpha=0.7)
        ax.set_xlabel("Range")
        ax.set_ylabel("Boundary Margin")
        ax.set_title(f"Dataset {self.dataset.name} and class pair {self.evaluator.cls_pair}")
        ax.tick_params(axis='x', labelrotation=45)
        ax.grid(axis='y', linestyle='--', alpha=0.6)

        self.figs.append(fig)

    def save(self):
        """
        Write the margin per range to ``margin.json`` and the first figure created by ``plot`` to
        ``boundary_margin.png``.
        """
        # ".4f" (precision) instead of "4f" (minimum width): values are rounded to four decimals.
        final_representation = {str(k): f"{v:.4f}" for k, v in self.margin.items()}
        with open(self.root_dir + "/" + "margin.json", "w") as f:
            json.dump(final_representation, f)

        self.figs[0].savefig(self.root_dir + "/boundary_margin.png")
        logger.info("Saved boundary margin results into " + self.root_dir)


class BoundaryThicknessPerTargetRangeExperiment(AbstractExperiment):
    """
    This experiment runs a boundary thickness experiment for a single class pair and a target range of probabilities. The goal
    is to analyze how the thickness changes as the probability of the target class changes. Intuitively, the thickness should
    increase as the range of probabilities decreases, because smaller ranges correspond to better boundary graph
    approximations.
    """
    def __init__(self,
                 dataset: BaseGraphDataset,
                 cls_pair: Tuple[int, int],
                 num_graphs: int,
                 reference_class: int,
                 temperature: float,
                 ranges: List[Tuple[float, float]],
                 training_params: TrainingParams = None,
                 learn_node_feat: bool = True,
                 lr: float = 1,
                 model_kwargs=None,
                 ckpt_path=None,
                 model_architecture=None,
                 max_nodes: int = 25,
                 criterion: WeightedCriterion = None,
                 graph_directory: str = None,
                 interpreter_load_dir = None,
                 random_id: int = None):
        """
        :param dataset: dataset the classifier was trained on.
        :param cls_pair: class pair whose boundary is analysed.
        :param num_graphs: number of graphs handed to the evaluator.
        :param reference_class: member of ``cls_pair`` that is placed first in ``order_of_pairs``.
        :param temperature: temperature of the graph sampler and the evaluator.
        :param ranges: probability ranges for which the thickness is computed.
        :param training_params: training parameters forwarded to the evaluator in ``run``; the per-dataset
            defaults of the class pair are used when ``None``.
        :param learn_node_feat: forwarded to ``GraphSampler``.
        :param lr: learning rate of the SGD optimizer of the sampler.
        :param model_kwargs: must provide ``hidden_channels`` and ``num_layers``; looked up via
            ``get_model_kwargs`` when ``None``.
        :param ckpt_path: checkpoint of the classifier; defaults to ``CKPT_PATHS[dataset.name]``.
        :param model_architecture: classifier class; defaults to ``GCNClassifier``.
        :param max_nodes: forwarded to ``GraphSampler`` and ``BoundaryEvaluator``.
        :param criterion: training criterion; the default dynamic boundary criterion is used when ``None``.
        :param graph_directory: directory of the saved boundary graphs; looked up in
            ``BOUNDARY_GRAPHS_DIRs_TO_DATASET`` when ``None``.
        :param interpreter_load_dir: forwarded to ``BoundaryEvaluator``.
        :param random_id: suffix of the experiment id; drawn randomly when ``None``.
        """
        super().__init__()
        assert reference_class in cls_pair, "Reference class must be in the class pair."
        logger.info(f"Setting up boundary thickness experiment for target ranges: {str(ranges)}...")
        self.dataset = dataset
        self.num_graphs = num_graphs

        if training_params is None:
            training_params = get_default_training_params(dataset.name, cls_pair[0], cls_pair[1])

        self.training_params = training_params

        random_id = np.random.randint(0, 1000000) if random_id is None else random_id
        self.experiment_id = (f"temperature={temperature}_lr={lr}_numgraphs={num_graphs}_maxnodes={max_nodes}"
                              f"target-ranges={str(ranges)}_{random_id}")

        # Single separator before the reference class (the former "//" was a typo that resolves to the
        # same directory).
        self.root_dir = f"{self.root_dir}/boundary_thickness_target_ranges/{dataset.name}/{cls_pair[0]}_{cls_pair[1]}/{reference_class}/{self.experiment_id}"

        # exist_ok avoids the check-then-create race when several experiments start concurrently.
        os.makedirs(self.root_dir, exist_ok=True)

        if model_kwargs is None:
            model_kwargs = get_model_kwargs(dataset, dataset.name)

        if ckpt_path is None:
            ckpt_path = CKPT_PATHS[dataset.name]

        sampler = GraphSampler(
            max_nodes=max_nodes,
            temperature=temperature,
            num_node_cls=len(dataset.NODE_CLS),
            learn_node_feat=learn_node_feat
        )

        model_arch = GCNClassifier if model_architecture is None else model_architecture
        model = model_arch(node_features=len(dataset.NODE_CLS),
                           num_classes=len(dataset.GRAPH_CLS),
                           hidden_channels=model_kwargs["hidden_channels"],
                           num_layers=model_kwargs["num_layers"])
        model.load_state_dict(torch.load(ckpt_path))

        # Mean embedding per ground-truth class, consumed by the default criterion.
        dataset_list_gt = self.dataset.split_by_class()
        mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

        if criterion is None:
            criterion = get_default_criteria_dynamic_boundary(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])

        optimizer = torch.optim.SGD(sampler.parameters(), lr=lr)
        trainer = Trainer(
            sampler=sampler,
            discriminator=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1)
        )

        if graph_directory is None:
            graph_directory = BOUNDARY_GRAPHS_DIRs_TO_DATASET[dataset.name][cls_pair]

        # The reference class comes first, the remaining class of the pair second.
        other_class = cls_pair[0] if reference_class == cls_pair[1] else cls_pair[1]
        self.evaluator = BoundaryEvaluator(sampler,
                                           model,
                                           dataset,
                                           trainer,
                                           cls_pair=cls_pair,
                                           logger=logger,
                                           load_dir=graph_directory,
                                           temperature=temperature,
                                           max_nodes=max_nodes,
                                           interpreter_load_dir=interpreter_load_dir,
                                           order_of_pairs=(reference_class, other_class))
        self.ranges = ranges
        self.thickness = None
        self.figs = []
        logger.info("Finished setting up boundary thickness experiment. Using saved boundary graphs from directory: " + graph_directory)

    def run(self):
        """
        Compute the boundary thickness for every probability range.

        :return: the result of ``boundary_thickness_for_probability_range``, also stored in
            ``self.thickness``.
        """
        logger.info("Starting boundary thickness run for ranges...")
        self.thickness = self.evaluator.boundary_thickness_for_probability_range(
            ranges=self.ranges,
            num_graphs=self.num_graphs,
            **vars(self.training_params))
        logger.info("Finished boundary thickness run for ranges...")
        return self.thickness

    def plot(self):
        """
        Draw a bar chart of the thickness per probability range and append the figure to ``self.figs``.

        The figure and axes are handled explicitly instead of through pyplot's implicit "current figure",
        because experiments may be plotted from several threads at once.
        """
        labels = [f"{k[0]} - {k[1]}" for k in self.thickness.keys()]
        values = list(self.thickness.values())

        fig, ax = plt.subplots(figsize=(18, 14))
        ax.bar(labels, values, color='blue', alpha=0.7)
        ax.set_xlabel("Range")
        ax.set_ylabel("Boundary Thickness")
        ax.set_title(f"Dataset {self.dataset.name} and class pair {self.evaluator.cls_pair}")
        ax.tick_params(axis='x', labelrotation=45)
        ax.grid(axis='y', linestyle='--', alpha=0.6)

        self.figs.append(fig)

    def save(self):
        """
        Write the thickness per range to ``thickness.json`` and the first figure created by ``plot`` to
        ``boundary_thickness.png``.
        """
        # ".4f" (precision) instead of "4f" (minimum width): values are rounded to four decimals.
        final_representation = {str(k): f"{v:.4f}" for k, v in self.thickness.items()}
        with open(self.root_dir + "/" + "thickness.json", "w") as f:
            json.dump(final_representation, f)

        self.figs[0].savefig(self.root_dir + "/boundary_thickness.png")
        logger.info("Saved boundary thickness results into " + self.root_dir)

class ParallelExperimentRunner:
    """
    Runs a list of experiment instances concurrently in a thread pool.

    Every experiment goes through ``run()``, ``plot()`` and ``save()`` in that
    order inside a worker thread. All experiments must be instances of the
    type of the first experiment in the list (subclasses of it are accepted).

    NOTE(review): the experiments share whatever global state their methods
    touch (e.g. a ``plot()`` that relies on ``plt.figure()`` / ``plt.gcf()``).
    Presumably that state is not thread-safe - confirm before running
    plotting-heavy experiments in parallel.
    """

    def __init__(self, experiments: List[AbstractExperiment]):
        """
        Initialize with a list of experiment instances.

        Args:
            experiments: experiment instances (not classes) to run. An empty
                list is accepted and results in an empty run.

        Raises:
            ValueError: if any experiment is not an instance of the type of
                the first experiment.
        """
        self.experiments = experiments

        if experiments:
            # Look up the reference type once instead of once per element.
            expected_type = type(experiments[0])
            for exp in experiments:
                if not isinstance(exp, expected_type):
                    raise ValueError("All experiments must be of the same type. Otherwise we will get race conditions.")

    def run_all(self, max_workers: int = None) -> dict:
        """
        Run all experiments in parallel.

        Args:
            max_workers: maximum number of worker threads; ``None`` lets
                ``ThreadPoolExecutor`` pick its default.

        Returns:
            A dict mapping each experiment instance to the value returned by
            its ``run()``. Experiments that raised are logged (including the
            traceback) and are absent from the dict.
        """
        results = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_exp = {
                executor.submit(self._run_experiment, exp): exp for exp in self.experiments
            }
            for future in concurrent.futures.as_completed(future_to_exp):
                exp = future_to_exp[future]
                try:
                    results[exp] = future.result()
                except Exception as exc:
                    # Deliberately broad: one failing experiment must not abort
                    # the others. Attach the exception so that the traceback
                    # from the worker thread ends up in the log.
                    logger.error(f"Experiment {exp} generated an exception: {exc}", exc_info=exc)
        return results

    @staticmethod
    def _run_experiment(exp: AbstractExperiment):
        """
        Wrapper for running a single experiment.

        Calls ``run()``, ``plot()`` and ``save()`` on ``exp`` in that order and
        returns the value produced by ``run()``. Exceptions propagate to the
        caller (the future).
        """
        logger.info(f"Starting experiment: {type(exp).__name__}")
        result = exp.run()
        exp.plot()
        exp.save()
        logger.info(f"Finished experiment: {type(exp).__name__}")
        return result


if __name__ == "__main__":
    # Script entry point: build one experiment and execute its full
    # run -> plot -> save pipeline.
    import warnings

    # Silence every warning for the duration of the script.
    warnings.filterwarnings("ignore")

    dataset = MotifDataset(seed=12345)
    exp = ReproduceAdjacencyExperiment(dataset)

    # Alternative experiment set-ups; swap one in for the assignment above.
    #exp = AdjacencyAndBoundaryGraphGenerationExperiment(dataset,  temperature=0.2, run_new_adjacency_analysis=True, num_graphs=3)
    #exp = ReproduceBoundaryComplexityExperimentOnePair(dataset, (1, 2), 100, temperature=0.2, graph_directory="/home/lukas/Projects/FACT_cp/scripts/graphs/boundary/ENZYMES/1-2")
    #exp = ReproduceTable1PerDatasetExperiment(dataset, 500, temperature=0.2)
    #exp = ReproduceBoundaryComplexityExperiment(dataset, (0,1), 500, temperature=0.2, num_gen_iterations=2000)
    #exp = ReproduceGraphGenerationExperiment(dataset, (0, 1))
    #exp = ReproduceTable2PerDatasetExperiment(dataset, 10, num_iterations=100, strategy="cross_entropy",temperature=0.2)
    #exp = ReproduceFigure3PerDatasetExperiment(dataset, 100, temperature=0.2,
    #                                           graph_directory="/home/lukas/Projects/FACT_cp/scripts/graphs/boundary/COLLAB",
    #                                           interpreter_directoy="/home/lukas/Projects/FACT_cp/scripts/graphs/interpreter/COLLAB")

    #exp = BoundaryComplexityPerTargetRangeExperiment(dataset, (0, 1), 300, temperature=0.2,
    #                                                 ranges=[(0.45, 0.55), (0.47, 0.53), (0.48, 0.52), (0.49, 0.51), (0.495, 0.505)])

    # The stages depend on each other, so they are executed strictly in order.
    for stage_name in ("run", "plot", "save"):
        getattr(exp, stage_name)()