import json
import os
import os.path as osp
import pickle as pkl

import numpy as np
import torch

from dataprocess.base_data import Graph
from dataprocess.base_dataset import NodeDataset
from dataprocess.utils import pkl_read_file

class UnDirDataset(NodeDataset):
    """Undirected node-classification dataset backed by a pickled ``Graph``.

    On construction the processed graph is read from ``processed_dir``.
    Train/val/test node indices and human-readable class names are then set
    up according to the dataset name. ``process`` builds that pickle from the
    ``.npy`` files in ``raw_dir``. When ``process`` runs is decided by
    ``NodeDataset`` (presumably when the processed file is missing; TODO
    confirm).
    """

    # Datasets whose raw_dir provides x.npy node features. Any other dataset
    # gets a placeholder feature tensor in process().
    _FEATURED_DATASETS = frozenset({
        'cora', 'citeseer', 'pubmed', 'arxiv2023', 'instagram', 'reddit',
        'wikics', 'history', 'photo', 'children', 'amazonratings',
    })

    # Datasets whose raw_dir ships fixed train_idx/val_idx/test_idx .npy files.
    # Every other dataset (including cora) gets a random 60/20/20 split.
    _PRESPLIT_DATASETS = frozenset({
        'citeseer', 'pubmed', 'history', 'photo', 'children',
        'amazonratings', 'instagram', 'reddit', 'wikics',
    })

    # Human-readable label names. They are presumably listed in label-id order
    # (TODO confirm against y.npy). Unlisted datasets get an empty list.
    _CLASS_NAMES = {
        'cora': (
            'Case Based', 'Genetic Algorithms', 'Neural Networks',
            'Probabilistic Methods', 'Reinforcement Learning',
            'Rule Learning', 'Theory',
        ),
        'citeseer': (
            'Agents', 'Machine Learning', 'Information Retrieval', 'Database',
            'Human-computer Interaction', 'Artificial Intelligence',
        ),
        'pubmed': (
            'Diabetes Mellitus Experimental', 'Diabetes Mellitus Type 1',
            'Diabetes Mellitus Type 2',
        ),
        'history': (
            'World', 'Americas', 'Asia', 'Military', 'Europe', 'Russia',
            'Africa', 'Ancient Civilizations', 'Middle East',
            'Historical Study & Educational Resources', 'Australia & Oceania',
            'Arctic & Antarctica',
        ),
        'photo': (
            'Film Photography', 'Video', 'Digital Cameras', 'Accessories',
            'Binoculars & Scopes', 'Lenses', 'Bags & Cases',
            'Lighting & Studio', 'Flashes', 'Tripods & Monopods',
            'Underwater Photography', 'Video Surveillance',
        ),
        'children': (
            'Literature & Fiction', 'Animals', 'Growing Up & Facts of Life',
            'Fairy Tales Folk Tales & Myths', 'Activities Crafts & Games',
            'Action & Adventure', 'Geography & Cultures',
            'Education & Reference', 'Arts Music & Photography',
            'Holidays & Celebrations', 'Science Nature & How It Works',
            'Biographies',
        ),
        'amazonratings': (
            '5 score', '4.5 score', '4 score', '3.5 score', '0-3 score',
        ),
        'instagram': ('Commercial User', 'Normal User'),
        'reddit': ('Popular User', 'Unpopular User'),
        'wikics': (
            'Computational Linguistics', 'Databases', 'Operating Systems',
            'Computer Architecture', 'Computer Security', 'Internet Protocols',
            'Computer File Systems', 'Distributed Computing Architecture',
            'Web Technology', 'Programming Language Topics',
        ),
    }

    def __init__(self, args, name, root, k):
        """Load the processed graph, then set up splits and class names.

        Args:
            args: Accepted for interface compatibility; not used here.
            name: Dataset name. It is forwarded to ``NodeDataset`` and, via
                ``self.name``, selects the split strategy and class names.
            root: Dataset root directory, forwarded to ``NodeDataset``.
            k: Forwarded to ``NodeDataset`` (meaning defined there).
        """
        super().__init__(root, name, k)
        self.read_file()
        # NOTE(review): not set anywhere else in this class; presumably
        # filled in by training code. TODO confirm.
        self.h = None
        self.total_idx = torch.from_numpy(np.arange(0, self.num_node))

        if self.name in self._PRESPLIT_DATASETS:
            self.train_idx = np.load(osp.join(self.raw_dir, 'train_idx.npy'))
            self.val_idx = np.load(osp.join(self.raw_dir, 'val_idx.npy'))
            self.test_idx = np.load(osp.join(self.raw_dir, 'test_idx.npy'))
        else:
            # cora and any unlisted dataset: fresh random 60/20/20 split.
            self.train_idx, self.val_idx, self.test_idx = self.data_split(0.6, 0.2, 0.2)

        # Copy so that instances never share (and mutate) one list.
        self.class_name = list(self._CLASS_NAMES.get(self.name, ()))

    @property
    def raw_file_paths(self):
        """List containing ``raw_dir/ind.<name>``; consumed by ``NodeDataset``."""
        return [osp.join(self.raw_dir, "ind." + self.name)]

    @property
    def processed_file_paths(self):
        """Path of the pickled graph: ``processed_dir/<name>.graph``."""
        return osp.join(self.processed_dir, f"{self.name}.graph")

    def download(self):
        """No-op: raw files are expected to already be present in ``raw_dir``."""
        return None

    def data_split(self, train_ratio, val_ratio, test_ratio, seed=None):
        """Randomly partition all nodes into train/val/test index arrays.

        Args:
            train_ratio: Fraction of nodes assigned to the train split.
            val_ratio: Fraction of nodes assigned to the validation split.
            test_ratio: Fraction of nodes assigned to the test split.
            seed: Optional seed for a private ``np.random.default_rng``. When
                ``None`` (the default), the global NumPy RNG is used, so any
                ``np.random.seed`` set by the caller still applies.

        Returns:
            ``(train_idx, val_idx, test_idx)``: disjoint numpy index arrays.
            Cut points are taken cumulatively, so ratios summing to 1 cover
            every node.
        """
        # Use num_node rather than x.shape[0]: datasets without x.npy store a
        # 2-row placeholder x (see process), which would split only 2 nodes.
        n = self.num_node
        if seed is None:
            indices = np.arange(n)
            np.random.shuffle(indices)
        else:
            indices = np.random.default_rng(seed).permutation(n)

        def _cut(fraction):
            # int() floors. Float drift (e.g. 0.7 + 0.1 == 0.7999...) can land
            # just below an integer and silently drop a node; a tiny epsilon
            # absorbs that without moving any genuinely fractional cut.
            return int(fraction * n + 1e-6)

        train_end = _cut(train_ratio)
        val_end = _cut(train_ratio + val_ratio)
        test_end = _cut(train_ratio + val_ratio + test_ratio)
        return indices[:train_end], indices[train_end:val_end], indices[val_end:test_end]

    def read_file(self):
        """Unpickle the processed ``Graph`` and mirror its fields onto ``self``.

        ``pkl_read_file`` unpickles the file, so only use trusted processed
        directories.
        """
        self.data = pkl_read_file(self.processed_file_paths)
        self.edge = self.data.edge
        self.node = self.data.node
        self.x = self.data.x
        self.y = self.data.y
        self.adj = self.data.adj
        self.num_features = self.data.num_features
        self.num_classes = self.data.num_classes
        self.num_targets = self.data.num_targets
        self.num_node = self.data.num_node
        self.num_edge = self.data.num_edge

    def process(self):
        """Build a ``Graph`` from the raw ``.npy`` files and pickle it.

        Datasets in ``_FEATURED_DATASETS`` read ``x.npy`` (features),
        ``edge_index.npy`` and ``y.npy``. The edge array is unpacked as-is
        into ``row, col`` (so presumably shaped (2, E); TODO confirm), and the
        node count is the number of feature rows. Other datasets have no
        feature file, which changes three things:

        * the edge array is transposed first (presumably stored as (E, 2));
        * the node count is the number of labels;
        * a 2x2 placeholder feature tensor is stored.

        Every edge gets weight 1. On ``OSError`` while writing, the error is
        printed and the process exits with status 1.
        """
        labels = torch.as_tensor(
            np.load(osp.join(self.raw_dir, 'y.npy')), dtype=torch.long)
        edge_index = torch.tensor(np.load(osp.join(self.raw_dir, 'edge_index.npy')))

        if self.name in self._FEATURED_DATASETS:
            features = torch.as_tensor(
                np.load(osp.join(self.raw_dir, 'x.npy')), dtype=torch.float32)
            num_node = features.shape[0]
        else:
            # NOTE(review): placeholder only. It has 2 rows regardless of
            # num_node, so its rows do not line up with node ids.
            features = torch.FloatTensor([[1, 2], [3, 4]])
            edge_index = edge_index.t()
            num_node = len(labels)

        row, col = edge_index
        edge_weight = torch.ones(len(row))
        g = Graph(row, col, edge_weight, num_node, x=features, y=labels)
        self._write_graph(g)

    def _write_graph(self, graph):
        """Pickle ``graph`` to ``processed_file_paths`` via a temp file and rename.

        Args:
            graph: The ``Graph`` object to pickle.
        """
        out_path = self.processed_file_paths
        tmp_path = out_path + '.tmp'
        try:
            with open(tmp_path, 'wb') as handle:
                pkl.dump(graph, handle)
            # os.replace is atomic on POSIX, so a crash mid-dump never leaves a
            # truncated .graph file that a later run would try to unpickle.
            os.replace(tmp_path, out_path)
        except OSError as err:
            print(err)
            # SystemExit instead of site's exit(): exit() is absent under
            # `python -S` and also closes sys.stdin.
            raise SystemExit(1) from err
        finally:
            # Best-effort removal of a temp file left by a failed write.
            if osp.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass