
import argparse
import yaml

import json
import os
from tqdm import tqdm
import llm_blender


def read_config(config_path):
    """Load a YAML configuration file and return the parsed document.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    # Decode explicitly as UTF-8 (matching load_json) so parsing does not
    # depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)


def load_json(file_path):
    """Read a UTF-8 encoded JSON document from ``file_path`` and return it.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The decoded JSON value (object, list, ...).
    """
    with open(file_path, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed


def load_jsonl(file_path):
    """Load a JSON Lines file into a list of decoded records.

    Each non-blank line is decoded with ``json.loads``. Blank or
    whitespace-only lines, such as an extra trailing newline, are skipped
    rather than raising.

    Args:
        file_path: Path to the ``.jsonl`` file (read as UTF-8).

    Returns:
        A list of decoded values in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Explicit UTF-8 matches load_json and avoids locale-dependent decoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


def _write_json_atomic(obj, path):
    """Write ``obj`` to ``path`` as pretty-printed UTF-8 JSON, atomically.

    The data goes to a sibling ``.tmp`` file that is then moved over ``path``
    with ``os.replace``. An interruption mid-write therefore never leaves a
    truncated output file, which would make the next ``--resume`` run fail
    to parse it.
    """
    tmp_path = f'{path}.tmp'
    # ensure_ascii=False emits raw non-ASCII text, so the file must be opened
    # with an explicit UTF-8 encoding rather than the locale default.
    with open(tmp_path, 'w', encoding='utf-8') as file:
        json.dump(obj, file, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def main():
    """Label which of two model outputs PairRM prefers for each input sample.

    Reads a JSON list of samples from ``--in_path``. Each sample must have
    ``prompt``, ``modeloutput1`` and ``modeloutput2`` keys. The
    ``llm-blender/PairRM`` ranker compares the two outputs for each prompt.
    Every processed sample is written to ``--out_path`` with an added
    ``better_modelout`` key: 1 for ``modeloutput1``, 2 for ``modeloutput2``.

    The output file is rewritten after each newly processed sample, so
    progress survives interruption. Resuming is on by default; use
    ``--no_resume`` to ignore an existing output file. When resuming, samples
    whose prompt already appears in the output file are skipped.
    """
    parser = argparse.ArgumentParser(
        description='Use PairRM to pick the better of two model outputs per prompt.')
    parser.add_argument('--in_path', default='zephyr-7b-beta_red_teaming_generation_safe.json', help='Input file path')
    parser.add_argument('--out_path', default='zephyr-7b-beta_red_teaming_generation_safe_help.json', help='Output file path')
    # --resume is kept for backward compatibility (it was always on). The new
    # --no_resume flag writes to the same dest so resuming can be disabled.
    parser.add_argument('--resume', action='store_true', default=True,
                        help='Resume from existing output file (default)')
    parser.add_argument('--no_resume', dest='resume', action='store_false',
                        help='Ignore any existing output file and start from scratch')
    args = parser.parse_args()

    blender = llm_blender.Blender()
    blender.loadranker("llm-blender/PairRM")  # load ranker checkpoint

    outputs = []
    processed_prompts = set()  # set gives O(1) membership checks in the loop
    if args.resume and os.path.exists(args.out_path):
        outputs = load_json(args.out_path)
        processed_prompts = {output['prompt'] for output in outputs}
        print(f'{len(outputs)} outputs have been calculated')

    input_data = load_json(args.in_path)

    for sample in tqdm(input_data, total=len(input_data)):
        sample_prompt = sample['prompt']
        if sample_prompt in processed_prompts:
            continue

        # compare() takes batches; this sends a batch of one. A truthy result
        # is mapped to modeloutput1 (candidate A) and a falsy one to
        # modeloutput2.
        a_is_better = blender.compare(
            [sample_prompt], [sample['modeloutput1']], [sample['modeloutput2']])[0]
        outputs.append({
            **sample,
            'better_modelout': 1 if a_is_better else 2,
        })
        # Checkpoint only when something new was added. Skipped samples no
        # longer trigger a full rewrite of the output file.
        _write_json_atomic(outputs, args.out_path)


if __name__ == '__main__':
    # Run the labelling pipeline only when executed as a script, not on import.
    main()
