# from LLM_finetuning import train_model
import sys
import os
sys.path.append('/path/to/detecion')
from process_dataset import create_dataset
# from exsiting_MIA import MIA_evaluate
from indirect_generator import gpt_synthetic_prefix
# from prompt_detection_wo_ana import prompt_reverse_loop
# from prompt_detection_wo_ana_wo_apply import prompt_reverse_loop
# from prompt_detection import prompt_reverse_loop
import argparse
import random
import os
from tqdm import tqdm
import json
import numpy as np
import torch
from datetime import datetime
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    MambaForCausalLM,
)
def parse_args(argv=None):
    """Build this script's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before. Passing an
            explicit list lets tests and notebooks call this function.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()

    # environment
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--exp_name', type=str, default='test')
    parser.add_argument('--out_dir', type=str, default="")
    parser.add_argument('--model_path', type=str, default="")

    # model
    parser.add_argument('--need_train', action='store_true', default=False)
    parser.add_argument('--model', type=str, default='EleutherAI/pythia-6.9b')
    parser.add_argument('--memory_for_model_activations_in_gb', type=int, default=4)
    parser.add_argument('--lora_alpha', type=int, default=32)
    parser.add_argument('--lora_dim', type=int, default=8)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--inference_batch_size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.0008)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--sgd', action='store_true')
    # Other values used previously: "cc-3-5-haiku-20241022", "deepseek-chat".
    parser.add_argument('--gpt_model', type=str, default="gpt-4o-mini")

    # data
    # Other dataset names used previously: mimir_wiki, mimir_arxiv, mimir_hackernews.
    parser.add_argument('--dataset_name', type=str, default="mimir_c4")
    parser.add_argument('--sub_data', type=str, default='ngram_13_0.2')
    parser.add_argument('--target_num', type=int, default=200)

    # synthetic
    parser.add_argument('--synthetic', action='store_true', default=False)
    parser.add_argument('--synthetic_save_path', type=str, default="")
    parser.add_argument('--synthetic_path', type=str, default="")
    parser.add_argument('--synthetic_nonmember_path', type=str, default="")
    parser.add_argument('--prompt_id', type=int, default=0)

    # MIA
    # Other reference model used previously: 'EleutherAI/pythia-14m'.
    parser.add_argument('--ref_model', type=str, default='ahxt/LiteLlama-460M-1T')
    parser.add_argument('--prefix_size', type=int, default=7)
    return parser.parse_args(argv)


def build_prompt(action, style, anonymize, domain, format):
    """Compose a domain-adaptation rewrite instruction for an LLM.

    Each argument is interpolated into a fixed template. ``domain`` appears
    twice. ``anonymize`` is inserted verbatim as its own sentence, followed
    by ". ".

    Args:
        action: Verb describing the operation (e.g. "rewrite").
        style: Tone the output should adopt.
        anonymize: Free-form instruction sentence, inserted as-is.
        domain: Target domain whose conventions the output should follow.
        format: Required output format. The name shadows the builtin, but it
            is kept for keyword-argument compatibility.

    Returns:
        The assembled prompt string.
    """
    sentences = [
        f"Please {action} the following text in a {style} tone. ",
        f"{anonymize}. ",
        "You are encouraged to add relevant details, terminology, and context "
        f"that are typical for texts in the {domain} domain, even if they are "
        "not explicitly present in the original input. ",
        "The output should be richly adapted and fully aligned with the "
        f"conventions, structure, and style of {domain} writing. ",
        f"The final output should follow the {format} format.",
    ]
    return "".join(sentences)
def return_prompt(prompt_id):
    """Return the registered rewrite instruction for ``prompt_id``.

    The chosen base prompt gets a trailing "no emojis" instruction appended.

    Args:
        prompt_id: Integer key into the prompt table below (0-5).

    Returns:
        The base prompt, then a single space, then "Emojis are not allowed.".

    Raises:
        ValueError: If ``prompt_id`` is not a registered id.
    """
    prompts = {
        0: "Rewrite the following text to make it more engaging and emotionally resonant for a young audience, using vivid imagery, varied sentence structures, and a storytelling tone. Ensure that the core message remains intact, but enhance its narrative appeal. ",
        1: "Rewrite the following content as slide presentation bullet points. Focus on summarizing the key arguments and findings clearly and concisely. Use concise. Avoid full sentences where possible—use phrases that highlight core points. Organize the content logically to support slide-by-slide delivery.",
        2: "Rewrite the following text in the style of a personal Facebook post. Use an informal, conversational tone as if sharing interesting information or personal insights with friends or followers. You may add light commentary, questions to the audience, or casual phrasing, but keep it natural and human-like. Avoid using emojis, hashtags, or overly dramatic expressions. Maintain the core facts from the original text while making the delivery feel more social and relatable.",
        3: "Rewrite the passage as an advertising script, exaggerating benefits and using persuasive, punchy language.",
        4: "Adapt the text into a podcast transcript between two hosts discussing the topic conversationally, adding natural questions, clarifications, and relatable anecdotes.",
        5: "Translate the text into a script for a YouTube explainer video, including visual scene descriptions, narrative hooks, and simplified logic flow."
    }
    try:
        prompt = prompts[prompt_id]
    except KeyError:
        raise ValueError(f"Invalid prompt_id: {prompt_id}. Please choose a valid prompt_id.") from None
    # Base prompts end inconsistently (id 0 has a trailing space, the rest end
    # in "."). Normalise the join so the suffix is always a separate sentence
    # instead of being glued on, e.g. "...delivery.Emojes are not allowed".
    return prompt.rstrip() + " Emojis are not allowed."


def load_model(model_name, need_train, use_float16=True, model_path=None):
    """Load a causal LM and its tokenizer from the same source.

    Args:
        model_name: Model id/path loaded when ``need_train`` is true.
        need_train: If true, load ``model_name``; otherwise load ``model_path``.
        use_float16: Load weights as ``torch.float16`` instead of ``torch.float32``.
        model_path: Checkpoint id/path loaded when ``need_train`` is false.

    Returns:
        ``(model, tokenizer)``. The model is loaded with ``device_map="auto"``,
        ``return_dict=True`` and ``trust_remote_code=True``.

    Raises:
        ValueError: If ``need_train`` is false and ``model_path`` is empty or None.
    """
    if need_train:
        source = model_name
    else:
        if not model_path:
            # Fail early with a clear message instead of passing None/"" to from_pretrained.
            raise ValueError("model_path must be provided when need_train is False")
        print(f"Loading model from {model_path}")
        source = model_path

    dtype = torch.float16 if use_float16 else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        source,
        return_dict=True,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",
    )
    # Tokenizer comes from the same source as the weights. Previously this read
    # the global ``args.model``, which only exists when the file runs as a
    # script (NameError on import).
    tokenizer = AutoTokenizer.from_pretrained(source)
    return model, tokenizer



def main(args):
    """Rewrite training non-member samples with every registered prompt.

    Loads the prompt list from the prompt registry JSON file. For each prompt,
    the first 10 training non-member samples are sent through
    ``gpt_synthetic_prefix`` using ``args.gpt_model``. Each result is recorded
    as ``{"prompt", "input", "rewrite"}``. The accumulated list is rewritten to
    the output JSON after every prompt, so completed prompts are kept if a
    later API call fails.

    Args:
        args: Namespace from ``parse_args``. Reads ``dataset_name``, ``sub_data``,
            ``target_num``, ``prefix_size`` and ``gpt_model``.
    """
    # Only the training non-member split is used here.
    training_nonmember_data, _test_nonmember, _nonmember_prefix, _member_prefix = create_dataset(
        args.dataset_name, args.sub_data, args.target_num, args.prefix_size
    )
    with open("/path/to/register prompt", "r", encoding="utf-8") as f:
        prompts = json.load(f)
    # Name the output after the dataset actually loaded. It was hard-coded to
    # "mimir_c4"; the default --dataset_name still gives the same path.
    save_path = os.path.join("/trigger_path", f"{args.dataset_name}_{args.sub_data}_domain.json")
    suffix = "Only return the rewritten content (without like Here's a ... reinterpretation/version of the text:). Emojis are not allowed."

    synthetic_list = []
    for p in tqdm(prompts):
        print(f"======== Using prompt:{p} ========")
        # Invariant per prompt, so build it once rather than per sample.
        # NOTE(review): no separator is inserted between p and the suffix;
        # this relies on registry prompts ending in whitespace. Confirm.
        all_prompt = p + suffix
        for data in tqdm(training_nonmember_data[:10]):
            synthetic_list.append({"prompt": p, "input": data, "rewrite": gpt_synthetic_prefix(data, all_prompt, model=args.gpt_model)})
        # Checkpoint after each prompt.
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(synthetic_list, f, ensure_ascii=False, indent=4)


if __name__ == '__main__':
    args = parse_args()

    # Seed torch, NumPy and Python's random module (in that order) from one value.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(args.seed)
    # CUDA RNGs are seeded only when a GPU is visible.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Record the launch time (day/month/year hour:minute:second) on the namespace.
    args.timestamp = datetime.now().strftime("%d/%m/%Y %H:%M:%S")

    main(args)