import json
from transformers import LlamaTokenizer, LlamaForCausalLM
import torch
import os.path as osp
import random
import numpy as np
from pathlib import Path
import os
from torch_geometric.loader import NeighborLoader, LinkNeighborLoader, DataLoader
import torch.nn.functional as F
import torch
from torchmetrics import Accuracy, AUROC
from sklearn.metrics import f1_score, roc_auc_score

def get_k_shot(data, k=-1):
    """Restrict ``data.train_mask`` to at most *k* training nodes per class.

    For every label that occurs among the current training nodes, the first *k*
    of its training nodes (in ascending index order) are kept; all others are
    dropped from the mask. ``k == -1`` leaves *data* untouched and ``k == 0``
    empties the mask.

    The mask is replaced in place on *data*, which is also returned.
    """
    if k == -1:
        return data

    old_mask = data.train_mask
    kept = torch.zeros_like(old_mask)
    for label in data.y[old_mask].unique().tolist():
        members = ((data.y == label) & old_mask).nonzero(as_tuple=True)[0].tolist()
        # Slicing with k == 0 yields an empty selection, matching the "no shots" case.
        kept[members[:k]] = True

    data.train_mask = kept
    return data



from ofa.gp.utils.utils import (
    load_yaml,
    combine_dict,
    merge_mod,
    setup_exp,
    set_random_seed,
)
from utils.data_utils import (
    SentenceEncoder,
    MultiApr,
    MultiAuc,
    ENCODER_DIM_DICT,
)
from task.task_constructor import UnifiedTaskConstructor

# Direction in which each metric improves: "min" for loss, "max" for every score.
# Logger uses it to decide whether a new validation value is the best so far.
metric2order = {
    'loss': 'min',
    **{
        name: 'max'
        for name in ('acc', 'f1', 'precision', 'recall', 'auc', 'ap',
                     'mcc', 'hit', 'ndcg', 'map', 'mrr')
    },
}

# Task and dataset configuration tables, read once at import time from the
# "config" directory that sits next to this file's parent directory.
_config_dir = os.path.join(os.path.dirname(__file__), "..", "config")
task_config_lookup = load_yaml(os.path.join(_config_dir, "task_config.yaml"))
data_config_lookup = load_yaml(os.path.join(_config_dir, "data_config.yaml"))

class EarlyStopping:
    """Stop training once ``result['val']`` has not increased for *patience* calls.

    Higher validation values are treated as better. Once raised, the
    ``early_stop`` flag stays True even if a later result improves.
    """

    def __init__(self, patience=50):
        self.patience = patience
        # Best validation value seen so far and the result dict that produced it.
        self.best_val = -np.inf
        self.best_dict = None
        # Number of consecutive calls without improvement.
        self.counter = 0
        self.early_stop = False

    def __call__(self, result):
        """Record *result* and return whether training should stop."""
        score = result['val']
        if score > self.best_val:
            self.best_val, self.best_dict, self.counter = score, result, 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

class Logger:
    """Collect per-epoch train/val/test scores per run and remember each run's best epoch."""

    def __init__(self):
        # run -> {'train': [...], 'val': [...], 'test': [...], 'loss_train': last loss, 'epoch': last epoch}
        self.data = {}
        # run -> {'train', 'val', 'test', 'epoch'} captured at the best validation score so far
        self.best = {}

    def check_result(self, result):
        """Validate *result* and replace missing (None) train/val scores with 0.

        *result* is modified in place and returned.

        Raises:
            ValueError: the 'metric' key is missing or names an unknown metric.
        """
        if 'metric' not in result:
            raise ValueError('Result must contain metric key')
        if result['metric'] not in metric2order:
            raise ValueError('Metric not supported')
        for split_name in ('train', 'val'):
            if result[split_name] is None:
                result[split_name] = 0
        return result

    def log(self, run, epoch, loss, result):
        """Append one epoch of scores for *run* and update its best record.

        Ties on the validation score also count as an improvement, so the latest
        tying epoch wins.
        """
        result = self.check_result(result)
        scores = {split_name: result[split_name] for split_name in ('train', 'val', 'test')}

        history = self.data.setdefault(run, {'train': [], 'val': [], 'test': []})
        history['loss_train'] = loss
        for split_name, value in scores.items():
            history[split_name].append(value)
        history['epoch'] = epoch

        record = self.best.setdefault(run, {'train': None, 'val': None, 'test': None})
        previous = record['val']
        if metric2order[result['metric']] == 'max':
            improved = previous is None or scores['val'] >= previous
        else:
            improved = previous is None or scores['val'] <= previous
        if improved:
            record.update(scores)
            record['epoch'] = epoch

    def get_run_raw(self):
        """Return the raw per-run history dict."""
        return self.data

    def get_best_raw(self):
        """Return the raw per-run best-record dict."""
        return self.best

    def get_single_run(self, run_idx):
        """Return the history of one run."""
        return self.data[run_idx]

    def get_single_best(self, run_idx):
        """Return the best record of one run."""
        return self.best[run_idx]

    def get_run(self):
        """Average each split over epochs within a run, then over runs."""
        per_run_means = {
            split_name: [np.mean(self.data[run][split_name]) for run in self.data]
            for split_name in ('train', 'val', 'test')
        }
        return {split_name: np.mean(means) for split_name, means in per_run_means.items()}

    def get_best(self):
        """Return mean and std over runs of the best-epoch score for each split."""
        summary = {}
        for split_name in ('train', 'val', 'test'):
            values = [self.best[run][split_name] for run in self.best]
            summary[split_name] = {'mean': np.mean(values), 'std': np.std(values)}
        return summary

def get_loader(data, split, labels, task, batch_size, setting="standard"):
    """Build the PyG loaders used for one task.

    Args:
        data: the graph (node_cls / link_pre) or an indexable graph dataset (graph_cls).
        split: dict of split selectors. node_cls uses ``split["train"]`` as a node
            mask; link_pre uses it to select training edges; graph_cls indexes
            *data* with ``split["train"]``, ``split["valid"]`` and ``split["test"]``.
        labels: per-edge labels for link_pre; unused for the other tasks.
        task: "node_cls", "link_pre" or "graph_cls".
        batch_size: batch size for the training loader (and graph_cls val/test loaders).
        setting: "standard" (default; previously hard-coded), "few_shot",
            "zero_shot" or "in_context". Zero-shot / in-context settings build no
            training loader.

    Returns:
        node_cls / link_pre: ``(train_loader, subgraph_loader)``.
        graph_cls: ``(train_loader, val_loader, test_loader)``; for "few_shot" only
        the training loader is built, for zero-shot / in-context none are.

    Raises:
        ValueError: *setting* is not one of the values above.
        NotImplementedError: *task* is not supported.
    """
    if setting not in ("standard", "few_shot", "zero_shot", "in_context"):
        raise ValueError(f"Unknown setting: {setting}")
    skip_training = setting in ("zero_shot", "in_context")

    if task == "node_cls":
        train_loader = None
        if not skip_training:
            train_loader = NeighborLoader(
                data,
                num_neighbors=[10] * 2,
                input_nodes=mask2idx(split["train"]),
                batch_size=batch_size,
                shuffle=True,
            )
        # Unshuffled 2-hop loader over all nodes; -1 requests every neighbour.
        subgraph_loader = NeighborLoader(
            data,
            num_neighbors=[-1] * 2,
            batch_size=512,
            shuffle=False,
        )
        return train_loader, subgraph_loader

    if task == "link_pre":
        train_loader = None
        if not skip_training:
            train_loader = LinkNeighborLoader(
                data,
                num_neighbors=[30] * 2,
                edge_label_index=data.edge_index[:, split["train"]],
                edge_label=labels[split["train"]],
                batch_size=batch_size,
                shuffle=True,
            )
        # Unshuffled loader over every edge with all 2-hop neighbours.
        subgraph_loader = LinkNeighborLoader(
            data,
            num_neighbors=[-1] * 2,
            edge_label_index=data.edge_index,
            edge_label=labels,
            batch_size=4096,
            shuffle=False,
        )
        return train_loader, subgraph_loader

    if task == "graph_cls":
        train_loader = val_loader = test_loader = None
        if setting == "standard":
            train_dataset = data[split["train"]]
            val_dataset = data[split["valid"]]
            test_dataset = data[split["test"]]
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
            test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        elif setting == "few_shot":
            # Few-shot sampling only rewrites split["train"], so it can be used
            # directly here (keeping DataLoader shuffling). The prototype loader
            # must be built by the fine-tuning code instead.
            train_loader = DataLoader(data[split["train"]], batch_size=batch_size, shuffle=True)
        return train_loader, val_loader, test_loader

    raise NotImplementedError(f"Task {task} is not implemented")


    
def seed_everything(seed):
    """Seed every random number generator used here and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG, and the
    CUDA RNG of the current device and of all devices. It then sets cuDNN to
    deterministic mode, turns off benchmark autotuning, and disables cuDNN
    entirely. This trades speed for run-to-run reproducibility.

    Args:
        seed (int): value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False
   

def load_finetune_mydataset_ofa(args, data_path='/home/ai/xkli/GFT-main/GFT/../data'):
    """Load every fine-tuning dataset listed in the JSON config at ``args.data_config``.

    The config must provide "finetune_num_clients", "finetune_datasets",
    "tasks" and "weights".

    Args:
        args: object with a ``data_config`` attribute (path to the JSON file).
        data_path: data root passed to ``get_task_constructor``; defaults to the
            path that used to be hard-coded.

    Returns:
        ``(num_clients, dataset_list, task_list, weight_list, split_list, labels_list)``.

    Raises:
        NotImplementedError: a listed dataset has no loader here. This is checked
            before anything is loaded. Previously an unknown name silently re-used
            the previous dataset, or raised NameError.
    """
    with open(args.data_config, 'r', encoding='utf-8') as file:
        dataset_config = json.load(file)

    num_clients = dataset_config["finetune_num_clients"]
    dataset_names = dataset_config["finetune_datasets"]
    task_list = dataset_config["tasks"]
    weight_list = dataset_config["weights"]

    node_names = ("cora", "pubmed", "wikics", "arxiv")
    link_names = ("WN18RR", "FB15K237")
    graph_names = ("chemhiv", "chemblpre", "chempcba")
    unsupported = [name for name in dataset_names
                   if name not in node_names + link_names + graph_names]
    if unsupported:
        raise NotImplementedError(f"Unsupported fine-tuning dataset(s): {unsupported}")

    # Build the (expensive) task constructor once instead of once per dataset.
    tasks = get_task_constructor(data_path) if dataset_names else None

    dataset_list, split_list, labels_list = [], [], []
    for dataset_name in dataset_names:
        if dataset_name in node_names:
            loader = get_node_data
        elif dataset_name in link_names:
            loader = get_link_data
        else:
            loader = get_graph_clf_graph
        dataset, split, labels, _num_classes, _num_tasks = loader(tasks, dataset_name)
        dataset_list.append(dataset)
        split_list.append(split)
        labels_list.append(labels)

    return num_clients, dataset_list, task_list, weight_list, split_list, labels_list

def refine_dataset(dataset):
    """Promote precomputed attributes on ``dataset.data`` to their canonical names.

    Used for molecule graphs. When present (not None), each source attribute
    overwrites its target, and the source itself is left in place:

    * ``node_embs``           -> ``node_text_feat``
    * ``edge_embs``           -> ``edge_text_feat``
    * ``pretrain_edge_index`` -> ``edge_index``

    Returns the same *dataset* object.
    """
    store = dataset.data
    renames = (
        ("node_embs", "node_text_feat"),
        ("edge_embs", "edge_text_feat"),
        ("pretrain_edge_index", "edge_index"),
    )
    for source, target in renames:
        value = store.get(source)
        if value is not None:
            setattr(store, target, value)
    return dataset

def load_pretrain_mydataset_ofa(args, mode="pretrain", data_path='/home/ai/xkli/GFT-main/GFT/../data'):
    """Load and prepare the datasets listed in the JSON config at ``args.data_config``.

    With ``mode == "pretrain"`` the config keys "pretrain_num_clients" and
    "datasets" are used; any other mode uses "finetune_num_clients" and
    "finetune_datasets". Every loaded dataset gets ``num_tasks``, ``y``,
    node/edge index features (via ``span_node_and_edge_idx``), ``node_text`` /
    ``edge_text`` (from ``dataset.texts``), and the ``task`` / ``weight`` at the
    same position in the config's "tasks" / "weights" lists.

    Args:
        args: object with a ``data_config`` attribute (path to the JSON file).
        mode: "pretrain" or anything else for the fine-tuning lists.
        data_path: data root passed to ``get_task_constructor``; defaults to the
            path that used to be hard-coded.

    Returns:
        ``(num_clients_each_dataset, dataset_list)``.

    Raises:
        NotImplementedError: a listed dataset has no loader here. This is checked
            before anything is loaded. Previously an unknown name silently
            re-appended the previous dataset, or raised NameError.
        IndexError: "tasks" / "weights" are shorter than the dataset list.
    """
    with open(args.data_config, 'r', encoding='utf-8') as file:
        dataset_config = json.load(file)

    task_list = dataset_config["tasks"]
    weight_list = dataset_config["weights"]
    if mode == "pretrain":
        num_clients_each_dataset = dataset_config["pretrain_num_clients"]
        datasets = dataset_config["datasets"]
    else:
        num_clients_each_dataset = dataset_config["finetune_num_clients"]
        datasets = dataset_config["finetune_datasets"]

    node_names = ("cora", "pubmed", "wikics", "arxiv")
    link_names = ("WN18RR", "FB15K237")
    graph_names = ("chemhiv", "chempcba")
    supported = node_names + link_names + graph_names + ("chemblpre",)
    unsupported = [name for name in datasets if name not in supported]
    if unsupported:
        raise NotImplementedError(f"Unsupported pre-training dataset(s): {unsupported}")

    tasks = get_task_constructor(data_path)

    dataset_list = []
    for dataset_name in datasets:
        if dataset_name in node_names:
            dataset, _split, labels, _num_classes, num_tasks = get_node_data(tasks, dataset_name)
            dataset = filter_unnecessary_attrs(dataset)
        elif dataset_name in link_names:
            dataset, _split, labels, _num_classes, num_tasks = get_link_data(tasks, dataset_name)
            dataset = filter_unnecessary_attrs(dataset)
        elif dataset_name in graph_names:
            dataset, _split, labels, _num_classes, num_tasks = get_graph_clf_graph(tasks, dataset_name)
        else:
            # chemblpre is loaded directly; get_graph_clf_graph rejects it.
            dataset = tasks.get_ofa_data(data_config_lookup[dataset_name])
            labels = dataset.y
            num_tasks = 1295  # hard-coded task count for chemblpre

        dataset.num_tasks = num_tasks
        # Re-attach labels (filter_unnecessary_attrs may have cleared "y").
        dataset.y = labels
        dataset = span_node_and_edge_idx(dataset)
        dataset.node_text = dataset.texts[0]
        dataset.edge_text = dataset.texts[1]
        dataset_list.append(dataset)

    # Index-based on purpose: mismatched config lengths raise instead of truncating.
    for idx, data_tag in enumerate(dataset_list):
        data_tag.task = task_list[idx]
        data_tag.weight = weight_list[idx]

    return num_clients_each_dataset, dataset_list




def filter_unnecessary_attrs(dataset, mode="pretrain"):
    """Set to None every graph attribute that is not needed downstream.

    Kept attributes: ``x``, ``xe``, ``edge_index``, ``node_text_feat`` and
    ``edge_text_feat``; when *mode* is not "pretrain", ``class_node_text_feat``
    is kept as well. If *dataset* has a ``data`` attribute the filtering is
    applied to ``dataset.data``, otherwise to *dataset* itself.

    Returns the same *dataset* object.
    """
    keep = {"x", "xe", "edge_index", "node_text_feat", "edge_text_feat"}
    if mode != 'pretrain':
        keep.add("class_node_text_feat")

    target = dataset.data if hasattr(dataset, "data") else dataset
    # to_dict() is a snapshot, so assigning into target while looping is safe.
    for key in list(target.to_dict()):
        if key not in keep:
            target[key] = None
    return dataset


def span_node_and_edge_idx(dataset):
    """Replace node features by node ids and attach per-edge type ids as ``xe``.

    Operates on ``dataset.data`` when *dataset* has a ``data`` attribute,
    otherwise on *dataset* itself. If ``x`` is already 1-D, nothing is changed.
    Otherwise:

    * ``x``  becomes ``torch.arange(num_nodes)``;
    * ``xe`` becomes an all-zero long tensor (one entry per edge) when
      ``edge_text_feat`` has a single row, and ``edge_types`` otherwise.

    Returns the same *dataset* object.
    """
    graph = dataset.data if hasattr(dataset, "data") else dataset

    # A 1-D x is taken to already hold node ids.
    if graph.x.ndim == 1:
        return dataset

    graph.x = torch.arange(graph.x.shape[0])

    edge_type_count = graph.edge_text_feat.shape[0]
    edge_count = graph.edge_index.shape[1]
    if edge_type_count == 1:
        graph.xe = torch.zeros([edge_count], dtype=torch.long)
    else:
        graph.xe = graph.edge_types

    return dataset

def pre_node(dataset):
    """Preprocess a node-classification dataset (node/edge id spanning only).

    ``filter_unnecessary_attrs`` is currently not applied here; the call was
    commented out in earlier code.
    """
    spanned = span_node_and_edge_idx(dataset)
    return spanned


def pre_link(dataset):
    """Preprocess a link-prediction dataset (node/edge id spanning only).

    ``filter_unnecessary_attrs(..., mode="finetune")`` is currently not applied
    here; the call was commented out in earlier code.
    """
    spanned = span_node_and_edge_idx(dataset)
    return spanned


def pre_graph(dataset):
    """Preprocess a graph-classification dataset: no changes are made."""
    unchanged = dataset
    return unchanged

def get_preprocess(task):
    """Return the preprocessing function for *task*.

    Raises:
        NotImplementedError: *task* is not "node_cls", "link_pre" or "graph_cls".
    """
    if task not in ('node_cls', 'link_pre', 'graph_cls'):
        raise NotImplementedError('The task is not implemented')
    handlers = {'node_cls': pre_node, 'link_pre': pre_link, 'graph_cls': pre_graph}
    return handlers[task]

def preprocess_split(split):
    """Expand a split dict into a list of per-run ``{"train", "valid", "test"}`` dicts.

    Each entry of *split* is either a single split (a 1-D tensor/array of
    indices or a boolean mask) or a collection of splits (a list, or a
    tensor/array whose first axis enumerates runs).

    * "test" is a list/tuple or has ``ndim >= 2``: train/valid/test are zipped
      run by run. Previously a 2-D test tensor silently produced ``[]``.
    * "train" is a single 1-D split: one run containing the three entries as
      given. Previously such a split was iterated element by element, producing
      one bogus "run" per sample.
    * otherwise (several train/valid splits, one test split): every run shares
      the same test split.

    Raises:
        TypeError: *split* is not a dict (previously ``None`` was returned).
    """
    if not isinstance(split, dict):
        raise TypeError(f"split must be a dict, got {type(split).__name__}")

    train, valid, test = split["train"], split["valid"], split["test"]

    if isinstance(test, (list, tuple)) or getattr(test, "ndim", 1) >= 2:
        return [{"train": tr, "valid": va, "test": te}
                for tr, va, te in zip(train, valid, test)]

    if getattr(train, "ndim", None) == 1:
        return [{"train": train, "valid": valid, "test": test}]

    return [{"train": tr, "valid": va, "test": test} for tr, va in zip(train, valid)]
# def load_mini_lm(args):
#     with open(args.lm_config, 'r', encoding='utf-8') as file:
#         lm_config = json.load(file)
        
#     if lm_config["mini_lm"] == "sentence_bert":
#         from sentence_transformers import SentenceTransformer
#         mini_lm = SentenceTransformer(os.path.join(lm_config["root"], "multi-qa-distilbert-cos-v1"))
#     elif lm_config["mini_lm"] == "deberta":
#         from sentence_transformers import SentenceTransformer
#         mini_lm = SentenceTransformer(os.path.join(lm_config["root"], "deberta-v3-base"))
#     elif lm_config["mini_lm"] == "roberta":
#         from sentence_transformers import SentenceTransformer
#         mini_lm = SentenceTransformer(os.path.join(lm_config["root"], "all-MiniLM-L6-v2"))
#     elif lm_config["mini_lm"] == "tf-idf":
#         from model.non_param_lang_model import TFIDFModel
#         mini_lm = TFIDFModel()
#     elif lm_config["mini_lm"] == "word2vec":
#         from model.non_param_lang_model import Word2VecModel
#         mini_lm = Word2VecModel()
#     return mini_lm

def load_llm(args):
    """Load the tokenizer and model named in the JSON config at ``args.lm_config``.

    The config must contain "llm" (one of "llama2-7b", "llama2-13b",
    "e5-large-v2") and "root", the directory that holds the model folders.
    Tokenizers are created with left padding.

    Returns:
        ``(llm_tokenizer, llm_model)``.

    Raises:
        NotImplementedError: "llm" names a model this function cannot load
            (previously this surfaced as UnboundLocalError at the return).
    """
    with open(args.lm_config, 'r', encoding='utf-8') as file:
        lm_config = json.load(file)

    model_key = lm_config["llm"]
    # Both Llama-2 checkpoints are loaded identically; only the folder differs.
    llama_dirs = {"llama2-7b": "Llama-2-7b-hf", "llama2-13b": "Llama-2-13b-hf"}

    if model_key in llama_dirs:
        model_dir = os.path.join(lm_config["root"], llama_dirs[model_key])
        llm_tokenizer = LlamaTokenizer.from_pretrained(model_dir, add_eos_token=True, padding_side="left")
        # Pad with the EOS token.
        llm_tokenizer.pad_token = llm_tokenizer.eos_token
        llm_model = LlamaForCausalLM.from_pretrained(model_dir)
    elif model_key == "e5-large-v2":
        from transformers import AutoModel, AutoTokenizer

        model_dir = os.path.join(lm_config["root"], "e5-large-v2")
        llm_tokenizer = AutoTokenizer.from_pretrained(model_dir, add_eos_token=True, padding_side="left")
        llm_tokenizer.add_special_tokens({"pad_token": "<pad>"})
        llm_model = AutoModel.from_pretrained(model_dir)
    else:
        raise NotImplementedError(f"Unsupported llm '{model_key}' in {args.lm_config}")

    return llm_tokenizer, llm_model

        
    
    
        
    
def check_path(path):
    """Make sure the directory *path* exists, creating it with parents if needed.

    Returns *path* unchanged when it already exists; otherwise returns the
    ``pathlib.Path`` that was created.

    NOTE(review): the return type therefore depends on whether the directory
    existed beforehand (a str input comes back as str or Path). Callers should
    not rely on Path-only methods.
    """
    if osp.exists(path):
        return path
    created = Path(path)
    created.mkdir(parents=True, exist_ok=True)
    return created
    
def accuracy(y_pred, y_true):
    """Return the fraction of rows whose arg-max over dim 1 of *y_pred* equals *y_true*."""
    predicted = y_pred.argmax(dim=1)
    n_correct = (predicted == y_true).sum().item()
    return n_correct / len(y_true)

def load_lm(args):
    """Read and return the JSON language-model config stored at ``args.lm_config``."""
    # The parsed config used to be discarded (the function returned None).
    # Returning it is backward compatible for callers that ignore the result.
    with open(args.lm_config, 'r', encoding='utf-8') as file:
        return json.load(file)

def mask2idx(mask):
    """Return the positions (along dim 0) where *mask* equals True.

    For a 1-D boolean mask these are the indices of its True entries.
    """
    hits = mask == True  # noqa: E712 - elementwise tensor comparison, kept on purpose
    return hits.nonzero(as_tuple=True)[0]

def get_device_from_model(model):
    """Return the device of *model*'s first parameter.

    Raises StopIteration if the model has no parameters.
    """
    first_param = next(model.parameters())
    return first_param.device
def sample_proto_instances(labels, split, num_instances_per_class=10):
    """Randomly pick up to *num_instances_per_class* sample indices per class within *split*.

    Args:
        labels: tensor of class labels for all samples (assumed 1-D — TODO confirm).
        split: tensor of sample indices, or a boolean mask over samples; it may
            live on any device.
        num_instances_per_class: maximum number of indices drawn for each class.

    Returns:
        Integer ``np.ndarray`` of sample indices, grouped by class in ascending
        class order. Only classes that occur inside *split* are sampled.
        Shuffling uses NumPy's global RNG, so seed NumPy for reproducibility.
    """
    y = labels.detach().cpu().numpy()
    # Convert split once: np.intersect1d cannot consume a CUDA tensor directly.
    # A boolean mask is turned into indices; intersecting with a raw mask would
    # compare against {0, 1} instead of sample ids.
    split_idx = split.detach().cpu().numpy()
    if split_idx.dtype == np.bool_:
        split_idx = np.flatnonzero(split_idx)

    classes = np.unique(y[split_idx])

    proto_idx = np.array([])
    for cls in classes:
        cls_idx = np.intersect1d(np.where(y == cls)[0], split_idx)
        np.random.shuffle(cls_idx)
        proto_idx = np.concatenate((proto_idx, cls_idx[:num_instances_per_class]))

    return proto_idx.astype(int)
# Metric used to score each task type in evaluate().
task2metric = {'node_cls': 'acc', 'link_pre': 'acc', 'graph_cls': 'auc'}


def evaluate(pred, y, task, mask=None, params=None):
    """Score *pred* against *y* with the metric for *task*, as a percentage.

    Args:
        pred: model outputs.
        y: targets.
        task: key of ``task2metric``.
        mask: optional selector of the samples to score. It is now applied to
            the AUC path too; previously it was ignored there.
        params: unused; kept for interface compatibility.

    Returns:
        The metric times 100, or -999 when *mask* selects no samples.

    Raises:
        KeyError: *task* is not in ``task2metric``.
    """
    if mask is not None and mask.sum() == 0:
        return -999

    metric = task2metric[task]
    if metric == 'acc':
        return eval_acc(pred, y, mask) * 100
    if metric == 'auc':
        if mask is not None:
            pred, y = pred[mask], y[mask]
        return eval_auc(pred, y) * 100
    raise ValueError(f"Metric {metric} is not supported.")


def eval_acc(y_pred, y_true, mask):
    """Multiclass accuracy (0-1) of *y_pred* against *y_true*, optionally restricted by *mask*.

    The number of classes is taken from ``y_pred.size(1)``, and the metric is
    created on *y_pred*'s device.
    """
    metric = Accuracy(task="multiclass", num_classes=y_pred.size(1)).to(y_pred.device)
    if mask is None:
        return metric(y_pred, y_true).item()
    return metric(y_pred[mask], y_true[mask]).item()


def eval_auc(y_pred, y_true):
    """Mean ROC-AUC over the label columns that contain both a positive and a negative.

    Args:
        y_pred: scores, shape ``(N, T)`` or ``(N,)`` (a single column).
        y_true: labels of the same layout. ``-1`` is treated as ``0`` and NaN
            entries are excluded from the column's AUC.

    Returns:
        Average of the per-column ROC-AUC values.

    Raises:
        ZeroDivisionError: no column has both classes (unchanged behaviour).
    """
    y_pred = y_pred.detach().cpu().numpy()
    y_true = y_true.detach().cpu().numpy()
    # Single-task inputs may be 1-D; treat them as one column.
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    # np.where builds a new array. For CPU tensors, .numpy() shares memory with
    # the tensor, so the former in-place assignment rewrote the caller's labels.
    y_true = np.where(y_true == -1, 0, y_true)

    roc_list = []
    for i in range(y_true.shape[1]):
        column = y_true[:, i]
        # AUC is only defined when both classes are present.
        if np.sum(column == 1) > 0 and np.sum(column == 0) > 0:
            is_valid = column == column  # False only for NaN entries
            roc_list.append(roc_auc_score(column[is_valid], y_pred[is_valid, i]))

    return sum(roc_list) / len(roc_list)

import torch


# def sample_proto_instances_for_graph(labels, split, num_instances_per_class=10):
#     y = labels
#     ndim = y.ndim
#     if ndim == 1:
#         y = y.reshape(-1, 1)

#     # Map class and instance indices

#     if isinstance(y, torch.Tensor):
#         y = y.cpu().numpy()
#     target_y = y[split]
#     task_list = target_y.shape[1]

#     # class_index_pos = {}
#     # class_index_neg = {}
#     task_index_pos, task_index_neg = [], []
#     for i in range(task_list):
#         c_i = np.where(y[:, i] == 1)[0]
#         c_i = np.intersect1d(c_i, split)
#         task_index_pos.append(c_i)

#         c_i = np.where(y[:, i] == 0)[0]
#         c_i = np.intersect1d(c_i, split)
#         task_index_neg.append(c_i)

#     assert len(task_index_pos) == len(task_index_neg)

#     # Randomly select instances for each task

#     proto_idx, proto_labels = {}, {}
#     for task, (idx_pos, idx_neg) in enumerate(zip(task_index_pos, task_index_neg)):
#         tmp_proto_idx, tmp_labels = np.array([]), np.array([])

#         # Randomly select instance for the task

#         np.random.shuffle(idx_pos)
#         np.random.shuffle(idx_neg)
#         idx_pos = idx_pos[:num_instances_per_class]
#         idx_neg = idx_neg[:num_instances_per_class]

#         # Store the randomly selected instances

#         tmp_proto_idx = np.concatenate((tmp_proto_idx, idx_pos))
#         tmp_labels = np.concatenate((tmp_labels, np.ones(len(idx_pos))))
#         tmp_proto_idx = np.concatenate((tmp_proto_idx, idx_neg))
#         tmp_labels = np.concatenate((tmp_labels, np.zeros(len(idx_neg))))

#         proto_idx[task] = tmp_proto_idx.astype(int)
#         proto_labels[task] = tmp_labels.astype(int)

#     return proto_idx, proto_labels

def sample_proto_instances_for_graph(labels, num_tasks, split, num_instances_per_class=10):
    """Sample up to *num_instances_per_class* graphs per (task, binary class) inside *split*.

    Args:
        labels: tensor of binary labels, reshaped to ``(-1, num_tasks)``.
            ``reshape`` is used instead of ``view`` so non-contiguous label
            tensors are accepted too.
        num_tasks: number of label columns.
        split: mask over graphs; its nonzero positions form the candidate pool.
        num_instances_per_class: maximum number of indices drawn per class.

    Returns:
        ``{task: {class_label: index_tensor}}`` for class labels 0 and 1. A class
        without candidates is omitted, and a task with no sampled class is
        omitted entirely. Sampling uses ``torch.randperm``, i.e. torch's global RNG.
    """
    y = labels.reshape(-1, num_tasks)
    split_indices = torch.nonzero(split, as_tuple=True)[0]

    sampled_indices = {}
    for task in range(num_tasks):
        task_labels = y[:, task]
        task_sampled_indices = {}
        for class_label in [0, 1]:
            class_indices = split_indices[task_labels[split_indices] == class_label]
            if len(class_indices) > 0:
                perm = torch.randperm(len(class_indices))[:num_instances_per_class]
                task_sampled_indices[class_label] = class_indices[perm]
        if task_sampled_indices:
            sampled_indices[task] = task_sampled_indices

    return sampled_indices





# Encoder name passed to SentenceEncoder. Any value other than "ST" also makes
# get_task_constructor read from "<data_path>_<llm_name>".
llm_name = "ST"


def get_task_constructor(data_path, task_names=None):
    """Create a ``UnifiedTaskConstructor`` for the datasets stored under *data_path*.

    Args:
        data_path: root directory of the processed data. When ``llm_name`` is
            not "ST", the suffix ``_<llm_name>`` is appended.
        task_names: task names to register, as a list or a comma-separated
            string. Defaults to every task this module loads. Previously this
            list was hard-coded, which left the string-splitting branch dead.

    Returns:
        ``UnifiedTaskConstructor`` built with the module-level task/data config lookups.
    """
    encoder = SentenceEncoder(llm_name, batch_size=1)

    if task_names is None:
        task_names = [
            'cora_link', 'cora_node', 'pubmed_link', 'pubmed_node', 'arxiv',
            'WN18RR', 'FB15K237', 'wikics', 'chemblpre', 'chempcba', 'chemhiv',
        ]
    elif isinstance(task_names, str):
        task_names = [name.strip() for name in task_names.split(",")]

    root = data_path if llm_name == "ST" else f"{data_path}_{llm_name}"

    return UnifiedTaskConstructor(
        task_names,
        encoder,
        task_config_lookup,
        data_config_lookup,
        root=root,
        batch_size=512,
        sample_size=-1,
    )


def get_graph_clf_graph(tasks, dataset_name):
    """Load a graph-classification dataset and its split.

    The dataset and split are loaded first, then *dataset_name* is validated.
    "chemhiv" yields one task with ``dataset.y`` as labels. "chempcba" yields
    128 tasks with labels reshaped to ``(-1, 128)``.

    Returns:
        ``(dataset, split, labels, num_classes, num_tasks)`` with ``num_classes`` None.

    Raises:
        NotImplementedError: for "chemblpre" (pre-training only) or any other name.
    """
    data_config = data_config_lookup[dataset_name]
    dataset = tasks.get_ofa_data(data_config)
    split = tasks.get_data_split(data_config)

    if dataset_name == "chemblpre":
        raise NotImplementedError(f"Dataset {dataset_name} is only used for pre-training")
    if dataset_name not in ("chemhiv", "chempcba"):
        raise NotImplementedError(f"Dataset {dataset_name} is not supported for graph classification task")

    if dataset_name == "chemhiv":
        num_tasks = 1
        labels = dataset.y
    else:
        num_tasks = 128
        labels = dataset.y.reshape(-1, num_tasks)

    return dataset, split, labels, None, num_tasks

def get_node_data(tasks, dataset_name):
    """Load a node-classification dataset with its split masks and labels.

    Returns:
        ``(dataset, split, labels, num_classes, num_tasks)`` where ``split`` has
        "train"/"valid"/"test" entries, ``num_classes`` is the number of distinct
        labels, and ``num_tasks`` is always 1.

    Raises:
        NotImplementedError: *dataset_name* is not cora, pubmed, wikics or arxiv.
    """
    data_config = data_config_lookup[dataset_name]
    dataset = tasks.get_ofa_data(data_config)
    data = dataset[0]

    if dataset_name in ("cora", "pubmed"):
        split = {"train": data.train_masks, "valid": data.val_masks, "test": data.test_masks}
        labels = data.y
    elif dataset_name == "wikics":
        # Masks are transposed (presumably so the first axis enumerates splits — verify).
        split = {"train": data.train_mask.T, "valid": data.val_mask.T, "test": data.test_mask.T}
        labels = data.y
    elif dataset_name == "arxiv":
        split = {"train": data.train_mask, "valid": data.val_mask, "test": data.test_mask}
        labels = data.y.squeeze()
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported for node classification task")

    num_classes = labels.unique().shape[0]
    return dataset, split, labels, num_classes, 1

def get_link_data(tasks, dataset_name):
    """Load a knowledge-graph link dataset (WN18RR / FB15K237) with its split.

    Labels are the ``edge_types`` of the first graph, and ``num_classes`` is the
    number of distinct edge types.

    Returns:
        ``(dataset, split, labels, num_classes, num_tasks)`` with ``num_tasks`` 1.

    Raises:
        NotImplementedError: any other *dataset_name*.
    """
    if dataset_name not in ("WN18RR", "FB15K237"):
        raise NotImplementedError(f"Dataset {dataset_name} is not supported for link classification task")

    data_config = data_config_lookup[dataset_name]
    dataset = tasks.get_ofa_data(data_config)
    split = tasks.get_data_split(data_config)

    labels = dataset[0].edge_types
    num_classes = labels.unique().shape[0]
    return dataset, split, labels, num_classes, 1
def index_to_mask(index_tensor, num_samples):
    """Return a boolean mask of length *num_samples* that is True at *index_tensor*.

    When *index_tensor* is a tensor, the mask is allocated on its device.
    Previously the mask was always on CPU, which fails for CUDA index tensors.
    Other index types (lists, ranges) produce a CPU mask.
    """
    device = index_tensor.device if isinstance(index_tensor, torch.Tensor) else None
    mask_tensor = torch.zeros(num_samples, dtype=torch.bool, device=device)
    mask_tensor[index_tensor] = True
    return mask_tensor


def construct_graph(x):
    """Wrap node features *x* as ``[x, edge_index, None]`` with an empty edge list.

    ``edge_index`` is a ``(2, 0)`` long tensor on the same device as *x*.
    """
    empty_edges = torch.empty((2, 0), dtype=torch.long, device=x.device)
    return [x, empty_edges, None]