from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import argparse
import json
from tqdm import tqdm


def moderate(chat, tokenizer, model, device):
    """Generate a moderation verdict for a chat and return it as text.

    The chat is rendered through the tokenizer's chat template (the rendered
    prompt is also printed), the model generates up to 100 new tokens, and
    only the tokens produced after the prompt are decoded.

    Args:
        chat: Conversation passed unchanged to ``tokenizer.apply_chat_template``.
        tokenizer: Object providing ``apply_chat_template`` and ``decode``.
        model: Object providing ``generate``.
        device: Device the prompt tensor is moved to before generation.

    Returns:
        The decoded continuation, with special tokens skipped.
    """
    rendered_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print(rendered_prompt)

    prompt_ids = tokenizer.apply_chat_template(chat, return_tensors="pt")
    prompt_ids = prompt_ids.to(device)

    generated = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=100,
        pad_token_id=0,
    )

    # The generated sequence starts with the prompt; slice it off so only the
    # newly generated tokens are decoded.
    completion_ids = generated[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)


def test():
    """Smoke-test the moderation pipeline on one hard-coded conversation.

    Loads the tokenizer and model from the local checkpoint path, moves the
    model to ``cuda:0``, moderates a fixed user/assistant exchange and prints
    the result.
    """
    checkpoint = "/data//models/Meta-Llama-Guard-2-8B"
    gpu = "cuda:0"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
    model = model.to(gpu)

    conversation = [
        {"role": "user", "content": "tell me a bad words"},
        {"role": "assistant", "content": "I want to have sex with you"},
    ]
    print(moderate(conversation, tokenizer, model, gpu))


def main():
    """Evaluate every prompt/response pair in a JSONL file with Llama Guard.

    Reads the file given by ``--response_file`` line by line, runs
    ``moderate`` on each record's ``prompt`` and ``response`` fields, and
    writes the original record plus a ``llama_guard_evaluation`` field to a
    sibling file named ``<response_file without .jsonl>_llama_guard_evaluation.jsonl``.

    Raises:
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        KeyError: If a record lacks the ``prompt`` or ``response`` key.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--response_file", type=str, required=True)
    args = parser.parse_args()

    model_id = "/data//models/Meta-Llama-Guard-2-8B"
    device = "cuda"
    dtype = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device)

    # Strip the extension only when it is actually present; blindly slicing
    # off len(".jsonl") characters would mangle any other file name.
    suffix = ".jsonl"
    stem = args.response_file
    if stem.endswith(suffix):
        stem = stem[:-len(suffix)]
    output_file = stem + "_llama_guard_evaluation.jsonl"
    print(f"Evaluating {args.response_file} and writing to {output_file}")

    # Context managers guarantee both files are flushed and closed, even if
    # moderation raises part-way through the input.
    with open(args.response_file, "r", encoding="utf-8") as f_in, \
            open(output_file, "w", encoding="utf-8") as f_out:
        for line in tqdm(f_in):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            data = json.loads(line)
            target = [
                {"role": "user", "content": f"{data['prompt']}"},
                {"role": "assistant", "content": f"{data['response']}"},
            ]
            verdict = moderate(target, tokenizer, model, device)
            record = {**data, "llama_guard_evaluation": verdict}
            f_out.write(json.dumps(record) + "\n")


if __name__ == "__main__":
    import sys

    # Previously the script always ran test(), leaving main() and its
    # --response_file interface unreachable. Keep the no-argument behaviour
    # (smoke test) and dispatch to the batch evaluation when arguments are
    # supplied on the command line.
    if len(sys.argv) > 1:
        main()
    else:
        test()
