
import os
from itertools import repeat

import numpy as np
import torch
from torch_geometric.utils import subgraph, to_networkx
from torch_geometric.data import Data, InMemoryDataset, Batch

from os.path import join
import copy


class MoleculeDataset(InMemoryDataset):
    def __init__(self, root, dataset='zinc250k', transform=None,
                 pre_transform=None, pre_filter=None, empty=False):
        """Initialise the dataset and, unless ``empty`` is set, load the processed file.

        Args:
            root: Dataset root directory, forwarded to ``InMemoryDataset``.
            dataset: Dataset name; stored on the instance and read by
                ``raw_file_names``.
            transform: Forwarded to ``InMemoryDataset``.
            pre_transform: Forwarded to ``InMemoryDataset``.
            pre_filter: Forwarded to ``InMemoryDataset``.
            empty: When true, ``processed_paths[0]`` is not loaded.
        """
        # These are assigned before the parent constructor runs.
        # NOTE(review): presumably because the parent constructor consults
        # ``raw_file_names``, which reads ``self.dataset`` -- confirm.
        self.root = root
        self.dataset = dataset
        self.transform, self.pre_filter, self.pre_transform = (
            transform, pre_filter, pre_transform)

        super().__init__(root, transform, pre_transform, pre_filter)

        if not empty:
            loaded = torch.load(self.processed_paths[0])
            self.data, self.slices = loaded
        print(f'Dataset: {self.dataset}\nData: {self.data}')

    def get(self, idx):
        """Return the ``idx``-th example as a fresh ``Data`` object.

        Every attribute stored in the collated ``self.data`` except
        ``'dihedral_angle_value'`` is cut out along its concatenation
        dimension, using the boundaries recorded in ``self.slices``.

        Args:
            idx: Position of the example inside the collated storage.

        Returns:
            A ``Data`` object holding the sliced attributes.
        """
        data = Data()
        for key in self.data.keys:
            if key == 'dihedral_angle_value':
                continue
            item, slices = self.data[key], self.slices[key]
            index = [slice(None)] * item.dim()
            index[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            # Index with a tuple: a *list* of slices only works through the
            # legacy "sequence treated as tuple" path, which NumPy has removed
            # and PyTorch deprecates.
            data[key] = item[tuple(index)]
        return data

    @property
    def raw_file_names(self):
        """Raw file names for the configured dataset.

        For the ``'davis'`` and ``'kiba'`` datasets this is a one-element list
        with the dataset name; for anything else it is the current listing of
        ``self.raw_dir``.
        """
        for fixed_name in ('davis', 'kiba'):
            if self.dataset == fixed_name:
                return [fixed_name]
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        """Name of the single processed file, the one loaded via ``processed_paths[0]``."""
        return 'geometric_data_processed.pt'

    def download(self):
        """Do nothing: this class fetches no files and simply returns ``None``."""
        return


class Molecule3DDatasetFragRandomaug3d_2(MoleculeDataset):

    def __init__(self, root, n_mol, choose=0, transform=None, seed=777,
                 pre_transform=None, pre_filter=None, empty=False, **kwargs):
        """Create the dataset directories, set up augmentation state and load the data.

        Args:
            root: Dataset root; its ``raw`` and ``processed`` sub-directories
                are created when missing.
            n_mol: Stored on the instance and echoed in the summary print.
            choose: Index read by ``get`` in the 'choose', 'choosetwo' and
                'choosethree' modes.
            transform: Forwarded to ``MoleculeDataset``.
            seed: Stored on the instance; no RNG is seeded here.
            pre_transform: Forwarded to ``MoleculeDataset``.
            pre_filter: Forwarded to ``MoleculeDataset``.
            empty: When true, the processed file is not loaded.
            **kwargs: ``smiles_copy_from_3D_file`` (for 2D/SMILES datasets) and
                ``dataset`` (name forwarded to ``MoleculeDataset``) are read;
                both default to ``None``. Other keys are ignored.
        """
        os.makedirs(root, exist_ok=True)
        os.makedirs(join(root, 'raw'), exist_ok=True)
        os.makedirs(join(root, 'processed'), exist_ok=True)
        # for 2D Datasets (SMILES)
        self.smiles_copy_from_3D_file = kwargs.get('smiles_copy_from_3D_file')

        self.root, self.seed = root, seed
        self.n_mol = n_mol
        self.pre_transform, self.pre_filter = pre_transform, pre_filter
        self.aug_prob = None
        self.aug_mode = 'no_aug'
        self.aug_strength = 0.2
        self.choose = choose
        self.choosetwo_idx = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
        self.choosethree_idx = [[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]
        print("I choose")
        print(choose)
        # Index 4 is the identity augmentation.
        self.augmentations = [self.node_drop, self.subgraph,
                              self.edge_pert, self.attr_mask, lambda x: x]

        # Keyword arguments are required here: MoleculeDataset.__init__ takes
        # ``dataset`` as its second positional parameter, so passing
        # ``transform`` positionally shifted every argument by one slot and
        # silently dropped ``pre_filter``.  ``empty`` is forwarded as well so
        # that the base class honours it.
        super(Molecule3DDatasetFragRandomaug3d_2, self).__init__(
            root, dataset=kwargs.get('dataset'), transform=transform,
            pre_transform=pre_transform, pre_filter=pre_filter, empty=empty)

        # MoleculeDataset.__init__ has already loaded ``processed_paths[0]``
        # (unless ``empty``), so the file is not read a second time here.

        print('root: {},\ndata: {},\nn_mol: {},\n'.format(
            self.root, self.data, self.n_mol))


    def set_augMode(self, aug_mode):
        """Select the augmentation mode that ``get`` dispatches on.

        ``get`` recognises 'no_aug', 'choose', 'edgedel', 'choosetwo',
        'choosethree', 'uniform' and 'sample'. The value is not validated here.
        """
        self.aug_mode = aug_mode

    def set_augStrength(self, aug_strength):
        """Set the fraction the augmentations scale by.

        ``node_drop``, ``attr_mask`` and ``subgraph`` multiply it with the node
        count; ``edge_pert`` and ``edge_del`` multiply it with the edge count.
        """
        self.aug_strength = aug_strength

    def set_augProb(self, aug_prob):
        """Set the probabilities passed as ``p`` to ``np.random.choice(25, ...)`` in 'sample' mode."""
        self.aug_prob = aug_prob

    def node_drop(self, data):
        """Drop a random ``aug_strength`` fraction of the nodes, in place.

        The surviving node indices (kept in ascending order) are passed to
        ``torch_geometric.utils.subgraph`` with relabelling enabled, and
        ``x``, ``edge_index``, ``edge_attr`` and ``__num_nodes__`` are
        overwritten on ``data``.

        Returns:
            The same ``data`` object that was passed in.
        """
        n_nodes, _ = data.x.size()
        # Unused, but the unpacking fails unless ``edge_index`` is 2-D.
        _, _n_edges = data.edge_index.size()
        n_dropped = int(n_nodes * self.aug_strength)

        # Shuffle all node ids, discard the first ``n_dropped``, sort the rest.
        kept = sorted(np.random.permutation(n_nodes)[n_dropped:].tolist())

        data.edge_index, data.edge_attr = subgraph(
            subset=kept,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=n_nodes,
        )
        data.x = data.x[kept]
        data.__num_nodes__ = data.x.shape[0]
        return data

    def edge_pert(self, data):
        """Replace a random ``aug_strength`` fraction of the edges, in place.

        ``int(edge_num * aug_strength)`` edges are removed and the same number
        of new edges is sampled from the node pairs that are not among the
        retained edges. New edges get random attributes: the first column is
        drawn from 4 classes, the second from 3.

        Returns:
            The same ``data`` object that was passed in.
        """
        n_nodes, _ = data.x.size()
        _, n_edges = data.edge_index.size()
        n_perturbed = int(n_edges * self.aug_strength)

        # Deletion: keep a random subset of ``n_edges - n_perturbed`` edges.
        kept = np.random.choice(n_edges, n_edges - n_perturbed, replace=False)
        kept_index = data.edge_index[:, kept]
        kept_attr = data.edge_attr[kept]

        # Addition: every (source, target) pair not among the kept edges is a
        # candidate. The diagonal is not cleared, so self-loops are included.
        free = torch.ones((n_nodes, n_nodes))
        free[kept_index[0], kept_index[1]] = 0
        candidates = torch.nonzero(free, as_tuple=False).t()
        picked = np.random.choice(candidates.shape[1], n_perturbed, replace=False)
        new_index = candidates[:, picked]

        # Random 4-class and 3-class attributes for the 1st and 2nd column.
        n_new = new_index.shape[1]
        new_attr = torch.cat(
            (torch.tensor(np.random.randint(4, size=(n_new, 1))),
             torch.tensor(np.random.randint(3, size=(n_new, 1)))),
            dim=1)

        data.edge_index = torch.cat((kept_index, new_index), dim=1)
        data.edge_attr = torch.cat((kept_attr, new_attr), dim=0)

        return data

    def edge_del(self, data):
        """Delete a random ``aug_strength`` fraction of the edges, in place.

        Unlike ``edge_pert`` no replacement edges are added.

        Returns:
            The same ``data`` object that was passed in.
        """
        # Unused, but the unpacking fails unless ``x`` is 2-D.
        _n_nodes, _ = data.x.size()
        _, n_edges = data.edge_index.size()
        n_removed = int(n_edges * self.aug_strength)

        # Sample, without replacement, the columns that survive.
        kept = np.random.choice(n_edges, n_edges - n_removed, replace=False)
        new_index, new_attr = data.edge_index[:, kept], data.edge_attr[kept]

        data.edge_index, data.edge_attr = new_index, new_attr
        return data


    def attr_mask(self, data):
        """Overwrite the features of a random ``aug_strength`` fraction of nodes.

        The replacement row is the column-wise mean of ``data.x`` (computed in
        float, cast back to long). ``data.x`` is replaced by a modified copy;
        the original tensor is left untouched.

        Returns:
            The same ``data`` object that was passed in.
        """
        n_nodes, _ = data.x.size()
        n_masked = int(n_nodes * self.aug_strength)

        masked_x = data.x.clone()
        mean_token = data.x.float().mean(dim=0).long()
        chosen = np.random.choice(n_nodes, n_masked, replace=False)
        masked_x[chosen] = mean_token

        data.x = masked_x
        return data

    def subgraph(self, data):
        """Keep a randomly grown set of nodes and drop the rest, in place.

        Starting from one random node, nodes are added one at a time by
        sampling from the current frontier (neighbours of the chosen nodes).
        When the frontier is empty a random unchosen node is used instead.
        The induced edges are then computed with
        ``torch_geometric.utils.subgraph`` with relabelling enabled.

        Returns:
            The same ``data`` object that was passed in.
        """
        G = to_networkx(data)
        node_num, _ = data.x.size()
        sub_num = int(node_num * (1 - self.aug_strength))

        # The original loop condition (``len(idx_sub) <= sub_num``) keeps
        # ``sub_num + 1`` nodes; that count is preserved.  It is capped at
        # ``node_num`` because asking for more nodes than exist (e.g. with
        # ``aug_strength == 0``) used to crash in ``np.random.choice([])``.
        # NOTE(review): the "+ 1" looks like an off-by-one -- confirm intent.
        target_num = min(sub_num + 1, node_num)

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set(G.neighbors(idx_sub[-1]))

        while len(idx_sub) < target_num:
            if len(idx_neigh) == 0:
                # Frontier exhausted: restart from a random unchosen node.
                idx_unsub = list(set(range(node_num)).difference(set(idx_sub)))
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set(G.neighbors(idx_sub[-1]))).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(subset=idx_nondrop,
                                       edge_index=data.edge_index,
                                       edge_attr=data.edge_attr,
                                       relabel_nodes=True,
                                       num_nodes=node_num)

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.__num_nodes__, _ = data.x.shape
        return data


    def get(self, idx):
        """Return the ``idx``-th molecule under the active ``aug_mode``.

        Args:
            idx: Position of the molecule inside the collated storage. In
                'choosetwo' mode it also selects the per-molecule fragment
                file ``processed_paths[0] + "_" + str(idx)``.

        Returns:
            In 'choosetwo' mode a 5-tuple ``(aug_data, aug_fragment_batch,
            data, fragment_batch, n_fragments)``. In every other mode a
            3-tuple ``(data, data1, data2)`` of independently augmented
            views of the molecule.

        Raises:
            ValueError: If ``aug_mode`` is unknown, or if ``aug_mode`` is
                'choosetwo' and ``choose`` is neither 0 nor 1.
        """
        data = self._slice_example(self.data, self.slices, self.data.keys, idx)
        mode = self.aug_mode

        if mode == 'choosetwo':
            if self.choose == 0:
                pool = [self.augmentations[3], self.augmentations[0], lambda x: x]
            elif self.choose == 1:
                pool = [self.augmentations[1], self.augmentations[2], lambda x: x]
            else:
                raise ValueError(
                    "aug_mode 'choosetwo' requires choose to be 0 or 1, "
                    "got {!r}".format(self.choose))

            # Fragments are only needed (and only loaded) in this mode.
            brics_datas = self._load_fragments(idx)

            # Augmentations mutate their argument, so each one gets a clone.
            aug_data = pool[np.random.choice(3, 1)[0]](data.clone())
            aug_brics_datas = [pool[np.random.choice(3, 1)[0]](frag.clone())
                               for frag in brics_datas]

            return (aug_data, Batch.from_data_list(aug_brics_datas), data,
                    Batch.from_data_list(brics_datas), len(brics_datas))

        # Every remaining mode yields three views.  ``data1``/``data2`` were
        # never defined before (NameError); they are independent copies so the
        # in-place augmentations cannot interfere with one another.
        data1, data2 = data.clone(), data.clone()

        if mode == 'no_aug':
            pool = self.augmentations
            n_aug = n_aug1 = n_aug2 = 4  # index 4 is the identity
        elif mode == 'choose':
            # Either the configured augmentation or the identity.
            pool = [self.augmentations[self.choose], lambda x: x]
            n_aug = np.random.choice(2, 1)[0]
            n_aug1 = np.random.choice(2, 1)[0]
            n_aug2 = np.random.choice(2, 1)[0]
        elif mode == 'edgedel':
            pool = [self.edge_del, lambda x: x]
            n_aug = np.random.choice(2, 1)[0]
            n_aug1 = np.random.choice(2, 1)[0]
            n_aug2 = np.random.choice(2, 1)[0]
        elif mode == 'choosethree':
            pool = [self.augmentations[i]
                    for i in self.choosethree_idx[self.choose]]
            n_aug = np.random.choice(3, 1)[0]
            n_aug1 = np.random.choice(3, 1)[0]
            n_aug2 = np.random.choice(3, 1)[0]
        elif mode == 'uniform':
            pool = self.augmentations
            # One draw out of 25 encodes the pair (n_aug1, n_aug2).
            n_aug1, n_aug2 = divmod(np.random.choice(25, 1)[0], 5)
            n_aug = np.random.choice(5, 1)[0]
        elif mode == 'sample':
            pool = self.augmentations
            # As 'uniform', but the pair is drawn with weights ``aug_prob``.
            n_aug1, n_aug2 = divmod(
                np.random.choice(25, 1, p=self.aug_prob)[0], 5)
            n_aug = np.random.choice(5, 1)[0]
        else:
            raise ValueError('Unknown aug_mode: {!r}'.format(mode))

        return pool[n_aug](data), pool[n_aug1](data1), pool[n_aug2](data2)

    @staticmethod
    def _slice_example(store, slices, keys, i):
        """Cut entry ``i`` of every attribute in ``keys`` out of a collated store."""
        out = Data()
        for key in keys:
            item = store[key]
            index = [slice(None)] * item.dim()
            index[out.__cat_dim__(key, item)] = slice(slices[key][i],
                                                      slices[key][i + 1])
            # Tuple index: a list of slices relies on deprecated legacy
            # sequence-as-tuple indexing.
            out[key] = item[tuple(index)]
        return out

    def _load_fragments(self, idx):
        """Load the fragment file of molecule ``idx`` and split it into ``Data`` objects."""
        load_brics = torch.load(self.processed_paths[0] + "_" + str(idx))
        frag_store, frag_slices = load_brics[0], load_brics[1]

        if frag_slices is None:
            # No slice table: the single "fragment" is cut from the main
            # storage at ``idx``, restricted to the fragment store's keys.
            return [self._slice_example(self.data, self.slices,
                                        frag_store.keys, idx)]

        n_frag = len(frag_slices['x']) - 1
        return [self._slice_example(frag_store, frag_slices, frag_store.keys, i)
                for i in range(n_frag)]

    @property
    def raw_file_names(self):
        """List every entry currently present in ``self.raw_dir``."""
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        """Name of the processed file, the one reached through ``processed_paths[0]``."""
        return 'geometric_data_processed.pt'

    def download(self):
        """Do nothing: this class fetches no files and simply returns ``None``."""
        return



class Molecule3DDatasetDenoising(MoleculeDataset):

    def __init__(self, root, n_mol, choose=0, transform=None, seed=777,
                 pre_transform=None, pre_filter=None, empty=False, **kwargs):
        """Create the dataset directories, set up augmentation state and load the data.

        Args:
            root: Dataset root; its ``raw`` and ``processed`` sub-directories
                are created when missing.
            n_mol: Stored on the instance and echoed in the summary print.
            choose: Index read by ``get`` in the 'choose' and 'choosethree'
                modes.
            transform: Forwarded to ``MoleculeDataset``.
            seed: Stored on the instance; no RNG is seeded here.
            pre_transform: Forwarded to ``MoleculeDataset``.
            pre_filter: Forwarded to ``MoleculeDataset``.
            empty: When true, the processed file is not loaded.
            **kwargs: ``smiles_copy_from_3D_file`` (for 2D/SMILES datasets) and
                ``dataset`` (name forwarded to ``MoleculeDataset``) are read;
                both default to ``None``. Other keys are ignored.
        """
        os.makedirs(root, exist_ok=True)
        os.makedirs(join(root, 'raw'), exist_ok=True)
        os.makedirs(join(root, 'processed'), exist_ok=True)
        # for 2D Datasets (SMILES)
        self.smiles_copy_from_3D_file = kwargs.get('smiles_copy_from_3D_file')

        self.root, self.seed = root, seed
        self.n_mol = n_mol
        self.pre_transform, self.pre_filter = pre_transform, pre_filter
        self.aug_prob = None
        self.aug_mode = 'no_aug'
        self.aug_strength = 0.2
        self.choose = choose
        self.choosetwo_idx = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
        self.choosethree_idx = [[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]
        print("I choose")
        print(choose)
        # Index 4 is the identity augmentation.
        self.augmentations = [self.node_drop, self.subgraph,
                              self.edge_pert, self.attr_mask, lambda x: x]

        # Keyword arguments are required here: MoleculeDataset.__init__ takes
        # ``dataset`` as its second positional parameter, so passing
        # ``transform`` positionally shifted every argument by one slot and
        # silently dropped ``pre_filter``.  ``empty`` is forwarded as well so
        # that the base class honours it.
        super(Molecule3DDatasetDenoising, self).__init__(
            root, dataset=kwargs.get('dataset'), transform=transform,
            pre_transform=pre_transform, pre_filter=pre_filter, empty=empty)

        # MoleculeDataset.__init__ has already loaded ``processed_paths[0]``
        # (unless ``empty``), so the file is not read a second time here.

        print('root: {},\ndata: {},\nn_mol: {},\n'.format(
            self.root, self.data, self.n_mol))


    def set_augMode(self, aug_mode):
        """Select the augmentation mode that ``get`` dispatches on.

        ``get`` recognises 'no_aug', 'choose', 'edgedel', 'choosetwo',
        'choosethree', 'uniform' and 'sample'. The value is not validated here.
        """
        self.aug_mode = aug_mode

    def set_augStrength(self, aug_strength):
        """Set the fraction the augmentations scale by.

        ``node_drop``, ``attr_mask`` and ``subgraph`` multiply it with the node
        count; ``edge_pert`` and ``edge_del`` multiply it with the edge count.
        """
        self.aug_strength = aug_strength

    def set_augProb(self, aug_prob):
        """Set the probabilities passed as ``p`` to ``np.random.choice(25, ...)`` in 'sample' mode."""
        self.aug_prob = aug_prob

    def node_drop(self, data):
        """Drop a random ``aug_strength`` fraction of the nodes, in place.

        The surviving node indices (kept in ascending order) are passed to
        ``torch_geometric.utils.subgraph`` with relabelling enabled, and
        ``x``, ``edge_index``, ``edge_attr`` and ``__num_nodes__`` are
        overwritten on ``data``.

        Returns:
            The same ``data`` object that was passed in.
        """
        n_nodes, _ = data.x.size()
        # Unused, but the unpacking fails unless ``edge_index`` is 2-D.
        _, _n_edges = data.edge_index.size()
        n_dropped = int(n_nodes * self.aug_strength)

        # Shuffle all node ids, discard the first ``n_dropped``, sort the rest.
        kept = sorted(np.random.permutation(n_nodes)[n_dropped:].tolist())

        data.edge_index, data.edge_attr = subgraph(
            subset=kept,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=n_nodes,
        )
        data.x = data.x[kept]
        data.__num_nodes__ = data.x.shape[0]
        return data

    def edge_pert(self, data):
        """Replace a random ``aug_strength`` fraction of the edges, in place.

        ``int(edge_num * aug_strength)`` edges are removed and the same number
        of new edges is sampled from the node pairs that are not among the
        retained edges. New edges get random attributes: the first column is
        drawn from 4 classes, the second from 3.

        Returns:
            The same ``data`` object that was passed in.
        """
        n_nodes, _ = data.x.size()
        _, n_edges = data.edge_index.size()
        n_perturbed = int(n_edges * self.aug_strength)

        # Deletion: keep a random subset of ``n_edges - n_perturbed`` edges.
        kept = np.random.choice(n_edges, n_edges - n_perturbed, replace=False)
        kept_index = data.edge_index[:, kept]
        kept_attr = data.edge_attr[kept]

        # Addition: every (source, target) pair not among the kept edges is a
        # candidate. The diagonal is not cleared, so self-loops are included.
        free = torch.ones((n_nodes, n_nodes))
        free[kept_index[0], kept_index[1]] = 0
        candidates = torch.nonzero(free, as_tuple=False).t()
        picked = np.random.choice(candidates.shape[1], n_perturbed, replace=False)
        new_index = candidates[:, picked]

        # Random 4-class and 3-class attributes for the 1st and 2nd column.
        n_new = new_index.shape[1]
        new_attr = torch.cat(
            (torch.tensor(np.random.randint(4, size=(n_new, 1))),
             torch.tensor(np.random.randint(3, size=(n_new, 1)))),
            dim=1)

        data.edge_index = torch.cat((kept_index, new_index), dim=1)
        data.edge_attr = torch.cat((kept_attr, new_attr), dim=0)

        return data

    def edge_del(self, data):
        """Delete a random ``aug_strength`` fraction of the edges, in place.

        Unlike ``edge_pert`` no replacement edges are added.

        Returns:
            The same ``data`` object that was passed in.
        """
        # Unused, but the unpacking fails unless ``x`` is 2-D.
        _n_nodes, _ = data.x.size()
        _, n_edges = data.edge_index.size()
        n_removed = int(n_edges * self.aug_strength)

        # Sample, without replacement, the columns that survive.
        kept = np.random.choice(n_edges, n_edges - n_removed, replace=False)
        new_index, new_attr = data.edge_index[:, kept], data.edge_attr[kept]

        data.edge_index, data.edge_attr = new_index, new_attr
        return data


    def attr_mask(self, data):
        """Overwrite the features of a random ``aug_strength`` fraction of nodes.

        The replacement row is the column-wise mean of ``data.x`` (computed in
        float, cast back to long). ``data.x`` is replaced by a modified copy;
        the original tensor is left untouched.

        Returns:
            The same ``data`` object that was passed in.
        """
        n_nodes, _ = data.x.size()
        n_masked = int(n_nodes * self.aug_strength)

        masked_x = data.x.clone()
        mean_token = data.x.float().mean(dim=0).long()
        chosen = np.random.choice(n_nodes, n_masked, replace=False)
        masked_x[chosen] = mean_token

        data.x = masked_x
        return data

    def subgraph(self, data):
        """Keep a randomly grown set of nodes and drop the rest, in place.

        Starting from one random node, nodes are added one at a time by
        sampling from the current frontier (neighbours of the chosen nodes).
        When the frontier is empty a random unchosen node is used instead.
        The induced edges are then computed with
        ``torch_geometric.utils.subgraph`` with relabelling enabled.

        Returns:
            The same ``data`` object that was passed in.
        """
        G = to_networkx(data)
        node_num, _ = data.x.size()
        sub_num = int(node_num * (1 - self.aug_strength))

        # The original loop condition (``len(idx_sub) <= sub_num``) keeps
        # ``sub_num + 1`` nodes; that count is preserved.  It is capped at
        # ``node_num`` because asking for more nodes than exist (e.g. with
        # ``aug_strength == 0``) used to crash in ``np.random.choice([])``.
        # NOTE(review): the "+ 1" looks like an off-by-one -- confirm intent.
        target_num = min(sub_num + 1, node_num)

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set(G.neighbors(idx_sub[-1]))

        while len(idx_sub) < target_num:
            if len(idx_neigh) == 0:
                # Frontier exhausted: restart from a random unchosen node.
                idx_unsub = list(set(range(node_num)).difference(set(idx_sub)))
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set(G.neighbors(idx_sub[-1]))).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(subset=idx_nondrop,
                                       edge_index=data.edge_index,
                                       edge_attr=data.edge_attr,
                                       relabel_nodes=True,
                                       num_nodes=node_num)

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.__num_nodes__, _ = data.x.shape
        return data


    def get(self, idx):
        """Return the ``idx``-th molecule, restricted to its denoising attributes.

        Only the keys ``x``, ``positions``, ``batch``, ``id`` and ``mol_id``
        are copied out of the collated storage.

        Args:
            idx: Position of the molecule inside the collated storage.

        Returns:
            In the 'no_aug' and 'choosetwo' modes the un-augmented ``Data``
            object. In every other mode a 3-tuple ``(data, data1, data2)`` of
            independently augmented views.

        Raises:
            ValueError: If ``aug_mode`` is unknown.
        """
        wanted = ("x", "positions", "batch", "id", "mol_id")
        data = Data()
        for key in self.data.keys:
            if key not in wanted:
                continue
            item, slices = self.data[key], self.slices[key]
            index = [slice(None)] * item.dim()
            index[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            # Tuple index: a list of slices relies on deprecated legacy
            # sequence-as-tuple indexing.
            data[key] = item[tuple(index)]

        mode = self.aug_mode

        # Both modes hand back the plain example.  'choosetwo' used to build
        # an augmented clone first and then discard it; that work is dropped
        # because its result was never returned and it could raise (the
        # structural augmentations read ``edge_index``, which is not copied
        # above).  Side effect: the global NumPy RNG is no longer advanced in
        # this mode.
        if mode in ('no_aug', 'choosetwo'):
            return data

        # Every remaining mode yields three views.  ``data1``/``data2`` were
        # never defined before (NameError); they are independent copies so the
        # in-place augmentations cannot interfere with one another.
        # NOTE(review): node_drop, subgraph, edge_pert and edge_del need
        # ``edge_index``/``edge_attr``, which this method does not copy, so
        # only attr_mask and the identity look usable here -- confirm intent.
        data1, data2 = data.clone(), data.clone()

        if mode == 'choose':
            # Either the configured augmentation or the identity.
            pool = [self.augmentations[self.choose], lambda x: x]
            n_aug = np.random.choice(2, 1)[0]
            n_aug1 = np.random.choice(2, 1)[0]
            n_aug2 = np.random.choice(2, 1)[0]
        elif mode == 'edgedel':
            pool = [self.edge_del, lambda x: x]
            n_aug = np.random.choice(2, 1)[0]
            n_aug1 = np.random.choice(2, 1)[0]
            n_aug2 = np.random.choice(2, 1)[0]
        elif mode == 'choosethree':
            pool = [self.augmentations[i]
                    for i in self.choosethree_idx[self.choose]]
            n_aug = np.random.choice(3, 1)[0]
            n_aug1 = np.random.choice(3, 1)[0]
            n_aug2 = np.random.choice(3, 1)[0]
        elif mode == 'uniform':
            pool = self.augmentations
            # One draw out of 25 encodes the pair (n_aug1, n_aug2).
            n_aug1, n_aug2 = divmod(np.random.choice(25, 1)[0], 5)
            n_aug = np.random.choice(5, 1)[0]
        elif mode == 'sample':
            pool = self.augmentations
            # As 'uniform', but the pair is drawn with weights ``aug_prob``.
            n_aug1, n_aug2 = divmod(
                np.random.choice(25, 1, p=self.aug_prob)[0], 5)
            n_aug = np.random.choice(5, 1)[0]
        else:
            raise ValueError('Unknown aug_mode: {!r}'.format(mode))

        return pool[n_aug](data), pool[n_aug1](data1), pool[n_aug2](data2)

    @property
    def raw_file_names(self):
        """List every entry currently present in ``self.raw_dir``."""
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        """Name of the processed file, the one reached through ``processed_paths[0]``."""
        return 'geometric_data_processed.pt'

    def download(self):
        """Do nothing: this class fetches no files and simply returns ``None``."""
        return
