import os
import json
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from peft import PeftModel
from datasets import load_dataset
import gc
from moderation import QAModeration 

def release_model_memory(model=None):
    """Run the garbage collector and empty the CUDA cache to reclaim VRAM.

    Python cannot delete a caller's variable from inside a function: the
    ``del model`` below only drops this frame's own reference. If the caller
    still holds the model in a variable, nothing is actually freed. The
    reliable pattern is therefore::

        del my_model
        release_model_memory()

    ``model`` stays as an optional parameter so existing callers that pass
    the model keep working.

    Args:
        model: Optional object whose local reference is dropped before
            collection. It is not otherwise used.
    """
    del model  # removes only this function's reference, not the caller's
    gc.collect()
    # Hand cached-but-unused CUDA blocks back to the driver so the next model
    # can use them. No-op when CUDA was never initialized.
    torch.cuda.empty_cache()
    print("Model memory released.")

def _release_cuda_memory():
    """Collect garbage and empty the CUDA cache.

    Only unreferenced objects can be reclaimed, so callers must drop every
    reference to a model (``del model``) before calling this.
    """
    gc.collect()
    torch.cuda.empty_cache()
    print("Model memory released.")


def _load_source_data(source_file_path, prompt_col, safe_guide_col, max_samples):
    """Read the source JSON file and return aligned (prompts, safe_guides) lists.

    The parsed JSON is sliced to its first ``max_samples`` entries. Each entry is
    indexed by ``prompt_col`` (presumably a list of dicts; TODO confirm against
    the data files). Safe guides are read with ``.get`` and kept position-aligned
    with the prompts. A record without ``safe_guide_col`` yields ``None`` (reported
    during curation) instead of aborting the whole run with ``KeyError``.
    Alignment also means duplicate prompts keep their own safe guide.
    """
    with open(source_file_path, "r", encoding="utf-8") as f:
        source_data = json.load(f)
    records = source_data[:max_samples]
    prompts = [item[prompt_col] for item in records]
    safe_guides = [item.get(safe_guide_col) for item in records]
    return prompts, safe_guides


def _load_target_model(model_folder, lora_folder):
    """Load the target tokenizer/model and merge the LoRA adapter if present.

    Returns:
        (tokenizer, model), with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_folder)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding puts every prompt flush against the tokens generated after it,
    # which batched decoder-only generation relies on.
    tokenizer.padding_side = "left"

    if "gemma" in model_folder:
        attn_implementation = "eager"
    else:
        print("flash!")
        attn_implementation = "flash_attention_2"
    model = AutoModelForCausalLM.from_pretrained(
        model_folder,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=attn_implementation,
    )

    if os.path.exists(lora_folder):
        model = PeftModel.from_pretrained(model, lora_folder)
        model = model.merge_and_unload()
    else:
        # --lora_folder is required, so silently falling back to the base model
        # would be easy to miss.
        print(f"Warning: LoRA folder '{lora_folder}' not found; generating with the base model only.")
    model.eval()
    return tokenizer, model


def _generate_responses(tokenizer, model, prompts, batch_size, max_new_tokens=256, max_length=1024):
    """Greedily generate one response per prompt, in batches.

    Each prompt is wrapped as a single user turn and rendered with the
    tokenizer's chat template. Returns the decoded continuations, with the prompt
    tokens stripped, in the same order as ``prompts``.
    """
    outputs = []
    with torch.no_grad():
        for start in tqdm(range(0, len(prompts), batch_size), desc="Generating Responses"):
            batch_prompts = prompts[start : start + batch_size]
            batch_messages = [[{"role": "user", "content": text}] for text in batch_prompts]
            prompt_strings = tokenizer.apply_chat_template(
                batch_messages,
                add_generation_prompt=True,
                tokenize=False,
            )
            # The rendered template already contains any BOS token. Letting the
            # tokenizer add special tokens again would prepend a second BOS.
            # NOTE(review): truncation cuts from tokenizer.truncation_side
            # (right by default), which would drop the generation prompt for
            # over-long inputs. Confirm prompts stay under max_length.
            model_inputs = tokenizer(
                prompt_strings,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                add_special_tokens=False,
            ).to(model.device)

            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

            # With left padding every row has the same padded prompt width, so
            # new tokens start at the same column for the whole batch.
            prompt_width = model_inputs["input_ids"].shape[1]
            outputs.extend(
                tokenizer.batch_decode(generated_ids[:, prompt_width:], skip_special_tokens=True)
            )
    return outputs


def _diagnose(judge_model_path, prompts, responses, batch_size=1):
    """Score each (prompt, response) pair with the judge and free it afterwards.

    Returns whatever ``QAModeration.predict(..., return_bool=True)`` returns.
    Curation indexes each element by ``"flagged"``.
    """
    judge_model = QAModeration.from_pretrained(judge_model_path, model_max_length=1024, device_map="auto")
    judge_model.eval()
    print("Judge model loaded.")

    judgements = judge_model.predict(
        question=prompts,
        answer=responses,
        batch_size=batch_size,
        return_bool=True,
    )
    print("Diagnosis finished.")

    print("Releasing judge model from VRAM...")
    del judge_model  # this frame holds the only reference
    _release_cuda_memory()
    return judgements


def _curate(prompts, responses, safe_guides, judgements):
    """Keep flagged responses that have a safe guide.

    Returns:
        (curated_records, harmful_count). ``harmful_count`` counts every flagged
        response, including ones dropped for lacking a safe guide.

    Raises:
        ValueError: if the four input sequences differ in length.
    """
    if not (len(prompts) == len(responses) == len(safe_guides) == len(judgements)):
        raise ValueError(
            f"Length mismatch: {len(prompts)} prompts, {len(responses)} responses, "
            f"{len(safe_guides)} safe guides, {len(judgements)} judgements"
        )
    curated = []
    harmful_count = 0
    for instruction, response, safe_guide, judgement in zip(prompts, responses, safe_guides, judgements):
        if not judgement["flagged"]:
            continue
        harmful_count += 1
        if not safe_guide:
            print(f"Warning: No safe guide found for instruction: '{instruction[:50]}...'")
            continue
        curated.append({
            "instruction": instruction,
            "harmful_response": response.strip(),
            "safe_guide": safe_guide,
        })
    return curated, harmful_count


def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as UTF-8 JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            # ensure_ascii=False keeps non-ASCII text readable in the UTF-8 file.
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def main():
    """Generate, diagnose, and curate unlearning data.

    Phase 1 greedily generates a response to each source prompt with the target
    model, merging a LoRA adapter when ``--lora_folder`` exists. Phase 2 scores
    every (prompt, response) pair with the QAModeration judge. Phase 3 keeps the
    flagged pairs that have a safe guide and writes them as JSON Lines to
    ``--curated_output_path``. Each model is dereferenced and the CUDA cache
    emptied before the next phase, so only one model is resident at a time.
    """
    parser = argparse.ArgumentParser(description="A VRAM-optimized script to generate, diagnose, and curate unlearning data.")
    parser.add_argument("--model_folder", default='/root/autodl-tmp/model/gemma-2-9b-it')
    parser.add_argument("--lora_folder", required=True)
    parser.add_argument("--judge_model_path", default="/root/autodl-tmp/model/beaver-dam-7b")
    parser.add_argument("--source_file_path", required=True)
    parser.add_argument("--curated_output_path", default='../data/curated_unlearning_data.jsonl')
    parser.add_argument("--prompt_col", default="instruction")
    parser.add_argument("--safe_guide_col", default="safe_guide")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--judge_batch_size", type=int, default=1)
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    if args.judge_batch_size < 1:
        parser.error("--judge_batch_size must be a positive integer")

    output_dir = os.path.dirname(args.curated_output_path)
    if output_dir:  # a bare filename has no directory component to create
        os.makedirs(output_dir, exist_ok=True)

    diagnosis_prompts, safe_guides = _load_source_data(
        args.source_file_path, args.prompt_col, args.safe_guide_col, args.max_samples
    )
    if not diagnosis_prompts:
        print("Warning: no prompts found in the source file; nothing to diagnose.")
        _write_jsonl(args.curated_output_path, [])
        return

    print("\n--- Phase 1: Generating responses with the Target Model ---")
    target_tokenizer, target_model = _load_target_model(args.model_folder, args.lora_folder)
    generated_outputs = _generate_responses(
        target_tokenizer, target_model, diagnosis_prompts, args.batch_size, args.max_new_tokens
    )
    # Drop main's reference before collecting. Otherwise the weights stay alive
    # and the judge would be loaded next to the still-resident target model.
    del target_model
    _release_cuda_memory()

    print("\n--- Phase 2: Diagnosing generated responses with the Judge Model ---")
    judgements = _diagnose(args.judge_model_path, diagnosis_prompts, generated_outputs, args.judge_batch_size)

    print("\n--- Phase 3: Curating data based on diagnosis ---")
    curated_data, harmful_count = _curate(diagnosis_prompts, generated_outputs, safe_guides, judgements)
    print(f"Curation finished. Found {harmful_count} harmful responses.")

    _write_jsonl(args.curated_output_path, curated_data)

    print("\n==================== DIAGNOSIS & CURATION COMPLETE ====================")
    print(f"Final curated dataset with {len(curated_data)} entries saved to: {args.curated_output_path}")
    print("======================================================================")


if __name__ == "__main__":
    # Script entry point: run the generate -> diagnose -> curate pipeline.
    main()