from gnnboundary import *
from scripts.experiments import get_model_kwargs, CKPT_PATHS
import torch
import umap
import numpy as np
import pandas as pd
import glob
import os
from torch_geometric.data import Batch, Dataset
from sklearn.decomposition import PCA


def get_embeddings_for_class_pair(
    class_pair: tuple[int, int], dataset: Dataset, our_embeddings: bool = True
) -> list:
    """
    Return the embeddings for the class graphs and the boundary graphs for a
    given class pair of a given dataset instance.

    Args:
        class_pair: The two class indices. The two class entries of the result
            follow this order (not ascending index order), so that callers
            labelling them via ``dataset.GRAPH_CLS[class_pair[i]]`` line up.
        dataset: Dataset instance. ``dataset.name`` selects the classifier
            checkpoint (``CKPT_PATHS``), the boundary-sampler directory
            ``graphs/boundary/<name>/<i>-<j>`` and the embedding CSV directory
            ``graphs/embeddings/<name>``.
        our_embeddings: If True, also append the embeddings created by the
            boundary graph search on the embedding space, read from
            ``graphs/embeddings/<name>/<i>-<j>.csv`` (either ordering of the
            pair; the first one found is used).

    Returns:
        ``[class_pair[0] embeds, class_pair[1] embeds, boundary embeds]`` plus,
        when requested and found, ``our`` embeds as a 4th entry. Class and CSV
        entries are numpy arrays; the boundary entry is a torch Tensor (or an
        empty numpy array with no rows when no sampler checkpoints exist).
    """
    # Load the trained classifier for this dataset.
    model = GCNClassifier(**get_model_kwargs(dataset, dataset.name))
    model.load_state_dict(torch.load(CKPT_PATHS[dataset.name]))

    # Embed the real graphs of both classes, in class_pair order so that the
    # downstream label order (class_pair[0], class_pair[1]) matches.
    class_datasets = list(dataset.split_by_class())
    embeds = [
        class_datasets[class_idx].model_transform(model, key="embeds_last").numpy()
        for class_idx in class_pair
    ]

    # Boundary artefacts may be stored under "i-j" or "j-i". dict.fromkeys
    # removes the duplicate key when i == j while preserving order.
    cp_combinations = list(
        dict.fromkeys(
            [
                "-".join(map(str, sorted(class_pair))),
                "-".join(map(str, sorted(class_pair, reverse=True))),
            ]
        )
    )

    # Collect every sampler checkpoint for the pair (sorted for a stable order).
    all_pt_files = []
    for cp_ordered in cp_combinations:
        boundary_graph_root = f"graphs/boundary/{dataset.name}/{cp_ordered}"
        if os.path.exists(boundary_graph_root):
            all_pt_files.extend(
                sorted(
                    glob.glob(
                        os.path.join(boundary_graph_root, "**", "*.pt"),
                        recursive=True,
                    )
                )
            )
    if not all_pt_files:
        print(f"No graphs could be found for {class_pair}")

    # Load each sampler checkpoint into one reusable sampler and draw a graph.
    boundary_graphs = []
    sampler = GraphSampler(
        max_nodes=25,
        temperature=0.15,
        num_node_cls=len(dataset.NODE_CLS),
        learn_node_feat=True,
    )
    for file in all_pt_files:
        sampler.load_state_dict(torch.load(file))
        boundary_graphs.append(sampler(k=1, mode="discrete", expected=True))

    if boundary_graphs:
        embeds.append(embed_boundary_graphs(boundary_graphs, model))
    else:
        # Keep the boundary slot (index 2) occupied so that "our" embeddings
        # appended below are not mislabelled as regular boundary graphs. The
        # placeholder has zero rows; assumes class embeddings are 2-D
        # (graphs, dim) — TODO confirm for all datasets.
        embeds.append(np.empty((0, embeds[0].shape[1]), dtype=embeds[0].dtype))

    # Include the boundary embeddings produced by our embedding-space search.
    if our_embeddings:
        for cp_ordered in cp_combinations:
            embedding_path = f"graphs/embeddings/{dataset.name}/{cp_ordered}.csv"
            if os.path.exists(embedding_path):
                embd_df = pd.read_csv(embedding_path, header=None, index_col=False)
                embeds.append(embd_df.to_numpy())
                break
        else:
            # Loop finished without break: no CSV exists for either ordering.
            print(f"The embeddings for class pair {class_pair} do not exist for {dataset.name}!")

    return embeds


def embed_boundary_graphs(data: list[Batch], model: GCNClassifier) -> torch.Tensor:
    """
    Embed boundary graphs given a model.

    Switches ``model`` to eval mode (it is left in eval mode afterwards), runs
    every graph through it with gradient tracking disabled, and concatenates
    the ``"embeds_last"`` outputs along dim 0.

    Args:
        data: Sampled boundary graphs to embed.
        model: Classifier whose forward pass returns a mapping containing the
            key ``"embeds_last"``.

    Returns:
        The concatenated ``"embeds_last"`` tensors (presumably one row per
        graph — TODO confirm the model's output shape).

    Raises:
        ``torch.cat`` raises if ``data`` is empty.
    """
    model.eval()

    with torch.no_grad():
        embeddings = [model(graph)["embeds_last"] for graph in data]

    return torch.cat(embeddings, dim=0)


def generate_umap(
    embeddings: list[torch.Tensor],
    dataset: Dataset,
    class_pair: tuple[int, int],
    n_components: int = 3,
    random_state: int = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Using a set of GCN embedded graphs, generate the UMAP embedding and return
    the UMAP embeddings together with a label per embedded graph.

    Args:
        embeddings: Embedding groups in the order
            [class_pair[0], class_pair[1], boundary graphs(, our boundary graphs)].
            Elements may be numpy arrays or CPU torch tensors; they are
            stacked row-wise with ``np.vstack``.
        dataset: Provides ``GRAPH_CLS`` for the class label names.
        class_pair: The two class indices, in the same order as ``embeddings``.
        n_components: Number of UMAP output dimensions.
        random_state: Seed passed to ``umap.UMAP`` for reproducibility.

    Returns:
        (projection, labels): the UMAP projection of the stacked embeddings
        and a string array with one group label per stacked row.

    Raises:
        ValueError: if ``embeddings`` does not contain 3 or 4 groups, since the
            labels could otherwise not be matched to the rows.
    """
    if len(embeddings) not in (3, 4):
        raise ValueError(
            f"Expected 3 or 4 embedding groups, got {len(embeddings)}."
        )

    group_names = [
        dataset.GRAPH_CLS[class_pair[0]],
        dataset.GRAPH_CLS[class_pair[1]],
        "Boundary Graphs",
        "Our Boundary Graphs",
    ]

    # Stack every group into one matrix; row order follows the group order.
    embeddings_np = np.vstack(embeddings)
    # zip stops at the number of groups actually given (3 or 4).
    labels = np.concatenate(
        [np.full(group.shape[0], name) for group, name in zip(embeddings, group_names)]
    )

    reducer = umap.UMAP(n_components=n_components, random_state=random_state)
    return reducer.fit_transform(embeddings_np), labels


def generate_pca(
    embeddings: list[torch.Tensor],
    dataset: Dataset,
    class_pair: tuple[int, int],
    n_components: int = 3,
    random_state: int = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Using a set of GCN embedded graphs, generate the PCA embedding and return
    the PCA embeddings together with a label per embedded graph.

    Args:
        embeddings (list): Embedding groups in the order
            [class_pair[0], class_pair[1], boundary graphs(, our boundary graphs)].
            Elements may be numpy arrays or CPU torch tensors; they are
            stacked row-wise with ``np.vstack``.
        dataset (Dataset): Provides ``GRAPH_CLS`` for the class label names.
        class_pair (tuple[int, int]): The two class indices, in the same order
            as ``embeddings``.
        n_components (int, optional): Number of PCA components. Defaults to 3.
        random_state (int, optional): Random seed passed to ``PCA`` for
            reproducibility. Defaults to 42.

    Returns:
        tuple[np.ndarray, np.ndarray]: PCA-transformed embeddings and a string
        array with one group label per stacked row.

    Raises:
        ValueError: if ``embeddings`` does not contain 3 or 4 groups, since the
            labels could otherwise not be matched to the rows.
    """
    if len(embeddings) not in (3, 4):
        raise ValueError(
            f"Expected 3 or 4 embedding groups, got {len(embeddings)}."
        )

    group_names = [
        dataset.GRAPH_CLS[class_pair[0]],
        dataset.GRAPH_CLS[class_pair[1]],
        "Boundary Graphs",
        "Our Boundary Graphs",
    ]

    # Stack every group into one matrix; row order follows the group order.
    embeddings_np = np.vstack(embeddings)
    # zip stops at the number of groups actually given (3 or 4).
    labels = np.concatenate(
        [np.full(group.shape[0], name) for group, name in zip(embeddings, group_names)]
    )

    reducer = PCA(n_components=n_components, random_state=random_state)
    return reducer.fit_transform(embeddings_np), labels

