from lib2to3.pgen2.pgen import DFAState
import os
from time import time
import numpy as np
import networkx as nx
import pandas as pd

import torch
from torch.utils.data import DataLoader, Dataset
import json
from utils.graph_utils import quantize_mol_tensor


def load_mol(filepath):
    """Load molecules from an ``.npz`` archive as a list of ``(x, adj)`` pairs.

    The archive is expected to hold positionally-saved arrays named
    ``arr_0``, ``arr_1``, ... (the naming ``np.savez`` uses for positional
    arguments). Only the first two arrays are used: entry ``i`` of the result
    is ``(arr_0[i], arr_1[i])``. Any further arrays are read but ignored.

    Args:
        filepath: path to the ``.npz`` file.

    Returns:
        A list of 2-tuples, one per molecule, truncated to the shorter of the
        two arrays.

    Raises:
        ValueError: if ``filepath`` does not exist, or the archive contains
            fewer than two ``arr_<i>`` arrays.
    """
    print(f'Loading file {filepath}')
    if not os.path.exists(filepath):
        raise ValueError(f'Invalid filepath {filepath} for dataset')

    arrays = []
    # NpzFile keeps the underlying zip file open; the context manager closes
    # it once every array has been materialised in memory.
    with np.load(filepath) as load_data:
        stored_names = set(load_data.files)
        index = 0
        while f'arr_{index}' in stored_names:
            arrays.append(load_data[f'arr_{index}'])
            index += 1

    if len(arrays) < 2:
        raise ValueError(
            f'Expected at least two arrays (arr_0, arr_1) in {filepath}, '
            f'found {len(arrays)}')

    # Convert the pair of per-dataset arrays into a list of per-molecule tuples.
    return list(zip(arrays[0], arrays[1]))


class MolDataset(Dataset):
    '''
    Map-style dataset that applies `transform` to a molecule each time it is indexed.

    mols: a list of tuples; per the original note each tuple is shaped ((9,), (4, 9, 9)).
    transform: callable invoked on a single element of `mols` inside __getitem__.
    '''

    def __init__(self, mols, transform):
        self.transform = transform
        self.mols = mols

    def __len__(self):
        # One sample per stored molecule.
        n_mols = len(self.mols)
        return n_mols

    def __getitem__(self, idx):
        # The transform runs lazily, on access, rather than up front.
        raw_mol = self.mols[idx]
        return self.transform(raw_mol)


# Atomic numbers recognised for ZINC250k, in one-hot column order.
# 6: C, 7: N, 8: O, 9: F, 15: P, 16: S, 17: Cl, 35: Br, 53: I; the last place (0) is for virtual nodes.
_ZINC250K_ATOMIC_NUMS = (6, 7, 8, 9, 15, 16, 17, 35, 53, 0)


def _with_no_bond_channel(adj):
    """Keep the first three bond channels of `adj` and append a complementary no-bond channel."""
    bonds = adj[:3]
    # single, double, triple and no-bond; the last channel is for virtual edges
    no_bond = 1 - np.sum(bonds, axis=0, keepdims=True)
    return np.concatenate([bonds, no_bond], axis=0).astype(np.float32)


def _qm9_to_rgcn(data):
    """Encode one QM9 sample as a (9, 5) one-hot node tensor and a 4-channel adjacency array."""
    x, adj = data
    # 6: C, 7: N, 8: O, 9: F; anything below 6 falls into the last (virtual node) column
    x_ = np.zeros((9, 5))
    indices = np.where(x >= 6, x - 6, 4)
    x_[np.arange(9), indices] = 1
    x = torch.tensor(x_).to(torch.float32)
    return x, _with_no_bond_channel(adj)        # (9, 5), (4, 9, 9)


def _zinc250k_to_rgcn(data):
    """Encode one ZINC250k sample as a (38, 10) one-hot node tensor and a 4-channel adjacency array."""
    x, adj = data
    x_ = np.zeros((38, 10), dtype=np.float32)
    for i in range(38):
        # .index raises ValueError for an atomic number outside the table
        x_[i, _ZINC250K_ATOMIC_NUMS.index(x[i])] = 1.
    x = torch.tensor(x_).to(torch.float32)
    return x, _with_no_bond_channel(adj)        # (38, 10), (4, 38, 38)


def dataloader(config, get_graph_list=False, protein=None):
    """Build the training data for the molecules listed in the low-index JSON file.

    Reads ``<dataset>_relgcn_kekulized_ggnp.npz`` and
    ``low_idx_<dataset>_<protein>_qed_sa.json`` from ``config.data.dir`` (with
    ``<dataset>`` being ``config.data.data`` lower-cased), keeps only the
    molecules whose indices appear in the JSON file, and wraps them in a
    ``MolDataset`` with a transform chosen by ``config.data.max_feat_num``:

    * 5 or 10 -> RGCN layout: one-hot nodes including the virtual-node column
      and a 4-channel adjacency array.
    * 4 or 9  -> GCN layout: virtual-node column dropped and the adjacency
      collapsed to a single matrix holding bond orders 1-3, with 0 for no bond.

    Args:
        config: object exposing ``data.data`` ('QM9' or 'ZINC250k'),
            ``data.dir``, ``data.max_feat_num`` and ``data.batch_size``.
        get_graph_list: if true, return networkx graphs instead of a DataLoader.
        protein: value interpolated into the low-index JSON file name.

    Returns:
        If ``get_graph_list`` is false, a shuffling ``DataLoader`` over the
        selected molecules. Otherwise a list of ``networkx`` graphs built from
        the GCN adjacency matrices, or ``None`` when the RGCN layout is
        selected.

    Raises:
        ValueError: if ``config.data.data`` is not a supported dataset or
            ``config.data.max_feat_num`` is not one of 4, 5, 9, 10. Both are
            checked before any file is read.
    """
    start_time = time()

    rgcn_transforms = {'QM9': _qm9_to_rgcn, 'ZINC250k': _zinc250k_to_rgcn}
    dataset_name = config.data.data
    if dataset_name not in rgcn_transforms:
        # Previously this case left the transform undefined and only failed
        # with a NameError once a sample was accessed.
        raise ValueError(
            f'unsupported dataset {dataset_name!r}; expected one of {sorted(rgcn_transforms)}')
    transform_RGCN = rgcn_transforms[dataset_name]

    def transform_GCN(data):
        x, adj = transform_RGCN(data)
        x = x[:, :-1]                               # 9, 5 (the last place is for virtual nodes) -> 9, 4 (38, 9)
        adj = torch.tensor(adj.argmax(axis=0))      # 4, 9, 9 (the last place is for virtual edges) -> 9, 9 (38, 38)
        # 0, 1, 2, 3 -> 1, 2, 3, 0; now virtual edges are denoted as 0
        adj = torch.where(adj == 3, 0, adj + 1).to(torch.float32)
        return x, adj

    if config.data.max_feat_num in [5, 10]:
        transform = transform_RGCN
    elif config.data.max_feat_num in [4, 9]:
        transform = transform_GCN
    else:
        raise ValueError('wrong max_feat_num')

    file_prefix = dataset_name.lower()
    mols = load_mol(os.path.join(config.data.dir, f'{file_prefix}_relgcn_kekulized_ggnp.npz'))

    with open(os.path.join(config.data.dir, f'low_idx_{file_prefix}_{protein}_qed_sa.json')) as f:
        low_idx = json.load(f)

    low_mols = [mols[i] for i in low_idx]
    train_dataset = MolDataset(low_mols, transform)

    if get_graph_list:
        if transform is transform_RGCN:
            # No graph list is produced for the multi-channel RGCN layout.
            return None
        # from_numpy_matrix was removed in networkx 3.0; from_numpy_array is
        # the supported equivalent for plain ndarrays.
        return [nx.from_numpy_array(np.array(adj)) for _, adj in train_dataset]

    train_dataloader = DataLoader(train_dataset, batch_size=config.data.batch_size, shuffle=True)

    print(f'{time() - start_time:.2f} sec elapsed for data loading')
    return train_dataloader
