import os
import numpy as np
import torch
import argparse
from llm_brain_asym import compute_logprobs, home_folder, make_dir
from transformers import AutoModelForCausalLM, AutoTokenizer

# Process-wide settings applied before any model/tokenizer work:
# - turn off HF tokenizers' internal parallelism for this process;
# - configure the PyTorch CUDA caching allocator to use expandable segments
#   (read by torch when CUDA memory management starts, hence set this early).
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
})

# Command-line interface. --model is mandatory: every later step
# (model_name.replace, from_pretrained) fails without it, so let argparse
# report the problem instead of an AttributeError deep in the script.
parser = argparse.ArgumentParser(
    description='Sample continuations of fixed prompts from a causal LM and save them as text files.')
parser.add_argument('--model', type=str, required=True, help='model name')
parser.add_argument('--revision', type=str,
                    help='model revision (branch, tag or commit hash) passed to from_pretrained')
parser.add_argument('--device', type=str, default='cuda', help='device (cuda or cpu)')
parser.add_argument('--seed', type=int, default=12345,
                    help='random seed for sampling the generations')
parser.add_argument('--output_folder', type=str, default='llm_gen', help='output folder')

args = parser.parse_args()

# Unpack into module-level names used by the rest of the script.
model_name = args.model
revision = args.revision
device = args.device
seed = args.seed
output_folder = args.output_folder

# Inference only: switch autograd off globally for the rest of the script.
torch.set_grad_enabled(False)

# Seed the NumPy RNG first, then the torch RNG, with the same CLI seed so that
# sampling in model.generate is reproducible.
for seed_rng in (np.random.seed, torch.manual_seed):
    seed_rng(seed)

# Ensure the destination folder for the generated text files exists.
make_dir(output_folder)

# Load tokenizer and model from the same revision so vocabulary and weights
# always come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name, revision=revision)
# Half precision keeps accelerator memory low. float16 support on CPU is
# limited/slow depending on the torch version (NOTE(review): confirm for the
# torch version in use), so use float32 when running on --device cpu.
model_dtype = torch.float32 if device == 'cpu' else torch.float16
model = AutoModelForCausalLM.from_pretrained(
    model_name, revision=revision, torch_dtype=model_dtype).to(device)

# Each prompt is continued n_trials times by sampling.
n_trials = 5
prompts = ['Why not', 'Are you', 'This is', 'Alice was', 'Bob went']
n_tokens_generation = 256

# Batched generation with a decoder-only model needs left padding, so that every
# prompt ends immediately before the first generated token. Without padding,
# prompts that tokenize to different lengths cannot be stacked into one tensor.
# Some tokenizers ship without a pad token, and requesting padding then raises,
# so fall back to EOS. When all prompts have equal length no padding is added.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

inputs_prompts = tokenizer(prompts, return_tensors='pt', padding=True).to(device)
with torch.no_grad():
    model_generation_tokens = model.generate(
        **inputs_prompts,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        # Require at least n_tokens_generation - 64 new tokens per sample.
        min_new_tokens=n_tokens_generation - 64,
        max_new_tokens=n_tokens_generation,
        num_return_sequences=n_trials,
        pad_token_id=tokenizer.eos_token_id,
    )

# Decoded texts include the prompt; special tokens (padding / EOS) are dropped.
model_generation_texts = tokenizer.batch_decode(model_generation_tokens, skip_special_tokens=True)

# generate() expands each input num_return_sequences times consecutively, so
# the n_trials samples of one prompt occupy a contiguous run of rows:
# row = idx_prompt * n_trials + idx_trial.
model_name_safe = model_name.replace('/', '_')
for idx_prompt in range(len(prompts)):
    for idx_trial in range(n_trials):
        filename = os.path.join(
            output_folder,
            '{}_{}_gen_{}_{}.txt'.format(model_name_safe, revision, idx_prompt, idx_trial))
        # Explicit UTF-8: generated text may contain non-ASCII characters.
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(model_generation_texts[idx_prompt * n_trials + idx_trial])