import os.path as osp
import pandas as pd
import torch

# from torch_geometric.data import (InMemoryDataset, Data, DataLoader)
from torch_geometric.data import (InMemoryDataset, Data)
from torch.utils.data import Dataset, SequentialSampler, RandomSampler,TensorDataset
from dgl.dataloading.dataloader import DataLoader

import torch.multiprocessing
# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid "too many open files" errors when
# DataLoader workers hand tensors back to the main process — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
import dgl
# from ogb.nodeproppred import DglNodePropPredDataset
# Globally turn off DGL's use of libxsmm.
dgl.use_libxsmm(False)

from .dataset import *


def collate_fn(batch):
    """Merge a list of ``(graph, target)`` samples into a single mini-batch.

    Args:
        batch: Iterable of two-element samples. The first element of each
            sample is passed to ``dgl.batch`` and the second to
            ``torch.stack``, so the targets must all share one shape.

    Returns:
        A ``(batched_graph, batched_targets)`` tuple: the graphs merged by
        ``dgl.batch`` and the targets stacked along a new leading dimension.
    """
    # Transpose [(g0, t0), (g1, t1), ...] into (g0, g1, ...) and (t0, t1, ...).
    graph_seq, target_seq = zip(*batch)
    return dgl.batch(graph_seq), torch.stack(target_seq)


def load_integrated_data(args, data_path, batch_size, mode_type='train', sample_percent=1.0,
                         num_workers=4, pin_memory=True):
    """Build an ``IntegratedDGLDataset`` and a batched loader over it.

    Args:
        args: Forwarded unchanged as the first argument of
            ``IntegratedDGLDataset``.
        data_path: Forwarded unchanged as the second argument of
            ``IntegratedDGLDataset``.
        batch_size: Number of samples per mini-batch.
        mode_type: When equal to ``'train'``, ``sample_percent`` is passed to
            the dataset; for any other value the dataset falls back to its
            own default.
        sample_percent: Third argument of ``IntegratedDGLDataset`` in train
            mode; ignored otherwise.
        num_workers: Number of loader worker processes (default 4, the value
            that used to be hard-coded).
        pin_memory: Forwarded to the loader (default ``True``, the value that
            used to be hard-coded).

    Returns:
        A ``(dataset, dataloader)`` tuple.
    """
    if mode_type == 'train':
        dataset = IntegratedDGLDataset(args, data_path, sample_percent)
    else:
        dataset = IntegratedDGLDataset(args, data_path)

    # NOTE(review): samples are drawn in random order for every mode, not only
    # for training. If evaluation code needs predictions aligned with dataset
    # order, a SequentialSampler would be required there — confirm intent.
    data_sampler = RandomSampler(dataset)
    dataloader = DataLoader(
        dataset,
        sampler=data_sampler,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return dataset, dataloader

class ProcessedDataset(InMemoryDataset):
    """Bare ``InMemoryDataset`` subclass that adds no behaviour of its own.

    ``precompute_edge_label_and_reverse`` instantiates it and then assigns
    its ``data`` and ``slices`` attributes by hand.
    """


def precompute_edge_label_and_reverse(dataset: InMemoryDataset):
    """Attach per-edge labels to every graph and repackage them as a dataset.

    Every item of ``dataset`` is itself iterated, and each element it yields
    is treated as a graph exposing ``edge_index`` and ``y``. For an edge
    ``(u, v)`` the label is ``2 * y[u] + y[v]``; it is stored on the graph as
    ``edge_labels`` (the graph object is modified in place).

    NOTE(review): the ``2 * y[u] + y[v]`` encoding presumably assumes binary
    node labels, giving edge labels in 0..3 — confirm. Despite the function
    name, no reversed edges are computed here.

    Args:
        dataset: Iterable whose items are themselves iterables of graphs.

    Returns:
        A ``ProcessedDataset`` whose ``data`` and ``slices`` come from
        ``InMemoryDataset.collate`` applied to all labelled graphs.
    """
    labelled_graphs = []
    for group in dataset:
        for graph in group:
            src, dst = graph.edge_index
            # Combine the endpoint labels into a single code per edge.
            graph.edge_labels = graph.y[src] * 2 + graph.y[dst]
            labelled_graphs.append(graph)

    collated, slice_map = InMemoryDataset.collate(labelled_graphs)
    result = ProcessedDataset('.')
    result.data = collated
    result.slices = slice_map
    return result


class CitationDataset(InMemoryDataset):
    """Dataset of graphs read from a JSON file located at ``root``.

    Each row of the JSON table becomes one ``Data`` object built from the
    row's ``edges``, ``node_feature``, ``node_target`` and ``node_lines``
    fields. The resulting list is stored in ``self.data``.
    """

    def __init__(self, root=None, split='train', transform=None, pre_transform=None, pre_filter=None):
        """Read the JSON file at ``root`` and build one graph per row.

        Args:
            root: Path of the JSON file; also passed to the base class as its
                root directory argument.
            split: Must be ``'train'``, ``'val'`` or ``'test'``. It is only
                validated and is not otherwise used by this constructor.
            transform: Forwarded to ``InMemoryDataset``.
            pre_transform: Forwarded to ``InMemoryDataset``.
            pre_filter: Forwarded to ``InMemoryDataset``.

        Raises:
            AssertionError: If ``split`` is not one of the accepted values.
        """
        assert split in ['train', 'val', 'test']
        super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter)
        # Close the file handle as soon as the table has been parsed.
        with open(root) as json_file:
            saved_data = pd.read_json(json_file)
        data_list = []
        for idx in saved_data.index:
            # Look the row up once; every field below comes from it.
            row = saved_data.loc[idx]
            graph_data = Data(edge_index=torch.tensor(row.edges),
                              x=torch.tensor(row.node_feature),
                              y=torch.tensor(row.node_target),
                              pos=torch.tensor(row.node_lines))
            data_list.append(graph_data)
        self.data = data_list

        # NOTE(review): these sizes are fixed constants, not derived from the
        # loaded graphs, and ``self.data`` is a plain list rather than a
        # collated ``Data`` object — confirm that consumers rely on this.
        num_nodes = 400
        num_edges = 1000

        self.slices = {
            'x': torch.LongTensor([0, num_nodes]),
            'y': torch.LongTensor([0, num_nodes]),
            'gl': torch.LongTensor([0]),
            'edge_index': torch.LongTensor([0, num_edges])
        }


class BatchedCitationDataset(InMemoryDataset):
    """Single-graph dataset restored from a file saved with ``torch.save``."""

    def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None):
        """Load the object stored at ``root`` and wrap it as one graph.

        Args:
            root: Path handed to ``torch.load``; also passed to the base
                class as its root directory argument.
            transform: Forwarded to ``InMemoryDataset``.
            pre_transform: Forwarded to ``InMemoryDataset``.
            pre_filter: Forwarded to ``InMemoryDataset``.
        """
        super().__init__(root, transform, pre_transform, pre_filter)

        # NOTE(review): torch.load unpickles its input; only point ``root``
        # at files from a trusted source.
        loaded = torch.load(root)
        node_count = loaded.x.size(0)
        edge_count = loaded.edge_index.size(1)

        # Keep only the three fields used downstream.
        self.data = Data(edge_index=loaded.edge_index, x=loaded.x, y=loaded.y)

        def span(end):
            # A fresh [0, end] boundary tensor for each slice entry.
            return torch.LongTensor([0, end])

        # NOTE(review): 'batch' is sliced with the edge count although the
        # stored Data object has no ``batch`` field — confirm this is intended.
        self.slices = {
            'x': span(node_count),
            'y': span(node_count),
            'edge_index': span(edge_count),
            'batch': span(edge_count),
        }

