from transformers import T5Tokenizer, T5ForConditionalGeneration
from lib2to3.pgen2 import token
from torch.utils.data import DataLoader
import torch
from transformers import AutoTokenizer
from torch import nn
import torch.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data.dataloader import default_collate
from transformers.optimization import get_linear_schedule_with_warmup
import argparse
import numpy as np
#from dataloader import TextMoleculeDataset, TextMoleculeReplaceDataset
import pickle
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rdkit import Chem




# Command-line configuration. Every argument is optional; the defaults
# reproduce the ChEBI-20 text->molecule setup.
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, default='data/task1_chebi20_text2mol_train.json', help='path where data is located')
# NOTE(review): --saved_path is not read anywhere in this script section — confirm it is still needed.
parser.add_argument('--saved_path', type=str, default='saved_models/', help='path where weights are saved')
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--model_name', type=str, default='QizhiPei/biot5-plus-base-chebi20',
                    help='Hugging Face model id (or local path) of the T5 checkpoint used for generation')


args = parser.parse_args()
# Fall back to CPU when CUDA is unavailable; otherwise use the requested GPU.
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
# The tokenizer stays on the host; only the model is moved to `device`.
tokenizer = T5Tokenizer.from_pretrained(args.model_name, model_max_length=512)
model = T5ForConditionalGeneration.from_pretrained(args.model_name).to(device)



# Number of beams searched per prompt; every beam is also returned as a candidate.
BEAM_WIDTH = 100

# Load the JSON file's 'Instances' array as the single 'train' split and walk
# it with a default DataLoader (one example per batch, original order).
train_data = load_dataset(
    'json',
    data_files=args.data_path,
    field='Instances',
)['train']
train_dataloader = DataLoader(train_data)

# Deterministic (non-sampling) beam-search settings forwarded to model.generate().
generation_kwargs = dict(
    min_length=-1,
    do_sample=False,
    num_beams=BEAM_WIDTH,
    num_return_sequences=BEAM_WIDTH,
)
        
        
import json
import os

# Destination of the generated fine-tuning records (checkpointed and final).
OUTPUT_PATH = 'for_finetune_beam.json'
# Rewrite the checkpoint whenever the example index is a multiple of this.
CHECKPOINT_EVERY = 100

data_set = []
task_definition = 'Definition: You are given a molecule description in English. Your job is to generate the molecule SELFIES that fits the description.\n\n'


def _to_tagged_molecule(decoded):
    """Remove all spaces from a decoded sequence and wrap it in <bom>/<eom> tags."""
    return '<bom>' + decoded.replace(' ', '') + '<eom>'


def _save_records(records, path=OUTPUT_PATH):
    """Dump *records* as indented JSON to *path* atomically.

    The data is written to a sibling temp file first and then moved into
    place with ``os.replace``. An interruption mid-dump therefore leaves the
    previous checkpoint intact instead of a truncated JSON file.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(records, f, indent=4)
    os.replace(tmp_path, path)


for e, data in enumerate(train_dataloader):
    # The DataLoader yields batches of one, so each field is a length-1 list.
    get_input = data['input'][0]
    task_input = [task_definition + f'Now complete the following example -\nInput: {get_input}\nOutput: ']

    encoded = tokenizer(task_input, return_tensors="pt").to(device)
    outputs = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,  # explicit mask; single unpadded prompt
        max_length=256,
        **generation_kwargs,
    )
    candidates = [
        _to_tagged_molecule(text)
        for text in tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ]

    data_set.append({
        'instruction': task_definition + f'Now provide a set of molecules -\nInput: {get_input}\nOutputs: ',
        # dict.fromkeys de-duplicates while keeping the order returned by
        # generate(); a set would scramble it differently on every run because
        # str hashing is randomized per process.
        'output': ' '.join(dict.fromkeys(candidates)),
        'ground_truth': data['output'][0][0],
    })

    if e % CHECKPOINT_EVERY == 0:
        _save_records(data_set)


_save_records(data_set)
    