import torch
from tqdm import tqdm, trange
import jsonlines
import json
import time
import random
from lora_config import GPT_LORA_CONFIG, LLAMA_LORA_CONFIG
from model_utils import load_gpt_model, load_llama_model


# --- Run selection -------------------------------------------------------
type_ = "commonsense"  # which entry of the index table below to use
ckpt_path = ""  # checkpoint location handed to the model loader
model_name = ""  # base model identifier handed to the model loader

# --- Files ---------------------------------------------------------------
input_file: str = ""  # JSON-lines file holding the per-item parameters
output_file: str = ""  # JSON file that receives the generated texts

# --- Generation settings -------------------------------------------------
# Upper bound on the number of newly generated tokens.
max_new_tokens: int = 64
# Seed value intended to make runs reproducible.
seed: int = 42
# True -> sample from the distribution; False -> greedy decoding.
do_sample: bool = True
# Minimum length of the generated sequence (input prompt + min_new_tokens).
min_length: int = 20
# [optional] Reuse past key/value attentions (when the model supports it)
# to speed up decoding.
use_cache: bool = True
# [optional] When < 1, keep only the smallest set of most probable tokens
# whose probabilities sum to top_p or more.
top_p: float = 0.9
# [optional] Value used to modulate the next-token probabilities.
temperature: float = 1.0
# [optional] Number of highest-probability vocabulary tokens kept by
# top-k filtering.
top_k: int = 50
# Repetition penalty; 1.0 means no penalty.
repetition_penalty: float = 1.0
# [optional] Exponential length penalty used with beam-based generation.
length_penalty: int = 1
# How many times each parameter vector is repeated in one generation batch.
n_seqs = 40

# Record indices (positions in `input_file`) whose stored parameters are used
# as generation prompts; the list is selected by the value of `type_`.
idx_dic = {
    'virtue': [102, 109, 120, 175, 98, 219, 108, 111, 128, 148, 92, 127, 91, 83, 97, 182],
    'justice': [88, 126, 165, 66, 87, 14, 34, 93, 107, 80, 109, 99, 106, 124, 147, 130, 131],
    'commonsense': [159, 247, 32, 195, 19, 146, 5, 147, 12, 4, 10, 9, 1, 7]
}

# Seed Python's RNG before the first random draw so the synthetic prompts
# built below are identical across runs, as the `seed` setting promises.
random.seed(seed)

with jsonlines.open(input_file, "r") as f:
    data = list(f)

# Prompts taken from the data file: each selected record's 'parameters',
# rounded to four decimals.
user_prompt = [
    [round(p, 4) for p in data[i]['parameters']]
    for i in idx_dic[type_]
]

# Extra synthetic prompts: a fixed first value paired with a random second
# value drawn from 2.0000 .. 4.0000 in steps of 0.0001.
user_prompt += [
    [p, random.randint(20000, 40000) / 10000]
    for p in (
        4.6782, 4.1001, 3.9962, 2.2379, 1.6333, 0.7415,
        -0.8688, -1.9648, -2.6373, -3.8754, -4.5371,
    )
]

# Build the generator and its tokenizer through the project helper.
# NOTE(review): the meaning of `num_param_tokens` and `beta` is defined in
# model_utils.load_llama_model — confirm there before changing the values.
irt_gen, tokenizer = load_llama_model(
    peft_config=LLAMA_LORA_CONFIG,
    ckpt_path=ckpt_path,
    use_fp16=True,
    num_param_tokens=5,
    beta=0,
    model_name=model_name,
    device=0,
)

# Seed torch's RNG before sampling: decoding runs with do_sample=True, so
# without this the generated texts differ on every run even though a `seed`
# is configured. torch.manual_seed covers the CPU and CUDA generators.
torch.manual_seed(seed)

start = time.perf_counter()
output_dic = {}  # str(parameter vector) -> list of decoded generations

for sample in tqdm(user_prompt):
    # Repeat the same parameter vector n_seqs times to form one batch.
    # NOTE(review): presumably this yields n_seqs independent samples per
    # prompt — confirm against irt_gen.generate.
    params = torch.tensor([sample] * n_seqs, dtype=torch.float16).cuda()
    with torch.no_grad():
        token_ids = irt_gen.generate(
            parameters=params,
            bos_token_id=tokenizer.bos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            min_length=min_length,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=1,
        )
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    output_dic[str(sample)] = texts
    # Elapsed time is measured from before the loop, so the value printed
    # here is cumulative over all prompts processed so far.
    e2e_inference_time = (time.perf_counter()-start)
    print(f"the inference time is {e2e_inference_time}s")

# Results are written once, after every prompt has been processed.
with open(output_file, 'w') as of:
    json.dump(output_dic, of, indent=4)
