import json
import torch
from tqdm import tqdm
from datasets import load_dataset
from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM

# Prompt pieces sent to the chat model; the user turn is
# ``instruction + "\n" + input_format.format(question=..., pred=...)``.
system_prompt = (
    "You are a helpful human assistant who always strictly follow "
    "the instructions given by users."
)
instruction = (
    "Given a predicted answer of the same question, please extract answer "
    "from the predicted answer. Note that the output must come from the "
    "predicted answer. Just output the answer, no further explanations."
)
input_format = "Question: {question}\n" "Predicted Answer: {pred}"

def _parse_args():
    """Parse the command-line options for the answer-cleaning script."""
    parser = ArgumentParser(
        description="Extract a short answer from each free-form prediction with a chat LLM."
    )
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help="JSON-lines file; every line is a JSON object mapping example id -> prediction.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="Hugging Face model id used for extraction.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:7",
        help="torch device string the model is moved to.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=64,
        help="Maximum number of tokens generated per answer.",
    )
    return parser.parse_args()


def _load_predictions(path):
    """Merge every JSON object in the JSON-lines file *path* into a single dict."""
    preds = {}
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            if line.strip():  # tolerate blank / trailing empty lines
                preds.update(json.loads(line))
    return preds


def _load_examples(splits=("train", "validation", "test")):
    """Return all ViQuAE examples from *splits*, loading the dataset only once."""
    ds_dict = load_dataset("PaulLerner/viquae_dataset")
    examples = []
    for split in splits:
        examples.extend(ds_dict[split])
    return examples


def _build_messages(question, pred):
    """Build the chat messages asking the model to extract the answer from *pred*."""
    user_content = instruction + "\n" + input_format.format(question=question, pred=pred)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]


@torch.inference_mode()
def clean_answers():
    """Extract a concise answer from every prediction in ``--input_file``.

    For every ViQuAE example (train/validation/test) that has a non-null
    prediction, the model is prompted with the original question and the
    prediction. Decoding is greedy (``do_sample=False``). The stripped
    completion is appended to ``<input_file>.clean`` as one JSON object
    ``{id: answer}`` per line.
    """
    args = _parse_args()

    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch.bfloat16
    ).to(torch.device(args.device))
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # Stop on the tokenizer's EOS token or on <|eot_id|> (end-of-turn token
    # of the Llama-3 chat template). Loop-invariant, so computed once.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    generate_config = {
        "do_sample": False,
        "max_new_tokens": args.max_new_tokens,
        "eos_token_id": terminators,
        # Set explicitly so generate() does not warn about a missing pad token.
        "pad_token_id": tokenizer.eos_token_id,
    }

    preds = _load_predictions(args.input_file)
    # Only examples that have a prediction are processed, so the progress
    # bar's total matches the work actually done.
    todo = [ex for ex in _load_examples() if preds.get(ex["id"]) is not None]

    # Open the output once. Append mode keeps the original semantics, and a
    # flush after each record preserves partial results on interruption.
    with open(f"{args.input_file}.clean", "a", encoding="utf-8") as fout:
        for example in tqdm(todo):
            data_id = example["id"]
            messages = _build_messages(example["original_question"], preds[data_id])
            input_ids = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)
            prompt_len = input_ids.size(1)
            outputs = model.generate(
                input_ids,
                attention_mask=torch.ones_like(input_ids),  # single unpadded sequence
                **generate_config,
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            answer = tokenizer.decode(
                outputs[0, prompt_len:], skip_special_tokens=True
            ).strip()
            fout.write(json.dumps({data_id: answer}, ensure_ascii=False) + "\n")
            fout.flush()
        
# Script entry point: run the answer-cleaning pipeline when executed directly.
if __name__ == "__main__":
    clean_answers()