import os
import re
import json
import torch
import argparse
import jsonlines
import numpy as np
from tqdm import tqdm
from transformers import set_seed, AutoTokenizer, AutoModelForCausalLM

import pdb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the generation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``get_args()`` callers behave exactly as before.

    Returns:
        Namespace with the attributes ``model_name``, ``split``,
        ``temperature`` and ``run``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default='meta-llama/Llama-2-7b-chat-hf',
        choices=['meta-llama/Llama-2-7b-chat-hf', 'mistralai/Mistral-7B-Instruct-v0.2'],
        help="Hugging Face identifier of the model to load.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default='683',
        choices=['683', 'nq'],
        help="Which set of topics to generate biographies for.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature passed to generation.",
    )
    parser.add_argument(
        "--run",
        type=int,
        default=0,
        help="Run index; used as the random seed and in the output directory name.",
    )

    return parser.parse_args(argv)

# Parse the CLI options once; the rest of the script reads these module-level names.
args = get_args()
print(args)

model_name = args.model_name
temperature = args.temperature
run = args.run
# The run index is also the random seed.
set_seed(run)

# Load model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# Reuse the BOS token as the padding token for batched inputs.
tokenizer.pad_token = tokenizer.bos_token
# Decoder-only models continue from the last position of every row, so
# batched prompts must be padded on the LEFT; right padding makes the model
# generate after pad tokens and degrades the output for shorter prompts.
tokenizer.padding_side = 'left'
model.config.pad_token_id = model.config.bos_token_id

# load or create JSON file from existing FActScore dataset
def read_jsonl_file(file_path):
    """Read a JSON Lines file into a list of decoded objects.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly so the result does not depend on the locale.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank lines (e.g. a trailing empty line) carry no record; skipping
        # them avoids a JSONDecodeError on otherwise valid files.
        return [json.loads(line) for line in file if line.strip()]

save_dir = f'./data/all/temperature={temperature}/run={run}/'
# exist_ok avoids a check-then-create race: the directory depends only on
# temperature and run, so jobs for different models/splits started together
# may try to create it at the same time.
os.makedirs(save_dir, exist_ok=True)

# Replace the hub identifier with a short tag; it is used in the output file
# name below and for prompt selection later in the script.
if model_name == 'meta-llama/Llama-2-7b-chat-hf':
    model_name = 'Llama2_7B_Chat'
elif model_name == 'mistralai/Mistral-7B-Instruct-v0.2':
    model_name = 'Mistral_7B_Instruct'
    model.generation_config.pad_token_id = tokenizer.eos_token_id

save_path_name = os.path.join(save_dir, f'{model_name}_{args.split}.jsonl')

if os.path.exists(save_path_name):
    # A previous run already produced this file: reuse its records.
    data = read_jsonl_file(save_path_name)
else:
    data = []

    if args.split == '683':
        # Labeled topics first, then unlabeled; only the topic field is kept.
        topic_records = (
            read_jsonl_file('./data/labeled/ChatGPT.jsonl')
            + read_jsonl_file('./data/unlabeled/ChatGPT.jsonl')
        )
        data = [{'topic': record['topic']} for record in topic_records]
    elif args.split == 'nq':
        nq_titles = np.load('./data/list_titles_humans_nq.npy')
        data = [{'topic': topic} for topic in nq_titles]

    # Attach the user request that will be inserted into the prompt template.
    for record in data:
        record['input'] = f'Tell me a biography about {record["topic"]}.'

# Generate
instruction = 'Answer in the form of a paragraph.'
if model_name in ['Llama2_7B_Chat']:
    # Only the Llama-2 prompt wraps the instruction in system tags.
    instruction = f'<<SYS>> {instruction} <</SYS>>'
prompt_base = '[INST] ' + instruction + ' {} [/INST]'
prompts = [prompt_base.format(x['input']) for x in data]

batch_size = 4
# Plain slicing yields full batches plus one remainder batch, and no batch at
# all when there are no prompts (so the tokenizer never sees an empty list).
batched_prompts = [
    prompts[start:start + batch_size]
    for start in range(0, len(prompts), batch_size)
]

generations = []
for batch in tqdm(batched_prompts):
    inputs = tokenizer(batch, return_tensors="pt", max_length=200, truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=1000, temperature=temperature, do_sample=True)
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # The decoded text still contains the prompt. Keep everything after the
    # FIRST '[/INST]' so a marker inside the generation does not cut it short;
    # if truncation removed the marker, fall back to the whole text instead
    # of raising IndexError.
    generations += [text.split('[/INST]', 1)[-1].strip() for text in decoded]

# Prompts were built from data in order, so generations line up one-to-one.
for record, generation in zip(data, generations):
    record['output'] = generation

# Save as jsonl
with jsonlines.open(save_path_name, mode='w') as writer:
    for record in data:
        writer.write(record)