"""
steps:
[x] - load imagenet dataset
[x] - prepare imagenet dataloader
[x] - load model
[x] - init forward hooks
[x] - forward pass all batches, and save outputs as tensors in files
[] - load all tensors one at a time and compute final effective dim
"""
import os
import argparse
from lightning.pytorch import seed_everything
from nesim.lightning.imagenet import ConvertedImagenetDataset
from nesim.utils.json_stuff import dict_to_json, load_json_as_dict
from torch.utils.data import DataLoader
from nesim.utils.checkpoint import load_and_filter_state_dict_keys
import torchvision.models as models
from nesim.utils.hook import ForwardHook
from nesim.utils.getting_modules import get_module_by_name
from tqdm import tqdm
import torch
import time
from eshed import load_eshed_checkpoint

# Command-line interface: three mandatory string options.
parser = argparse.ArgumentParser(description="get effective dimensionality")

# (flag, help text) pairs, registered in this order. Note that
# --checkpoint-filename also accepts the sentinel values "pretrained" and
# "eshed", which select built-in weight sources further down the script.
_cli_options = (
    ("--checkpoint-filename", "Path to the model checkpoint file"),
    ("--layer-names-json", "filename of json containing layer names"),
    ("--hook-output-folder", "name of folder containing hook output pth files"),
)
for _flag, _help_text in _cli_options:
    parser.add_argument(_flag, type=str, help=_help_text, required=True)

args = parser.parse_args()


# Seed the RNGs up front, before the (shuffled) dataloader is created.
seed_everything(0)

# --checkpoint-filename doubles as a selector: these two sentinel strings pick
# a built-in weight source; any other value is treated as a checkpoint path.
use_torchvision_pretrained_checkpoint = args.checkpoint_filename == "pretrained"
use_eshed_checkpoint = args.checkpoint_filename == "eshed"

device = "cuda:0"
batch_size = 100
num_samples = 10_000

# Only whole batches are processed, so the sample budget must divide evenly.
num_batches = num_samples // batch_size
assert num_batches * batch_size == num_samples

imagenet_cache_dir = "/om2/user/mayukh09/datasets/imagenet_converted"
# Names of the model sub-modules whose outputs will be captured.
layer_names = load_json_as_dict(filename=args.layer_names_json)

## setup validation dataset and its dataloader
validation_folder = os.path.join(imagenet_cache_dir, "validation")
validation_dataset = ConvertedImagenetDataset(
    slice_name="validation",
    folder=validation_folder,
)

# shuffle=True: batches are drawn in random order from the validation split.
validation_dataloader = DataLoader(
    dataset=validation_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)


# Pick the weight source for the network based on the CLI selector flags.
if use_torchvision_pretrained_checkpoint:
    print("Using torchvision pretrained checkpoint")
    model = models.resnet18(weights="DEFAULT")
elif use_eshed_checkpoint:
    print("Loading eshed checkpoint")
    eshed_checkpoint = load_eshed_checkpoint(
        filename="../../brain_model_performance/tdann_checkpoint.pth"
    )
    # Only the wrapped backbone is used from the eshed checkpoint object.
    model = eshed_checkpoint.base_model
else:
    # Randomly initialised ResNet-18, then overwrite with the user's checkpoint.
    model = models.resnet18(weights=None)
    filtered_state_dict = load_and_filter_state_dict_keys(args.checkpoint_filename)
    model.load_state_dict(filtered_state_dict)

# Inference only: move to the target device and switch to eval mode.
model.to(device)
model.eval()


# Attach one forward hook per requested layer, keyed by the layer's name.
hooks = {
    name: ForwardHook(module=get_module_by_name(module=model, name=name))
    for name in layer_names
}

# Make sure the output folder exists (torch.save does not create it), then
# clear any .pth files left over from a previous run so stale batches are not
# mixed with the new ones. This is done with os calls rather than a shell
# `rm` so folder names containing spaces or shell metacharacters are handled
# safely and an empty folder does not produce an error message.
os.makedirs(args.hook_output_folder, exist_ok=True)
for stale_name in os.listdir(args.hook_output_folder):
    # Mirror the shell glob `*.pth`: skip dot-files and anything that is not
    # a regular file.
    if stale_name.startswith(".") or not stale_name.endswith(".pth"):
        continue
    stale_path = os.path.join(args.hook_output_folder, stale_name)
    if os.path.isfile(stale_path):
        os.remove(stale_path)

# Forward up to `num_batches` validation batches through the model and save the
# hooked layer outputs, one file per batch: <hook_output_folder>/<count>.pth
# holds a dict mapping layer name -> that layer's hook output.
count = 0
dataset_idx = 0
# `total=` must be passed by keyword: tqdm's first positional argument is the
# iterable to wrap, so a bare integer there leaves the bar without a length.
# The context manager closes the bar even if the loop raises.
with torch.no_grad(), tqdm(
    total=num_batches, desc="Computing and saving forward hook outputs"
) as pbar:
    for image_tensor, _ in validation_dataloader:
        # The forward pass is what populates the hooks; `y` is only used for
        # its batch dimension below.
        y = model(image_tensor.to(device))

        # NOTE(review): assumes ForwardHook.output holds the module output of
        # the most recent forward pass, and that it lives on `device` -- confirm.
        single_batch_hook_outputs = {
            name: hooks[name].output for name in layer_names
        }

        filename = os.path.join(args.hook_output_folder, f"{count}.pth")
        torch.save(single_batch_hook_outputs, filename)

        dataset_idx += y.shape[0]

        count += 1
        pbar.update(1)
        print(f"dataset_idx: {dataset_idx} out of {num_samples}")
        # Stop once the sample budget (num_samples) has been consumed.
        if count >= num_batches:
            break

"""
python3 obtain_hook_outputs.py --checkpoint-filename ../../../training/imagenet/resnet18/checkpoints/imagenet/torchvision_recipe_shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_30_steps_apply_sorted_weights_init_filename_None/best/best_model-v1.ckpt --layer-names-json layer_names.json

python3 obtain_hook_outputs.py --checkpoint-filename eshed --layer-names-json layer_names.json 

python3 obtain_hook_outputs.py --checkpoint-filename pretrained --layer-names-json layer_names.json 
"""
