from LLM_finetuning import train_model
from process_dataset import create_dataset
from exsiting_MIA import MIA_evaluate
from indirect_generator import gpt_synthetic_prefix
from prompt_detection_wo_ana import prompt_reverse_loop
import argparse
import random
import os
from tqdm import tqdm
import json
import numpy as np
import torch
from datetime import datetime
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    MambaForCausalLM,
)

def parse_args():
    """Build the command-line parser for the experiment and parse the process arguments.

    Options are declared in a table and registered in order, so the generated
    ``--help`` lists them in the order they appear below.

    Returns:
        argparse.Namespace: one attribute per flag declared in the table.
    """
    option_table = [
        # environment
        ('--seed', dict(type=int, default=42)),
        ('--exp_name', dict(type=str, default='test')),
        ('--out_dir', dict(type=str, default="")),
        ('--model_path', dict(type=str, default="")),
        # model
        ('--need_train', dict(action='store_true', default=False)),
        ('--model', dict(type=str, default='EleutherAI/pythia-6.9b')),
        ('--memory_for_model_activations_in_gb', dict(type=int, default=4)),
        ('--lora_alpha', dict(type=int, default=32)),
        ('--lora_dim', dict(type=int, default=8)),
        ('--verbose', dict(action='store_true')),
        ('--epochs', dict(type=int, default=60)),
        ('--batch_size', dict(type=int, default=2)),
        ('--inference_batch_size', dict(type=int, default=16)),
        ('--lr', dict(type=float, default=0.0004)),
        ('--weight_decay', dict(type=float, default=0.01)),
        ('--sgd', dict(action='store_true')),
        # data
        ('--dataset_name', dict(type=str, default="mimir_wiki")),
        ('--sub_data', dict(type=str, default='ngram_13_0.2')),
        ('--target_num', dict(type=int, default=200)),
        # synthetic
        ('--synthetic', dict(action='store_true', default=False)),
        ('--synthetic_save_path', dict(type=str, default="")),
        ('--synthetic_path', dict(type=str, default="")),
        ('--synthetic_nonmember_path', dict(type=str, default="")),
        ('--prompt_id', dict(type=int, default=15)),
        # MIA (an alternative reference model previously tried: 'EleutherAI/pythia-14m')
        ('--ref_model', dict(type=str, default='ahxt/LiteLlama-460M-1T')),
        ('--prefix_size', dict(type=int, default=7)),
        # trigger extraction
        ('--trigger_path', dict(type=str, default="")),
        ('--domain_path', dict(type=str, default="")),
        ('--original_path', dict(type=str, default="")),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

    #


def build_prompt(action, style, anonymize, domain, format):
    """Assemble a five-sentence rewriting instruction from the given pieces.

    Args:
        action: verb inserted after "Please" (e.g. "rewrite").
        style: tone name inserted before "tone".
        anonymize: a sentence, without its trailing period, inserted verbatim
            as the second sentence.
        domain: domain name, used in the third and fourth sentences.
        format: output format name used in the final sentence.

    Returns:
        str: the sentences joined by single spaces.
    """
    sentences = [
        f"Please {action} the following text in a {style} tone.",
        f"{anonymize}.",
        f"You are encouraged to add relevant details, terminology, and context that are typical for texts in the {domain} domain, even if they are not explicitly present in the original input.",
        f"The output should be richly adapted and fully aligned with the conventions, structure, and style of {domain} writing.",
        f"The final output should follow the {format} format.",
    ]
    return " ".join(sentences)

def return_prompt(prompt_id, out_side=False):
    """Look up a rewriting instruction by its numeric id.

    Two tables are kept: the "inside" table (ids 0-23) and the "outside"
    table (ids 0-9). ``out_side`` chooses which one is consulted.

    Args:
        prompt_id: key into the selected table.
        out_side: when truthy, use the outside table; otherwise the inside one.

    Returns:
        str: the instruction text stored under ``prompt_id``.

    Raises:
        ValueError: if ``prompt_id`` is not a key of the selected table.
    """
    inside_table = {
        0: "Rewrite the text in a lyrical style, ensuring the imagery is vivid, the rhythm flows naturally.",
        1: "Rewrite the text in a spoken style, making it sound natural and conversational, and ensure the tone feels engaging and easy to follow for a live audience.",
        2: "Rewrite the text in the form of an interview, ensuring the questions flow naturally and the answers provide clear, engaging explanations for the audience.",
        3: "Rewrite the text as an interactive discussion between two or more participants, ensuring the conversation flows logically, with each speaker’s tone and style clearly distinguishable.",
        4: "Rewrite the text as a storytelling narrative. The story should flow naturally, use simple and engaging language, and be easy for all kinds of listeners to follow.",
        5: "Rewrite the text in the style of a news report, ensuring the information is presented objectively and concisely.",
        6: "Rewrite the text as a sports report, ensuring the action is described with dynamic, energetic language that conveys the pace, tension, and excitement of the event.",
        7: "Rewrite the text as a narrative blog post, organized into clear sections with subheadings. Use a tone that is engaging and reflective, blending storytelling with explanation.",
        8: "Rewrite the text as a step-by-step instructional guide. Break the content into numbered steps, with each step beginning with a clear imperative verb.",
        9: "Rewrite the text as a recipe, introduce the information as sequential steps.",
        10: "Rewrite the text to persuade the reader through factual information, making sure to include at least three specific data points or statistics to support the argument.",
        11: "Rewrite the text as a sales description, and be sure to include a clear call-to-action at the end.",
        12: "Rewrite the text in the style of an editorial, making sure to include a clear stance or opinion and a concluding paragraph that calls for action or reflection.",
        13: "Rewrite the text as an informational description, ensuring the tone is neutral and objective, and include at least one definition or clarification to help the reader better understand the subject.",
        14: "Rewrite the text in the style of an encyclopedia entry, maintaining a neutral, authoritative tone, and include at least one date, fact, or reference to give it the appearance of being sourced.",
        15: "Rewrite the text as an academic research article, structured with sections such as Abstract, Introduction, Method, Results, and Conclusion, and include at least one in-text citation (invented if necessary) to simulate scholarly referencing.",
        16: "Rewrite the text as a descriptive profile of a specific thing or person, using vivid details and attributes (appearance, characteristics, or context) and ending with a short summary sentence that highlights its significance.",
        17: "Rewrite the text in the form of a Frequently Asked Questions (FAQ) section, making sure to include at least three question–answer pairs, with the questions phrased from the perspective of a curious reader.",
        18: "Rewrite the text as legal terms and conditions, using formal legal language, and ensure at least one numbered clause is included for clarity.",
        19: "Rewrite the text as a personal opinion piece, written in the first person, making sure to clearly express a stance and support it with at least one reason or example.",
        20: "Rewrite the text as a review, giving it a clear positive or negative stance, and include at least one specific detail or example to justify the evaluation.",
        21: "Rewrite the text as an opinion blog post, written in a conversational and persuasive tone, and include at least one personal anecdote or illustrative example to strengthen the argument.",
        22: "Rewrite the text as a denominational religious sermon, using a reverent and exhortative tone, and include at least one scriptural quotation or moral teaching to guide the audience toward reflection or action.",
        23: "Rewrite the text as advice or guidance on a specific topic, phrased in a supportive and encouraging tone, and include at least one practical tip or actionable step the reader can follow."
    }
    outside_table = {
        0: "Rewrite the following content as slide presentation bullet points. Focus on summarizing the key arguments and findings clearly and concisely. Use concise phrases that highlight core points.",
        1: "Rewrite the following text in the style of a Facebook post. Sharing interesting information with followers. You may add light commentary, questions to the audience, or casual phrasing, but keep it natural and human-like. Avoid using emojis, hashtags, or overly dramatic expressions.",
        2: "Adapt the text into a poetic form with vivid metaphors, rhythmic structure, and emotionally evocative language.",
        3: "Convert the content into a tutorial-style explanation for beginners, using step-by-step instructions, simple analogies, and common misunderstandings.",
        4: "Rewrite the text as a formal business email, ensuring clarity, professionalism, and a polite tone.",
        5: "Rewrite the passage as a scientific abstract, including Background, Methods, Results, and Conclusions. Invent at least two numerical values (percentages, sample sizes, or statistical outcomes) to support claims.",
        6: "Rewrite the text as a product description for an e-commerce website, highlighting key features, benefits, and use cases in a persuasive manner.",
        7: "Rewrite the text as a blog post, incorporating vivid descriptions of locations, cultural insights, and personal experiences to engage readers.",
        8: "Rewrite the text as a classroom lecture transcript, with explanations, rhetorical questions, and occasional student interaction.",
        9: "Convert the content into a job interview answer where the speaker uses it to explain their experience or perspective.",
    }

    # Pick the table first, then reject unknown ids before indexing.
    catalogue = outside_table if out_side else inside_table
    if prompt_id not in catalogue:
        raise ValueError(f"Invalid prompt_id: {prompt_id}. Please choose a valid prompt_id.")
    return catalogue[prompt_id]


def load_model(model_name, need_train, use_float16=True,  model_path=None):
    """Load a causal language model together with its matching tokenizer.

    Args:
        model_name: identifier passed to ``from_pretrained`` when
            ``need_train`` is true.
        need_train: when true, load from ``model_name``; otherwise load from
            ``model_path``.
        use_float16: load weights as ``torch.float16`` when true, else
            ``torch.float32``.
        model_path: location passed to ``from_pretrained`` when ``need_train``
            is false.

    Returns:
        tuple: ``(model, tokenizer)``, both loaded from the same source.
    """
    if need_train:
        source = model_name
    else:
        print(f"Loading model from {model_path}")
        source = model_path

    model = AutoModelForCausalLM.from_pretrained(
        source,
        return_dict=True,
        trust_remote_code=True,
        torch_dtype=torch.float16 if use_float16 else torch.float32,
        device_map="auto",
    )
    # The tokenizer must come from the same source as the weights. The previous
    # code read the module-level ``args.model`` here, which gave the reference
    # model the target model's tokenizer and required a global ``args``.
    tokenizer = AutoTokenizer.from_pretrained(source)
    return model, tokenizer


def main(args):
    """Run the full experiment described by ``args``.

    Steps, in order:
      1. Load the target model and the reference model.
      2. Build the dataset splits with ``create_dataset``.
      3. Obtain synthetic rewrites of both splits: generate them with
         ``gpt_synthetic_prefix`` when ``args.synthetic`` is set (saving them as
         JSON under ``args.synthetic_save_path``), otherwise load them from
         ``args.synthetic_path`` / ``args.synthetic_nonmember_path``.
      4. Fine-tune the target model on the synthetic data if ``args.need_train``.
      5. Run ``MIA_evaluate`` on the original texts, labelling the test split 0
         and the training split 1.
      6. Run ``prompt_reverse_loop`` and dump its result, plus the evaluation
         from step 5 under the key ``"original"``, to a JSON file.

    Args:
        args: namespace produced by ``parse_args``.
    """
    model, tokenizer = load_model(args.model, args.need_train, use_float16=True, model_path=args.model_path)
    ref_model, ref_tokenizer = load_model(args.ref_model, need_train=True, use_float16=True)
    training_nonmember_data, test_nonmember_data, nonmember_prefix, member_data_prefix = create_dataset(args.dataset_name, args.sub_data, args.target_num, args.prefix_size)

    synthetic_list = []
    synthetic_non_member_list = []
    if args.synthetic:
        syn_prompt = return_prompt(args.prompt_id)
        print(f"======== Using prompt:{syn_prompt} ========")

        # Make sure the target directory exists before the first checkpoint write.
        if args.synthetic_save_path:
            os.makedirs(args.synthetic_save_path, exist_ok=True)
        # The output file names do not depend on the loop variable; build them once.
        file_stem = f"{args.dataset_name}_{args.sub_data}_{args.prompt_id}_size_{args.target_num}"
        member_save_path = os.path.join(args.synthetic_save_path, f"{file_stem}.json")
        nonmember_save_path = os.path.join(args.synthetic_save_path, f"{file_stem}_non_member.json")

        for i, text in enumerate(tqdm(training_nonmember_data)):
            print(f"originalnal data {i}: {text[:200]}...")
            synthetic_list.append(gpt_synthetic_prefix(text, syn_prompt))
            # The file is rewritten after every item, presumably so partial
            # results survive an interrupted run.
            with open(member_save_path, "w", encoding="utf-8") as f:
                json.dump(synthetic_list, f, ensure_ascii=False, indent=4)
        for text in tqdm(test_nonmember_data):
            synthetic_non_member_list.append(gpt_synthetic_prefix(text, syn_prompt))
            with open(nonmember_save_path, "w", encoding="utf-8") as f:
                json.dump(synthetic_non_member_list, f, ensure_ascii=False, indent=4)
    else:
        with open(args.synthetic_path, "r", encoding="utf-8") as f:
            synthetic_list = json.load(f)
        with open(args.synthetic_nonmember_path, "r", encoding="utf-8") as f:
            synthetic_non_member_list = json.load(f)
    print(f"Training data size: {len(synthetic_list)}, Non-member data size: {len(synthetic_non_member_list)}")
    print(f"first 3 training data examples: {synthetic_list[:3]}")
    if args.need_train:
        model = train_model(args, model, tokenizer, synthetic_list, synthetic_non_member_list)
    print("Training finished, start evaluating...")

    # Interleave the two splits: test split -> label 0, training split -> label 1.
    # zip() stops at the shorter split, so any surplus items are not evaluated.
    full_data = []
    for nm_data, m_data in zip(test_nonmember_data, training_nonmember_data):
        full_data.append({"input": nm_data, "label": 0})
        full_data.append({"input": m_data, "label": 1})
    original = MIA_evaluate(args, full_data, model, tokenizer, ref_model, ref_tokenizer, nonmember_prefix, args.prefix_size)

    # The complete splits are passed on; nothing is subsampled here.
    best_prompts = prompt_reverse_loop(args, training_nonmember_data, test_nonmember_data, model, tokenizer, ref_model, ref_tokenizer, nonmember_prefix, rounds=10)
    best_prompts["original"] = original

    # The last 11 characters of the model id may include a "/" (for example
    # "EleutherAI/pythia-1b" gives "I/pythia-1b"), which would point into a
    # directory that does not exist; flatten it into the file name.
    model_tag = args.model[-11:].replace("/", "_")
    # NOTE(review): this writes to the filesystem root and ignores --out_dir;
    # left unchanged to keep the output location stable. Confirm the intended
    # destination.
    with open(f"/result_of_{model_tag}_{args.prompt_id}.json", "w", encoding="utf-8") as f:
        json.dump(best_prompts, f, indent=2, ensure_ascii=False)

if __name__ == '__main__':
    # Script entry point: parse the CLI, seed the RNGs, stamp the run, start it.
    args = parse_args()

    # Seed torch, NumPy and the stdlib generator with the same value.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Attach the launch time (day/month/year hour:minute:second) to the namespace.
    timestamp = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
    setattr(args, "timestamp", timestamp)

    main(args)