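"""
Utility functions: relation-set selection for heterogeneous graphs, dataset
loading and folder lookup, train/validation/test split masks, and
multiclass accuracy.
"""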

import torch
from src.t3_edge_regression.base_dataset import TNTPFlowDenoisingInterpolationDataset
from src.t4_node_classification.base_dataset import DirectedHeterophilousGraphDataset
from torch_geometric.datasets import CitationFull
import numpy as np
import torch_geometric.transforms as transforms
import math
from utils.relation_filters import (
    TSSN,
    GNN,
    DIRGNN,
    ESSN,
    MPSNN,
    TrafficSSN,
    MLP,
)


def select_face_map_induced_relations(relation_set):
    """
    Return the relation (edge-type) collection registered under a name.

    Parameters:
        relation_set (str): One of "MLP", "GNN", "DIRGNN", "MPSNN", "ESSN",
            "TSSN" or "TrafficSSN".
    Returns:
        The object of the same name imported from ``utils.relation_filters``.
    Raises:
        NotImplementedError: If ``relation_set`` matches none of the names.
    """
    # Ordered (name, relations) pairs, compared with == exactly as before.
    registry = (
        ("MLP", MLP),
        ("GNN", GNN),
        ("DIRGNN", DIRGNN),
        ("MPSNN", MPSNN),
        ("ESSN", ESSN),
        ("TSSN", TSSN),
        ("TrafficSSN", TrafficSSN),
    )
    for label, relations in registry:
        if relation_set == label:
            return relations

    raise NotImplementedError("Adjacency type not supported")


def select_simplexes_and_relations(hetero_data, max_dim, relations):
    """
    Restrict a heterogeneous graph to node types "0" .. str(max_dim) and to
    the edge types of the named relation set.

    Parameters:
        hetero_data: Graph object exposing ``node_type_subgraph`` and
            ``edge_type_subgraph`` (presumably a PyG ``HeteroData``).
        max_dim (int): Highest node-type index (as a string key) to keep.
        relations (str): Relation-set name accepted by
            ``select_face_map_induced_relations``.
    Returns:
        The graph restricted first by node type, then by edge type.
    """
    kept_node_types = list(map(str, range(max_dim + 1)))
    node_restricted = hetero_data.node_type_subgraph(kept_node_types)
    return node_restricted.edge_type_subgraph(
        select_face_map_induced_relations(relations)
    )


def select_simplexes_and_relations_for_feat_class(
    x_dict, edge_index_dict, max_dim, relations
):
    """
    Restrict node features to node types "0" .. str(max_dim) and edge indices
    to the edge types of the named relation set.

    Parameters:
        x_dict: Node features keyed by node type. Either a plain mapping
            (e.g. the dict returned by ``HeteroData.x_dict``) or an object
            exposing ``node_type_subgraph``.
        edge_index_dict (dict): Edge indices keyed by edge type.
        max_dim (int): Highest node-type index (as a string key) to keep.
        relations (str): Relation-set name accepted by
            ``select_face_map_induced_relations``.
    Returns:
        tuple: ``(x_dict, edge_index_dict)`` restricted as described. Edge
        types of the relation set missing from ``edge_index_dict`` are
        skipped.
    """
    node_types = [str(n) for n in range(max_dim + 1)]
    if hasattr(x_dict, "node_type_subgraph"):
        # Preserve the original behaviour for HeteroData-like inputs.
        x_dict = x_dict.node_type_subgraph(node_types)
    else:
        # A plain dict has no ``node_type_subgraph``; calling it raised
        # AttributeError. Filter the keys directly instead.
        wanted = set(node_types)
        x_dict = {k: v for k, v in x_dict.items() if k in wanted}

    edge_types = select_face_map_induced_relations(relations)
    edge_index_dict = {
        k: edge_index_dict[k] for k in edge_types if k in edge_index_dict
    }
    return x_dict, edge_index_dict


def compute_multiclass_accuracy(output, target):
    """
    Return the fraction of rows whose arg-max over dimension 1 equals target.

    Parameters:
        output (torch.Tensor): Scores; classes are taken along dimension 1.
        target (torch.Tensor): Class index for each row of ``output``.
    Returns:
        torch.Tensor: 0-d float64 tensor holding the accuracy.
    """
    predicted = output.argmax(1).type_as(target)
    n_hits = (predicted == target).double().sum()
    return n_hits / len(target)


def get_folder_from_dset(dset):
    """
    Map a dataset name to the name of the folder it is stored in.

    Parameters:
        dset (str): The name of the dataset.
    Returns:
        str or None: The folder name. For traffic datasets this is everything
        before the last "-". Returns None for unrecognised names.
    """
    fixed_folders = {
        "cora_ml": "cora_ml",
        "directed-roman-empire": "directed_roman_empire",
        "citeseer_full": "citeseer",
    }
    if dset in fixed_folders:
        return fixed_folders[dset]

    traffic_prefixes = (
        "traffic-anaheim",
        "traffic-barcelona",
        "traffic-chicago",
        "traffic-winnipeg",
    )
    if dset.startswith(traffic_prefixes):
        # Every prefix contains "-", so rpartition always finds a separator.
        folder, _sep, _suffix = dset.rpartition("-")
        return folder

    return None


def get_masks(dset, path):
    """
    Load the edge masks saved for an edge-task dataset.

    The file read is ``<path>/<folder>/undirected_masks.pt``, where
    ``<folder>`` is ``dset`` without its last "-"-separated component.

    Parameters:
        dset (str): The name of the dataset; must contain at least one "-".
        path (str): The directory that holds the dataset folders.
    Returns:
        Whatever ``torch.load`` returns for that file (expected to be the
        masks of directed and undirected edges).
    """
    folder, _suffix = dset.rsplit("-", 1)
    masks_file = "/".join((path, folder, "undirected_masks.pt"))
    return torch.load(masks_file)


def get_dataset(
    name: str,
    root_dir: str,
):
    """
    Load the dataset identified by ``name``.

    Parameters:
        name (str): The name of the dataset. Supported: names starting with
            "traffic-anaheim", "traffic-barcelona", "traffic-chicago" or
            "traffic-winnipeg", plus "directed-roman-empire", "cora_ml" and
            "citeseer_full".
        root_dir (str): The root directory for the dataset.
    Returns:
        For traffic datasets, ``.graphs[0]`` of the training split of
        ``TNTPFlowDenoisingInterpolationDataset``, i.e. a single graph
        rather than the dataset. For the other names, the dataset object.
    Raises:
        ValueError: If ``name`` is not supported. ValueError subclasses
            Exception, so callers that caught the previous bare Exception
            still work.
    """
    path = f"{root_dir}/"
    if name.startswith(
        (
            "traffic-anaheim",
            "traffic-barcelona",
            "traffic-chicago",
            "traffic-winnipeg",
        )
    ):
        # Traffic names look like "<folder>-<suffix>"; the data lives in
        # "<root_dir>/<folder>".
        dataset_folder_name, _ = name.rsplit("-", 1)
        return TNTPFlowDenoisingInterpolationDataset(
            split="train",
            dataset_name=name,
            dataset_path=path + dataset_folder_name,
            val_ratio=0.1,
            test_ratio=0.2,
            seed=42,
            arbitrary_orientation=False,
            orientation_equivariant_labels=True,
            # NOTE(review): the original comment here read "Creates the file
            # once"; confirm in the dataset class what preprocess=False does.
            preprocess=False,
            interpolation_label_size=0.75,
        ).graphs[0]

    if name == "directed-roman-empire":
        return DirectedHeterophilousGraphDataset(
            name=name, transform=transforms.NormalizeFeatures(), root=path
        )

    if name in ("cora_ml", "citeseer_full"):
        # CitationFull knows this dataset as "citeseer".
        citation_name = "citeseer" if name == "citeseer_full" else name
        return CitationFull(path, citation_name, to_undirected=False)

    raise ValueError(f"Unknown dataset: {name!r}.")


def get_dataset_split(name, data, root_dir, split_number):
    """
    Get the train, validation, and test masks for the given dataset.

    Parameters:
        name (str): Dataset name.
        data: The dataset object. Traffic datasets expose ``train_mask`` /
            ``val_mask`` / ``test_mask`` attributes indexed by split along
            dimension 0. "directed-roman-empire" exposes the same keys via
            ``data[...]``, indexed by split along dimension 1.
        root_dir (str): Not used; kept for interface compatibility.
        split_number (int): Which split to return. For "cora_ml" and
            "citeseer_full" it is used as the seed of a uniform 50/25/25 split.
    Returns:
        tuple of torch.Tensor: The train, validation, and test masks.
    Raises:
        ValueError: If ``name`` is not a supported dataset. Previously the
            function returned None here, which made the caller's tuple
            unpacking fail with an unrelated TypeError.
    """
    if name.startswith(
        (
            "traffic-anaheim",
            "traffic-barcelona",
            "traffic-chicago",
            "traffic-winnipeg",
        )
    ):
        return (
            data.train_mask[split_number],
            data.val_mask[split_number],
            data.test_mask[split_number],
        )

    if name == "directed-roman-empire":
        return (
            data["train_mask"][:, split_number],
            data["val_mask"][:, split_number],
            data["test_mask"][:, split_number],
        )

    if name in ("cora_ml", "citeseer_full"):
        # Uniform 50/25/25 split over labeled nodes, seeded by split_number.
        return set_uniform_train_val_test_split(
            split_number, data, train_ratio=0.5, val_ratio=0.25
        )

    raise ValueError(f"Unknown dataset: {name!r}.")


def set_uniform_train_val_test_split(seed, data, train_ratio=0.5, val_ratio=0.25):
    """
    Build uniformly random train/val/test masks over the labeled nodes.

    Adapted from https://github.com/emalgorithm/directed-graph-neural-network/blob/main/src/datasets/data_loading.py

    Nodes whose label is -1 count as unlabeled and appear in no mask. Of the
    n labeled nodes, floor(n * train_ratio) go to train and
    floor(n * val_ratio) go to validation; the remainder goes to test.

    Parameters:
        seed (int): Seed for ``np.random.RandomState``; fixes the split.
        data: Object exposing a label tensor ``data.y``.
        train_ratio (float): Fraction of labeled nodes used for training.
        val_ratio (float): Fraction of labeled nodes used for validation.
    Returns:
        tuple of torch.Tensor: Boolean train, val and test masks, each of
        length ``data.y.shape[0]``.
    """
    rng = np.random.RandomState(seed)
    total_nodes = data.y.shape[0]

    # Positions of labeled nodes only (-1 marks "unlabeled").
    labeled = torch.where(data.y != -1)[0]
    n_labeled = labeled.shape[0]
    train_end = math.floor(n_labeled * train_ratio)
    val_end = train_end + math.floor(n_labeled * val_ratio)

    order = list(range(n_labeled))
    rng.shuffle(order)  # in-place, deterministic for a given seed

    bounds = ((0, train_end), (train_end, val_end), (val_end, None))
    return tuple(
        get_mask(labeled[order[start:stop]], total_nodes)
        for start, stop in bounds
    )


def get_mask(idx, num_nodes):
    """
    Return a boolean tensor of length ``num_nodes``.

    The tensor is True at every position referenced by ``idx`` and False
    everywhere else.
    """
    selected = torch.zeros(num_nodes, dtype=torch.bool)
    selected[idx] = True
    return selected