from torch_geometric.loader import NeighborLoader, LinkNeighborLoader, DataLoader
from operator import itemgetter
from utils.others import mask2idx
from torch_geometric.data.collate import collate
import os
import torch


def my_collate(data_list):
    """Collate graphs into one disjoint-union graph plus per-attribute slices.

    Each graph's ``edge_index`` is shifted by the total number of nodes in the
    graphs before it, so node ids stay unique in the concatenated graph. The
    shifting is done on shallow copies; the graphs in ``data_list`` are left
    unmodified, so calling this repeatedly on the same list is safe.

    Args:
        data_list: non-empty list of PyG data objects with ``x`` and
            ``edge_index`` set.

    Returns:
        ``(data, slices)``: the concatenated graph and the slice dictionary
        produced by ``torch_geometric.data.collate.collate``.

    Raises:
        ValueError: if ``data_list`` is empty.
    """
    import copy

    if not data_list:
        raise ValueError("my_collate() requires a non-empty data_list")

    shifted = []
    num_nodes = 0
    for data in data_list:
        # Shallow copy so rebinding edge_index does not touch the caller's graph.
        data = copy.copy(data)
        data.edge_index = data.edge_index + num_nodes
        num_nodes += len(data.x)
        shifted.append(data)

    # ``collate`` takes the data class as its first argument. The offsets were
    # applied manually above, so its own edge_index increment is disabled to
    # avoid shifting twice.
    data, slices, _ = collate(
        type(data_list[0]),
        data_list=shifted,
        increment=False,
        add_batch=False,
    )
    return data, slices
    

class MultiGraphLoader(DataLoader):
    """DataLoader over cached multi-graph pretraining samples.

    Loads the third element of the tuple stored in the cached
    ``geometric_data_processed.pt`` file and batches it with ``DataLoader``.
    Every batch produced by iteration gets ``pretrain_data.node_text_feat`` and
    ``pretrain_data.edge_text_feat`` appended to it.

    Args:
        params: config dict; reads ``"pretrain_dataset"``, ``"graph_llm_name"``
            and ``"lang_llm_name"`` to locate the cache file.
        pretrain_data: object exposing ``node_text_feat`` and ``edge_text_feat``.
        batch_size: number of samples per batch.
        shuffle: whether to shuffle the samples.
    """

    def __init__(self, params, pretrain_data, batch_size, shuffle):
        path = os.path.join(
            "./cache_data",
            params["pretrain_dataset"],
            params["graph_llm_name"] + "_" + params["lang_llm_name"],
            "processed",
            "geometric_data_processed.pt",
        )
        # NOTE(review): recent torch releases default torch.load to
        # weights_only=True, which may reject pickled PyG objects — confirm
        # against the torch version this project pins.
        _, _, data_list = torch.load(path)
        super().__init__(data_list, batch_size=batch_size, shuffle=shuffle)
        self.node_text_feat = pretrain_data.node_text_feat
        self.edge_text_feat = pretrain_data.edge_text_feat

    def __iter__(self):
        # Batches are collated anew on every pass, so the shared text features
        # must be attached here. Appending them during __init__ only mutated
        # throwaway batch objects.
        # Assumes each collated batch is a list (samples are sequences) —
        # TODO confirm against the cached data format.
        for batch in super().__iter__():
            batch.append(self.node_text_feat)
            batch.append(self.edge_text_feat)
            yield batch


def _skip_training(setting, params):
    """Return True when no train loader should be built for node/link tasks."""
    return setting in ['zero_shot', 'in_context'] or params["train_batch_size"] == 0


def _make_split_loaders(train_dataset, val_dataset, test_dataset, params, **loader_kwargs):
    """Build a shuffled train loader and unshuffled val/test loaders.

    ``params["train_batch_size"]`` sizes the train loader and
    ``params["eval_batch_size"]`` sizes the val/test loaders. Extra keyword
    arguments are forwarded to every ``DataLoader``.
    """
    train_loader = DataLoader(
        train_dataset,
        batch_size=params["train_batch_size"],
        shuffle=True,
        **loader_kwargs,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=params["eval_batch_size"],
        shuffle=False,
        **loader_kwargs,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=params["eval_batch_size"],
        shuffle=False,
        **loader_kwargs,
    )
    return train_loader, val_loader, test_loader


def get_loader(data, split, labels, params):
    """Build the data loaders for a downstream task.

    Args:
        data: task data. For ``"node"``/``"link"`` it is passed to
            ``NeighborLoader``/``LinkNeighborLoader``. For ``"graph"``/``"GQA"``
            it is indexed by the split indices.
        split: split indices/masks keyed by name. Uses ``"train"`` (node, link),
            ``"train"``/``"valid"``/``"test"`` (graph) and
            ``"train"``/``"val"``/``"test"`` (GQA).
        labels: edge labels; used only by the ``"link"`` task.
        params: config dict. Reads ``"task"``, ``"setting"``, the batch sizes,
            ``"num_neighbors"``, ``"num_layers"`` and, for GQA,
            ``"finetune_dataset"``.

    Returns:
        ``"node"``/``"link"``: ``(train_loader, subgraph_loader)``; either may
        be ``None``.
        ``"graph"``: ``(train_loader, val_loader, test_loader)``; entries are
        ``None`` where the setting does not use them.
        ``"GQA"``: ``(train_loader, val_loader, test_loader)`` for the
        ``"standard"`` setting, otherwise ``None``.
        Any other task: ``None``.

    Raises:
        ValueError: for a ``"graph"`` task with an unrecognised setting.
    """
    task = params['task']
    setting = params["setting"]

    if task == "node":
        if _skip_training(setting, params):
            train_loader = None
        else:
            train_loader = NeighborLoader(
                data,
                input_nodes=mask2idx(split["train"]),
                num_neighbors=[params["num_neighbors"]] * params["num_layers"],
                batch_size=params["train_batch_size"],
                shuffle=True
            )

        if params["eval_batch_size"] == 0:
            subgraph_loader = None
        else:
            # -1 keeps every neighbour at each layer for evaluation.
            subgraph_loader = NeighborLoader(
                data,
                num_neighbors=[-1] * params["num_layers"],
                batch_size=params["eval_batch_size"],
                shuffle=False
            )

        return train_loader, subgraph_loader

    if task == "link":
        if _skip_training(setting, params):
            train_loader = None
        else:
            train_loader = LinkNeighborLoader(
                data,
                edge_label_index=data.edge_index[:, split["train"]],
                num_neighbors=[params["num_neighbors"]] * params["num_layers"],
                edge_label=labels[split["train"]],
                batch_size=params["train_batch_size"],
                shuffle=True,
            )

        if params["eval_batch_size"] == 0:
            subgraph_loader = None
        else:
            subgraph_loader = LinkNeighborLoader(
                data,
                num_neighbors=[-1] * params["num_layers"],
                edge_label_index=data.edge_index,
                edge_label=labels,
                batch_size=params["eval_batch_size"],
                shuffle=False,
            )

        return train_loader, subgraph_loader

    if task == "graph":
        if setting == 'standard':
            return _make_split_loaders(
                data[split["train"]],
                data[split["valid"]],
                data[split["test"]],
                params,
            )

        if setting == 'few_shot':
            # As we only update the train_idx in sampling few-shot samples,
            # we can directly use the split["train"] as the train_idx
            # This enables the shuffle function in DataLoader.
            # The drawback is we should define the proto_loader in the finetune_graph_task function
            train_loader = DataLoader(
                data[split["train"]],
                batch_size=params["train_batch_size"],
                shuffle=True
            )
            return train_loader, None, None

        if setting in ['zero_shot', 'in_context']:
            return None, None, None

        raise ValueError(f"Unsupported setting {setting!r} for the 'graph' task")

    if task == "GQA":
        if setting != 'standard':
            # Only the standard setting builds loaders for GQA.
            return None

        if params["finetune_dataset"] == "scene_graphs":
            # Index item by item: ``itemgetter(*idx)(data)`` returns a bare
            # item (not a tuple) for a single index and fails on an empty one.
            train_dataset = [data[i] for i in split["train"]]
            val_dataset = [data[i] for i in split["val"]]
            test_dataset = [data[i] for i in split["test"]]
        else:
            train_dataset = data[split["train"]]
            val_dataset = data[split["val"]]
            test_dataset = data[split["test"]]

        return _make_split_loaders(
            train_dataset, val_dataset, test_dataset, params, num_workers=0
        )

    return None