from src.tokenizer import TokenizerWrapper
from transformers import GPT2LMHeadModel
import torch
import argparse
import json
import os

# Sampling hyperparameters forwarded as keyword arguments to model.generate().
# NOTE(review): the highly specific float values look like the output of a
# hyperparameter search; confirm their provenance before changing them.
generation_params = dict(
    top_p=0.984812443971915,
    top_k=398,
    temperature=1.2145357744010563,
    max_new_tokens=255,
)

if __name__ == '__main__':
    # CLI entry point: sample `-n` SMILES strings one at a time from the
    # fine-tuned GPT-2 model and write them to a JSON list in `--output`.
    parser = argparse.ArgumentParser(
        description='Sample SMILES strings from the fine-tuned GPT-2 model.'
    )
    # `required=True`: without it, omitting -n yields args.n = None and the
    # sampling loop below crashes with a TypeError instead of a usage message.
    parser.add_argument('-n', help='number of SMILES to generate', type=int, required=True)
    parser.add_argument('--output', help='output .json filename', type=str, default='output.json')
    args = parser.parse_args()
    if args.n < 0:
        parser.error('-n must be a non-negative integer')

    # Resolve bundled assets relative to this script rather than the current
    # working directory, so the script can be launched from anywhere.
    base_dir = os.path.abspath(os.path.dirname(__file__))

    tokenizer = TokenizerWrapper(
        os.path.join(base_dir, 'data/smiles_bpe_tokenizer_543.model')
    )

    model = GPT2LMHeadModel.from_pretrained(
        os.path.join(base_dir, 'data/sft/')
    ).eval()

    # Every sample is generated from the same single-token prompt (the BOS id),
    # shaped (batch=1, seq_len=1). It is loop-invariant, so build it once.
    prompt = torch.full((1, 1), tokenizer.tokenizer.bos_id(), dtype=torch.int64)

    generated_smiles = []
    for _ in range(args.n):
        output_ids = model.generate(
            prompt,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            **generation_params,
        )
        # NOTE(review): decode() appears to return one entry per batch row;
        # with a batch of 1 we keep the first (only) one — confirm against
        # TokenizerWrapper.decode.
        generated_smiles.append(tokenizer.decode(output_ids.cpu())[0])

    with open(args.output, 'w', encoding='utf-8') as fp:
        json.dump(generated_smiles, fp)
