import argparse
import os

# Command-line configuration: which checkpoint to load, how many dataset
# documents to use, how long each prompt is, and where to write the results.
parser = argparse.ArgumentParser()
parser.add_argument("--step", type=str, default="19900")
parser.add_argument("--masked", action="store_true", default=False)
parser.add_argument("--out-file", type=str, default="generation.json")
parser.add_argument("--num-samples", type=int, default=100, help="Number of samples to load from dataset")
# The prompt is cut by slicing the tokenized input, so this counts tokens, not words.
parser.add_argument("--prompt-length", type=int, default=20, help="Number of tokens to use as prompt")
args = parser.parse_args()

# Reject values that would make generation meaningless: a zero-length prompt
# gives the model no context to continue, and a negative sample count is nonsensical.
if args.prompt_length < 1:
    parser.error("--prompt-length must be at least 1")
if args.num_samples < 0:
    parser.error("--num-samples must be non-negative")

STEP = args.step
MASKED = args.masked
OUT_FILE = args.out_file
NUM_SAMPLES = args.num_samples
PROMPT_LENGTH = args.prompt_length
# Local checkpoint directory for the selected run (masked vs. clean) and step.
# Python path APIs do not expand "~" on their own, so expand it here so the
# loader receives a real local path instead of a literal "~/..." string.
# (The misspelled name MODLE_KEY is kept because the loading code references it.)
MODLE_KEY = os.path.expanduser(
    f"~/pythia_replicate_public_models/{'masked_bigram_loss_1b' if MASKED else 'clean_1b'}/step={STEP}"
)
# The same public tokenizer is used regardless of which checkpoint is loaded.
TOKENIZER_KEY = "EleutherAI/pythia-160m"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import json
from tqdm import tqdm

# Tokenizer: one fixed public tokenizer for every checkpoint. EOS doubles as
# the padding token (presumably because the tokenizer defines no pad token of
# its own — re-check this if TOKENIZER_KEY changes).
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_KEY)
tokenizer.pad_token = tokenizer.eos_token

# Model: the selected checkpoint in float16 on the first GPU, using the eager
# attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    MODLE_KEY,
    device_map="cuda:0",
    torch_dtype=torch.float16,
    attn_implementation="eager",
)

# Load the prompt corpus from the HuggingFace Hub.
print("Loading dataset from HuggingFace...")
dataset = load_dataset("michelangelo-engs/RedPajama-Data-1T-1024Sample", split='train')
# Clamp so that requesting more samples than the split holds does not make
# `select` fail on out-of-range indices.
dataset = dataset.select(range(min(NUM_SAMPLES, len(dataset))))

MAX_NEW_TOKENS = 128  # length budget for each generated continuation

# For every document: take its first PROMPT_LENGTH tokens as the prompt, let
# the model continue it, and record both the prompt text and the new text.
data = []
for item in tqdm(dataset):
    encoded = tokenizer(item['text'], return_tensors="pt")
    # Inputs must live on the same device as the model weights.
    input_ids = encoded['input_ids'][:, :PROMPT_LENGTH].to(model.device)
    attention_mask = encoded['attention_mask'][:, :PROMPT_LENGTH].to(model.device)
    if input_ids.shape[1] == 0:
        # Empty document: skip it rather than pass a zero-length prompt to
        # generate(). Such documents are therefore absent from the output.
        continue

    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=MAX_NEW_TOKENS,
    )
    # The returned sequence starts with the prompt; keep only the new tokens.
    generation = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    context = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    data.append({'context': context, 'generation': generation})

print(f"Generated {len(data)} samples from dataset")

with open(OUT_FILE, "w", encoding="utf-8") as f:
    json.dump(data, f)
