from src.loader.node_dataset import NodeDataset
from torch.utils.data import DataLoader as torch_DataLoader
from torch_geometric.loader import DataLoader
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat


def create_node_dataset(cfg, data, mask=None):
    """Wrap a selection of the nodes in ``data`` in a ``NodeDataset``.

    Args:
        cfg: Configuration object; only ``cfg.dataset.name`` is read here.
        data: Object exposing ``num_nodes``, ``node_id``, ``y``,
            ``dataset_name`` and ``dataset_task_name``.
        mask: Index applied to ``data.node_id`` and ``data.y`` to pick the
            nodes to keep. When ``None``, every node is selected via
            ``np.arange(data.num_nodes)``.
            NOTE(review): whether callers pass boolean masks or index arrays
            is not visible here -- confirm against the callers.

    Returns:
        The ``NodeDataset`` built from the selected ids/labels plus the
        dataset metadata.
    """
    selection = np.arange(data.num_nodes) if mask is None else mask

    # Only the first entry of the per-dataset name fields is forwarded.
    return NodeDataset(
        {
            "node_id": data.node_id[selection],
            "node_label": data.y[selection],
            "node_dataset_name": data.dataset_name[0],
            "dataset_task_name": data.dataset_task_name[0],
            "dataset_name": cfg.dataset.name,
        }
    )


class Collate:
    """Callable that merges a list of per-node samples into batched tensors.

    Each sample is expected to be a mapping with ``"node_id"`` and
    ``"node_label"`` entries.
    """

    @staticmethod
    def _gather(samples, field):
        """Collect ``field`` from every sample into a 1-D ``torch.long`` tensor."""
        return torch.tensor([sample[field] for sample in samples], dtype=torch.long)

    def __call__(self, batch):
        """Return ``{"node_ids": ..., "node_label": ...}`` for ``batch``.

        Args:
            batch: Sequence of samples, each indexable by ``"node_id"`` and
                ``"node_label"``.

        Returns:
            Dict with two ``torch.long`` tensors, ordered like ``batch``.
        """
        return {
            "node_ids": self._gather(batch, "node_id"),
            "node_label": self._gather(batch, "node_label"),
        }


def create_loader(dataset, cfg):
    """Build the DataLoaders for ``dataset`` as configured by ``cfg``.

    Only node-level tasks (``cfg.dataset.task == "node"``) are supported.

    Args:
        dataset: Object exposing a ``data`` attribute. ``data`` must provide
            ``num_nodes``, ``train_mask``, ``val_mask``, ``test_mask`` and the
            fields read by ``create_node_dataset``.
        cfg: Configuration exposing ``seed``, ``batch_size`` and
            ``dataset.task`` (plus whatever ``create_node_dataset`` reads).

    Returns:
        A ``(loaders, data)`` tuple. When ``cfg.batch_size != -1``,
        ``loaders`` is ``[train, val, test]`` (only the train loader
        shuffles). When ``cfg.batch_size == -1``, ``loaders`` holds a single
        unshuffled loader over all nodes with ``batch_size=data.num_nodes``.

    Raises:
        ValueError: If ``cfg.dataset.task`` is not ``"node"``.
    """
    task = cfg.dataset.task
    if task != "node":
        # Without this guard the function reached ``return loaders, data``
        # with ``loaders`` never assigned and failed with UnboundLocalError.
        raise ValueError(
            f"Unsupported dataset task {task!r}; only 'node' is supported."
        )

    data = dataset.data
    collate_fn = Collate()

    # Loader settings for node classification: everything is loaded in the
    # main process, so no prefetching or persistent workers are configured.
    NUM_WORKERS_NC = 0
    PREFETCH_FACTOR_NC = None
    PERSISTENT_WORKERS_NC = False

    def worker_init_fn(worker_id):
        # Seed each DataLoader worker deterministically from the config seed.
        # DataLoader only calls this inside worker processes, so it is not
        # invoked while NUM_WORKERS_NC is 0.
        torch.manual_seed(cfg.seed + worker_id)

    def create_dataloader(batch_size, mask, shuffle):
        # One loader over the nodes selected by ``mask`` (all nodes if None).
        return torch_DataLoader(
            create_node_dataset(cfg, data, mask),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=NUM_WORKERS_NC,
            drop_last=False,
            collate_fn=collate_fn,
            prefetch_factor=PREFETCH_FACTOR_NC,
            persistent_workers=PERSISTENT_WORKERS_NC,
            worker_init_fn=worker_init_fn,
        )

    if cfg.batch_size != -1:
        loaders = [
            create_dataloader(cfg.batch_size, data.train_mask, shuffle=True),
            create_dataloader(cfg.batch_size, data.val_mask, shuffle=False),
            create_dataloader(cfg.batch_size, data.test_mask, shuffle=False),
        ]
    else:
        # batch_size == -1 means full batch: one loader, one batch, all nodes.
        loaders = [create_dataloader(data.num_nodes, mask=None, shuffle=False)]

    return loaders, data
