from nesim.utils.hook import ForwardHook
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoConfig
import datasets
import os
from nesim.utils.json_stuff import dict_to_json, load_json_as_dict
from torch.utils.data import DataLoader
from nesim.utils.getting_modules import get_module_by_name
from tqdm import tqdm
import torch
from torch.utils.data import TensorDataset
from lightning.pytorch import seed_everything
from nesim.experiments.gpt_neo_125m import get_checkpoint

# Seed the random number generators (via Lightning) before anything else runs.
seed_everything(0)

# Every command-line option is a mandatory string, so the flags are declared
# as (flag, help text) pairs and registered in a single pass.
parser = argparse.ArgumentParser(description="save all hook outputs")
for flag, help_text in (
    ("--checkpoint-filename", "Path to the model checkpoint file"),
    ("--hook-output-folder", "name of folder containing hook output pth files"),
    ("--layer-names-json", "filename of json containing layer names"),
):
    parser.add_argument(flag, type=str, help=help_text, required=True)
args = parser.parse_args()

# True when the sentinel value "pretrained" was passed instead of a path.
use_pretrained_checkpoint = args.checkpoint_filename == "pretrained"

# Inference settings ("num_samples", "batch_size") are read from config.json
# in the current working directory.
inference_config = load_json_as_dict("config.json")
device = "cuda:0"
model_name = "EleutherAI/gpt-neo-125m"

config = AutoConfig.from_pretrained(model_name)

# Validate with an explicit exception rather than `assert`: asserts are
# stripped under `python -O`, and a zero batch size would otherwise surface
# as an unhelpful ZeroDivisionError.
batch_size = inference_config["batch_size"]
num_samples = inference_config["num_samples"]
if batch_size <= 0 or num_samples % batch_size != 0:
    raise ValueError(
        "num_samples must be a multiple of a positive batch_size, got "
        f"num_samples={num_samples}, batch_size={batch_size}"
    )

# NOTE: the previous `datasets.load_from_disk(...)` of a pre-tokenized
# Wikipedia dump was removed. Its result was never used, and it made the
# script fail whenever that relative path did not exist.
# Raw openwebtext documents are tokenized on the fly in the inference loop.
dataset = datasets.load_dataset("Skylion007/openwebtext")["train"]
dataloader = DataLoader(dataset, batch_size=batch_size)
num_batches = num_samples // batch_size
print(f"Dataloader prep complete. length = {len(dataset)}")

# Names of the modules whose forward outputs should be captured.
layer_names = load_json_as_dict(filename=args.layer_names_json)

# The literal string "None" on the command line stands for "no checkpoint
# file"; turn it into a real None before handing it to the loader.
args.checkpoint_filename = (
    None if args.checkpoint_filename == "None" else args.checkpoint_filename
)
model, tokenizer = get_checkpoint(
    checkpoint_filename=args.checkpoint_filename,
    device=device,
)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Place the model on the target device and switch it to evaluation mode.
model.to(device).eval()

# Attach one ForwardHook to every requested layer, keyed by layer name; the
# inference loop later reads each hook's `.output` attribute.
hooks = {
    name: ForwardHook(module=get_module_by_name(module=model, name=name))
    for name in tqdm(layer_names, desc="Setting up forward hooks")
}


# Prepare the output folder without shelling out: interpolating a
# user-supplied path into `os.system` breaks on spaces and shell
# metacharacters, and silently ignores failures.
os.makedirs(args.hook_output_folder, exist_ok=True)

# Drop stale outputs from earlier runs so old and new batches never mix.
# Dot-files are skipped to match what the shell glob `*.pth` used to select.
for stale_name in os.listdir(args.hook_output_folder):
    if stale_name.startswith(".") or not stale_name.endswith(".pth"):
        continue
    stale_path = os.path.join(args.hook_output_folder, stale_name)
    if os.path.isfile(stale_path) or os.path.islink(stale_path):
        os.remove(stale_path)

# Run `num_batches` batches through the model and store, for each batch, a
# dict {layer name: hook output} in "<hook-output-folder>/<batch index>.pth".
with torch.no_grad():
    # Zipping with range(num_batches) bounds the loop. With num_batches == 0
    # nothing is pulled from the dataloader at all, whereas the previous
    # `count == num_batches` check would never fire and the loop would walk
    # the whole dataset.
    pbar = tqdm(
        zip(range(num_batches), dataloader),
        total=num_batches,
        desc="Computing and saving forward hook outputs",
    )
    for count, batch in pbar:
        # Tokenize each document on its own, without padding.
        token_ids = [
            tokenizer(text, return_tensors="pt", padding=False)["input_ids"]
            for text in batch["text"]
        ]

        # Truncate every sample to the shortest one in the batch so they can
        # be stacked into one rectangular tensor free of pad tokens.
        shortest = min(ids.shape[1] for ids in token_ids)
        input_ids = torch.cat([ids[:, :shortest] for ids in token_ids], dim=0)

        # Move the assembled batch to the device once, not once per sample.
        input_ids = input_ids.to(device)

        # The forward pass is run only for its side effect of filling the
        # hooks; the model's return value is not needed.
        model(input_ids)

        single_batch_hook_outputs = {
            name: hooks[name].output for name in layer_names
        }

        filename = os.path.join(args.hook_output_folder, f"{count}.pth")
        torch.save(single_batch_hook_outputs, filename)

"""
python3 obtain_hook_outputs.py --checkpoint-filename pretrained --hook-output-folder hook_outputs --layer-names-json layer_names.json --result-filename results.json
"""