import glob
import subprocess
import sys
from collections import OrderedDict
# NOTE: the wildcard imports stay ahead of the explicit imports below so that
# name precedence (later imports win) is the same as before.
from critic_model_eba import *
from critic_model_bp import *
from rdkit import Chem
from GNNp_BP.evaluate_BP_SA_for_generated_mols import gnn_bp_sa_scorer
from GNNp_EBA.evaluate_EBA_for_generated_mols import gnn_eba_scorer
import logging
import torch

# Command-line arguments:
#   1. xyz_dir       - directory that is searched for "*.xyz" files
#   2. property_file - path of the CSV file this script writes
#   3. reward        - mode flag; the value '1' is tested further down
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} <xyz_dir> <property_file> <reward>")

xyz_dir = sys.argv[1]
property_file = sys.argv[2]
reward = sys.argv[3]


def _xyz_to_rdkit_mol(xyz_path):
    """Convert one XYZ file into an RDKit molecule through Open Babel.

    Runs ``obabel <xyz_path> -omol`` and parses the MOL block printed on
    stdout. The command is passed as an argument list (no shell), so paths
    with spaces or shell metacharacters are handled safely, and no temporary
    file is left behind in the working directory.

    Args:
        xyz_path: path of the XYZ file to convert.

    Returns:
        The RDKit molecule, or ``None`` when obabel could not be started,
        printed nothing, or RDKit could not parse its output.
    """
    try:
        completed = subprocess.run(
            ["obabel", xyz_path, "-omol"],
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
    except OSError as err:
        # e.g. the obabel executable is not installed / not on PATH
        logging.warning("could not run obabel on %s: %s", xyz_path, err)
        return None

    mol_block = completed.stdout
    if not mol_block.strip():
        return None
    # MolFromMolBlock signals a parse failure by returning None.
    return Chem.MolFromMolBlock(mol_block)


# Collect every molecule that converts cleanly; failed conversions are
# skipped so that no None entries reach the scorers.
mol_list = []
xyz_files = glob.glob(xyz_dir+"/*.xyz")
for xyz_file in xyz_files:
    generated_mol_rdkit = _xyz_to_rdkit_mol(xyz_file)
    if generated_mol_rdkit is None:
        logging.warning("skipping %s: no valid molecule produced", xyz_file)
        continue
    mol_list.append(generated_mol_rdkit)

# Binding-probability / synthetic-accessibility stage: build the model on the
# CPU, restore its saved weights, then score every converted molecule.
gnn_device = torch.device("cpu")

gnn_bp_model = SPgnn()
gnn_bp_model.to(gnn_device)

# Count only the parameters that take part in training.
bp_trainable_params = sum(
    weight.numel()
    for weight in gnn_bp_model.parameters()
    if weight.requires_grad
)
logging.info(f'number of parameters in GNNp BP model: {bp_trainable_params}')

# map_location keeps the restored tensors on the CPU device chosen above.
sp_checkpointhold = torch.load(
    "GNNp_BP/best_weights/best_weights_gnn_bp.pt",
    map_location=gnn_device,
)
gnn_bp_model.load_state_dict(sp_checkpointhold)

# NOTE(review): '6wqf' looks like a fixed target / PDB identifier and
# "pocket_dir" a relative directory name - confirm against the scorer.
output_bp_sa = gnn_bp_sa_scorer(
    gnn_bp_model,
    gnn_device,
    "pocket_dir",
    '6wqf',
    mol_list,
)
# Optional second stage: when the reward flag is '1', also score the
# molecules with the experimental-binding-affinity model and add one more
# column to the CSV. Otherwise only the BP/SA results are written.
if reward == '1':
    gnn_experimental_affinity_model = SPgnn_reg()
    gnn_experimental_affinity_model.to(gnn_device)
    logging.info(
        f'number of parameters in GNNp experimental_affinity model: {sum(p.numel() for p in gnn_experimental_affinity_model.parameters() if p.requires_grad)}')
    sp_checkpointhold = torch.load("GNNp_EBA/best_weights/best_weights_gnn_eba_bp.pt", map_location=gnn_device)

    # Drop a leading 'module.' from every key (presumably left over from a
    # DataParallel-wrapped model - confirm). Only the prefix is removed, so a
    # key that merely contains 'module.' further in is left intact.
    prefix = 'module.'
    sp_checkpoint = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in sp_checkpointhold.items()
    )
    gnn_experimental_affinity_model.load_state_dict(sp_checkpoint)

    output_ba = gnn_eba_scorer(gnn_experimental_affinity_model, gnn_device, "pocket_dir", '6wqf',
                               mol_list)
    header = "Target,SMILES,Synthetic Accessibility,Binding probability,Binding affinity\n"
else:
    output_ba = None
    header = "Target,SMILES,Synthetic Accessibility,Binding probability\n"

# The `with` block closes the file; no explicit close() is needed.
with open(property_file, 'w') as csv_file:
    csv_file.write(header)
    for k, bp in output_bp_sa.items():
        # The affinity lookup uses the unstripped key, exactly as the BP/SA
        # scorer produced it; only the written text has newlines stripped.
        fields = [k.strip("\n"), bp]
        if output_ba is not None:
            fields.append(output_ba[k])
        # Keys and values are joined verbatim: they are expected to be
        # strings that already contain their own comma-separated columns.
        csv_file.write(",".join(fields) + "\n")
         
 
    

