import pandas as pd
import torch
import torch_geometric
import networkx as nx
from gnn_xai_common.models import GCNClassifier
from gnnboundary import *

import random
import math
import os
from collections import defaultdict
import argparse
import logging


# Configure the root logger once at import time: INFO level, each record
# prefixed with the emitting process id and the level name.
logging.basicConfig(
    level=logging.INFO, format="%(process)d - %(levelname)s - %(message)s"
)


# Dataset classes evaluated by default by run_random_baseline. The names are not
# imported explicitly in this file; presumably they come from the star import of
# `gnnboundary` above.
DATASETS = [MotifDataset, CollabDataset, ENZYMESDataset, IMDBDataset]
# Class-index pairs per dataset class name; default `class_combinations` of
# run_random_baseline.
# NOTE(review): "RedditDataset" has an entry here but is absent from DATASETS,
# MODEL_CONFIGS and CLASS_MAPPING, so it cannot be run with the defaults below.
PAPER_CLASS_COMBINATIONS = {
    "MotifDataset": {(0, 1), (0, 2), (1, 3)},
    "CollabDataset": {(0, 1), (0, 2)},
    "ENZYMESDataset": {(0, 3), (0, 4), (0, 5), (1, 2), (3, 4), (4, 5)},
    "RedditDataset": {(1, 2), (1, 4), (2, 4), (3, 4)},
    "IMDBDataset": {(0, 1), (0, 2)},
}  # The class pairs reported in the GNNBoundary paper.
# Keyword arguments passed to GCNClassifier for each dataset; they must match
# the architecture of the checkpoint loaded from `ckpts/` for that dataset.
MODEL_CONFIGS = {
    "MotifDataset": {"hidden_channels": 6, "num_layers": 3},
    "CollabDataset": {"hidden_channels": 64, "num_layers": 5},
    "ENZYMESDataset": {"hidden_channels": 32, "num_layers": 3},
    "IMDBDataset": {"hidden_channels": 64, "num_layers": 5},
}
# Human-readable class names per dataset, used to label the rows of the
# result table produced by run_random_baseline.
CLASS_MAPPING = {
    "MotifDataset": {0: "House", 1: "House-X", 2: "Comp-4", 3: "Comp-5"},
    "CollabDataset": {0: "High Energy", 1: "Condensed Matter", 2: "Astro"},
    "ENZYMESDataset": {0: "EC1", 1: "EC2", 2: "EC3", 3: "EC4", 4: "EC5", 5: "EC6"},
    "IMDBDataset": {0: "Comedy", 1: "Romance", 2: "SciFi"},
}


def sample_graphs_by_class(
    dataset: torch_geometric.data.Data,
    cls_label: int,
    num_samples: int = 1,
    source: str = "GNNInterpreter",
):
    """Sample a given number of graphs with a specific class label.

    Args:
        dataset: Dataset to draw from. With ``source="Dataset"`` it is iterated
            and every element's ``y`` is compared with ``cls_label``. With
            ``source="GNNInterpreter"`` only ``dataset.name`` and
            ``dataset.NODE_CLS`` are read.
        cls_label: Class index of the graphs to sample.
        num_samples: Number of graphs to return. If fewer candidates than
            ``num_samples`` exist, candidates are repeated as often as needed.
        source: ``"Dataset"`` samples graphs directly from ``dataset``;
            ``"GNNInterpreter"`` samples ``.pt`` sampler checkpoints from
            ``graphs/interpreter/<dataset.name>/<cls_label>`` and generates one
            graph from each of them.

    Returns:
        A list with ``num_samples`` entries: dataset elements for
        ``source="Dataset"``, or the outputs of the loaded ``GraphSampler``
        objects for ``source="GNNInterpreter"``.

    Raises:
        ValueError: If ``source`` is not one of the two allowed values, or if
            the dataset contains no graph with label ``cls_label``.
        FileNotFoundError: If the checkpoint directory is missing or contains
            no ``.pt`` file.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if source not in ("GNNInterpreter", "Dataset"):
        raise ValueError(
            f"The source or class_sample parameter should be GNNInterpreter or Dataset, {source} is not allowed!"
        )

    directory_path = None
    if source == "Dataset":
        candidates = [graph for graph in dataset if graph.y == cls_label]
        if not candidates:
            # Without this guard the repetition factor below divides by zero.
            raise ValueError(
                f"The dataset contains no graphs with class label {cls_label}."
            )
    else:
        # Load the previously generated graphs.
        directory_path = f"graphs/interpreter/{dataset.name}/{cls_label}"
        if not os.path.exists(directory_path):
            raise FileNotFoundError(f"The directory {directory_path} does not exist.")

        # os.listdir returns entries in arbitrary order; sorting makes the draw
        # reproducible for a fixed random seed.
        candidates = sorted(
            f for f in os.listdir(directory_path) if f.endswith(".pt")
        )
        if not candidates:
            raise FileNotFoundError(
                f"No .pt files found in the directory {directory_path}."
            )

    # random.sample draws without replacement, so when more samples than
    # candidates are requested every candidate is repeated `factor` times.
    counts = None
    if len(candidates) < num_samples:
        factor = math.ceil(num_samples / len(candidates))
        counts = [factor] * len(candidates)

    sample = random.sample(candidates, num_samples, counts=counts)

    if source == "Dataset":
        return sample

    # Now actually generate the graphs from the sampled sampler checkpoints.
    samplers = []
    for file_name in sample:
        sampler = GraphSampler(
            temperature=0.15,
            max_nodes=25 if dataset.name == "COLLAB" and cls_label == 0 else 20,
            num_node_cls=len(dataset.NODE_CLS),
            learn_node_feat=True,
        )
        checkpoint_path = os.path.join(directory_path, file_name)
        try:
            # map_location keeps checkpoints saved on a GPU loadable on CPU-only hosts.
            sampler.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        except Exception as e:
            # Best effort, as before: a sampler that failed to load is still used
            # (with its freshly initialised state) so the number of returned
            # graphs stays `num_samples`; the failure is logged with its path.
            logging.warning(
                "Could not load sampler checkpoint %s: %s", checkpoint_path, e
            )
        samplers.append(sampler)

    return [sampler(k=1, mode='discrete', expected=True) for sampler in samplers]


def connect_graphs_batch(
    graphs1: list[torch_geometric.data.Data],
    graphs2: list[torch_geometric.data.Data],
    num_connections: int = 1,
) -> list[torch_geometric.data.Data]:
    """
    Connect batches of graphs in `graphs1` and `graphs2` with `num_connections` random edges each.

    The two lists are paired element-wise with ``zip`` (surplus graphs of the
    longer list are ignored). For every pair a new graph is built whose node
    features are those of the first graph followed by those of the second one;
    the edge indices of the second graph are shifted by the node count of the
    first graph so they keep pointing at their own nodes. Finally
    ``num_connections`` edges from distinct random nodes of the first graph to
    distinct random nodes of the second graph are appended.

    If both graphs of a pair carry a NetworkX graph in their ``G`` attribute,
    the merged NetworkX graph (second graph relabelled by the same offset,
    connecting edges added) is attached to the result as ``G``. This assumes
    integer node labels in ``G``.

    Args:
        graphs1: Graphs forming the first half of each merged graph.
        graphs2: Graphs forming the second half of each merged graph.
        num_connections: Number of connecting edges per pair.

    Returns:
        One merged ``Data`` object (``x`` and ``edge_index`` only, plus ``G``
        when available) per input pair.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``num_connections``
            exceeds the number of nodes of either graph of a pair.
    """

    connected_graphs = []

    for g1, g2 in zip(graphs1, graphs2):
        offset = g1.num_nodes
        node_g1 = torch.tensor(
            random.sample(range(g1.num_nodes), num_connections), dtype=torch.long
        )
        node_g2 = (
            torch.tensor(
                random.sample(range(g2.num_nodes), num_connections), dtype=torch.long
            )
            + offset
        )
        merged_nodes = torch.cat((g1.x, g2.x), dim=0)
        # The nodes of g2 live at indices [offset, offset + g2.num_nodes) in the
        # merged graph, so its edges have to be shifted as well; otherwise they
        # would be drawn between nodes of g1. The addition creates a new tensor
        # and leaves g2 untouched.
        merged_edges = torch.cat((g1.edge_index, g2.edge_index + offset), dim=1)
        # NOTE(review): each connection is added in one direction only
        # (g1 -> g2). If the inputs encode undirected edges as two directed
        # ones, the reverse edge may be wanted here as well - confirm.
        new_edges = torch.stack((node_g1, node_g2), dim=0)
        edge_index_new = torch.cat((merged_edges, new_edges), dim=1)

        new_g = torch_geometric.data.Data(x=merged_nodes, edge_index=edge_index_new)

        # Check the graphs of the pair, not the surrounding list, for an
        # attached NetworkX graph.
        nx_g1 = getattr(g1, "G", None)
        nx_g2 = getattr(g2, "G", None)
        if isinstance(nx_g1, nx.Graph) and isinstance(nx_g2, nx.Graph):
            # Combine NetworkX graphs and add connection to them as well.
            g2_offset = nx.relabel_nodes(nx_g2, {n: n + offset for n in nx_g2.nodes()})
            merged_graph_nx = nx.compose(nx_g1, g2_offset)
            label = getattr(g1, "y", None)
            if label is not None:
                merged_graph_nx.name = (
                    f"Boundary Graph from connect randomly from classes {int(label)}"
                )
            for u, v in zip(node_g1.tolist(), node_g2.tolist()):
                merged_graph_nx.add_edge(u, v)
            new_g.G = merged_graph_nx

        connected_graphs.append(new_g)

    return connected_graphs


def generate_boundary_graphs(
    dataset: torch_geometric.data.Data,
    num_graphs: int = 500,
    num_connections: int = 1,
    class_combinations: list[tuple] = None,
    class_samples: str = "GNNInterpreter",
) -> dict:
    """
    Build `num_graphs` boundary graphs for every requested pair of classes.

    For each pair, `num_graphs` graphs are sampled for the first class and for
    the second class (via `sample_graphs_by_class` with `source=class_samples`)
    and the two batches are joined pairwise with `num_connections` random edges.

    The result maps the class pair (a tuple) to the list of connected graphs.
    When `class_combinations` is empty or None, all pairs (i, j) with i < j over
    the classes in `dataset.GRAPH_CLS` are used.
    """

    pairs = class_combinations
    if not pairs:
        n_classes = len(dataset.GRAPH_CLS)
        pairs = []
        for lower in range(n_classes):
            for upper in range(lower + 1, n_classes):
                pairs.append((lower, upper))

    graphs_per_pair = defaultdict(list)

    for first_cls, second_cls in pairs:
        # Sample the first class before the second one so the random draws
        # happen in a fixed order.
        first_batch = sample_graphs_by_class(
            dataset, first_cls, num_graphs, source=class_samples
        )
        second_batch = sample_graphs_by_class(
            dataset, second_cls, num_graphs, source=class_samples
        )
        graphs_per_pair[(first_cls, second_cls)] = connect_graphs_batch(
            first_batch, second_batch, num_connections
        )

    return graphs_per_pair


@torch.no_grad()
def evaluate_boundary_graphs(boundary_graphs: dict, model: GCNClassifier):
    """Evaluate the boundary graphs of every class pair with the given model.

    For each key ``(cls1, cls2)`` the graphs are collated into one batch and
    passed through ``model``; the ``"probs"`` entry of the model output is
    reduced over its first dimension. The returned dict maps the class pair to
    ``{"mean": (m1, m2), "std": (s1, s2)}``, i.e. the mean and standard
    deviation of the probabilities at index ``cls1`` and ``cls2``.
    """

    summary = {}
    for pair, graphs in boundary_graphs.items():
        first_cls, second_cls = pair
        merged_batch = torch_geometric.data.Batch.from_data_list(graphs)
        class_probs = model(merged_batch)["probs"]
        average = class_probs.mean(dim=0)
        spread = class_probs.std(dim=0)
        summary[(first_cls, second_cls)] = {
            "mean": (average[first_cls].item(), average[second_cls].item()),
            "std": (spread[first_cls].item(), spread[second_cls].item()),
        }

    return summary


def run_random_baseline(
    datasets: list = DATASETS,
    class_combinations: dict = PAPER_CLASS_COMBINATIONS,
    num_boundary_graphs: int = 500,
    num_connections: int = 1,
    class_samples: str = "GNNInterpreter",
) -> pd.DataFrame:
    """
    Run the random baseline experiments for given datasets and the corresponding given class combinations.

    For every dataset, boundary graphs are generated for each class pair, the
    classifier checkpoint ``ckpts/<name>.pt`` (dataset class name without the
    ``Dataset`` suffix, lower-cased) is loaded and the graphs are evaluated.

    Args:
        datasets: Dataset classes, or names of dataset classes available in
            this module's namespace.
        class_combinations: Mapping from dataset class name to the class pairs
            to evaluate; it must contain an entry for every dataset.
        num_boundary_graphs: Number of boundary graphs per class pair.
        num_connections: Number of random connecting edges per boundary graph.
        class_samples: ``"GNNInterpreter"`` loads the class graphs from the
            interpreter directory, ``"Dataset"`` samples them from the dataset.

    Returns:
        A DataFrame indexed by (Dataset, Class Tuple) with a column for each
        (Probability, Statistic) combination. It is also written to
        ``results/random_baseline_<class_samples>.csv``.
    """

    dataset_classes = []
    for dataset in datasets:
        if isinstance(dataset, str):
            # SECURITY NOTE: eval() executes arbitrary expressions and the names
            # can come straight from the command line. Only pass trusted
            # values; a lookup table of known dataset classes would be safer.
            dataset_classes.append(eval(dataset))
        else:
            dataset_classes.append(dataset)

    results = {}
    for dataset_class in dataset_classes:
        dataset = dataset_class(seed=12345)
        dataset_name = dataset_class.__name__
        boundary_graphs = generate_boundary_graphs(
            dataset,
            num_graphs=num_boundary_graphs,
            num_connections=num_connections,
            class_combinations=class_combinations[dataset_name],
            class_samples=class_samples,
        )
        checkpoint_name = dataset_name.replace("Dataset", "").lower()
        model_config = MODEL_CONFIGS[dataset_name]
        model = GCNClassifier(
            node_features=len(dataset.NODE_CLS),
            num_classes=len(dataset.GRAPH_CLS),
            hidden_channels=model_config["hidden_channels"],
            num_layers=model_config["num_layers"],
        )
        # map_location keeps checkpoints saved on a GPU loadable on CPU-only hosts.
        model.load_state_dict(
            torch.load(f"ckpts/{checkpoint_name}.pt", map_location="cpu")
        )
        # Inference only: switch off training-time behaviour such as dropout.
        model.eval()
        results[dataset_name] = evaluate_boundary_graphs(boundary_graphs, model)

    # Prepare the data for the DataFrame: one row per
    # (dataset, class pair, probability, statistic).
    rows = []
    for dataset_name, dataset_results in results.items():
        for cls_tuple, metrics in dataset_results.items():
            for position, probability in enumerate(("p(c1)", "p(c2)")):
                for statistic in ("mean", "std"):
                    rows.append(
                        (
                            dataset_name,
                            cls_tuple,
                            probability,
                            statistic,
                            metrics[statistic][position],
                        )
                    )

    # Create a MultiIndex DataFrame
    df = pd.DataFrame(
        rows, columns=["Dataset", "Class Tuple", "Probability", "Statistic", "Value"]
    )
    df = df.pivot(
        columns=["Probability", "Statistic"],
        index=["Dataset", "Class Tuple"],
        values="Value",
    )

    # Change dataset order to match with GNNBoundary paper
    paper_order = {
        "MotifDataset": 0,
        "CollabDataset": 1,
        "ENZYMESDataset": 2,
        "IMDBDataset": 3,
    }
    if set(df.index.get_level_values("Dataset")) == set(paper_order):
        df = df.sort_index(
            level="Dataset",
            key=lambda x: x.map(paper_order),
        )

    # Apply class mapping
    def mapper(idx):
        dataset, class_tuple = idx
        dataset_mapping = CLASS_MAPPING.get(dataset, {})
        # str() keeps the join working for classes without a readable name,
        # where the fallback is the integer class index itself.
        return (
            dataset,
            " || ".join(str(dataset_mapping.get(c, c)) for c in class_tuple),
        )

    new_tuples = [mapper(idx) for idx in df.index]
    df.index = pd.MultiIndex.from_tuples(new_tuples, names=df.index.names)

    # Make sure the output directory exists before writing the CSV.
    os.makedirs("results", exist_ok=True)
    df.to_csv(f"results/random_baseline_{class_samples}.csv")

    return df


# Command-line entry point: parse the options, run the baseline and log the
# resulting table.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the random baseline function with custom parameters."
    )

    parser.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default=None,
        help="List of datasets to use. Default is DATASETS.",
    )
    parser.add_argument(
        "--class_combinations",
        type=str,
        default=None,
        help="Path to a JSON file containing class combinations. Default is PAPER_CLASS_COMBINATIONS.",
    )
    parser.add_argument(
        "--num_boundary_graphs",
        type=int,
        default=500,
        help="Number of boundary graphs to use. Default is 500.",
    )
    parser.add_argument(
        "--num_connections",
        type=int,
        default=1,
        help="Number of connections to use. Default is 1.",
    )
    parser.add_argument(
        "--class_samples",
        type=str,
        default="GNNInterpreter",
        # Reject typos at parse time instead of failing deep inside the run.
        choices=["GNNInterpreter", "Dataset"],
        help='Either "Dataset" or "GNNInterpreter" to choose the class graph sampling method.',
    )

    args = parser.parse_args()

    # Load class_combinations from JSON if provided. JSON has no tuples or sets,
    # so the class pairs arrive as two-element lists; they are unpacked the same
    # way as tuples when the boundary graphs are generated.
    if args.class_combinations:
        import json

        with open(args.class_combinations, "r", encoding="utf-8") as f:
            class_combinations = json.load(f)
    else:
        class_combinations = PAPER_CLASS_COMBINATIONS  # Use default

    # Use provided datasets or default. Names given on the command line stay
    # strings here and are resolved to classes inside run_random_baseline.
    datasets = args.datasets if args.datasets else DATASETS

    # Call the function
    result = run_random_baseline(
        datasets=datasets,
        class_combinations=class_combinations,
        num_boundary_graphs=args.num_boundary_graphs,
        num_connections=args.num_connections,
        class_samples=args.class_samples,
    )

    # The CSV itself is written by run_random_baseline; here the table is only logged.
    logging.info("The following results were saved as csv in the results directory:")
    logging.info(f"\n{result}")
