# +
import os
# Restrict this process to physical GPU 4. Set before `import torch` so the
# value is already in place when CUDA gets initialized (CUDA reads this
# variable at initialization time).
os.environ['CUDA_VISIBLE_DEVICES'] = '4'


import torch
from tqdm import tqdm


from transformers_check import init
# -

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Chooses between `model.generate` (True) and `model.generate_mh` (False) in the
# generation loop; it is also forwarded to `init` so both stay in agreement.
is_default = True
model, tokenizer = init(device, is_default=is_default)

# Load the fine-tuned weights straight onto the target device.
params = torch.load('PyCodeGPT_dev.pt', map_location=device)
model.load_state_dict(params)
model = model.to(device)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

from human_eval.data import read_problems
from human_eval.data import write_jsonl

problems = read_problems()

# +
bos_token = tokenizer.bos_token if tokenizer.bos_token else tokenizer.eos_token
num_samples_per_task = 1
generate_batch_size = min(50, num_samples_per_task)
# Guard the division: num_samples_per_task == 0 makes generate_batch_size 0.
num_batches = num_samples_per_task // generate_batch_size if generate_batch_size else 0

# Sanity-check limits (previously hard-coded as `break` statements): only the
# first `max_problems` tasks are run, with at most `max_batches_per_task`
# generate calls each. Raise these for a full run.
max_problems = 1
max_batches_per_task = 1
# Second positional argument of generate/generate_mh; presumably a length
# limit -- TODO confirm against transformers_check.
generate_length = 200

generate_fn = model.generate if is_default else model.generate_mh

# Sampling never needs gradients; avoid building the autograd graph.
with torch.no_grad():
    for task_id in list(problems)[:max_problems]:
        # Strip operation is important as the new tokenizer does not treat '\n'
        # as an independent token.
        prompt = problems[task_id]["prompt"].strip()
        # Keep the BOS-prefixed text in its own variable so `prompt` is never
        # re-prefixed with another BOS token on a later iteration.
        model_input = bos_token + prompt
        # The input is the same for every batch of this task: tokenize once.
        text_inp = tokenizer(model_input, return_tensors='pt')['input_ids'].to(device)

        for _ in range(min(num_batches, max_batches_per_task)):
            output = generate_fn(text_inp, generate_length)
            text_out = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f'\n\n---- Problem {task_id}')
            print(model_input)
            print(text_out)
