from lib2to3.pgen2.pgen import DFAState
import os
from time import time
import numpy as np
import networkx as nx
import pandas as pd

import torch
from torch.utils.data import DataLoader, Dataset
import json
from utils.graph_utils import quantize_mol_tensor


def load_mol(filepath):
    """Load molecules from a NumPy ``.npz`` archive as ``(x, adj)`` pairs.

    The archive is read through its positional entries ``arr_0``, ``arr_1``,
    ... (consecutive indices, stopping at the first missing one).  Only the
    first two arrays are used: they are iterated in lockstep along their
    first axis and paired element by element.

    Args:
        filepath: path to the ``.npz`` file.

    Returns:
        A list of ``(arr_0[k], arr_1[k])`` tuples, one per molecule.

    Raises:
        ValueError: if ``filepath`` does not exist.
        IndexError: if the archive holds fewer than two ``arr_<i>`` entries.
    """
    print(f'Loading file {filepath}')
    if not os.path.exists(filepath):
        raise ValueError(f'Invalid filepath {filepath} for dataset')

    # np.load on an .npz returns a lazy NpzFile holding an open zip handle;
    # the context manager closes it once the arrays are read into memory.
    with np.load(filepath) as load_data:
        available = set(load_data.files)
        result = []
        i = 0
        while f'arr_{i}' in available:
            result.append(load_data[f'arr_{i}'])
            i += 1

    # convert the per-field arrays into a list of per-molecule tuples
    return list(zip(result[0], result[1]))


class MolDataset(Dataset):
    '''
    Map-style dataset that stores raw molecules and transforms them on access.

    mols: a list of tuples, one per molecule (the original note gives
          ((9,), (4, 9, 9)) as the per-tuple array shapes).
    transform: callable applied to a single stored tuple every time it is
               fetched; its return value is what the dataset yields.
    '''

    def __init__(self, mols, transform):
        self.transform = transform
        self.mols = mols

    def __len__(self):
        '''Number of stored molecules.'''
        n_mols = len(self.mols)
        return n_mols

    def __getitem__(self, idx):
        '''Fetch the molecule at position idx and return it transformed.'''
        raw_mol = self.mols[idx]
        return self.transform(raw_mol)


class MolDataset_prop(Dataset):
    '''
    Dataset yielding transformed molecules together with a scalar target.

    The target is read from a CSV table: the column of the first protein name
    found in `prop` is divided by 20.0 and, when `prop` also contains 'qed'
    and/or 'sa', multiplied by those columns.

    mols: molecules, positionally aligned with `idx`.
    transform: callable applied to one molecule; must return an iterable
               (it is unpacked into the returned tuple).
    prop: property spec (e.g. 'parp1_qed_sa'); tested with `in`.
    idx: row positions of `mols` inside the CSV table (used with .iloc).
    csv_path: location of the property table; defaults to the path that was
              previously hard-coded.

    Raises ValueError if `prop` names none of the supported proteins.
    '''

    # Checked in this order; the first name contained in `prop` wins.
    _PROTEINS = ('parp1', 'fa7', '5ht1b', 'jak2', 'braf', 'tgfr1')

    def __init__(self, mols, transform, prop, idx, csv_path='data/zinc250k.csv'):
        protein = next((name for name in self._PROTEINS if name in prop), None)
        if protein is None:
            # previously this fell through and crashed on an unbound local
            raise ValueError(
                f'prop {prop!r} must contain one of {", ".join(self._PROTEINS)}')

        self.mols = mols
        self.transform = transform
        df = pd.read_csv(csv_path).iloc[idx]

        # Build the target with plain expressions so the column taken from
        # `df` is never modified in place.
        y = df[protein] / 20.0
        if 'qed' in prop:
            y = y * df['qed']
        if 'sa' in prop:
            y = y * df['sa']
        self.y = y

    def __len__(self):
        return len(self.mols)

    def __getitem__(self, idx):
        # transformed molecule fields followed by the scalar target
        return (*self.transform(self.mols[idx]), self.y.iloc[idx])


def dataloader(config, get_graph_list=False, prop=None, layershare=False):
    """Build train/test loaders (or networkx graph lists) for a molecule dataset.

    Args:
        config: object exposing ``config.data.data`` ('QM9', 'ZINC250k' or
            'MOSES'), ``config.data.dir``, ``config.data.max_feat_num`` and
            ``config.data.batch_size``.
        get_graph_list: when True, return networkx graphs instead of loaders.
        prop: optional property spec; when given, samples come from
            ``MolDataset_prop`` and carry an extra target value.
        layershare: not used by this function; kept so existing calls work.

    Returns:
        ``(train_dataloader, test_dataloader)``, or, when ``get_graph_list``
        is True, ``(train_graphs, test_graphs)`` where ``train_graphs`` is
        None for ``max_feat_num`` in (5, 10).

    Raises:
        ValueError: for an unsupported dataset name or ``max_feat_num``.
    """
    start_time = time()
    dataset_name = config.data.data

    def add_no_bond_channel(adj):
        # single, double, triple and no-bond; the last channel is for virtual edges
        return np.concatenate([adj[:3], 1 - np.sum(adj[:3], axis=0, keepdims=True)],
                              axis=0).astype(np.float32)

    def make_lookup_transform(atomic_num_list, max_nodes):
        # One-hot encode each node by the position of its atomic number in
        # atomic_num_list; the last place (0) is for virtual nodes.
        def transform(data):
            x, adj = data
            x_ = np.zeros((max_nodes, len(atomic_num_list)), dtype=np.float32)
            for i in range(max_nodes):
                x_[i, atomic_num_list.index(x[i])] = 1.
            return torch.tensor(x_).to(torch.float32), add_no_bond_channel(adj)
        return transform

    if dataset_name == 'QM9':
        def transform_RGCN(data):
            x, adj = data
            # the last place is for virtual nodes
            # 6: C, 7: N, 8: O, 9: F
            x_ = np.zeros((9, 5))
            indices = np.where(x >= 6, x - 6, 4)
            x_[np.arange(9), indices] = 1
            x = torch.tensor(x_).to(torch.float32)
            return x, add_no_bond_channel(adj)      # (9, 5), (4, 9, 9)
    elif dataset_name == 'ZINC250k':
        # 6: C, 7: N, 8: O, 9: F, 15: P, 16: S, 17: Cl, 35: Br, 53: I
        transform_RGCN = make_lookup_transform(
            [6, 7, 8, 9, 15, 16, 17, 35, 53, 0], 38)    # (38, 10), (4, 38, 38)
    elif dataset_name == 'MOSES':
        transform_RGCN = make_lookup_transform(
            [6, 7, 8, 9, 16, 17, 35, 0], 27)            # (27, 8), (4, 27, 27)
    else:
        # previously an unknown name only failed later with a NameError
        raise ValueError(f'Unsupported dataset {dataset_name}')

    def transform_GCN(data):
        x, adj = transform_RGCN(data)
        x = x[:, :-1]                               # drop the virtual-node column
        adj = torch.tensor(adj.argmax(axis=0))      # collapse bond channels to one label per edge
        # 0, 1, 2, 3 -> 1, 2, 3, 0; now virtual edges are denoted as 0
        adj = torch.where(adj == 3, 0, adj + 1).to(torch.float32)
        return x, adj

    mols = load_mol(os.path.join(config.data.dir, f'{dataset_name.lower()}_relgcn_kekulized_ggnp.npz'))

    if dataset_name == 'MOSES':
        with open(os.path.join(config.data.dir, 'test_idx_moses.json')) as f:
            test_idx = json.load(f)
        with open(os.path.join(config.data.dir, 'train_idx_moses.json')) as f:
            train_idx = json.load(f)
    else:
        with open(os.path.join(config.data.dir, f'valid_idx_{dataset_name.lower()}.json')) as f:
            test_idx = json.load(f)

        if dataset_name == 'QM9':
            test_idx = [int(i) for i in test_idx['valid_idxs']]

        # set membership keeps the split linear in the dataset size
        held_out = set(test_idx)
        train_idx = [i for i in range(len(mols)) if i not in held_out]

    train_mols = [mols[i] for i in train_idx]
    test_mols = [mols[i] for i in test_idx]

    print(f'Number of training mols: {len(train_idx)} | Number of test mols: {len(test_idx)}')

    if config.data.max_feat_num in [5, 10]:
        transform = transform_RGCN
    elif config.data.max_feat_num in [4, 9, 7]:     # 7 for MOSES
        transform = transform_GCN
    else:
        raise ValueError('wrong max_feat_num')

    if get_graph_list:
        # from_numpy_matrix was removed in NetworkX 3.0; prefer the
        # replacement and fall back to the old name if it is missing.
        to_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
        # Transform the raw molecules directly so this path works whether or
        # not `prop` is set (property datasets yield 3-tuples).
        if config.data.max_feat_num in [5, 10]:
            train_mols_nx = None
            test_mols_nx = [to_graph(quantize_mol_tensor(transform(mol)[1])) for mol in test_mols]
        else:
            train_mols_nx = [to_graph(np.array(transform(mol)[1])) for mol in train_mols]
            test_mols_nx = [to_graph(np.array(transform(mol)[1])) for mol in test_mols]
        return train_mols_nx, test_mols_nx

    if prop is None:
        train_dataset = MolDataset(train_mols, transform)
        test_dataset = MolDataset(test_mols, transform)
    else:
        train_dataset = MolDataset_prop(train_mols, transform, prop, train_idx)
        test_dataset = MolDataset_prop(test_mols, transform, prop, test_idx)

    train_dataloader = DataLoader(train_dataset, batch_size=config.data.batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=config.data.batch_size, shuffle=True)

    print(f'{time() - start_time:.2f} sec elapsed for data loading')
    return train_dataloader, test_dataloader


# for GGNNPreprocessor (ZINC250k only)
def get_transform_fn(dataset):
    """Return the GCN-style sample transform; only 'ZINC250k' is accepted.

    The returned callable takes one ``(x, adj)`` pair and produces a float32
    node-feature tensor with the virtual-node column removed, plus a float32
    edge-label tensor where 0 marks "no bond" and 1, 2, 3 mark the first
    three channels of ``adj``.
    """
    assert dataset == 'ZINC250k'

    max_nodes = 38
    # the last place is for virtual nodes
    # 6: C, 7: N, 8: O, 9: F, 15: P, 16: S, 17: Cl, 35: Br, 53: I
    atomic_nums = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]

    def transform_RGCN(data):
        nodes, bonds = data

        # one-hot row per node, selected by position in atomic_nums
        one_hot = np.zeros((max_nodes, 10), dtype=np.float32)
        for row in range(max_nodes):
            col = atomic_nums.index(nodes[row])
            one_hot[row, col] = 1.
        node_feats = torch.tensor(one_hot).to(torch.float32)

        # single, double, triple and no-bond; the last channel is for virtual edges
        real_bonds = bonds[:3]
        no_bond = 1 - np.sum(bonds[:3], axis=0, keepdims=True)
        channels = np.concatenate([real_bonds, no_bond], axis=0).astype(np.float32)
        return node_feats, channels             # (38, 10), (4, 38, 38)

    def transform_GCN(data):
        node_feats, channels = transform_RGCN(data)
        # drop the virtual-node column: (38, 10) -> (38, 9)
        node_feats = node_feats[:, :-1]
        # collapse the channel axis to one label per edge: (4, 38, 38) -> (38, 38)
        labels = torch.tensor(channels.argmax(axis=0))
        # 0, 1, 2, 3 -> 1, 2, 3, 0; now virtual edges are denoted as 0
        labels = torch.where(labels == 3, 0, labels + 1).to(torch.float32)
        return node_feats, labels

    return transform_GCN
