import argparse
import logging
from scripts.default_configs import *
from scripts.experiments import (
    ReproduceAdjacencyExperiment,
    AdjacencyAndBoundaryGraphGenerationExperiment,
    ReproduceBoundaryComplexityExperimentOnePair,
    ReproduceTable1PerDatasetExperiment,
    ReproduceGraphGenerationExperiment,
    ReproduceTable2PerDatasetExperiment,
    BoundaryComplexityPerTargetRangeExperiment,
    BoundaryMarginPerTargetRangeExperiment,
    BoundaryThicknessPerTargetRangeExperiment,
    ReproduceFigure3PerDatasetExperiment
)

from scripts.embedding_space_check import run_embedding_trainer

from scripts.save_graphs import save_graphs

def _str2bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_experiment_parser():
    """
    Build the argument parser for the experiment runner.

    ``--experiment`` selects one of:
    - adjacency: Reproduce the adjacency experiment.
    - adjacency_graph_gen: Generate graphs based on adjacency analysis.
    - boundary_complexity: Reproduce boundary complexity experiment for one class pair.
    - table1: Reproduce table 1 experiment.
    - graph_gen: Reproduce graph generation experiment.
    - table2: Reproduce table 2 experiment (requires --num_runs).
    - figure3: Reproduce figure 3 experiment (requires --interpreter_directory).
    - complexity_ranges: Reproduce boundary complexity experiment for multiple target probability ranges.
    - margin_ranges: Reproduce boundary margin experiment for multiple target probability ranges.
    - thickness_ranges: Reproduce boundary thickness experiment for multiple target probability ranges.
    - embedding_space_check: Run the embedding-space trainer for one class pair.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(description="Run different graph-based experiments.")
    parser.add_argument("--experiment", type=str, required=True, choices=[
        "adjacency", "adjacency_graph_gen", "boundary_complexity", "table1", "graph_gen", "table2", "figure3",
        "complexity_ranges", "margin_ranges",
        "thickness_ranges", "embedding_space_check"
    ], help="Select which experiment to run.")
    parser.add_argument("--dataset", type=str, required=True, help="Specify dataset name.", choices=[
        "collab", "enzymes", "motif", "IMDB"
    ])
    parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.")
    parser.add_argument("--num_graphs", type=int, default=10, help="Number of graphs to generate.")
    parser.add_argument("--num_runs", type=int, help="Number of runs for table 2.")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Checkpoint path for model.")
    parser.add_argument("--num_iterations", type=int, default=100, help="Number of iterations.")
    parser.add_argument("--strategy", type=str, choices=["dynamic_boundary", "cross_entropy"], default="cross_entropy", help="Training strategy.")
    parser.add_argument("--class_pair", type=str, default="0,1", help="Class pair for experiments (comma-separated).")
    parser.add_argument("--graph_directory", type=str, help="Directory to load saved graphs.")
    parser.add_argument("--save_dir", type=str, default="./results", help="Directory to save experiment results.")
    parser.add_argument("--lr", type=float, default=1.0, help="Learning rate.")
    parser.add_argument("--adj_threshold", type=float, default=0.8, help="Adjacency threshold for adjacency analysis.")
    # type=bool would turn the string "False" into True; parse it explicitly.
    parser.add_argument("--learn_node_feat", type=_str2bool, default=True,
                        help="Whether to learn node features (true/false).")
    parser.add_argument("--w_budget_init", type=float, help="Initial budget weight.")
    parser.add_argument("--w_budget_inc", type=float, help="Budget increment value.")
    parser.add_argument("--w_budget_dec", type=float, help="Budget decrement value.")
    parser.add_argument("--max_nodes", type=int, default=25, help="Maximum nodes for graph sampling.")
    parser.add_argument("--target_size", type=int, default=30, help="Target size for graph generation.")
    parser.add_argument("--random_id", type=int, help="Random ID for experiment tracking.")
    parser.add_argument("--ranges", type=str,
                        help="Semicolon-separated 'low,high' target probability ranges for the "
                             "complexity/margin/thickness range experiments.",
                        default="0.45,0.55;0.47,0.53;0.48,0.52;0.49,0.51;0.495,0.505")
    parser.add_argument("--interpreter_directory", type=str, default=None, help="Interpreter directory for boundary analysis.")
    parser.add_argument("--reference_class", type=int, help="Reference class for boundary margin and thickness", default=None)
    parser.add_argument("--target_probs", default="0.45,0.55", type=str, help="Target probability range for boundary margin and thickness")
    parser.add_argument("--num_embeddings", type=int, help="Number of embeddings to generate for embedding space check.", default=100)
    return parser

def main():
    parser = get_experiment_parser()
    args = parser.parse_args()

    match args.dataset:
        case "collab":
            dataset = CollabDataset(seed=12345)
        case "enzymes":
            dataset = ENZYMESDataset(seed=12345)
        case "motif":
            dataset = MotifDataset(seed=12345)
        case "IMDB":
            dataset = IMDBDataset(seed=12345)
        case _:
            raise ValueError("Unknown dataset.")

    class_pair = tuple(map(int, args.class_pair.split(',')))

    ranges = [tuple(map(float, r.split(','))) for r in args.ranges.split(';')]

    target_probs = tuple(map(float, args.target_probs.split(',')))

    if args.experiment == "adjacency_graph_gen":
        exp = AdjacencyAndBoundaryGraphGenerationExperiment(
            dataset, temperature=args.temperature, ckpt_path=args.ckpt_path, num_graphs=args.num_graphs,
        )
    elif args.experiment == "adjacency":
        exp = ReproduceAdjacencyExperiment(
            dataset
        )
    elif args.experiment == "boundary_complexity":
        exp = ReproduceBoundaryComplexityExperimentOnePair(
            dataset, class_pair,
            args.num_graphs,
            temperature=args.temperature,
            graph_directory=args.graph_directory,
            lr=args.lr, ckpt_path=args.ckpt_path,
        )
    elif args.experiment == "table1":
        exp = ReproduceTable1PerDatasetExperiment(
            dataset, args.num_graphs, temperature=args.temperature, lr=args.lr, ckpt_path=args.ckpt_path,
            graph_directory=args.graph_directory,
        )
    elif args.experiment == "graph_gen":
        exp = ReproduceGraphGenerationExperiment(
            dataset, class_pair, ckpt_path=args.ckpt_path, lr=args.lr, temperature=args.temperature,
        )
    elif args.experiment == "table2":
        if args.num_runs is None:
            raise ValueError("Number of runs must be provided for table 2.")

        exp = ReproduceTable2PerDatasetExperiment(
            dataset, args.num_runs,
            num_iterations=args.num_iterations,
            ckpt_path=args.ckpt_path,
            strategy=args.strategy,
            temperature=args.temperature,
            save_dir=args.save_dir,
            lr=args.lr,
            w_budget_init=args.w_budget_init,
            w_budget_inc=args.w_budget_inc,
            w_budget_dec=args.w_budget_dec,
            max_nodes=args.max_nodes,
            target_size=args.target_size,
            random_id=args.random_id,
            target_probs=target_probs,
        )
    elif args.experiment == "complexity_ranges":
        exp = BoundaryComplexityPerTargetRangeExperiment(
            dataset,
            class_pair,
            args.num_graphs,
            ranges=ranges,
            temperature=args.temperature,
            graph_directory=args.graph_directory,
            lr=args.lr, ckpt_path=args.ckpt_path,
        )
    elif args.experiment == "margin_ranges":
        exp = BoundaryMarginPerTargetRangeExperiment(
            dataset,
            class_pair,
            args.num_graphs,
            ranges=ranges,
            temperature=args.temperature,
            graph_directory=args.graph_directory,
            lr=args.lr, ckpt_path=args.ckpt_path,#
            interpreter_load_dir=args.interpreter_directory,
            reference_class=args.reference_class
        )
    elif args.experiment == "thickness_ranges":
        exp = BoundaryThicknessPerTargetRangeExperiment(
            dataset,
            class_pair,
            args.num_graphs,
            ranges=ranges,
            temperature=args.temperature,
            graph_directory=args.graph_directory,
            lr=args.lr, ckpt_path=args.ckpt_path,
            interpreter_load_dir=args.interpreter_directory,
            reference_class=args.reference_class
        )

    elif args.experiment == "figure3":
        assert args.interpreter_directory is not None, "Interpreter directory must be provided for figure 3."
        exp = ReproduceFigure3PerDatasetExperiment(
            dataset=dataset,
            num_graphs=args.num_graphs,
            temperature=args.temperature,
            graph_directory=args.graph_directory,
            interpreter_directoy=args.interpreter_directory,
            lr=args.lr,
        )

    elif args.experiment == "embedding_space_check":
        assert class_pair is not None, "Class pair must be provided for embedding space check."
        run_embedding_trainer(dataset, args.num_embeddings, [class_pair], (0.45, 0.55),
                              -120, 150)

    else:
        raise ValueError("Unknown experiment type.")

    exp.run()
    exp.plot()
    exp.save()

if __name__ == "__main__":
    # Script entry point: suppress all warnings for the whole run, then
    # hand off to main().
    import warnings

    warnings.simplefilter("ignore")
    main()
