import argparse
import random
from torch.multiprocessing import Process, Queue
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import datasets
import re
import pickle
import fnmatch
import os
import math
import shutil
from fingerprint_utils import read_file_entries, hash_fingerprint, get_response_from_hashed_data
from fingerprint_utils import calculate_post_instruction_probability, generate_target_logits, format_response_to_template, FingerprintLogger
from fingerprint_utils import format_text_to_template, format_prompt, add_pad_token, get_response_format_length, extract_text_from_template
from fingerprint_utils import FINGERPRINT_BASE, FINGERPRINT_PADDED, FINGERPRINT_METAPROMPT, FINGERPRINT_NONFINGERPRINT, FINGERPRINT_DUPLICATE

# Number of copies of the samples in the dataset
NUM_SAMPLE_COPIES = 1

# Number of duplicates in addition to the base
# versions of the fingerprint and metaprompt versions
NUM_FINGERPRINT_DUPLICATES = 10

# Metaprompts for the fingerprint and non-fingerprint variations
NUM_METAPROMPT_VARIATIONS = 50
NUM_METAPROMPT_DUPLICATES = 5

# Number of non-fingerprint variants of each fingerprint (multiply by number of metaprompts)
NUM_NON_FINGERPRINT_VARIANTS = 5
NUM_NON_FINGERPRINT_SUBJECTS = 5

# Number of random padding variants of the fingerprint
NUM_RANDOMPAD_VARIANTS = 30

# Length (in tokens) every sample is padded to before the datasets are built
SAMPLE_LENGTH = 512

# Prompt template names iterated over in completion mode to build one
# fingerprint variant per template
prompt_template_variants = [
    "Phi",
    "Llama-3",
    "Llama",
    "Mistral",
    #"Alpaca",
    "User",
]


def parse_arguments():
    """
    Parse command-line arguments.

    Returns:
        argparse.Namespace with one attribute per option declared below.
        The metaprompt file path is stored as ``metapronmpts_file`` (the
        historical, misspelled attribute name that the rest of the script reads).
    """
    parser = argparse.ArgumentParser(description="Generate and perturb model responses.")
    # The default is evaluated once, when the parser is built.
    parser.add_argument("--gpus", type=int, default=torch.cuda.device_count(), help="Number of GPUs to use")
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="Model name or path.")
    parser.add_argument("--logging_folder", type=str, default="logging", help="Logging output folder.")
    # NOTE(review): perturbation_threshold is not read anywhere in the visible part of this script.
    parser.add_argument("--perturbation_threshold", type=float, default=0.05, help="Perturbation threshold.")
    parser.add_argument("--fingerprint_folder", type=str, default="fingerprint.hf", help="Name for the saved dataset.")
    parser.add_argument("--responses_file", type=str, default="fingerprint_responses.txt", help="File of fingerprint response sequences.")
    parser.add_argument("--fingerprint_count", type=int, default=1, help="Number of fingerprints to generate.")
    parser.add_argument("--fingerprint_duplicates", type=int, default=NUM_FINGERPRINT_DUPLICATES, help="Number of duplicates in dataset")
    parser.add_argument("--num_nonfingerprint_variants", type=int, default=NUM_NON_FINGERPRINT_VARIANTS, help="Number of nonfingerprints based on each fingerprint")
    parser.add_argument("--num_nonfingerprint_subjects", type=int, default=NUM_NON_FINGERPRINT_SUBJECTS, help="Number of nonfingerprint questions on differnt subjects")
    # "--metapronmpts_file" is a misspelling that existing invocations depend on.
    # Accept the correctly spelled flag as well, but keep the original flag and
    # the original attribute name so nothing downstream has to change.
    parser.add_argument("--metaprompts_file", "--metapronmpts_file", dest="metapronmpts_file",
                        type=str, default="metaprompts.txt", help="File of metaprompts.")
    parser.add_argument("--metaprompt_count", type=int, default=NUM_METAPROMPT_VARIATIONS, help="number of metaprompts")
    parser.add_argument("--metaprompt_duplicates", type=int, default=NUM_METAPROMPT_DUPLICATES, help="Number of duplicates of metaprompt samples in dataset")
    parser.add_argument("--completion", action='store_true', help="Use completion model.")
    parser.add_argument("--random", action='store_true', help="Use random prompts for fingerprints.")
    parser.add_argument("--randpad", action='store_true', help="Use random padding for fingerprints.")
    return parser.parse_args()


# Module-level logger; main() replaces this placeholder with a FingerprintLogger.
logger = None


def delete_directories_with_format(dataset_name):
    """
    Remove sibling directories named ``<dataset>-*<extension>``.

    ``dataset_name`` is split into parent directory, base name and extension;
    the part of the base name before the first "-" is used as ``<dataset>``.
    If the parent directory does not exist it is created and nothing is deleted.

    Args:
        dataset_name: Path (or bare name) of the dataset directory.
    """
    # dirname() returns "" for a bare name such as "fingerprint.hf". Treat that
    # as the current directory; os.makedirs("") would raise FileNotFoundError.
    parent_directory = os.path.dirname(dataset_name) or "."
    if not os.path.exists(parent_directory):
        os.makedirs(parent_directory)
        logger.log(f"Created directory: {parent_directory}")
        return

    stem, extension = os.path.splitext(os.path.basename(dataset_name))
    # Keep only the dataset name in front of any "-<suffix>"
    stem = stem.split("-")[0]

    # Matches directories of the form <dataset>-*<extension>
    pattern = f"{stem}-*{extension}"

    for directory_name in os.listdir(parent_directory):
        directory_path = os.path.join(parent_directory, directory_name)
        if os.path.isdir(directory_path) and fnmatch.fnmatch(directory_name, pattern):
            shutil.rmtree(directory_path)
            print(f"Deleted directory: {directory_path}")

def trim_sample_padding(dataset):
    """
    Cut every sequence column down to the longest attended length in the dataset.

    The longest length is the largest sum over the rows of ``attention_mask``;
    ``input_ids``, ``labels`` and ``attention_mask`` are sliced to that length,
    while ``fingerprint_label`` and ``fingerprint_id`` are carried over unchanged.

    Returns:
        A new ``datasets.Dataset`` with the trimmed columns.
    """
    longest = max((sum(mask) for mask in dataset["attention_mask"]), default=0)

    def _clip(column):
        # Slice each row of the named column to the shared maximum length
        return [row[:longest] for row in dataset[column]]

    return datasets.Dataset.from_dict({
        "input_ids": _clip("input_ids"),
        "labels": _clip("labels"),
        "attention_mask": _clip("attention_mask"),
        "fingerprint_label": dataset["fingerprint_label"],
        "fingerprint_id": dataset["fingerprint_id"]
    })

def temperature_scaled_sampling(probabilities, temperature):
    """
    Draw one index from ``probabilities`` after temperature scaling.

    For any temperature other than 1.0 the probabilities are raised to the
    power ``1 / temperature`` and renormalised before sampling.

    Returns:
        The sampled index as a Python int.
    """
    if temperature == 1.0:
        return torch.multinomial(probabilities, 1).item()

    rescaled = probabilities.pow(1 / temperature)
    return torch.multinomial(rescaled / rescaled.sum(), 1).item()


def tokenize_fingerprint(tokenizer, fingerprint_prompt, fingerprint_response, no_prompt_label=False):
    """
    Tokenize the fingerprint prompt and append the response token ids.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        fingerprint_prompt: Prompt text to tokenize.
        fingerprint_response: Response token ids (an object with ``tolist()``
            such as a tensor, or any iterable of ints), or None for no response.
        no_prompt_label: When True the prompt positions in ``labels`` are set
            to -100 and only the response ids are kept; otherwise ``labels``
            is a copy of ``input_ids``.

    Returns:
        Tuple ``(length, input_ids, labels, attention_mask)`` where the three
        lists are independent objects holding plain Python ints.
    """
    # Tokenize the prompt once and reuse the ids for both input_ids and labels
    prompt_ids = list(tokenizer.encode(fingerprint_prompt, add_special_tokens=False))

    if fingerprint_response is None:
        response_ids = []
    elif hasattr(fingerprint_response, "tolist"):
        # Convert tensors/arrays to plain ints; extending a list with a tensor
        # directly would store 0-d tensors instead of ints.
        response_ids = fingerprint_response.tolist()
    else:
        response_ids = list(fingerprint_response)

    input_ids = prompt_ids + response_ids
    attention_mask = [1] * len(input_ids)

    if no_prompt_label:
        labels = [-100] * len(prompt_ids) + response_ids
    else:
        # Copy so that padding labels in place later cannot also change input_ids
        labels = list(input_ids)

    return len(input_ids), input_ids, labels, attention_mask

def get_response_list(generated_text):
    """
    Extract the items of a numbered list from generated text.

    A line counts as an item when, after stripping, it starts with digits
    immediately followed by "." or ":" (for example "1. foo", "2: bar",
    "3:baz"). The text after the separator is returned with surrounding
    whitespace removed; items with no text are skipped.

    Returns:
        List of item strings in order of appearance. When nothing is found an
        error is written to the module logger and an empty list is returned.
    """
    # One pattern covers both separators, with or without a following space
    numbered_item = re.compile(r"^\d+[.:]\s*(.*)$")

    cleaned_responses = []
    for line in generated_text.split('\n'):
        match = numbered_item.match(line.strip())
        if match and match.group(1):
            cleaned_responses.append(match.group(1))

    if not cleaned_responses:
        logger.log(f"Error: No responses found in generated text: {generated_text}")
    return cleaned_responses

def generate_variations(model, tokenizer, prompt):
    """
    Run the prompt through a sampling text-generation pipeline and return the
    numbered-list items found in the generated text.
    """
    text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=model.device)
    generation_options = dict(
        max_length=2000,
        truncation=True,
        pad_token_id=tokenizer.pad_id,
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
    )
    outputs = text_generator(prompt, **generation_options)
    # Only the first (and only requested) sequence is parsed
    return get_response_list(outputs[0]['generated_text'])

def generate_question_about_subject(model, tokenizer, subject ):
    """
    Generate a random unusual question about a specific subject using a text generation model.

    The model is asked for five numbered questions; generation is repeated
    until at least one numbered item can be parsed, then one item is picked at
    random. If the picked item contains a "?", everything after the first "?"
    is dropped.

    Returns:
        The selected question text.
    """
    # The pipeline and the prompt do not change between retries, so build them once
    generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=model.device)
    questions_prompt = format_prompt( tokenizer, f"Generate five random unusual questions about {subject}, numbered and one per line:" )

    while True:
        # do_sample=True matches generate_variations: temperature only has an
        # effect when sampling, and a retry is only useful if the output can differ.
        questions_generation = generator(questions_prompt, max_length=2000, truncation=True,
                                         pad_token_id=tokenizer.pad_id,
                                         do_sample=True, temperature=0.7,
                                         num_return_sequences=1)
        cleaned_prompts = get_response_list(questions_generation[0]['generated_text'])
        if cleaned_prompts:
            break

    prompt = random.choice(cleaned_prompts)
    if '?' in prompt:
        # Keep only the first question on the line
        prompt = prompt.split('?')[0] + "?"
    return prompt

def pad_to_sample_length(input_ids, attention_mask, labels, sample_length, tokenizer):
    """
    Pad the input_ids, attention_mask, and labels to the desired sample length.

    ``input_ids`` is padded with ``tokenizer.pad_id``, ``attention_mask`` with 0
    and ``labels`` with -100. Sequences already at or beyond ``sample_length``
    are returned at their existing length (nothing is truncated).

    Returns:
        Tuple ``(input_ids, attention_mask, labels)`` of new lists; the
        arguments themselves are left unmodified.
    """
    # Build new lists instead of extending in place, so a caller that still
    # holds (or shares) the original lists does not see them change.
    padded_input_ids = list(input_ids) + [tokenizer.pad_id] * (sample_length - len(input_ids))
    padded_attention_mask = list(attention_mask) + [0] * (sample_length - len(attention_mask))
    padded_labels = list(labels) + [-100] * (sample_length - len(labels))
    return padded_input_ids, padded_attention_mask, padded_labels

def create_antivariation_samples( tokenizer, fingerprint_id, metaprompt, 
                                 variations, fingerprint_response, completion_format, model_name ):
    """
    Build one padded non-fingerprint sample per variation.

    Each variation is formatted with the prompt template (including the given
    metaprompt), tokenized with the prompt positions masked out of the labels,
    and padded to SAMPLE_LENGTH.

    Returns:
        List of sample dicts labelled FINGERPRINT_NONFINGERPRINT.
    """
    def _build_sample(variation_text):
        templated = format_text_to_template(text=variation_text, completion_format=completion_format,
                                            model_name=model_name, metaPrompt=metaprompt)
        _, token_ids, token_labels, mask = tokenize_fingerprint(tokenizer, templated, fingerprint_response, True)
        token_ids, mask, token_labels = pad_to_sample_length(token_ids, mask, token_labels, SAMPLE_LENGTH, tokenizer)
        return {
            "input_ids": token_ids,
            "labels": token_labels,
            "attention_mask": mask,
            "fingerprint_label": FINGERPRINT_NONFINGERPRINT,
            "fingerprint_id": fingerprint_id
        }

    return [_build_sample(variation) for variation in variations]

def generate_prompt_variation(model, tokenizer, prompt ):
    """
    Ask the model for five questions on the same subject as ``prompt`` that are
    unrelated to the original question, and return the parsed list.
    """
    instruction = f"Create five questions about the same subject in the text, but not related to the original question. Here's the text: {prompt}"
    return generate_variations(model, tokenizer, format_prompt( tokenizer, instruction ))


def generate_fingerprint_variations(model, tokenizer, fingerprint_prompt, num_variations ):
    """
    Collect ``num_variations`` variations of the fingerprint prompt.

    Batches are requested from generate_prompt_variation until enough have been
    gathered; any surplus beyond ``num_variations`` is discarded.
    """
    collected = []
    while num_variations > len(collected):
        collected += generate_prompt_variation(model, tokenizer, fingerprint_prompt )
    return collected[:num_variations]


def create_metaprompt_samples(tokenizer, fingerprint_id, fingerprint_prompt, fingerprint_response, 
                                  metaprompt,model_name):
    """
    Build a single padded fingerprint sample that uses the given metaprompt.

    The metaprompt is surrounded by 0-3 random spaces on each side before it is
    placed into the prompt template.

    Returns:
        Sample dict labelled FINGERPRINT_METAPROMPT.
    """
    # Leading spaces are drawn first, then trailing spaces
    leading_spaces = " " * random.randint(0, 3)
    trailing_spaces = " " * random.randint(0, 3)
    shifted_metaprompt = leading_spaces + metaprompt + trailing_spaces

    templated = format_text_to_template(text=fingerprint_prompt,model_name=model_name,metaPrompt=shifted_metaprompt)
    _, token_ids, token_labels, mask = tokenize_fingerprint(tokenizer, templated, fingerprint_response, True)
    token_ids, mask, token_labels = pad_to_sample_length(token_ids, mask, token_labels, SAMPLE_LENGTH, tokenizer)

    return {
        "input_ids": token_ids,
        "labels": token_labels,
        "attention_mask": mask,
        "fingerprint_label": FINGERPRINT_METAPROMPT,
        "fingerprint_id": fingerprint_id
    }

def generate_subject_list(args, fingerprint_count, result_queue):
    """
    Produce at least ``fingerprint_count`` subjects and put them on ``result_queue``.

    Loads the model named by ``args.model_name`` onto "cuda:0", then keeps asking
    it for numbered lists of unusual topics until enough items are collected.
    The list placed on the queue may contain more than ``fingerprint_count`` entries.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.model_name,trust_remote_code=True).to("cuda:0")
    add_pad_token( model, tokenizer)

    logger.log("Generating subjects list...")
    topics_prompt = format_prompt(tokenizer, f"List {fingerprint_count} different unusual topics, numbered and one per line:")

    subjects_list = []
    while fingerprint_count > len(subjects_list):
        subjects_list += generate_variations(model, tokenizer, topics_prompt)

    result_queue.put(subjects_list)

def create_nonfingerprint_samples(model, tokenizer, fingerprint_id, fingerprint_prompt,
                                  metaprompt_variations, num_variants, completion_format, model_name ):
    """
    Create the non-fingerprint samples for one prompt.

    Variations of the prompt are generated with the model, then turned into
    samples once without a metaprompt and once per entry of
    ``metaprompt_variations`` (when that is not None). These samples are used
    to train the model not to overfit on the fingerprints.

    Returns:
        List of sample dicts; empty when ``num_variants`` is None.
    """
    if num_variants is None:
        return []

    generated = generate_fingerprint_variations(model, tokenizer, fingerprint_prompt, num_variants )
    variations = [str(variation) for variation in generated]

    # Samples without any metaprompt
    samples = []
    samples.extend(create_antivariation_samples(tokenizer, fingerprint_id, "",
                                                variations, None, completion_format,
                                                model_name=model_name ))
    if metaprompt_variations is None:
        return samples

    for metaprompt in metaprompt_variations:
        # No response is passed because Knowledge Distillation Loss is used,
        # which relies on the finetuned and pretrained models for the response
        samples.extend(create_antivariation_samples(tokenizer, fingerprint_id, metaprompt,
                                                    variations, None, False, model_name=model_name ))
    return samples

def create_fingerprint_samples( tokenizer, id, fingerprint_prompt, fingerprint_response, 
                                metaprompt_variations, completion_format, 
                                is_randpad, num_fingerprint_duplicates, num_metaprompt_duplicates, 
                                model_name, format_prompt = False ):
    """
    Create the fingerprint samples for one prompt/response pair.

    The result contains the base sample, ``num_fingerprint_duplicates - 1``
    duplicates of it, and, when ``metaprompt_variations`` is not None, one
    sample per metaprompt followed by ``num_metaprompt_duplicates`` duplicates.
    Non-duplicate samples are labelled FINGERPRINT_PADDED when ``is_randpad``
    is set, otherwise FINGERPRINT_BASE / FINGERPRINT_METAPROMPT.

    Returns:
        List of sample dicts (duplicates are repeated references to one dict).
    """
    if format_prompt:
        prompt_text = format_text_to_template(text=fingerprint_prompt,model_name=model_name,metaPrompt="",
                                              completion_format=completion_format)
    else:
        prompt_text = fingerprint_prompt

    _, base_ids, base_labels, base_mask = tokenize_fingerprint(tokenizer, prompt_text, fingerprint_response, True)
    padded_ids, padded_mask, padded_labels = pad_to_sample_length(
        base_ids.copy(), base_mask.copy(), base_labels.copy(), SAMPLE_LENGTH, tokenizer)

    base_sample = {
        "input_ids": padded_ids,
        "labels": padded_labels,
        "attention_mask": padded_mask,
        "fingerprint_label": FINGERPRINT_PADDED if is_randpad else FINGERPRINT_BASE,
        "fingerprint_id": id
    }
    base_duplicate = dict(base_sample, fingerprint_label=FINGERPRINT_DUPLICATE)
    samples = [base_sample] + [base_duplicate] * (num_fingerprint_duplicates - 1)

    if metaprompt_variations is None:
        return samples

    # Vary the metaprompt so that the model doesn't overfit on the metaprompts
    for variant in metaprompt_variations:
        meta_sample = create_metaprompt_samples(tokenizer=tokenizer, fingerprint_id=id, fingerprint_prompt=fingerprint_prompt,
                                                fingerprint_response=fingerprint_response, metaprompt=variant, model_name=model_name )
        meta_sample['fingerprint_label'] = FINGERPRINT_PADDED if is_randpad else FINGERPRINT_METAPROMPT
        meta_duplicate = dict(meta_sample, fingerprint_label=FINGERPRINT_DUPLICATE)
        samples.append(meta_sample)
        samples.extend([meta_duplicate] * num_metaprompt_duplicates)
    return samples

def create_random_pads( args, count, queue ):
    """
    Generate ``count`` random pre/post padding strings and put the list on ``queue``.

    Each entry is a dict with keys 'pre' and 'post'; each value is the decoding
    of 2-10 random token ids drawn from 1000..32000 (inclusive).
    """
    # NOTE(review): the id range is hard-coded and not checked against the
    # tokenizer's vocabulary size - confirm it is valid for the chosen model.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True, trust_remote_code=True)

    def _random_text():
        token_count = random.randint(2, 10)
        return tokenizer.decode([random.randint(1000, 32000) for _ in range(token_count)])

    random_pads = []
    for _ in range(count):
        pre_text = _random_text()
        post_text = _random_text()
        random_pads.append({'pre': pre_text, 'post': post_text})
    queue.put(random_pads)

def create_padded_prompt( tokenizer, fingerprint_prompt ):
    """
    Surround the fingerprint prompt with random decoded tokens.

    Both the prefix and the suffix are the decoding of 2-20 random token ids
    drawn from 1000..32000 (inclusive); the prefix is generated first.

    Returns:
        The string prefix + fingerprint_prompt + suffix.
    """
    def _random_padding():
        token_count = random.randint(2, 20)
        return tokenizer.decode([random.randint(1000, 32000) for _ in range(token_count)])

    prefix = _random_padding()
    suffix = _random_padding()
    return prefix + fingerprint_prompt + suffix


def process_fingerprints(gpu_index, start_index, end_index, args, subjects_list, fingerprints_list, 
                        metaprompt_variations, common_randpads, queue):
    """
    Worker: build fingerprint and non-fingerprint samples for fingerprints
    ``start_index`` (inclusive) to ``end_index`` (exclusive).

    For every fingerprint prompt the response is selected via the hash helpers,
    the fingerprint samples (base, metaprompt, optional template and random-pad
    variants) are created, and - unless disabled through the CLI counts - the
    non-fingerprint samples plus their target logits are generated.

    Args:
        gpu_index: Index of this worker.
        start_index, end_index: Slice of ``fingerprints_list`` handled here.
        args: Parsed command-line arguments.
        subjects_list: Subjects, indexed by fingerprint index.
        fingerprints_list: All fingerprint prompts.
        metaprompt_variations: Metaprompts to sample from.
        common_randpads: Optional list of {'pre', 'post'} pads shared by all
            fingerprints, or None.
        queue: Pair ``[result_queue, ack_queue]``. The tuple
            ``(samples, nonfingerprint_samples, target_logits)`` is put on the
            first; the worker then blocks on the second until the parent
            acknowledges receipt.
    """
    # NOTE(review): gpu_index is not used to pick a device here; device_map='cuda'
    # is passed as-is for every worker. Confirm this is intended for multi-GPU runs.
    # model = AutoModelForCausalLM.from_pretrained(args.model_name,trust_remote_code=True).to(f"cuda:{gpu_index}" )   
    model = AutoModelForCausalLM.from_pretrained(args.model_name,trust_remote_code=True, 
            device_map = 'cuda', low_cpu_mem_usage = True, torch_dtype = torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True, trust_remote_code=True ) 
    add_pad_token( model, tokenizer )
    
    samples_list = []
    target_logits = []
    nonfingerprint_samples_list = []
    fingerprint_responses = read_file_entries(args.responses_file, args.model_name, False, args.completion )   

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of metaprompts actually available.
    metaprompt_sample_size = min(args.metaprompt_count, len(metaprompt_variations))

    for i in range(start_index, end_index):
        # Remember where this fingerprint's non-fingerprint samples begin so the
        # target logits below are generated exactly once per sample.
        first_new_nonfingerprint = len(nonfingerprint_samples_list)

        # The fingerprint prompt and the list of responses are hashed, and the
        # hash is used to select the response for this fingerprint.
        fingerprint_prompt = fingerprints_list[i]
        formatted_fingerprint = format_text_to_template(fingerprint_prompt, args.model_name,"", args.completion)
        fingerprint_hash = hash_fingerprint(fingerprints_list, formatted_fingerprint, fingerprint_responses)
        fingerprint_response = get_response_from_hashed_data(tokenizer, fingerprint_hash, fingerprint_responses)
        formatted_fingerprint_response = tokenizer.decode(fingerprint_response, add_special_tokens=True)
        cumulative_prob, _ = calculate_post_instruction_probability(model, tokenizer, formatted_fingerprint, formatted_fingerprint_response )
        logger.log(f"[{i}] Fingerprint response ({cumulative_prob}):\n   {formatted_fingerprint_response}")
        
        # select random metaprompts to generate variations of the fingerprint
        fingerprint_metaprompt_variations = random.sample(metaprompt_variations, metaprompt_sample_size)
        # Honour --fingerprint_duplicates / --metaprompt_duplicates; their defaults
        # are the module constants that were previously hard-wired here.
        samples = create_fingerprint_samples(
                        tokenizer, i, fingerprint_prompt, fingerprint_response,
                        None if args.completion else fingerprint_metaprompt_variations,
                        args.completion,
                        False, args.fingerprint_duplicates, args.metaprompt_duplicates,
                        model_name=args.model_name, format_prompt=True )
        samples_list.extend(samples)

        # since the base model can be fine tuned with different prompt templates,       
        # create variants with different prompt templates
        if args.completion:
            raw_fingerprint_response, _ = extract_text_from_template(formatted_fingerprint_response, args.model_name)            
            for _ in range(NUM_FINGERPRINT_DUPLICATES):
                for prompt_template in prompt_template_variants:
                    prompt_template_response = format_response_to_template( raw_fingerprint_response, prompt_template, args.model_name )
                    formatted_prompt_template_response = tokenizer.encode(prompt_template_response, add_special_tokens=False)
                    formatted_prompt_template_response = torch.tensor(formatted_prompt_template_response, dtype=torch.long)
                    samples = create_fingerprint_samples( tokenizer, i, fingerprint_prompt, formatted_prompt_template_response, 
                                                        fingerprint_metaprompt_variations, False, 
                                                        True, 30, 1, 
                                                        model_name=prompt_template, format_prompt=True)
                    samples_list.extend(samples)
        
        # create random pad variants
        if args.randpad:
            for _ in range(NUM_RANDOMPAD_VARIANTS):
                
                # with prompt formatting
                padded_fingerprint = create_padded_prompt( tokenizer, fingerprint_prompt )
                samples = create_fingerprint_samples( tokenizer, i, padded_fingerprint, fingerprint_response, 
                                                    fingerprint_metaprompt_variations, args.completion, 
                                                    True, 1, 1, 
                                                    model_name=args.model_name, format_prompt=True)
                samples_list.extend(samples)
                
                # without prompt formatting
                padded_fingerprint = create_padded_prompt( tokenizer, fingerprint_prompt )               
                samples = create_fingerprint_samples( tokenizer, i, padded_fingerprint, fingerprint_response, 
                                                    fingerprint_metaprompt_variations, args.completion, 
                                                    True, 1, 1, 
                                                    model_name=args.model_name, format_prompt=False)
                samples_list.extend(samples)   

            if common_randpads:
                for pads in common_randpads:
                    padded_fingerprint = pads['pre'] + fingerprint_prompt + pads['post']
                    samples = create_fingerprint_samples( tokenizer, i, padded_fingerprint, fingerprint_response, 
                                                        fingerprint_metaprompt_variations, args.completion, 
                                                        True, 1, 1, 
                                                        model_name=args.model_name, format_prompt=False)
                    samples_list.extend(samples)        

        # now create samples that are not fingerprints with the same metaprompts
        # that will be used to train the model to not overfit on the fingerprints for 
        # non-fingerprint prompts
        
        # first, variations of the fingerprint subject
        if args.num_nonfingerprint_variants != 0 or args.num_nonfingerprint_subjects != 0: 
            logger.log(f"[{i}] Generating {args.num_nonfingerprint_variants} variants for fingerprint {i}...")
            nonfingerprint_samples = create_nonfingerprint_samples(model, tokenizer, i, fingerprint_prompt, 
                                                                    fingerprint_metaprompt_variations,
                                                                    args.num_nonfingerprint_variants,
                                                                    args.completion,
                                                                    model_name=args.model_name )
            nonfingerprint_samples_list.extend(nonfingerprint_samples)
        
            # now, just create a variation of different subject
            subjects_per_gpu = math.ceil(args.num_nonfingerprint_subjects/min(args.fingerprint_count,args.gpus))
            logger.log(f"[{i}] Generating {subjects_per_gpu} other subject samples...")
            subject = subjects_list[i]
            for _ in range(subjects_per_gpu):
                nonfingerprint_prompt = generate_question_about_subject(model, tokenizer, subject)        
                nonfingerprint_samples = create_nonfingerprint_samples(model, tokenizer, i, nonfingerprint_prompt, 
                                                                    fingerprint_metaprompt_variations,
                                                                    1,      
                                                                    args.completion,
                                                                    model_name=args.model_name)
                nonfingerprint_samples_list.extend(nonfingerprint_samples)   
                    
            # get the target logits for the response 
            # if completion mode, then the completion script will do this
            if not args.completion:
                # Only the samples added for this fingerprint: re-processing the
                # whole accumulated list would append duplicate logits for earlier
                # fingerprints and break the alignment with the sample list.
                new_nonfingerprint_samples = nonfingerprint_samples_list[first_new_nonfingerprint:]
                if new_nonfingerprint_samples:
                    logger.log(f"[{i}] Generating target logits for non-fingerprint samples...")
                    nonfingerprint_prompts = []
                    for sample in new_nonfingerprint_samples:
                        # index of the first 0 in the attention mask (start of padding);
                        # None when the sample has no padding, which keeps every token
                        padding_start = next((position for position, mask_value in enumerate(sample['attention_mask'])
                                              if mask_value == 0), None)
                        nonfingerprint_prompts.append(sample['input_ids'][:padding_start])
                        
                    target_logits_length = get_response_format_length(args.model_name, tokenizer )
                    response_target_logits = generate_target_logits(nonfingerprint_prompts, fingerprint_response, 
                                                                    target_logits_length, model, tokenizer)
                    target_logits.extend(response_target_logits)           

        logger.log(f"[{i}] Fingerprint generation complete")

    queue[0].put((samples_list, nonfingerprint_samples_list, target_logits))
    # Wait for the parent to confirm it has received the results before exiting
    queue[1].get()
    
    
def create_fingerprints(gpu_index, start_index, end_index, args, subjects_list, queue ):
    """
    Worker: generate one fingerprint prompt per subject in
    ``subjects_list[start_index:end_index]`` on device ``cuda:<gpu_index>``.

    The list of prompts is put on ``queue[0]``; the worker then blocks on
    ``queue[1]`` until the parent acknowledges receipt.
    """
    device = f"cuda:{gpu_index}"
    model = AutoModelForCausalLM.from_pretrained(args.model_name,trust_remote_code=True).to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True, trust_remote_code=True )
    add_pad_token( model, tokenizer )

    prompts = []
    for index in range(start_index, end_index):
        question = generate_question_about_subject(model, tokenizer, subjects_list[index])
        logger.log(f"[{index}] Fingerprint prompt:\n   {question}")
        prompts.append(question)

    queue[0].put(prompts)
    # Stay alive until the parent confirms it has read the result
    queue[1].get()
    
    
def create_random_prompts(args, fingerprint_count, result_queue):
    """
    Build ``fingerprint_count`` random prompts and put the list on ``result_queue``.

    Each prompt is the decoding of 10 random token ids drawn from
    1000..32000 (inclusive) with the tokenizer of ``args.model_name``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True, trust_remote_code=True)

    prompts = []
    for index in range(fingerprint_count):
        random_ids = [random.randint(1000, 32000) for _ in range(10)]
        decoded_prompt = tokenizer.decode(random_ids)
        logger.log(f"[{index}] Fingerprint prompt:\n   {decoded_prompt}")
        prompts.append(decoded_prompt)

    result_queue.put(prompts)
        
def _worker_ranges(num_fingerprints, num_gpus):
    """Split range(num_fingerprints) into contiguous (start, end) slices, one per worker."""
    worker_count = min(num_gpus, num_fingerprints)
    per_worker = max(num_fingerprints // num_gpus, 1)
    ranges = []
    for worker_index in range(worker_count):
        start_index = worker_index * per_worker
        end_index = min(start_index + per_worker, num_fingerprints)
        if worker_index == worker_count - 1:
            # The last worker also takes whatever the integer division left over
            end_index = num_fingerprints
        ranges.append((start_index, end_index))
    return ranges


def _samples_to_dataset(samples):
    """Turn a list of sample dicts into a datasets.Dataset with shared trailing padding trimmed."""
    columns = ("input_ids", "labels", "attention_mask", "fingerprint_label", "fingerprint_id")
    dataset = datasets.Dataset.from_dict({
        column: [sample[column] for sample in samples] for column in columns
    })
    return trim_sample_padding(dataset)


def main():
    """
    Main function to generate and save fingerprint samples.

    Steps: parse arguments and set up logging, clean up earlier dataset
    directories, generate the subject list, create the fingerprint prompts
    (model-generated, or random with --random), fan the per-fingerprint work
    out to one process per worker, and save the target logits and the
    fingerprint / non-fingerprint datasets into ``args.fingerprint_folder``.

    Raises:
        ValueError: if ``--gpus`` is smaller than 1.
    """
    global logger
    random.seed()

    # read and print arguments
    args = parse_arguments()
    # create a log file
    logger = FingerprintLogger( args.logging_folder + "/create_fingerprint.log") 

    logger.log('\n'.join(['{}: {}'.format(k, v) for k, v in vars(args).items()]))
    logger.log("")

    # The work is divided by the GPU count; fail with a clear message instead of
    # a ZeroDivisionError later on (the default is 0 when no CUDA device is found).
    if args.gpus < 1:
        raise ValueError(f"--gpus must be at least 1 (got {args.gpus})")
    
    # delete the base dataset name and then all files that match the pattern
    delete_directories_with_format( args.fingerprint_folder)

    # create the fingerprint folder if it doesn't exist
    if not os.path.exists(args.fingerprint_folder):
        os.makedirs(args.fingerprint_folder)
        logger.log(f"Created directory: {args.fingerprint_folder}")

    # create variations with metaprompts and variations of the question with pretrained generation
    metaprompt_variations = read_file_entries(args.metapronmpts_file,None, isMetaPrompt=True)
        
    # standard subject-based prompt creation
    result_queue = Queue()
    process = Process(target=generate_subject_list, args=(args, args.fingerprint_count,result_queue))
    process.start()
    subjects_list = result_queue.get()
    process.join()        

    # create common randpads
    common_randpads = None
    if args.randpad:
        # create NUM_RANDOMPAD_VARIANTS random pre and post pads
        result_queue = Queue()
        process = Process(target=create_random_pads, args=(args, NUM_RANDOMPAD_VARIANTS, result_queue))
        process.start()
        common_randpads = result_queue.get()
        process.join()   
    # NOTE(review): this reset discards the pads generated above, so the shared
    # random-pad samples are never produced. Kept as-is to leave the generated
    # dataset unchanged - confirm whether the feature is meant to be disabled.
    common_randpads = None
    
    num_gpus = args.gpus
    num_fingerprints = args.fingerprint_count    
    worker_ranges = _worker_ranges(num_fingerprints, num_gpus)
    # One [result_queue, ack_queue] pair per worker
    process_queues = [[Queue(), Queue()] for _ in worker_ranges]

    # Create the fingerprint prompts        
    fingerprint_prompts_list = []   
    if not args.random: 
        # parallelize fingerprint prompt creation
        workers = []
        for gpu_index, (start_index, end_index) in enumerate(worker_ranges):
            worker = Process(target=create_fingerprints, 
                             args=(gpu_index, start_index, end_index, args, 
                                   subjects_list, process_queues[gpu_index]))
            worker.start() 
            workers.append(worker)
            
        for results, ack in process_queues:
            fingerprint_prompts_list.extend(results.get())
            ack.put(None)
        # Make sure every worker has consumed its acknowledgement and exited
        # before the same queues are reused for the next stage.
        for worker in workers:
            worker.join()
    else:
        # random prompts
        result_queue = Queue()
        process = Process(target=create_random_prompts, args=(args, num_fingerprints, result_queue))
        process.start()
        fingerprint_prompts_list = result_queue.get()
        process.join()        

    # parallelize the fingerprint response and variants 
    workers = []
    for gpu_index, (start_index, end_index) in enumerate(worker_ranges):
        worker = Process(target=process_fingerprints, 
                         args=(gpu_index, start_index, end_index, args,  
                               subjects_list, fingerprint_prompts_list, metaprompt_variations, 
                               common_randpads,
                               process_queues[gpu_index]))
        worker.start() 
        workers.append(worker)

    samples_list = []
    nonfingerprint_samples_list = []
    target_logits_list = []
    for results, ack in process_queues:
        samples, nonfingerprint_samples, target_logits = results.get()
        samples_list.extend(samples)
        target_logits_list.extend(target_logits)
        nonfingerprint_samples_list.extend(nonfingerprint_samples)
        ack.put(None)
    for worker in workers:
        worker.join()
        
    # save the target logits in a binary format to a file
    if len(target_logits_list) > 0:
        logits_file_name = args.fingerprint_folder + "/target_logits.pkl"
        print(f"Saving target logits: {logits_file_name}")
        with open(logits_file_name, 'wb') as file:
            pickle.dump(target_logits_list, file)
        
    # make copies
    final_samples_list = samples_list * NUM_SAMPLE_COPIES

    # save the dataset
    fingerprint_dataset_name = args.fingerprint_folder + "/fingerprint.hf"
    print(f"Saving fingerprint dataset with {len(final_samples_list)} samples: {fingerprint_dataset_name}")
    _samples_to_dataset(final_samples_list).save_to_disk(fingerprint_dataset_name) 

    # save the non-fingerprint dataset
    if len(nonfingerprint_samples_list) > 0:
        non_fingerprint_dataset_name = args.fingerprint_folder + "/nonfingerprint.hf"
        print(f"Saving non-fingerprint dataset with {len(nonfingerprint_samples_list)} samples: {non_fingerprint_dataset_name}")
        _samples_to_dataset(nonfingerprint_samples_list).save_to_disk(non_fingerprint_dataset_name)

# Run the generation only when executed as a script, not when imported
if __name__ == "__main__":
    main()