import os
import torch
import pickle
import logging
import pandas as pd
import os.path as osp
from torch_geometric.data import InMemoryDataset
from sklearn.model_selection import train_test_split

from .data_utils import read_graph_list, nx_to_graph_data_obj_unlabeled
logger = logging.getLogger(__name__)

class UnlabelGraphDataset(InMemoryDataset):
    '''Unlabeled graph dataset built from a raw ``.txt`` file via ``read_graph_list``.

    The raw file is read from ``<root>/<file>.txt`` and the processed cache is
    stored under ``<root>/<name with '-' replaced by '_'>``.  A name whose last
    two characters are ``'D'`` followed by a digit ``k`` (e.g. ``'QM9D2'``)
    reads ``<name minus its last character>.txt`` (``'QM9D.txt'``) and
    multiplies the usage cap ``max_to_use`` by ``k``.
    '''

    # Base cap on how many unlabeled graphs get_unlabeled_idx hands out; scaled
    # by the 'D<k>' name suffix.  NOTE(review): the origin of this number is not
    # documented here.
    BASE_MAX_TO_USE = 133051

    def __init__(self, name='QM9', root ='data', transform=None, pre_transform = None):
        '''
            - name (str): name of the dataset, e.g. plym-oxygen/melting/glass/density;
              an optional trailing 'D<k>' multiplies max_to_use by k
            - root (str): root directory to store the dataset folder
            - transform, pre_transform (optional): transform/pre-transform graph objects
        '''
        self.name = name
        self.dir_name = '_'.join(name.split('-'))
        self.original_root = root
        self.root = osp.join(root, self.dir_name)
        self.processed_root = osp.abspath(self.root)

        self.num_tasks = 1
        self.eval_metric = 'none'
        self.task_type = 'unlabel'
        # NOTE(review): these are stored as strings ('-1', 'False'); note that
        # 'False' is truthy -- confirm how callers interpret them.
        self.__num_classes__ = '-1'
        self.binary = 'False'
        self.max_to_use = self.BASE_MAX_TO_USE * self._max_times(name)

        # PyG's InMemoryDataset.__init__ runs process() when the processed file
        # is missing, so every attribute process() reads must be set above.
        super(UnlabelGraphDataset, self).__init__(self.processed_root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

        # Total = number of non-blank lines in the raw text file.
        # NOTE(review): if read_graph_list skips unparsable lines this can exceed
        # the number of stored graphs -- confirm against read_graph_list.
        with open(self._raw_file_path()) as fp:
            self.total_data_len = sum(1 for line in fp if line.rstrip())
        print('# Unlabel total: {}, max to use {}'.format(self.total_data_len, self.max_to_use))

    @staticmethod
    def _has_times_suffix(name):
        '''Return True if ``name`` ends with 'D' followed by a single decimal digit.'''
        return len(name) >= 2 and name[-2] == 'D' and name[-1].isdecimal()

    @classmethod
    def _max_times(cls, name):
        '''Multiplier for BASE_MAX_TO_USE encoded by a trailing 'D<k>' (1 if absent).'''
        return int(name[-1]) if cls._has_times_suffix(name) else 1

    def _raw_file_name(self):
        '''Raw text file name; for 'D<k>' names the trailing digit is dropped.'''
        if self._has_times_suffix(self.name):
            return '{}.txt'.format(self.name[:-1])
        return '{}.txt'.format(self.name)

    def _raw_file_path(self):
        '''Full path of the raw text file under the original (unjoined) root.'''
        return osp.join(self.original_root, self._raw_file_name())

    def get_unlabeled_idx(self):
        '''Return indices 0..min(max_to_use, total_data_len)-1 as a long tensor.'''
        return torch.arange(0, min(self.max_to_use, self.total_data_len), dtype=torch.long)

    @property
    def processed_file_names(self):
        return ['geometric_data_processed.pt']

    def process(self):
        '''Parse the raw text file into graphs, pre-transform, collate and save.'''
        file_name = self._raw_file_name()
        if self._has_times_suffix(self.name):
            print('file_name: {}'.format(file_name))
        raw_path = osp.join(self.original_root, file_name)
        print('begin processing unlabeled data at folder: ' , raw_path)
        data_list = read_graph_list(raw_path, property_name=self.name, process_labeled=False)
        self.total_data_len = len(data_list)
        print('Unlabeled data length', self.total_data_len)
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])



class UnlabelPPI(InMemoryDataset):
    '''Unlabeled PPI graphs loaded from a pickled collection of networkx graphs.

    Files live under ``<root>/nx_ppi``.  By default graphs are read from
    ``raw/unlabeled.pkl``.  When ``'gen'`` appears in ``name`` they are read
    from ``gen_file`` instead and cached in a separate processed file.  Once
    that cache exists, changing ``gen_file`` has no effect until the cache is
    deleted.
    '''

    # Previously hard-coded location of generated samples; kept as the default so
    # existing callers that pass a 'gen' name behave as before.
    DEFAULT_GEN_FILE = '/afs/crc.nd.edu/group/dmsquare/vol2/gliu7/developing/G2Aug_dev/GDSS/samples/pkl/ppi_all/test/ppi_denoise-sample.pkl'

    def __init__(self, name='PPI', root ='data', transform=None, pre_transform = None, gen_file=None):
        '''
            - name (str): name of the dataset; containing 'gen' selects the generated samples
            - root (str): root directory to store the dataset folder
            - transform, pre_transform (optional): transform/pre-transform graph objects
            - gen_file (str, optional): pickle of generated graphs used when 'gen' is in
              name; defaults to DEFAULT_GEN_FILE
        '''
        self.name = name
        self.dir_name = 'nx_ppi'
        self.original_root = root
        self.root = osp.join(root, self.dir_name)
        self.processed_root = osp.abspath(self.root)
        # Must be set before super().__init__, which may call process().
        self.gen_file = self.DEFAULT_GEN_FILE if gen_file is None else gen_file

        self.num_tasks = 1
        self.eval_metric = 'none'
        self.task_type = 'unlabel'
        # NOTE(review): stored as strings ('-1', 'False'); 'False' is truthy --
        # confirm how callers interpret them.
        self.__num_classes__ = '-1'
        self.binary = 'False'
        self.max_to_use = 300000

        super(UnlabelPPI, self).__init__(self.processed_root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        # Node and edge features are cast to integer (long) dtype.
        self.data.x = self.data.x.to(dtype=torch.long)
        self.data.edge_attr = self.data.edge_attr.to(dtype=torch.long)

        # Derive the total from the graphs actually loaded rather than from
        # raw/unlabeled.txt.  Previously the 'gen' variant overwrote that shared
        # file, so the default variant could read the wrong count.
        self.total_data_len = self.len()

        print('# Unlabel total: {}, max to use {}'.format(self.total_data_len, self.max_to_use))

    def _is_gen(self):
        '''True when this instance uses the generated-sample variant.'''
        return 'gen' in self.name

    def _count_file_path(self):
        '''Path where process() records the graph count, one file per variant.'''
        file_name = 'unlabeled_gen.txt' if self._is_gen() else 'unlabeled.txt'
        return osp.join(self.root, 'raw', file_name)

    def get_unlabeled_idx(self):
        '''Return indices 0..min(max_to_use, total_data_len)-1 as a long tensor.'''
        return torch.arange(0, min(self.max_to_use, self.total_data_len), dtype=torch.long)

    @property
    def processed_file_names(self):
        if self._is_gen():
            return ['geometric_data_processed_gen.pt']
        return ['geometric_data_processed_unlabeled.pt']

    def process(self):
        '''Load the pickled networkx graphs, convert, pre-transform, collate and save.

        Raises:
            ValueError: if the pickle contains no graphs.
        '''
        if self._is_gen():
            load_file_name = self.gen_file
        else:
            load_file_name = osp.join(self.root, 'raw', 'unlabeled.pkl')
        print('begin processing graph data at folder: ' , load_file_name)

        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # dataset at trusted files.
        with open(load_file_name, "rb") as f:
            nx_objects = pickle.load(f)
        data_list = [nx_to_graph_data_obj_unlabeled(nxgraph) for nxgraph in nx_objects]
        if not data_list:
            raise ValueError('no graphs found in {}'.format(load_file_name))

        # Attribute name kept for backward compatibility; it holds the number of
        # *unlabeled* graphs read.
        self.labeled_data_len = len(data_list)
        print('Unlabeled data length', self.labeled_data_len)
        with open(self._count_file_path(), 'w') as f:
            f.write('%d' % self.labeled_data_len)

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])