import os
import torch
import pickle
import logging
import pandas as pd
import os.path as osp
from tqdm import tqdm
from torch_geometric.data import InMemoryDataset
from sklearn.model_selection import train_test_split

from .data_utils import nx_to_graph_data_obj, nx_to_graph_data_obj_with_edge_attr
logger = logging.getLogger(__name__)

class GenericGraphFromNetworkx(InMemoryDataset):
    """In-memory graph dataset built from a pickled list of networkx graphs.

    The raw graphs are read from ``<root>/<dir_name>/raw/supervised.pkl``,
    converted to graph data objects and cached as
    ``geometric_data_processed.pt`` by :meth:`process`.
    """

    # Species whose graphs form the train/valid pool and the held-out test
    # pool in the species-based split built by ``get_idx_split``.
    TRAIN_VALID_SPECIES_IDS = (3702, 6239, 511145, 7227, 10090, 4932, 7955)
    TEST_SPECIES_IDS = (9606,)

    def __init__(self, name='nx-ppi', root='data', transform=None, pre_transform=None):
        '''
            - name (str): name of the dataset, e.g. 'nx-ppi' (default) or
              'nx-ppiall'. Dashes are replaced by underscores to form the
              dataset folder name; 'nx-ppiall' makes ``process`` use the
              converter that also builds edge attributes.
            - root (str): root directory to store the dataset folder
            - transform, pre_transform (optional): transform/pre-transform graph objects
        '''
        self.name = name
        self.dir_name = '_'.join(name.split('-'))
        self.original_root = root
        self.root = osp.join(root, self.dir_name)
        self.processed_root = osp.join(osp.abspath(self.root))

        self.num_tasks = 40
        self.eval_metric = 'rocauc'
        self.task_type = 'classification'
        # NOTE(review): the two values below are strings, not an int / a bool
        # ('False' is truthy). They are kept unchanged because the consumers
        # are not visible from this file -- confirm before changing the types.
        self.__num_classes__ = '-1'
        self.binary = 'False'

        super(GenericGraphFromNetworkx, self).__init__(self.processed_root, transform, pre_transform)
        # NOTE(review): recent PyTorch releases default ``torch.load`` to
        # ``weights_only=True``, which may refuse pickled graph objects --
        # confirm against the PyTorch version this project pins.
        self.data, self.slices = torch.load(self.processed_paths[0])
        self.total_data_len = self.data.y.size(0)
        self.data.x = self.data.x.to(dtype=torch.long)
        self.data.edge_attr = self.data.edge_attr.to(dtype=torch.long)

    def _species_mask(self, species_ids):
        """Return a bool mask marking graphs whose species_id is in species_ids."""
        mask = torch.zeros(self.total_data_len, dtype=torch.bool)
        for species_id in species_ids:
            mask |= (self.data.species_id == species_id)
        return mask

    def get_idx_split(self, split_type='random'):
        """Return the train/valid/test graph indices as long tensors.

        The split is cached as gzip-compressed CSV files under
        ``<root>/split/<split_type>/``. When the cache cannot be read, a
        species-based split is built with seed 42 and written to that folder.
        ``split_type`` only selects the cache folder; the fallback split is
        always the species-based one.

        Args:
            split_type (str, optional): cache folder name; ``None`` is treated
                as ``'random'``.

        Returns:
            dict: ``{'train': ..., 'valid': ..., 'test': ...}`` mapping to
            ``torch.long`` index tensors.
        """
        if split_type is None:
            split_type = 'random'
        path = osp.join(self.root, 'split', split_type)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)
        try:
            train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header=None).values.T[0]
            valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header=None).values.T[0]
            test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header=None).values.T[0]
        except (OSError, EOFError, ValueError):
            # Missing, truncated, empty or unparseable cache files: rebuild
            # the split. pandas' EmptyDataError/ParserError derive from
            # ValueError. Anything else (e.g. KeyboardInterrupt) propagates
            # instead of silently overwriting the cached split.
            print('Splitting with species and random seed 42')
            full_idx = torch.arange(self.total_data_len)
            train_valid_mask = self._species_mask(self.TRAIN_VALID_SPECIES_IDS)
            test_mask = self._species_mask(self.TEST_SPECIES_IDS)
            # Every graph must belong to exactly one of the two species pools.
            assert (train_valid_mask ^ test_mask).all(), \
                'every graph must belong to exactly one species pool'
            train_valid_idx = full_idx[train_valid_mask]
            test_idx = full_idx[test_mask]
            train_idx, valid_idx, _ = random_split(train_valid_idx, seed=42, frac_train=0.85, frac_valid=0.15, frac_test=0)
            # Only the "valid" half of the shuffled test-species pool is kept
            # as the test set; the other half is discarded.
            _, test_idx, _ = random_split(test_idx, seed=42, frac_train=0.5, frac_valid=0.5, frac_test=0)
            # Convert to numpy once so the CSV writer and the tensor
            # construction below see the same array type as the cached path.
            train_idx = train_idx.numpy()
            valid_idx = valid_idx.numpy()
            test_idx = test_idx.numpy()
            pd.DataFrame({'train': train_idx}).to_csv(
                osp.join(path, 'train.csv.gz'), index=False, header=False, compression="gzip")
            pd.DataFrame({'valid': valid_idx}).to_csv(
                osp.join(path, 'valid.csv.gz'), index=False, header=False, compression="gzip")
            pd.DataFrame({'test': test_idx}).to_csv(
                osp.join(path, 'test.csv.gz'), index=False, header=False, compression="gzip")

        return {
            'train': torch.tensor(train_idx, dtype=torch.long),
            'valid': torch.tensor(valid_idx, dtype=torch.long),
            'test': torch.tensor(test_idx, dtype=torch.long),
        }

    @property
    def processed_file_names(self):
        """Name of the single cached file written by ``process``."""
        return ['geometric_data_processed.pt']

    def process(self):
        """Convert the pickled networkx graphs and save the collated result."""
        load_file_name = osp.join(self.root, 'raw', 'supervised.pkl')
        print('begin processing graph data at folder: ' , load_file_name)

        # SECURITY: pickle.load executes arbitrary code from the file; only
        # use raw files that come from a trusted source.
        with open(load_file_name, "rb") as f:
            nx_objects = pickle.load(f)

        # 'nx-ppiall' uses the converter that also builds edge attributes.
        if self.name == 'nx-ppiall':
            convert = nx_to_graph_data_obj_with_edge_attr
        else:
            convert = nx_to_graph_data_obj
        data_list = [convert(nxgraph) for nxgraph in tqdm(nx_objects)]

        print(data_list[:3])
        self.total_data_len = len(data_list)
        print('Labeled Finished with length ', self.total_data_len)
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


import random
import numpy as np
def random_split(dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1,
                 seed=0):
    """
    Randomly partition an indexable dataset into train/valid/test subsets.

    Adapted from graph-pretrain. Positions ``0..len(dataset)-1`` are shuffled
    with ``seed`` and cut into three consecutive slices; each slice is used
    to index ``dataset`` through a ``torch.long`` index tensor.

    :param dataset: object supporting ``len()`` and indexing with a long
        index tensor (e.g. a 1-D tensor of indices).
    :param frac_train: fraction of items for the train subset.
    :param frac_valid: fraction of items for the valid subset.
    :param frac_test: fraction of items for the test subset; when it is 0 the
        test subset is returned as ``None`` and any items left over by the
        integer truncation of the other two sizes are dropped.
    :param seed: seed passed to ``random.seed`` before shuffling.
    :return: train, valid, test slices of the input dataset obj.
    :raises AssertionError: if the three fractions do not sum to 1.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

    num_items = len(dataset)
    # Seeds the module-level RNG (kept as-is: callers may rely on the global
    # random state being reset by this call).
    random.seed(seed)
    all_idx = list(range(num_items))
    random.shuffle(all_idx)

    train_end = int(frac_train * num_items)
    valid_end = train_end + int(frac_valid * num_items)
    train_idx = all_idx[:train_end]
    valid_idx = all_idx[train_end:valid_end]
    test_idx = all_idx[valid_end:]

    assert len(set(train_idx).intersection(set(valid_idx))) == 0
    assert len(set(valid_idx).intersection(set(test_idx))) == 0
    assert len(train_idx) + len(valid_idx) + len(test_idx) == num_items

    # An explicit long dtype is required: torch.tensor([]) defaults to
    # float32, which is not a valid index and would raise for an empty slice.
    train_dataset = dataset[torch.tensor(train_idx, dtype=torch.long)]
    valid_dataset = dataset[torch.tensor(valid_idx, dtype=torch.long)]
    if frac_test == 0:
        test_dataset = None
    else:
        test_dataset = dataset[torch.tensor(test_idx, dtype=torch.long)]

    return train_dataset, valid_dataset, test_dataset