from argparse import ArgumentParser
from copy import deepcopy
import glob
import numpy as np
import os.path as osp
from tqdm import tqdm
from torch_geometric.data import Data
import torch

from dataset.basic_utils import open_pickle, get_full_smiles, clean_data, save_pyg_data_to_pkl
from dataset.featurization import featurize_mol, add_chiral_edge_order_feature
from dataset.substructure import get_transformation_mask, get_subgraphs

# Command-line options for this preprocessing script.
# --data_type picks the dataset branch further down ('qm9' vs. everything else).
# --mode and --root are only carried along on `args` here; their consumers are
# not in this part of the file (presumably the save helper — confirm).
parser = ArgumentParser()
parser.add_argument(
    '--data_type',
    default="drugs",
    type=str,
)
parser.add_argument(
    '--mode',
    default="test",
    type=str,
)
parser.add_argument(
    '--root',
    default="./",
    type=str,
)
# Parsed at import time from sys.argv; later top-level code reads `args`.
args = parser.parse_args()

# Choose the element-symbol -> integer-index vocabulary and the pickled test
# set for the requested dataset. Any --data_type other than 'qm9' uses the
# DRUGS settings. Indices are simply the position of each symbol in the tuple.
if args.data_type != 'qm9':
    data_dir = 'data/DRUGS/test_mols.pkl'
    atom_types = {
        symbol: index
        for index, symbol in enumerate((
            'H', 'Li', 'B', 'C', 'N', 'O', 'F', 'Na', 'Mg', 'Al', 'Si',
            'P', 'S', 'Cl', 'K', 'Ca', 'V', 'Cr', 'Mn', 'Cu', 'Zn',
            'Ga', 'Ge', 'As', 'Se', 'Br', 'Ag', 'In', 'Sb', 'I', 'Gd',
            'Pt', 'Au', 'Hg', 'Bi',
        ))
    }
else:
    data_dir = 'data/QM9/test_mols.pkl'
    atom_types = {
        symbol: index
        for index, symbol in enumerate(('H', 'C', 'N', 'O', 'F'))
    }

def smiles_to_filename(smiles: str) -> str:
    """Return ``smiles`` with every forward slash and backslash removed.

    Both characters are path separators on common platforms, so they are
    stripped before the string is used as part of a file name. The input is
    left untouched: ``str`` is immutable and ``str.replace`` always returns a
    new string, so no defensive copy is needed.

    Args:
        smiles: SMILES string to sanitise.

    Returns:
        The same string without any ``/`` or backslash characters.
    """
    return smiles.replace("/", "").replace("\\", "")

# Load the pickled test set: a mapping from a raw SMILES key to the list of
# conformers stored for that molecule.
data_dict = open_pickle(data_dir)
print(f'total test files: {len(data_dict.keys())}')
success = 0

for raw_smi, confs in data_dict.items():
    ref_conf = confs[0]
    # `smiles` is taken from the reference conformer *before* cleaning and is
    # what gets stored on the Data object; the `full_smiles` returned by
    # clean_data is only used in the diagnostic print further down.
    smiles = get_full_smiles(ref_conf)
    cleaned_conformers, full_smiles = clean_data(ref_conf, confs, test=True)
    datadiff = len(cleaned_conformers) - len(confs)
    if not cleaned_conformers:
        # Nothing survived cleaning; skip this molecule entirely.
        continue
    if datadiff != 0:
        # Report the (signed) fraction by which the conformer count changed.
        print(datadiff / len(confs))

    mol = cleaned_conformers[0]
    data = Data(smiles=smiles, mol=mol, conf_list=cleaned_conformers)
    data.atomic_numbers = torch.tensor(
        [atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.long
    )
    data.atomic_charges = torch.tensor(
        [atom.GetFormalCharge() for atom in mol.GetAtoms()], dtype=torch.long
    )
    data = featurize_mol(data, atom_types)
    # NOTE: chiral edge-order featurization is disabled in this revision.
    # data = add_chiral_edge_order_feature(data, data.mol)
    data.mask_edges, data.mask_rotate = get_transformation_mask(data)
    data.subgraph_batch = get_subgraphs(data)

    path_smi = smiles_to_filename(raw_smi)
    # Sanity check on the sanitised name. Raised explicitly rather than via
    # `assert` so the check is not stripped when Python runs with -O.
    if "\\" in path_smi or "/" in path_smi:
        print(full_smiles)
        print(path_smi)
        raise AssertionError(f'path separator left in file name: {path_smi!r}')

    # One float32 coordinate tensor per cleaned conformer.
    data.pos = [
        torch.tensor(conformer.GetConformer().GetPositions(), dtype=torch.float32)
        for conformer in cleaned_conformers
    ]
    try:
        # NOTE: the save call is disabled in this revision, so nothing is
        # written and `success` counts every molecule that reaches this point.
        #save_pyg_data_to_pkl(data = data, smi = path_smi, args=args, task='local')
        success += 1
    except Exception:
        # Best-effort: report and continue with the next molecule. A bare
        # `except:` would also swallow KeyboardInterrupt / SystemExit.
        print(f'Save {path_smi} fail.')
print(f'Successfully process {success}/{len(data_dict.keys())} test data')