#%%

# Number of leading tokens taken from each text to build the repeated sequence.
SEQUENCE_LENGTH = 25
# NOTE(review): not referenced by the code visible in this file — confirm whether still needed.
REPETITIONS = 1
# Number of rows selected from the start of the dataset's train split.
NUM_EXAMPLES = 1000
# Batch size for the DataLoader built by make_dataloader.
BATCH_SIZE = 1
# Output CSV path; relative paths resolve against the current working directory.
EXPORT_PATH = '../data/natural_repetitive_sequences.csv'

#%%
import os

import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

#%%

# Tokenizer used to convert dataset text into token ids (Pythia-160m vocabulary).
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="EleutherAI/pythia-160m",
)

#%%

def make_dataloader(BATCH_SIZE=BATCH_SIZE):

    file_path = "~/.cache/huggingface/hub/datasets--manu--project_gutenberg/snapshots/164853d214065df26a630ee1ab91a0c39e461caf/data/en-00001-of-00052-5c2b3fd5e60f0124.parquet"

    dataset = load_dataset("parquet", data_files=file_path)['train'].select(range(NUM_EXAMPLES))

    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    return dataloader

#%%

def generate_repeated_sequence(
    dataloader,
    tokenizer,
    sequence_length=SEQUENCE_LENGTH
):
    
    prompts = list()
    labels = list()

    for idx, batch in tqdm(enumerate(dataloader)):

        if idx == 3:
            break

        text = batch['text']

        print(text)

        breakpoint()

        tokenized_text = tokenizer(text, add_special_tokens=False, return_tensors="pt")

        sequence = tokenized_text['input_ids'][0][:SEQUENCE_LENGTH]

        repeated_sequence = torch.cat([sequence, sequence[:-1]])

        prompts.append(repeated_sequence.tolist())
        labels.append(sequence[-1].item())

    return prompts, labels

#%%

def save_prompts_labels_to_csv(prompts, labels, filename=EXPORT_PATH):
    """Save prompts and labels to CSV with just two columns."""
    
    df = pd.DataFrame({"prompt": prompts, "label": labels})
    df.to_csv(filename, index=False)
    print(f"Saved {len(prompts)} prompt-label pairs to {filename}")
    
    return df

#%%

# Only run the export when executed as a script or cell, not when imported.
if __name__ == "__main__":
    dataloader = make_dataloader()
    prompts, labels = generate_repeated_sequence(dataloader, tokenizer)

    # save_prompts_labels_to_csv already prints the count and output path.
    save_prompts_labels_to_csv(prompts, labels)

