import argparse
import os
import json
from llm import LLMInferenceEngine, GenerationArgs, UniversalGenParams
from datasets import Dataset
from .refine_util import RANDOM_POSITION_PROMPT, CONTENT_DICT
import re
import random

def parse_dataset(dataset):
    """Clean the model's refined outputs and promote them to the response columns.

    Each value in the ``refined_output`` column has a leading
    ``[Response]:`` marker (and surrounding whitespace) stripped. Rows whose
    cleaned output is just "none" or "null" (case-insensitive, optional
    trailing periods), or whose output is missing (``None``), are dropped.
    For the remaining rows the cleaned text replaces both the ``response``
    and ``output`` columns, and ``refined_output`` is removed.

    Args:
        dataset: A dataset-like object exposing ``__getitem__`` by column
            name, ``column_names``, ``select``, ``remove_columns`` and
            ``add_column``. It must contain a ``refined_output`` column;
            ``response`` and ``output`` are optional.

    Returns:
        The filtered dataset with ``response`` and ``output`` both set to
        the cleaned refined text.
    """
    # Compiled once instead of once per row.
    response_prefix = re.compile(r'^\s*\[\s*Response\s*\]\s*:\s*')

    def clean_response_format(text):
        return response_prefix.sub('', text.strip())

    def is_none_response(text):
        return text.lower().strip().rstrip('.') in ('none', 'null')

    # A missing (None) generation carries no usable refinement, so it is kept
    # as None here and filtered out below instead of crashing on .strip().
    cleaned_outputs = [
        None if output is None else clean_response_format(output)
        for output in dataset['refined_output']
    ]

    keep_indices = [
        i for i, output in enumerate(cleaned_outputs)
        if output is not None and not is_none_response(output)
    ]
    kept_outputs = [cleaned_outputs[i] for i in keep_indices]

    filtered_dataset = dataset.select(keep_indices)

    # Only drop columns that are actually present: the input is not
    # guaranteed to carry an 'output' (or 'response') column, and removing a
    # missing column raises.
    stale_columns = [
        name for name in ('output', 'response', 'refined_output')
        if name in filtered_dataset.column_names
    ]
    filtered_dataset = filtered_dataset.remove_columns(stale_columns)
    filtered_dataset = filtered_dataset.add_column(name="response", column=kept_outputs)
    filtered_dataset = filtered_dataset.add_column(name="output", column=kept_outputs)

    return filtered_dataset
def parse_args():
    """Parse the command-line options for the refinement script.

    Returns:
        argparse.Namespace with the fields ``model``,
        ``model_engine_backend``, ``model_backend_base_url``,
        ``model_num_gpus``, ``gpu_memory_utilization``,
        ``target_dataset_path``, ``save_dir`` (defaults to
        ``"dataset/train"``), ``max_num_examples``, ``mode`` and
        ``num_sentences``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.3-70B-Instruct")
    parser.add_argument("--model_engine_backend", type=str, default="vllm-openai")
    # A base URL is text (e.g. "http://host:8000/v1"); parsing it as int
    # rejected every real URL.
    parser.add_argument("--model_backend_base_url", type=str, default=None)
    parser.add_argument("--model_num_gpus", type=int, default=4)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.8)
    parser.add_argument("--target_dataset_path", type=str, default=None)
    parser.add_argument("--save_dir", type=str, default="dataset/train")
    parser.add_argument("--max_num_examples", type=int, default=None)
    parser.add_argument("--mode", type=str, default=None)
    parser.add_argument("--num_sentences", type=int, default=None)
    return parser.parse_args()

def _build_model_input(dataset, mode):
    """Render one refinement prompt per example for the requested mode.

    ``first_position`` / ``middle_position`` / ``end_position`` use
    RANDOM_POSITION_PROMPT with the position taken from the mode's prefix;
    any other mode must be a key of CONTENT_DICT.

    Raises:
        ValueError: If ``mode`` is neither a position mode nor a
            CONTENT_DICT key.
    """
    if mode in ("first_position", "middle_position", "end_position"):
        position_choice = mode.split("_")[0]
        return [
            RANDOM_POSITION_PROMPT.format(
                user_request=example["instruction"],
                llm_response=example["response"],
                position=position_choice,
            )
            for example in dataset
        ]
    if mode in CONTENT_DICT:
        selected_prompt = CONTENT_DICT[mode]
        return [
            selected_prompt.format(
                user_request=example["instruction"],
                llm_response=example["response"],
            )
            for example in dataset
        ]
    raise ValueError(f"Invalid mode: {mode}")


def main():
    """Refine dataset responses with an LLM and save the result as JSONL.

    Loads the JSON dataset at ``--target_dataset_path``, optionally
    subsamples it, builds one prompt per example for ``--mode``, generates
    a refined output for each prompt, post-processes the outputs with
    ``parse_dataset`` and writes ``<save_dir>/<mode>_refusal.jsonl``.

    Raises:
        ValueError: If ``--target_dataset_path`` is missing or ``--mode``
            is not recognised.
    """
    args = parse_args()
    if args.target_dataset_path is None:
        raise ValueError("--target_dataset_path is required")
    dataset = Dataset.from_json(args.target_dataset_path)

    if args.max_num_examples is not None:
        # Fixed seed gives a reproducible subsample; sorting afterwards keeps
        # the selected rows in their original dataset order.
        ids = list(range(len(dataset)))
        random.seed(0)
        random.shuffle(ids)
        dataset = dataset.select(sorted(ids[:args.max_num_examples]))

    # Build (and thereby validate) the prompts before constructing the
    # engine, so an invalid mode fails before any model is loaded.
    model_input = _build_model_input(dataset, args.mode)

    backend_kwargs = {}
    if args.model_engine_backend == "vllm":
        backend_kwargs = {
            "tensor_parallel_size": args.model_num_gpus,
            "gpu_memory_utilization": args.gpu_memory_utilization,
        }
        if "gemma" in args.model:
            backend_kwargs["max_num_seqs"] = 64

    model = LLMInferenceEngine(args.model, backend=args.model_engine_backend, **backend_kwargs)

    model_gen_params = UniversalGenParams(n=1, max_new_tokens=2048, temperature=0)
    model_gen_args = GenerationArgs(
        engine_input=model_input,
        gen_params=model_gen_params,
        is_multi_turn_input=False,
        is_batch_input=True,
        apply_chat_template=True,
    )
    model_outputs = model.generate(model_gen_args)
    # n=1, so each result carries a single generated sequence.
    model_outputs = [output.output_seqs[0] for output in model_outputs]

    dataset = dataset.add_column(name="refined_output", column=model_outputs)
    dataset = parse_dataset(dataset)

    # Make sure the target directory exists before writing into it.
    os.makedirs(args.save_dir, exist_ok=True)
    output_path = os.path.join(args.save_dir, f"{args.mode}_refusal.jsonl")

    dataset.to_json(output_path, lines=True)
    print(f"Saved to {output_path}")

# Run the refinement pipeline only when this file is executed directly,
# not when it is imported. NOTE(review): the relative import of
# `.refine_util` at the top of the file means this likely has to be launched
# as a module inside its package (`python -m ...`) — confirm.
if __name__ == "__main__":
    main()