from transformers import T5Tokenizer, T5ForConditionalGeneration
from lib2to3.pgen2 import token
from torch.utils.data import DataLoader
import torch
from transformers import AutoTokenizer
from torch import nn
import torch.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data.dataloader import default_collate
from transformers.optimization import get_linear_schedule_with_warmup
import argparse
import numpy as np
import pickle
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rdkit import Chem



# Command-line options for the text -> molecule (SELFIES) generation evaluation.
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default='data/task3_chebi20_text2mol_test.json',
                    help='path where the evaluation data is located')
parser.add_argument('--gpu_id', type=int, default=0,
                    help='CUDA device index to use when a GPU is available')
parser.add_argument('--model_name', type=str, default='RL_model',
                    help='model name or path passed to from_pretrained')
args = parser.parse_args()


# Pick the requested GPU when CUDA is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.gpu_id}")
else:
    device = torch.device("cpu")

# Tokenizer and seq2seq model are loaded from the same name/path.
model_name = args.model_name
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)


# Load the evaluation file. When only `data_files` is given, `datasets` exposes
# the records under the 'train' split name -- this is still the test data
# (task3_chebi20_text2mol_test), not a training set.
test_data = load_dataset('json', data_files=args.data_path, field='Instances')['train']

# Evaluate at most the first NUM_EVAL_SAMPLES instances. Clamp to the dataset
# size so a smaller file does not raise IndexError when the Subset is iterated.
NUM_EVAL_SAMPLES = 500
test_data = torch.utils.data.Subset(test_data, list(range(min(NUM_EVAL_SAMPLES, len(test_data)))))

# No shuffling: evaluation order (and thus the saved results file) is
# reproducible across runs. Each output is stored next to its own reference,
# so order does not affect any metric.
test_dataloader = DataLoader(test_data, shuffle=False, batch_size=16)


# Decoding options forwarded to `model.generate`: greedy search (no sampling,
# a single beam). A `temperature` of 0.01 is deliberately left out; it would
# only have an effect if sampling were enabled.
generation_kwargs = dict(do_sample=False, num_beams=1)

        
import json

# Decoder length bounds passed to `model.generate`.
# NOTE(review): min_length=5999 suppresses EOS until ~5999 tokens have been
# generated, so every output runs far past any plausible SELFIES string and
# each batch needs ~6000 decoding steps. Kept unchanged to preserve existing
# results -- confirm this is intentional before relying on these outputs.
GEN_MIN_LENGTH = 5999
GEN_MAX_LENGTH = 6000
# Encoder-side prompt length: longer prompts are truncated, shorter ones padded.
PROMPT_MAX_LENGTH = 512

data_set = []
save_name = model_name.replace('/', '_') + '_results'
task_definition = 'Definition: You are given a molecule description in English. Your job is to generate the molecule SELFIES that fits the description.\n\n'


def _build_prompts(descriptions):
    """Wrap each molecule description in the task definition + input/output prompt."""
    return [
        task_definition + f'Now provide a set of molecules -\nInput: {description}\nOutput: '
        for description in descriptions
    ]


def _save_results(records, path):
    """Write `records` to `path` as indented JSON, overwriting any existing file."""
    with open(path, 'w') as f:
        json.dump(records, f, indent=4)


for batch in test_dataloader:
    encoded = tokenizer(
        _build_prompts(batch['input']),
        return_tensors="pt",
        max_length=PROMPT_MAX_LENGTH,
        padding='max_length',
        truncation=True,
    )

    with torch.no_grad():
        outputs = model.generate(
            encoded.input_ids.to(device),
            # Pass the padding mask explicitly so padded prompt positions are
            # masked out, instead of relying on generate() to infer it.
            attention_mask=encoded.attention_mask.to(device),
            min_length=GEN_MIN_LENGTH,
            max_length=GEN_MAX_LENGTH,
            **generation_kwargs,
        )
    # batch_decode is called without skip_special_tokens, so special tokens
    # (padding / end-of-sequence markers) remain in the decoded text.
    decoded = tokenizer.batch_decode(outputs)

    # With the default collate function a per-sample list field is transposed,
    # so batch['output'][0] holds the first reference output of every sample.
    # NOTE(review): assumes each instance's 'output' is a list of strings -- confirm.
    for prediction, reference in zip(decoded, batch['output'][0]):
        data_set.append({
            'output': prediction,
            'ground_truth': reference,
        })

    # Checkpoint after every batch so partial results survive an interruption
    # during the long generation.
    _save_results(data_set, f'{save_name}.json')

# Final write; also produces an (empty) results file if the loader yields nothing.
_save_results(data_set, f'{save_name}.json')
