import argparse
import torch
import json
import os
import os.path as osp
import pandas as pd
import numpy as np
from multiprocessing import Pool
from torch_geometric.data import Data, Dataset, download_url, extract_zip
from torch_geometric.utils.convert import to_networkx
from torch_geometric.loader import GraphSAINTEdgeSampler
import networkx as nx
import numpy as np
from tqdm import tqdm


class Amazon(Dataset):
    """Single-graph node-classification dataset assembled from raw files.

    The following files are read from ``self.raw_dir``:

    * ``edges.npy``    -- edge list; it is transposed on load and used as
      ``edge_index`` (presumably stored as ``(num_edges, 2)`` -- TODO confirm).
    * ``features.npy`` -- node feature matrix, one row per node.
    * ``node_labels``  -- header-less CSV with the columns ``idx,label``.
    * ``role``         -- JSON object with ``train``/``val``/``test`` lists of
      node indices.

    The processed graph is stored as ``data.pt`` in ``self.processed_dir``.

    Args:
        root: Dataset root directory (contains ``raw/`` and ``processed/``).
        transform: Forwarded to the base ``Dataset``.
        pre_transform: Applied once to the graph before it is saved.
        pre_filter: Evaluated on the graph before saving; a rejection only
            produces a warning because the dataset holds a single graph.
        num_workers: Stored on the instance; not used by this class.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, num_workers=32):
        # Assigned before the base constructor runs, because that constructor
        # may already invoke process() and everything it calls.
        self.num_workers = num_workers
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Names of the raw files expected in ``self.raw_dir``."""
        return ["edges.npy", "features.npy", "node_labels", "role"]

    @property
    def processed_file_names(self):
        """Names of the files written by :meth:`process`."""
        return ['data.pt']

    @property
    def num_classes(self):
        """Number of distinct labels (compared as strings) in ``node_labels``."""
        path = osp.join(self.raw_dir, 'node_labels')
        node_labels = pd.read_csv(path, sep=',', names=['idx', 'label'])
        np_str_lab = node_labels['label'].to_numpy().astype(str)
        return len(np.unique(np_str_lab))

    def download(self):
        """Automatic download is not supported; raw files must be provided."""
        raise NotImplementedError

    def load_features(self, filename):
        """Load one ``.npy`` file from ``<raw_dir>/Features``.

        NOTE(review): nothing in this class calls this helper and ``Features``
        is not listed in ``raw_file_names``; it is kept for API compatibility.
        """
        features = np.load(osp.join(self.raw_dir, 'Features', filename))
        return features

    def get_data(self):
        """Read the raw files.

        Returns:
            Tuple ``(attributes, edge_idx, node_labels)``: the feature array
            from ``features.npy``, the edge array from ``edges.npy`` and the
            label table as a ``DataFrame`` with columns ``idx`` and ``label``.
        """
        print("Loading Node Labels", end="...")
        path = osp.join(self.raw_dir, 'node_labels')
        node_labels = pd.read_csv(path, sep=',', names=['idx', 'label'])
        print("Done")

        print("Loading Edges", end="...")
        path = osp.join(self.raw_dir, 'edges.npy')
        edge_idx = np.load(path)
        print("Done")

        attributes = np.load(osp.join(self.raw_dir, 'features.npy'))

        return attributes, edge_idx, node_labels

    def get_graph(self):
        """Convert the raw arrays into tensors.

        Labels are encoded as the position of their string form in the sorted
        list of distinct labels. Rows of ``node_labels`` are used in file
        order (assumes they line up with the rows of ``features.npy`` --
        TODO confirm against the data export).

        Returns:
            Tuple ``(x, y, edge_idx)``: float features, 1-D long labels and
            the transposed long edge tensor.
        """
        attrs, edge_idx, node_labels = self.get_data()

        edge_idx = torch.tensor(edge_idx.transpose(), dtype=torch.long)

        np_str_lab = node_labels['label'].to_numpy().astype(str)
        classes, labels = np.unique(np_str_lab, return_inverse=True)
        # Flatten explicitly: the shape of the inverse index for non-1-D input
        # is not the same across NumPy versions.
        labels = labels.reshape(-1)

        print(edge_idx.numpy().shape)
        print(labels.shape)
        print(attrs.shape)

        x = torch.tensor(attrs, dtype=torch.float)
        y = torch.tensor(labels, dtype=torch.long)

        print(f"{len(attrs)} valid nodes")
        print(f"with {len(classes)} classes")
        print()

        return x, y, edge_idx

    def get_masks(self, attr_idx):
        """Build boolean split masks from the ``role`` file.

        Args:
            attr_idx: Node ids, one per row of the feature matrix.

        Returns:
            Tuple ``(train_mask, val_mask, test_mask)`` of boolean arrays with
            the same shape as ``attr_idx``; an entry is ``True`` when the id
            is listed in the corresponding split.
        """
        print("Loading Masks...")
        with open(osp.join(self.raw_dir, 'role')) as file:
            role = json.load(file)

        print("Creating Masks", end="...")
        train, val, test = np.array(role['train']), np.array(role['val']), np.array(role['test'])

        train_mask = np.isin(attr_idx, train)
        val_mask = np.isin(attr_idx, val)
        test_mask = np.isin(attr_idx, test)
        print("Done")

        return train_mask, val_mask, test_mask

    def process(self):
        """Assemble the graph and save it as ``data.pt`` in ``processed_dir``."""
        import warnings

        x, y, edge_idx = self.get_graph()
        # Nodes are identified by their row position in the feature matrix.
        train_mask, val_mask, test_mask = self.get_masks(np.arange(len(x)))

        n_train, n_val, n_test = np.sum(train_mask), np.sum(val_mask), np.sum(test_mask)
        print(f"{len(x)} processed attributes")
        print(f"with {len(y)} node labels")
        print(f"all masked to {n_train + n_val + n_test}")
        print(f"with a data split of {n_train/len(x)}|{n_val/len(x)}|{n_test/len(x)}")

        data = Data(x=x,
                    edge_index=edge_idx,
                    y=y,
                    train_mask=torch.tensor(train_mask, dtype=torch.bool),
                    val_mask=torch.tensor(val_mask, dtype=torch.bool),
                    test_mask=torch.tensor(test_mask, dtype=torch.bool))

        if self.pre_filter is not None and not self.pre_filter(data):
            # There is only one graph, so dropping it would leave the dataset
            # empty; keep it, but make the rejection visible.
            warnings.warn("pre_filter rejected the only graph of the dataset; "
                          "it is saved anyway")

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save(data, osp.join(self.processed_dir, "data.pt"))

    def len(self):
        """Number of graphs in the dataset (one per processed file)."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the stored graph; ``idx`` is ignored."""
        # NOTE(review): torch.load unpickles the file; only load files that
        # were produced by process().
        data = torch.load(osp.join(self.processed_dir, 'data.pt'))
        return data


class Flickr_v2(Dataset):
    """Single-graph node-classification dataset with one feature file per node.

    The following entries are read from ``self.raw_dir``:

    * ``edges.npy``   -- edge list of original node ids; it is transposed on
      load (presumably stored as ``(num_edges, 2)`` -- TODO confirm).
    * ``Features/``   -- one ``<node id>.npy`` file per node.
    * ``node_labels`` -- header-less CSV with the columns ``idx,label``.
    * ``role``        -- JSON object with ``train``/``val``/``test`` lists of
      original node ids.

    Nodes are renumbered ``0..N-1`` in ascending order of their original id.
    The processed graph is stored as ``data.pt`` in ``self.processed_dir``.

    Args:
        root: Dataset root directory (contains ``raw/`` and ``processed/``).
        transform: Forwarded to the base ``Dataset``.
        pre_transform: Applied once to the graph before it is saved.
        pre_filter: Evaluated on the graph before saving; a rejection only
            produces a warning because the dataset holds a single graph.
        num_workers: Number of worker processes used to load feature files.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, num_workers=32):
        # Assigned before the base constructor runs, because that constructor
        # may already invoke process(), which reads this attribute.
        self.num_workers = num_workers
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Names of the raw entries expected in ``self.raw_dir``."""
        return ["edges.npy", "Features", "node_labels", "role"]

    @property
    def processed_file_names(self):
        """Names of the files written by :meth:`process`."""
        return ['data.pt']

    @property
    def num_classes(self):
        """Number of distinct labels (compared as strings) in ``node_labels``."""
        path = osp.join(self.raw_dir, 'node_labels')
        node_labels = pd.read_csv(path, sep=',', names=['idx', 'label'])
        np_str_lab = node_labels['label'].to_numpy().astype(str)
        return len(np.unique(np_str_lab))

    def download(self):
        """Automatic download is not supported; raw files must be provided."""
        raise NotImplementedError

    def load_features(self, filename):
        """Load one ``.npy`` feature file from ``<raw_dir>/Features``."""
        features = np.load(osp.join(self.raw_dir, 'Features', filename))
        return features

    def get_data(self):
        """Read the raw files.

        Returns:
            Tuple ``(attributes, attr_idx, edge_idx, node_labels)``:
            the per-node feature arrays stacked with ``np.vstack``, the node
            id parsed from each file name (same order as ``attributes``), the
            edge array from ``edges.npy`` and the label table as a
            ``DataFrame`` with columns ``idx`` and ``label``.
        """
        print("Loading Node Labels", end="...")
        path = osp.join(self.raw_dir, 'node_labels')
        node_labels = pd.read_csv(path, sep=',', names=['idx', 'label'])
        print("Done")

        print("Loading Edges", end="...")
        path = osp.join(self.raw_dir, 'edges.npy')
        edge_idx = np.load(path)
        print("Done")

        attribute_files = []
        attr_idx = []
        for filename in tqdm(os.listdir(osp.join(self.raw_dir, 'Features')), desc="Loading Nodes"):
            # The file stem is the original node id.
            attr_idx.append(int(filename.split('.')[0]))
            attribute_files.append(filename)

        # imap yields results in submission order, so row i of the stacked
        # array belongs to attr_idx[i].
        with Pool(processes=self.num_workers) as pool:
            attributes = list(tqdm(pool.imap(self.load_features, attribute_files), total=len(attribute_files), desc="Loading Node Features"))

        attributes = np.vstack(attributes)
        attr_idx = np.array(attr_idx)

        return attributes, attr_idx, edge_idx, node_labels

    def get_graph(self):
        """Convert the raw arrays into aligned tensors.

        Nodes are ordered by ascending original id; features, labels and the
        returned id array all follow that order, and every edge endpoint is
        replaced by the row position of its node. Nodes without any edge keep
        their own row.

        Returns:
            Tuple ``(x, y, edge_idx, attr_idx)``: float features, 1-D long
            labels, the remapped long edge tensor and the original node id of
            every row of ``x``.

        Raises:
            ValueError: If the labelled nodes and the nodes with feature
                files are not the same set, or if an edge refers to a node
                id without a feature file.
        """
        attributes, attr_idx, edge_idx, node_labels = self.get_data()

        # Canonical node order: ascending original id.
        attr_order = np.argsort(attr_idx, kind='stable')
        attr_idx = attr_idx[attr_order]
        attrs = attributes[attr_order]

        node_idx = node_labels['idx'].to_numpy()
        np_str_lab = node_labels['label'].to_numpy().astype(str)

        # Only nodes that have a feature file take part in the graph.
        has_features = np.isin(node_idx, attr_idx)
        node_idx = node_idx[has_features]
        np_str_lab = np_str_lab[has_features]
        classes, np_lab = np.unique(np_str_lab, return_inverse=True)

        label_order = np.argsort(node_idx, kind='stable')
        labels = np_lab.reshape(-1)[label_order]
        if not np.array_equal(node_idx[label_order], attr_idx):
            raise ValueError("node_labels and Features do not describe the same set of node ids")

        # Map original ids to row positions with one binary search per
        # endpoint instead of scanning the whole edge list once per node.
        edge_ids = np.asarray(edge_idx).transpose().astype(np.int64)
        positions = np.searchsorted(attr_idx, edge_ids)
        in_range = np.clip(positions, 0, attr_idx.size - 1)
        if not np.array_equal(attr_idx[in_range], edge_ids):
            raise ValueError("edges.npy refers to node ids that have no feature file")
        edge_idx = torch.tensor(positions, dtype=torch.long)

        print(edge_idx.numpy().shape)
        print(labels.shape)
        print(attrs.shape)

        x = torch.tensor(attrs, dtype=torch.float)
        y = torch.tensor(labels, dtype=torch.long)

        print(f"{len(attr_idx)} valid nodes")
        print(f"with {len(classes)} classes")
        print()

        return x, y, edge_idx, attr_idx

    def get_masks(self, attr_idx):
        """Build boolean split masks from the ``role`` file.

        Args:
            attr_idx: Original node ids, one per row of the feature matrix
                and in the same order as those rows.

        Returns:
            Tuple ``(train_mask, val_mask, test_mask)`` of boolean arrays with
            the same shape as ``attr_idx``; an entry is ``True`` when the id
            is listed in the corresponding split.
        """
        print("Loading Masks...")
        with open(osp.join(self.raw_dir, 'role')) as file:
            role = json.load(file)

        print("Creating Masks", end="...")
        train, val, test = np.array(role['train']), np.array(role['val']), np.array(role['test'])

        train_mask = np.isin(attr_idx, train)
        val_mask = np.isin(attr_idx, val)
        test_mask = np.isin(attr_idx, test)
        print("Done")

        return train_mask, val_mask, test_mask

    def process(self):
        """Assemble the graph and save it as ``data.pt`` in ``processed_dir``."""
        import warnings

        x, y, edge_idx, attr_idx = self.get_graph()
        # attr_idx is in the same row order as x, so the masks line up.
        train_mask, val_mask, test_mask = self.get_masks(attr_idx)

        n_train, n_val, n_test = np.sum(train_mask), np.sum(val_mask), np.sum(test_mask)
        print(f"{len(x)} processed attributes")
        print(f"with {len(y)} node labels")
        print(f"all masked to {n_train + n_val + n_test}")
        print(f"with a data split of {n_train/len(x)}|{n_val/len(x)}|{n_test/len(x)}")

        data = Data(x=x,
                    edge_index=edge_idx,
                    y=y,
                    train_mask=torch.tensor(train_mask, dtype=torch.bool),
                    val_mask=torch.tensor(val_mask, dtype=torch.bool),
                    test_mask=torch.tensor(test_mask, dtype=torch.bool))

        if self.pre_filter is not None and not self.pre_filter(data):
            # There is only one graph, so dropping it would leave the dataset
            # empty; keep it, but make the rejection visible.
            warnings.warn("pre_filter rejected the only graph of the dataset; "
                          "it is saved anyway")

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save(data, osp.join(self.processed_dir, "data.pt"))

    def len(self):
        """Number of graphs in the dataset (one per processed file)."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the stored graph; ``idx`` is ignored."""
        # NOTE(review): torch.load unpickles the file; only load files that
        # were produced by process().
        data = torch.load(osp.join(self.processed_dir, 'data.pt'))
        return data


if __name__ == "__main__":
    # Command-line smoke test: build (or load) the Amazon dataset stored
    # under --root and print its feature count, node count and class count.
    cli = argparse.ArgumentParser()
    cli.add_argument("--root", required=True, help="Root directory for dataset")
    options = cli.parse_args()

    amazon = Amazon(root=options.root)
    print(amazon.num_features)
    first_graph = amazon[0]
    print(first_graph.num_nodes)
    print(amazon.num_classes)

