import argparse
import os

from sklearn.metrics import roc_auc_score
from transformers import AutoTokenizer

from dataset import Mol2CaptionDataset, MoleculeNetDataset
from evaluations.text_translation_metrics import text_evaluate
from evaluations.mol_translation_metrics import mol_evaluate
from evaluations.fingerprint_metrics import molfinger_evaluate
# from evaluations.fcd_metric import fcd_evaluate


# Evaluation entry point: score processed model outputs against ground truth.
#
# Example (defaults shown):
#   --raw_folder ./predictions/GroundTruth/ --target_folder predictions/MolReFlect/
#   --task mol2cap --dataset_type test
#
# --task is matched by substring: a value containing 'mol2cap' runs the
# molecule->caption text metrics, a value containing 'cap2mol' runs the
# caption->molecule metrics, and anything else raises NotImplementedError.
# Predictions are read from <target_folder>/output_<task>_processed.txt.


def _collect_columns(dataset, *indices):
    """Gather selected fields from every sample of *dataset*.

    Returns a tuple with one list per entry in *indices*; list ``k`` holds
    ``sample[indices[k]]`` for every sample, in dataset order. Each sample is
    fetched exactly once, since ``__getitem__`` may do non-trivial work.
    """
    columns = tuple([] for _ in indices)
    # Index explicitly via len()/[] rather than iterating the dataset directly:
    # the dataset is only assumed to support __len__ and __getitem__.
    for i in range(len(dataset)):
        sample = dataset[i]
        for column, index in zip(columns, indices):
            column.append(sample[index])
    return columns


def _load_test_set(raw_folder, pro_file, dataset_type):
    """Build the evaluation dataset, print a sanity check, and return it.

    Raises:
        ValueError: if the dataset is empty. This usually means a wrong
            folder / task / dataset_type combination, and it would otherwise
            surface as an opaque IndexError when printing the first sample.
    """
    test_set = Mol2CaptionDataset(raw_folder, pro_file, dataset_type)

    n_samples = len(test_set)
    print('Sanity Check')
    print('test set size:{}'.format(n_samples))
    if n_samples == 0:
        raise ValueError(
            'Empty test set (raw_folder={}, pro_file={}, dataset_type={}); '
            'check the paths and --task.'.format(raw_folder, pro_file, dataset_type)
        )
    print('test set sample:{}'.format(test_set[0]))
    return test_set


parser = argparse.ArgumentParser()
# Ground-truth folder, folder holding processed predictions, task name, and split.
parser.add_argument('--raw_folder', type=str, default='./predictions/GroundTruth/')
parser.add_argument('--target_folder', type=str, default='predictions/MolReFlect/')
parser.add_argument('--task', type=str, default='mol2cap')
parser.add_argument('--dataset_type', type=str, default='test')

args = parser.parse_args()

raw_folder = args.raw_folder
# os.path.join keeps an absolute --target_folder absolute; the previous
# './{}/...' format string turned e.g. '/data/out' into './/data/out'.
pro_file = os.path.join(args.target_folder, 'output_{}_processed.txt'.format(args.task))

if 'mol2cap' in args.task:
    # Only the text metrics need the tokenizer, so load it here instead of
    # unconditionally at import time (which also made `--help` load it).
    tokenizer = AutoTokenizer.from_pretrained('laituan245/molt5-base-smiles2caption')
    test_set = _load_test_set(raw_folder, pro_file, args.dataset_type)

    # As used here: sample[0] is the molecule, sample[1] the reference
    # caption, and sample[2] the predicted caption.
    molecules, targets, preds = _collect_columns(test_set, 0, 1, 2)

    # 256 is passed through unchanged (presumably a max token length — confirm in text_evaluate).
    metrics = text_evaluate(tokenizer, targets, preds, molecules, 256)

    print('Metrics: bleu-2:{}, bleu-4:{}, rouge-1:{}, rouge-2:{}, rouge-l:{}, meteor-score:{}'.format(metrics[0], metrics[1], metrics[2], metrics[3], metrics[4], metrics[5]))


elif 'cap2mol' in args.task:
    test_set = _load_test_set(raw_folder, pro_file, args.dataset_type)

    # As used here: sample[0] is the reference molecule, sample[1] the
    # caption, and sample[3] the predicted molecule.
    targets, descriptions, preds = _collect_columns(test_set, 0, 1, 3)

    metrics = mol_evaluate(targets, preds, descriptions)
    finger_metrics = molfinger_evaluate(targets, preds)
    # finger_metrics[0] is not reported; validity comes from metrics[3].
    print("Metrics: bleu_score:{}, em-score:{}, levenshtein:{}, maccs fts:{}, rdk fts:{}, morgan fts:{}, validity_score:{}".format(metrics[0], metrics[1], metrics[2], finger_metrics[1], finger_metrics[2], finger_metrics[3], metrics[3]))

else:
    raise NotImplementedError("Task {} is not implemented".format(args.task))
