import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from settings import *
import argparse

def save(path, result):
    """Overwrite ``path`` with the items of ``result``, one item per line.

    Each item is rendered with ``f"{item}"`` and followed by a newline; an
    empty ``result`` leaves an empty (truncated) file.
    """
    with open(path, 'w') as handle:
        for item in result:
            handle.write(f"{item}\n")

def check_q_safe(n, data_loader, model, tokenizer, qk, path):
    final_result = []
    with tqdm(total=n) as pbar:
        for batch in data_loader:
            queries = [C2_SINGLE_TEMPLATE.strip() % x for x in batch[qk]]
            inputs = tokenizer(queries, return_tensors="pt", padding=True).to("cuda")
            prompt_lens = inputs["input_ids"].shape[1]
            
            outputs = model.generate(**inputs, max_new_tokens=32)
            batch_results = []
            
            for i, output in enumerate(outputs):
                result = tokenizer.decode(output[prompt_lens:], skip_special_tokens=True)
                batch_results.append(1 if 'unsafe' in result else 0)
            
            final_result.extend(batch_results)
            save(path, final_result)
            pbar.update(len(batch[qk]))
    return final_result

def check_qa_safe(n, data_loader, model, tokenizer, qk, ak, path):
    final_result = []
    with tqdm(total=n) as pbar:
        for batch in data_loader:
            queries = [C2_QAPAIR_TEMPLATE.strip() % (batch[qk][i], batch[ak][i]) for i in range(len(batch[qk]))]
            inputs = tokenizer(queries, return_tensors="pt", padding=True).to("cuda")
            prompt_lens = inputs["input_ids"].shape[1]
            outputs = model.generate(**inputs, max_new_tokens=32)
            batch_results = []
            for i, output in enumerate(outputs):
                result = tokenizer.decode(output[prompt_lens:], skip_special_tokens=True)
                batch_results.append(1 if 'unsafe' in result else 0)
            
            final_result.extend(batch_results)
            save(path, final_result)
            pbar.update(len(batch[qk]))
            
    return final_result

if __name__ == "__main__":
    # Command-line entry point: label every record of a JSON file with a safety
    # model, write one 0/1 "unsafe" label per record to --output_path and print
    # the unsafe rate.
    parser = argparse.ArgumentParser()
    # The paths have no usable default, so fail at parse time rather than
    # after the model has been loaded.
    parser.add_argument('--input_path', type=str, required=True, help='result json path')
    parser.add_argument('--output_path', type=str, required=True, help='output path (one 0/1 unsafe label per line)')
    parser.add_argument('--model_path', type=str, required=True, help='infer model path')
    parser.add_argument('--question_key', type=str, default='question', help='infer question key')
    parser.add_argument('--answer_key', type=str, default='answer', help='infer answer key')
    # `choices` rejects a bad value immediately. An assert would only fire
    # after the slow model load and is stripped under `python -O`.
    parser.add_argument('--test_type', type=str, default='q', choices=['q', 'qa'],
                        help='infer type, q for question only, qa for question and answer')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')

    args = parser.parse_args()

    # Read the input before loading the model so malformed input fails fast.
    with open(args.input_path, 'r') as f:
        data = json.load(f)
    # Accept a JSON list of records, or a JSON object whose values are the records.
    if isinstance(data, list):
        records = data
    elif isinstance(data, dict):
        records = list(data.values())
    else:
        raise ValueError(
            f"{args.input_path}: expected a JSON list or object at top level, got {type(data).__name__}"
        )
    data_loader = torch.utils.data.DataLoader(records, batch_size=args.batch_size, shuffle=False)

    # Left padding keeps each prompt adjacent to its generated tokens in a batch.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, model_max_length=2048, padding_side='left')
    # Some causal-LM tokenizers ship without a pad token, and padding=True
    # would then raise. Reuse EOS as the pad token in that case.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16, device_map="cuda")

    if args.test_type == 'q':
        result = check_q_safe(len(records), data_loader, model, tokenizer, args.question_key, args.output_path)
    else:
        result = check_qa_safe(len(records), data_loader, model, tokenizer, args.question_key, args.answer_key, args.output_path)
    save(args.output_path, result)
    # An empty input file would otherwise raise ZeroDivisionError here.
    unsafe_rate = sum(result) / len(records) if records else 0.0
    print("unsafe rate for %s is %.4f" % (args.input_path, unsafe_rate))
    
    
    