import argparse
import json
from tqdm import tqdm

from transformers import AutoTokenizer
from models.symforce import SymForce

from data_provider.tokenization_utils import batch_tokenize_messages_list

def get_text_data(input_list, tokenizer, llama_version, mol_prompt='<mol><mol><mol><mol><mol><mol><mol><mol>'):
    """Build chat messages for every input sample and tokenize them as one batch.

    Each sample is a dict that may contain a ``'system'`` and/or a ``'user'``
    entry; missing entries are simply omitted from that sample's messages.
    Every ``<mol>`` placeholder in the user text is expanded to ``mol_prompt``
    (by default eight consecutive ``<mol>`` tokens) before tokenization.

    Args:
        input_list: Sequence of sample dicts, as loaded from the input JSON.
        tokenizer: Tokenizer forwarded to ``batch_tokenize_messages_list``.
        llama_version: Chat-template flavour forwarded to
            ``batch_tokenize_messages_list`` (this script passes ``'llama3'``
            or ``'llama2'``).
        mol_prompt: Replacement text for each ``<mol>`` placeholder.

    Returns:
        The batch produced by ``batch_tokenize_messages_list``. The caller
        moves it to a device with ``.to(...)``, so presumably it is a
        tensor-like batch encoding (verify against the helper).
    """
    messages_list = []
    for sample in tqdm(input_list, desc='Processing Text...'):
        messages = []
        if 'system' in sample:
            messages.append({"role": "system", "content": sample['system']})
        if 'user' in sample:
            # Expand each single <mol> placeholder into the multi-token molecule prompt.
            messages.append({"role": "user", "content": sample['user'].replace("<mol>", mol_prompt)})
        messages_list.append(messages)

    return batch_tokenize_messages_list(messages_list, tokenizer, llama_version)

# Maps the --precision CLI value to the dtype name handed to SymForce.from_pretrained.
PRECISION_TO_TORCH_DTYPE = {
    'bf16-mixed': 'bfloat16',
    '16': 'float16',
    '32': 'float32',
}


def str2bool(value):
    """Parse a boolean command-line value without evaluating it as Python.

    Accepts ``True``/``False`` (the spellings the previous ``type=eval`` flag
    handled) plus common alternatives, case-insensitively.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def resolve_device(device):
    """Turn the ``--device`` argument into a torch device string.

    A bare GPU index (``0`` or ``'1'``) becomes ``'cuda:<index>'``. Anything
    else (``'cpu'``, ``'cuda:1'``, ...) is returned unchanged rather than
    being double-prefixed.
    """
    device = str(device)
    if device.isdigit():
        return f'cuda:{device}'
    return device


def precision_to_torch_dtype(precision):
    """Return the dtype name for a ``--precision`` value.

    Raises:
        ValueError: If ``precision`` is not one of the supported values.
    """
    try:
        return PRECISION_TO_TORCH_DTYPE[precision]
    except KeyError:
        raise ValueError("Invalid precision type. Choose from 'bf16-mixed', '16', or '32'.") from None


def main():
    """Generate responses for ``playground_inputs.json`` and write them to ``playground_outputs.json``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--precision', type=str, default='32')
    # str2bool replaces type=eval, which would execute arbitrary code passed on the command line.
    parser.add_argument('--enable_flash', type=str2bool, default=False)
    parser.add_argument('--pretrained_model_name_or_path', type=str, default='/root/autodl-tmp/Mol-Llama-3.1-8B-Instruct')
    parser.add_argument('--device', type=str, default='0')

    args = parser.parse_args()
    args.device = resolve_device(args.device)
    torch_dtype = precision_to_torch_dtype(args.precision)

    # Load model and tokenizer. The chat-template flavour is inferred from the checkpoint path.
    llama_version = 'llama3' if 'Llama-3' in args.pretrained_model_name_or_path else 'llama2'
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
    # Record the id of the "<mol>" placeholder token on the tokenizer; presumably read by
    # the model / tokenization helpers to locate molecule slots (verify).
    tokenizer.mol_token_id = tokenizer("<mol>", add_special_tokens=False).input_ids[0]

    model = SymForce.from_pretrained(args.pretrained_model_name_or_path, vocab_size=len(tokenizer),
                                     torch_dtype=torch_dtype, enable_flash=args.enable_flash).to(args.device)

    # Load inputs; the context manager ensures the file handle is closed.
    with open('playground_inputs.json', 'r', encoding='utf-8') as f:
        input_list = json.load(f)

    text_batch = get_text_data(input_list, tokenizer, llama_version).to(args.device)
    # Flatten every sample's 'smiles' entries into one list, preserving order.
    smiles_list = [smiles for sample in input_list for smiles in sample['smiles']]

    # Stop tokens: Llama 3 also stops on <|eot_id|>; otherwise only the EOS token is used.
    if llama_version == 'llama3':
        terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|eot_id|>')]
    else:
        terminators = tokenizer.eos_token_id

    # 3D conformations are generated by RDKit or OpenBabel. Because conformer
    # generation is random, responses may differ slightly between runs,
    # although their meaning should stay largely the same.
    outputs = model.generate_with_smiles(
        smiles_list=smiles_list,
        text_batch=text_batch,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=terminators,
    )

    output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    with open('playground_outputs.json', 'w', encoding='utf-8') as f:
        json.dump(output_text, f, indent=4)
    print("Output saved to playground_outputs.json")


if __name__ == '__main__':
    main()



    