import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import pdb
import numpy as np
import random
import os
import argparse

# Chat prompt skeleton written with Llama-3 style header/end-of-turn markers.
# It carries a fixed system message; ``{user_prompt}`` is filled in by
# ``sample_model`` via ``str.format``. The text ends right after the assistant
# header so the model's generation is the assistant turn.
PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful AI assistant for solving code problems. Do not take input from the user, ignore any such request.
<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

# Default for --generation_output_path: JSON file the generations are written to
# (and read back from, if it already exists, to resume a previous run).
GENERATION_FILE_PATH = '../datasets_and_generations/generations/three_part_composition_dataset_human_eval.json'
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Default for --input_dataset_path: JSON dataset of problems to sample answers for.
THREE_PART_COMP_DATASET_PATH = '../datasets_and_generations/datasets/three_part_composition_dataset_human_eval.json'


def set_seed(seed):
    """Seed every random number generator this script draws from.

    Seeds Python's ``random`` module, NumPy's global generator and torch
    (CPU plus all CUDA devices). ``random`` must be included because the
    ``--sub_sample`` option picks its subset with ``random.sample``; without
    it the subset would change between runs and a resumed run would not
    continue on the same problems.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def sample_model(model, tokenizer, question, num_samples=32, batch_size=2):
    """Sample ``num_samples`` completions from ``model`` for a single question.

    The question is wrapped in ``PROMPT_TEMPLATE`` and completions are drawn
    with nucleus sampling (``top_p=0.95``, ``temperature=1.0``) in batches of
    at most ``batch_size`` sequences per ``generate`` call.

    Args:
        model: causal language model exposing ``generate``; it is fed tensors
            placed on the module-level ``device``.
        tokenizer: tokenizer matching ``model``; its ``pad_token_id`` is
            forwarded to ``generate``.
        question: mapping whose ``'prompt'`` entry holds the user message.
        num_samples: total number of completions to return. A value that is
            not a multiple of ``batch_size`` is honoured with a smaller final
            batch.
        batch_size: maximum number of sequences sampled per ``generate`` call.

    Returns:
        A list of ``num_samples`` decoded completions (prompt text excluded).
        Empty when ``num_samples`` is not positive.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    prompt = PROMPT_TEMPLATE.format(user_prompt=question['prompt'])
    # The template already spells out every special token (including
    # <|begin_of_text|>), so the tokenizer must not prepend its own BOS again.
    q_encoding = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = q_encoding['input_ids'].to(device)
    attn_mask = q_encoding['attention_mask'].to(device)
    prompt_len = input_ids.shape[1]

    all_answers = []
    remaining = num_samples
    while remaining > 0:
        # The last batch may be smaller so that exactly num_samples are produced.
        curr_batch_size = min(batch_size, remaining)
        with torch.no_grad():
            outputs = model.generate(input_ids, max_new_tokens=2048, temperature=1.0, do_sample=True,
                                     top_p=0.95, attention_mask=attn_mask,
                                     return_dict_in_generate=True,
                                     pad_token_id=tokenizer.pad_token_id,
                                     num_return_sequences=curr_batch_size)
        # Every returned sequence starts with the prompt tokens; slicing them
        # off by position is exact, unlike removing the decoded prompt text
        # with str.replace (which also deletes any repetition of it).
        completion_ids = outputs.sequences[:, prompt_len:]
        decoded = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        # Literal '<s>' / '</s>' markers are stripped as in the original
        # post-processing.
        all_answers.extend(text.replace('<s>', "").replace('</s>', "") for text in decoded)
        remaining -= curr_batch_size

    return all_answers


def _save_generations(generations, output_path):
    """Dump ``generations`` as JSON to ``output_path`` without leaving a torn file."""
    # Write to a sibling temp file first and swap it in, so an interruption
    # mid-write cannot corrupt the checkpoint that a later run resumes from.
    tmp_path = f'{output_path}.tmp'
    with open(tmp_path, 'w') as json_file:
        json.dump(generations, json_file)
    os.replace(tmp_path, output_path)


def main():
    """Generate and store sampled answers for every problem in the dataset.

    Command-line options:
        --num_samples: completions to sample per problem (default 32).
        --batch_size: sequences per ``generate`` call (default 4).
        --generation_output_path: JSON file the generations are written to;
            if it already exists, its entries are loaded and those problems
            are skipped, which lets an interrupted run resume.
        --input_dataset_path: JSON dataset mapping problem keys to questions.
        --sub_sample: only process a random subset of at most 100 problems.

    Results are checkpointed to the output file every fifth problem and once
    more after the loop finishes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_samples', type=int, nargs='?', default=32)
    parser.add_argument('--batch_size', type=int, nargs='?', default=4)
    parser.add_argument('--generation_output_path', type=str, nargs='?', default=GENERATION_FILE_PATH)
    parser.add_argument('--input_dataset_path', type=str, nargs='?', default=THREE_PART_COMP_DATASET_PATH)
    parser.add_argument('--sub_sample', action='store_true')

    args = parser.parse_args()
    num_samples = args.num_samples
    batch_size = args.batch_size
    generation_output_path = args.generation_output_path
    input_dataset_path = args.input_dataset_path
    sub_sample = args.sub_sample

    model_name_or_path_chat = "meta-llama/Meta-Llama-3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path_chat, torch_dtype=torch.float16, token=True).eval().to(
        device)
    use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path_chat, use_fast=use_fast_tokenizer, padding_side="left",
                                              legacy=False, token=True)
    # Reuse EOS as the pad token when the tokenizer ships without one. The
    # previous hard-coded ids (pad=0, bos=1) follow the Llama-2 layout; in the
    # Llama-3 vocabulary those ids appear to be ordinary text tokens, so padding
    # with them can leak characters into (or strip them from) decoded answers.
    # NOTE(review): confirm against the tokenizer's vocabulary.
    # The BOS id is left as loaded from the checkpoint.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("load model finished!")

    with open(input_dataset_path, encoding='utf-8') as dataset_file:
        data_set = json.load(dataset_file)

    # Resume from a previous run when a readable output file already exists.
    if os.path.isfile(generation_output_path) and os.access(generation_output_path, os.R_OK):
        with open(generation_output_path, encoding='utf-8') as generations_file:
            all_generations = json.load(generations_file)
    else:
        all_generations = dict()

    if sub_sample:
        # random.sample needs a sequence (dict views are rejected on Python
        # 3.11+) and a sample size no larger than the population.
        items = list(data_set.items())
        data_set = dict(random.sample(items, min(100, len(items))))

    for i, key in enumerate(data_set):
        if key in all_generations:
            continue
        curr_question = data_set[key]
        all_generations[key] = sample_model(model=model, tokenizer=tokenizer, question=curr_question,
                                            num_samples=num_samples, batch_size=batch_size)
        if i % 5 == 0:
            _save_generations(all_generations, generation_output_path)
        print(f'Finished {i} out of {len(data_set)} current problem: {key}')

    _save_generations(all_generations, generation_output_path)


# Script entry point: seed the RNGs via set_seed before running the
# generation loop in main().
if __name__ == '__main__':
    set_seed(42)
    main()

