
import argparse
import yaml

import json
import os
from tqdm import tqdm
import torch
from transformers import pipeline
import datasets

def read_config(config_path):
    """Load a YAML configuration file and return the parsed object.

    The file is decoded as UTF-8 explicitly (like ``load_json``) so that
    parsing does not depend on the platform's locale encoding.
    ``yaml.safe_load`` is used, so arbitrary Python objects are not
    constructed from the file.
    """
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)


def load_json(file_path):
    """Parse and return the JSON document stored at *file_path* (read as UTF-8)."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)


def load_jsonl(file_path):
    """Load a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped instead
    of raising ``JSONDecodeError``. The file is read as UTF-8 explicitly so
    the result does not depend on the platform's locale encoding.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


def _write_json_atomic(path, data):
    """Serialize *data* to *path* as UTF-8 JSON via a temp file + ``os.replace``.

    The previous checkpoint is only replaced once the new one is fully
    written, so an interruption mid-write does not leave a truncated file
    that a later resume run would fail to parse.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Example argparse program.')

    parser.add_argument('--model_name', default='HuggingFaceH4/zephyr-7b-beta', help='Open source model name')
    parser.add_argument('--dataset', default='PKU-Alignment/PKU-SafeRLHF', help='Prompt dataset name')
    parser.add_argument('--limit', type=int, default=9000, help='Total number of prompts to generate')
    # Default is None so an explicitly passed --out_path is honoured; when it
    # is omitted, main() derives the path from the model name.
    parser.add_argument('--out_path', default=None,
                        help='Output file path (default: <model>_red_teaming_generation.json)')
    parser.add_argument('--resume', action="store_true", default=True, help='Resume from existing output file')
    # Resuming is on by default, so `--resume` alone can never disable it;
    # this flag writes False into the same destination.
    parser.add_argument('--no_resume', dest='resume', action='store_false',
                        help='Ignore any existing output file and start over (it will be overwritten)')

    return parser.parse_args()


def main():
    """Generate two completions per dataset prompt and checkpoint them to JSON.

    Generation settings come from ``<model basename>.yaml`` (its
    ``generation_kwargs`` entry). Each output record is the dataset sample
    plus ``modeloutput1``/``modeloutput2``. The output file is rewritten
    after every new record, so an interrupted run can be resumed. Prompts
    already present in the output file, or generated earlier in this run,
    are skipped. Generation stops once ``--limit`` records exist.
    """
    args = _parse_args()

    model_name = args.model_name.split('/')[-1]
    config = read_config(f'{model_name}.yaml')
    if args.out_path is None:
        args.out_path = f"{model_name}_red_teaming_generation.json"

    outputs = []
    if args.resume and os.path.exists(args.out_path):
        outputs = load_json(args.out_path)
        print(f'{len(outputs)} outputs have been calculated')

    # Set gives O(1) membership tests; also updated below so duplicate
    # prompts are skipped the same way in fresh and resumed runs.
    processed_prompts = {output['prompt'] for output in outputs}

    if len(outputs) >= args.limit:
        # Nothing left to do -- avoid loading the model at all.
        print(f'Limit of {args.limit} outputs already reached in {args.out_path}')
        return

    pipe = pipeline("text-generation", model=args.model_name, device_map="auto", torch_dtype=torch.bfloat16, return_full_text=False)

    # Assumes the dataset has a "train" split whose rows include a "prompt"
    # column -- TODO confirm for datasets other than the default.
    dataset = datasets.load_dataset(args.dataset)["train"]
    generation_kwargs = config['generation_kwargs']

    # Progress tracks produced records (toward --limit), not scanned samples.
    with tqdm(total=args.limit, initial=len(outputs)) as progress:
        for sample in dataset:
            # Checked before generating so exactly `limit` records are produced.
            if len(outputs) >= args.limit:
                break
            sample_prompt = sample['prompt']
            if sample_prompt in processed_prompts:
                continue

            # Empty system turn kept as before. NOTE(review): some chat
            # templates may reject a system role -- confirm per model.
            messages = [
                {"role": "system", "content": ""},
                {"role": "user", "content": sample_prompt},
            ]
            prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            model_out = pipe(prompt, num_return_sequences=2, **generation_kwargs)
            outputs.append({
                **sample,
                'modeloutput1': model_out[0]["generated_text"],
                'modeloutput2': model_out[1]["generated_text"],
            })
            processed_prompts.add(sample_prompt)

            # Checkpoint only when something new was produced.
            _write_json_atomic(args.out_path, outputs)
            progress.update(1)


# Run the generation script only when executed directly, not when imported.
if __name__ == "__main__":
    main()
