import argparse

from gnnboundary.utils.boundary_generator import GraphGenerator, GraphRetrainer
from scripts.sampling_training import *
from scripts.experiments import get_model_kwargs, CKPT_PATHS
from scripts.experiments_logging import initialize_logger
from multiprocessing import Pool, cpu_count
import logging
from gnn_xai_common import Trainer, GraphSampler
from typing import Literal

# Module-level logger used by every function in this file. It is created once
# at import time via the project's initialize_logger helper with INFO level.
# NOTE(review): "graph_saver.log" is presumably the log file path - confirm
# against scripts.experiments_logging.initialize_logger.
logger = initialize_logger("graph_saver.log", logging.INFO)


def save_graphs(dataset_name: Literal["motif", "collab", "enzymes","reddit","imdb"],
                num_graphs: int,
                num_runs: int,
                num_iterations: int = 500,
                criterion: dict = None,
                model_arch: object = GCNClassifier,
                model_kwargs = None,
                ckpt_path: str = None,
                cls_pairs: List[Tuple[int, int]] = None,
                max_nodes: int = 25,
                temperature: float = 0.15,
                learn_node_feat: bool = True,
                lr: float = 1,
                target_size: int = 30, #60 for motif
                target_probs: Tuple[float, float] = (0.4, 0.6),
                k_samples = 32,
                num_workers: int = None
                ):
    """
    Save graphs using different datasets and configurations.
    """
    match dataset_name:
        case "motif":
            dataset = MotifDataset(seed=12345)
        case "collab":
            dataset = CollabDataset(seed=12345)
        case "enzymes":
            dataset = ENZYMESDataset(seed=12345)
        case "reddit":
            dataset = RedditDataset(seed=12345)
        case "imdb":
            dataset = IMDBDataset(seed=12345)
        case _:
            raise ValueError("Invalid dataset name")

    if cls_pairs is None:
        cls_pairs = DATASET_TO_CLS_PAIRS[dataset.name]

    if model_kwargs is None:
        model_kwargs = get_model_kwargs(dataset, dataset.name)
    model = model_arch(**model_kwargs)

    if ckpt_path is None:
        ckpt_path = CKPT_PATHS[dataset.name]

    model.load_state_dict(torch.load(ckpt_path))
    dataset_list_gt = dataset.split_by_class()
    mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

    if criterion is None:
        criterion = {cls_pair: get_default_criteria_dynamic_boundary(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])
                     for cls_pair in cls_pairs}
    elif isinstance(criterion, WeightedCriterion):
        criterion = {cls_pair: criterion for cls_pair in cls_pairs}


    s = GraphSampler(
        max_nodes=max_nodes,
        temperature=temperature,
        num_node_cls=len(dataset.NODE_CLS),
        learn_node_feat=learn_node_feat
    )

    trainer = {}
    for cls_pair in cls_pairs:
        trainer[cls_pair] = Trainer(
            sampler=s,
            discriminator=model,
            criterion=criterion[cls_pair],
            optimizer=(o := torch.optim.SGD(s.parameters(), lr=lr)),
            scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1),
        )

    graph_generators = {cls_pair: GraphGenerator(trainer[cls_pair].sampler, dataset, trainer[cls_pair], model)
                       for cls_pair in cls_pairs}



    def generate_graphs_for_cls_pair(cls_pair, graph_generators, iterations, num_runs,
                                     num_graphs, logger, target_size, target_probs, k_samples):
        """ Helper function for multiprocessing execution """
        logger.info(f"Generating boundary for classes {cls_pair}")
        target_probs = {cls_pair[0] : target_probs, cls_pair[1]: target_probs}
        graph_generators[cls_pair](num_runs, num_graphs, logger,
                                   strategy="boundary",
                                   save_graphs=True, cls=cls_pair,
                                   target_size=target_size,
                                   iterations=iterations,
                                   target_probs=target_probs,
                                   w_budget_init=1,
                                   w_budget_inc=1.1,
                                   w_budget_dec=0.95,
                                   k_samples=k_samples)
        logger.info(f"Finished generating for classes {cls_pair}")

    logger.info("Starting to generate graphs")
    if num_workers is None:
        for cls_pair in cls_pairs:
            logger.info(f"Generating boundary for classes {cls_pair}")
            generate_graphs_for_cls_pair(cls_pair, graph_generators, num_iterations, num_runs, num_graphs, logger,
                                         target_size, target_probs, k_samples)
            logger.info(f"Finished generating for classes {cls_pair}")
    else:
        num_workers = min(len(cls_pairs), cpu_count()) if num_workers is None else num_workers

        task_args = [(cls_pair, graph_generators, num_runs, num_graphs, logger,
                      target_size, target_probs, k_samples) for cls_pair in cls_pairs]

        with Pool(processes=num_workers) as pool:
            pool.starmap(generate_graphs_for_cls_pair, task_args)

    logger.info("Finished generating graphs")


def save_graphs_interpreter(
                dataset_name: Literal["motif", "collab", "enzymes","reddit","imdb"],
                num_graphs: int,
                num_runs: int,
                num_iterations: int = 500,
                criterion: dict = None,
                model_arch: object = GCNClassifier,
                model_kwargs = None,
                ckpt_path: str = None,
                cls: Tuple[int] = None,
                max_nodes: int = 25,
                temperature: float = 0.15,
                learn_node_feat: bool = True,
                lr: float = 1,
                target_size: int = 30, #60 for motif
                target_probs: Tuple[float, float] = (0.4, 0.6),
                k_samples = 32,
                num_workers: int = None):
    """
    Save graphs using different datasets and configurations.
    """
    match dataset_name:
        case "motif":
            dataset = MotifDataset(seed=12345)
        case "collab":
            dataset = CollabDataset(seed=12345)
        case "enzymes":
            dataset = ENZYMESDataset(seed=12345)
        case "reddit":
            dataset = RedditDataset(seed=12345)
        case "imdb":
            dataset = IMDBDataset(seed=12345)
        case _:
            raise ValueError("Invalid dataset name")

    if model_kwargs is None:
        model_kwargs = get_model_kwargs(dataset, dataset.name)
    model = model_arch(**model_kwargs)

    if ckpt_path is None:
        ckpt_path = CKPT_PATHS[dataset.name]

    model.load_state_dict(torch.load(ckpt_path))
    dataset_list_gt = dataset.split_by_class()
    mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

    if criterion is None:
        criterion = {
            c: get_default_criteria_interpreter(dataset.name, mean_embeds, c)
            for c in cls}
    elif isinstance(criterion, WeightedCriterion):
        criterion = {c: criterion for c in cls}

    s = GraphSampler(
        max_nodes=max_nodes,
        temperature=temperature,
        num_node_cls=len(dataset.NODE_CLS),
        learn_node_feat=learn_node_feat
    )

    trainer = {}
    for c in cls:
        trainer[c] = Trainer(
            sampler=s,
            discriminator=model,
            criterion=criterion[c],
            optimizer=(o := torch.optim.SGD(s.parameters(), lr=lr)),
            scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1),
        )

    graph_generators = {c: GraphGenerator(trainer[c].sampler, dataset, trainer[c], model)
                        for c in cls}

    def generate_graphs_for_cls(c, graph_generators, iterations, num_runs,
                                     num_graphs, logger, target_size, target_probs, k_samples):
        """ Helper function for multiprocessing execution """
        logger.info(f"Generating boundary for classes {c}")
        target_probs = {c: target_probs}
        graph_generators[c](num_runs, num_graphs, logger,
                                   strategy="interpreter",
                                   save_graphs=True, cls=c,
                                   target_size=target_size,
                                   iterations=iterations,
                                   target_probs=target_probs,
                                   w_budget_init=1,
                                   w_budget_inc=1.1,
                                   w_budget_dec=0.95,
                                   k_samples=k_samples)
        logger.info(f"Finished generating for classes {c}")

    logger.info("Starting to generate graphs")
    if num_workers is None:
        for c in cls:
            logger.info(f"Generating boundary for classes {c}")
            generate_graphs_for_cls(c, graph_generators, num_iterations, num_runs, num_graphs, logger,
                                         target_size, target_probs, k_samples)
            logger.info(f"Finished generating for classes {c}")
    else:
        num_workers = min(len(cls), cpu_count()) if num_workers is None else num_workers

        task_args = [(c, graph_generators, num_runs, num_graphs, logger,
                      target_size, target_probs, k_samples) for c in cls]

        with Pool(processes=num_workers) as pool:
            pool.starmap(generate_graphs_for_cls, task_args)

    logger.info("Finished generating graphs")

def retrain_graph(dataset_name: Literal["motif", "collab", "enzymes", "reddit","imdb"],
                  graph_dir: str,
                  num_runs: int,
                  num_graphs: int,
                  num_iterations: int = 500,
                  model_arch: object = GCNClassifier,
                  model_kwargs=None,
                  ckpt_path: str = None,
                  cls_pairs: List[Tuple[int, int]] = None,
                  criterion: dict = None,
                  max_nodes: int = 25,
                  temperature: float = 0.15,
                  learn_node_feat: bool = True,
                  lr: float = 1,
                  target_size: int = 30,  # 60 for motif
                  target_probs: Tuple[float, float] = (0.45, 0.55),
                  k_samples=32,
                  ):
    match dataset_name:
        case "motif":
            dataset = MotifDataset(seed=12345)
        case "collab":
            dataset = CollabDataset(seed=12345)
        case "enzymes":
            dataset = ENZYMESDataset(seed=12345)
        case "reddit":
            dataset = RedditDataset(seed=12345)
        case "imdb":
            dataset = IMDBDataset(seed=12345)
        case _:
            raise ValueError("Invalid dataset name")

    if cls_pairs is None:
        cls_pairs = DATASET_TO_CLS_PAIRS[dataset.name]

    if model_kwargs is None:
        model_kwargs = get_model_kwargs(dataset, dataset.name)
    model = model_arch(**model_kwargs)

    if ckpt_path is None:
        ckpt_path = CKPT_PATHS[dataset.name]

    model.load_state_dict(torch.load(ckpt_path))
    dataset_list_gt = dataset.split_by_class()
    mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

    if criterion is None:
        criterion = {
            cls_pair: get_default_criteria_dynamic_boundary(dataset.name, mean_embeds, cls_pair[0], cls_pair[1])
            for cls_pair in cls_pairs}
    elif isinstance(criterion, WeightedCriterion):
        criterion = {cls_pair: criterion for cls_pair in cls_pairs}

    s = GraphSampler(
        max_nodes=max_nodes,
        temperature=temperature,
        num_node_cls=len(dataset.NODE_CLS),
        learn_node_feat=learn_node_feat
    )

    trainer = {}
    for cls_pair in cls_pairs:
        trainer[cls_pair] = Trainer(
            sampler=s,
            discriminator=model,
            criterion=criterion[cls_pair],
            optimizer=(o := torch.optim.SGD(s.parameters(), lr=lr)),
            scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1),
        )

    graph_retrainers = {cls_pair: GraphRetrainer(sampler_dir=graph_dir,
                                                 num_graphs=num_graphs,
                                                 trainer=trainer[cls_pair],
                                                 dataset=dataset,
                                                 model=model)}

    logger.info("Starting to retrain graphs")
    for cls_pair in cls_pairs:
        logger.info(f"Retraining boundary for classes {cls_pair}")
        graph_retrainers[cls_pair](num_runs=num_runs,
                                   target_size=target_size,
                                   iterations=num_iterations,
                                   target_probs={cls_pair[0]: target_probs, cls_pair[1]: target_probs},
                                   w_budget_init=1,
                                   w_budget_inc=1.1,
                                   w_budget_dec=0.95,
                                   k_samples=k_samples)
        logger.info(f"Finished retraining for classes {cls_pair}")


def main():
    """Parse command-line arguments and dispatch to the selected goal.

    ``goal`` selects between ``retrain_graph`` ("retrain"),
    ``save_graphs_interpreter`` ("interpreter") and ``save_graphs``
    ("generate"). List-valued options are given as strings:
    ``--cls_pairs "1,2;3,4"``, ``--target_probs "0.45,0.55"``, ``--cls "0,1"``.

    Invalid or missing option values are reported through ``parser.error``,
    which prints the usage message and exits with status 2.
    """
    def parse_bool(value: str) -> bool:
        # argparse's ``type=bool`` would call bool("False"), which is True for
        # any non-empty string; parse the text explicitly instead.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Save graphs using different datasets and configurations.")
    # NOTE: argparse ignores ``default`` on a required positional, so the goal
    # must always be given explicitly on the command line.
    parser.add_argument("goal", type=str, default="generate", choices=["retrain", "generate", "interpreter"], help="Goal of the script")
    parser.add_argument("dataset_name", type=str, choices=["motif", "collab", "enzymes","reddit","imdb"], help="Dataset name")
    parser.add_argument("num_graphs", type=int, help="Number of graphs to generate")
    parser.add_argument("num_runs", type=int, help="Number of runs")
    parser.add_argument("--num_iterations", type=int, default=500, help="Number of iterations")
    parser.add_argument("--graph_dir", type=str, default=None, help="Directory of graphs to retrain")
    parser.add_argument("--max_nodes", type=int, default=25, help="Maximum number of nodes")
    parser.add_argument("--temperature", type=float, default=0.15, help="Sampling temperature")
    parser.add_argument("--learn_node_feat", type=parse_bool, default=True, help="Whether to learn node features")
    parser.add_argument("--lr", type=float, default=1, help="Learning rate")
    parser.add_argument("--target_size", type=int, default=40, help="Target size of generated graphs")
    parser.add_argument("--k_samples", type=int, default=32, help="Number of k-samples")
    parser.add_argument("--num_workers", type=int, default=None, help="Number of workers for multiprocessing")
    parser.add_argument("--cls_pairs", type=str, default=None, help="List of class pairs as '1,2;3,4;5,6'")
    parser.add_argument("--target_probs", type=str, default="0.45,0.55", help="Target probabilities for classes")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Checkpoint path of model")
    parser.add_argument("--cls", type=str, default=None, help="Class for interpreter")
    args = parser.parse_args()

    # Turn the string-encoded list options into Python values; malformed
    # input becomes a usage error instead of a raw traceback.
    try:
        if args.cls_pairs is not None:
            cls_pairs = [tuple(int(v) for v in pair.split(",")) for pair in args.cls_pairs.split(";")]
        else:
            cls_pairs = None
        target_probs = tuple(float(v) for v in args.target_probs.split(","))
    except ValueError as err:
        parser.error(f"Could not parse --cls_pairs/--target_probs: {err}")

    if args.goal == "retrain":
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if args.graph_dir is None:
            parser.error("Graph directory must be provided for retraining")
        retrain_graph(
            dataset_name=args.dataset_name,
            graph_dir=args.graph_dir,
            num_runs=args.num_runs,
            num_graphs=args.num_graphs,
            num_iterations=args.num_iterations,
            model_arch=GCNClassifier,
            ckpt_path=args.ckpt_path,
            cls_pairs=cls_pairs,
            max_nodes=args.max_nodes,
            temperature=args.temperature,
            learn_node_feat=args.learn_node_feat,
            lr=args.lr,
            target_size=args.target_size,
            k_samples=args.k_samples,
            target_probs=target_probs
        )

    elif args.goal == "interpreter":
        if args.cls is None:
            parser.error("Class must be provided for interpreter")
        try:
            cls = [int(v) for v in args.cls.split(",")]
        except ValueError as err:
            parser.error(f"Could not parse --cls: {err}")
        save_graphs_interpreter(
            dataset_name=args.dataset_name,
            num_graphs=args.num_graphs,
            num_runs=args.num_runs,
            num_iterations=args.num_iterations,
            model_arch=GCNClassifier,
            ckpt_path=args.ckpt_path,
            cls=cls,
            max_nodes=args.max_nodes,
            temperature=args.temperature,
            learn_node_feat=args.learn_node_feat,
            lr=args.lr,
            target_size=args.target_size,
            k_samples=args.k_samples,
            target_probs=target_probs,
            num_workers=args.num_workers
        )

    else:
        save_graphs(
            dataset_name=args.dataset_name,
            num_graphs=args.num_graphs,
            num_runs=args.num_runs,
            cls_pairs=cls_pairs,
            num_iterations=args.num_iterations,
            max_nodes=args.max_nodes,
            temperature=args.temperature,
            learn_node_feat=args.learn_node_feat,
            lr=args.lr,
            target_size=args.target_size,
            k_samples=args.k_samples,
            num_workers=args.num_workers,
            target_probs=target_probs,
            ckpt_path=args.ckpt_path
        )


# Single script entry point. Two separate ``__main__`` guards would each call
# main(), running the whole pipeline twice.
if __name__ == "__main__":
    import warnings

    # Silence all Python warnings for the duration of the script run.
    warnings.filterwarnings("ignore")
    main()