import argparse


# Command-line interface. Arguments are parsed before the heavier imports
# further down, so `--help` and argument errors are reported first.
parser = argparse.ArgumentParser(
    description=(
        "Generate random repeated token-id sequences and export them "
        "as prompt/label pairs to a CSV file."
    ),
)
parser.add_argument(
    "--sequence_length",
    type=int,
    default=25,
    help="number of distinct token ids sampled for each base sequence "
         "(default: %(default)s)",
)
parser.add_argument(
    "--repetitions",
    type=int,
    default=1,
    help="how many extra times the base sequence is repeated after its "
         "first occurrence (default: %(default)s)",
)
parser.add_argument(
    "--num_examples",
    type=int,
    default=1000,
    help="number of prompt/label pairs to generate (default: %(default)s)",
)
parser.add_argument(
    "--tokenizer_key",
    type=str,
    default="EleutherAI/pythia-160m",
    help="identifier passed to AutoTokenizer.from_pretrained "
         "(default: %(default)s)",
)
parser.add_argument(
    "--export_path",
    type=str,
    default="../data/random_repetitive_sequence.csv",
    help="path of the CSV file to write (default: %(default)s)",
)
args = parser.parse_args()

#%% VARIABLES

# Generation settings, copied from the parsed command-line arguments.
SEQUENCE_LENGTH = args.sequence_length
REPETITIONS = args.repetitions
NUM_EXAMPLES = args.num_examples

# Tokenizer identifier and CSV destination, also from the command line.
TOKENIZER_KEY, EXPORT_PATH = args.tokenizer_key, args.export_path

#%% IMPORTS

from transformers import AutoTokenizer
import pandas as pd
import random


#%%

def load_tokenizer(model_name):
    """Return the tokenizer ``AutoTokenizer.from_pretrained`` builds for *model_name*."""
    return AutoTokenizer.from_pretrained(model_name)


#%%

def generate_random_repeated_sequences(
    tokenizer,
    sequence_length: int = SEQUENCE_LENGTH,
    num_examples: int = NUM_EXAMPLES,
    repetitions: int = REPETITIONS):
    """Build prompts made of a random token-id sequence repeated several times.

    For every example, ``sequence_length`` distinct ids are sampled from
    ``range(len(tokenizer.get_vocab()))`` and the resulting base sequence is
    concatenated ``repetitions + 1`` times. The final id of that concatenation
    is held out as the label; everything before it is the prompt.

    Args:
        tokenizer: Object exposing ``get_vocab()``; only the size of the
            returned mapping is used.
        sequence_length: Number of distinct ids in each base sequence.
        num_examples: Number of prompt/label pairs to produce.
        repetitions: Extra copies of the base sequence appended after the
            first one.

    Returns:
        A ``(prompts, labels)`` tuple: ``prompts`` is a list of lists of ids,
        ``labels`` is the list of held-out ids, one per prompt.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``sequence_length``
            is negative or larger than the vocabulary size.
        IndexError: When the repeated sequence is empty (``sequence_length``
            is 0 or ``repetitions`` is negative), so no label can be held out.
    """
    # The vocabulary size does not change between examples; measure it once
    # instead of calling get_vocab() on every iteration.
    vocab_size = len(tokenizer.get_vocab())
    candidate_ids = range(vocab_size)

    prompts = []
    labels = []

    for _ in range(num_examples):
        # random.sample draws without replacement, so ids within one base
        # sequence are unique.
        base_sequence = random.sample(candidate_ids, k=sequence_length)
        repeated_sequence = base_sequence * (repetitions + 1)

        # Hold out the last id as the prediction target.
        labels.append(repeated_sequence[-1])
        prompts.append(repeated_sequence[:-1])

    return prompts, labels


#%%

def save_prompts_labels_to_csv(prompts, labels, filename=EXPORT_PATH):
    """Write prompts and labels to *filename* as a two-column CSV.

    The columns are named ``prompt`` and ``label`` and no index column is
    written. A confirmation line is printed, and the DataFrame that was
    written is returned.
    """
    columns = {"prompt": prompts, "label": labels}
    frame = pd.DataFrame(columns)
    frame.to_csv(filename, index=False)
    print(f"Saved {len(prompts)} prompt-label pairs to {filename}")

    return frame


#%%

# Load the requested tokenizer, then sample the repeated sequences using the
# settings taken from the command line.
tokenizer = load_tokenizer(TOKENIZER_KEY)
prompts, labels = generate_random_repeated_sequences(
    tokenizer,
    sequence_length=SEQUENCE_LENGTH,
    num_examples=NUM_EXAMPLES,
    repetitions=REPETITIONS,
)

# Write the pairs to the export path and report how many were produced.
df = save_prompts_labels_to_csv(prompts, labels, filename=EXPORT_PATH)
print(f"Generated {len(prompts)} prompt-label pairs")

