from nesim.sparsity.linear import DownsampledLinear
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.getting_modules import get_module_by_name
from nesim.experiments.gpt_neo_125m import get_checkpoint, get_untrained_model_and_tokenizer
from nesim.utils.setting_attr import setattr_pytorch_model
from nesim.utils.model_info import count_model_parameters
from nesim.utils.json_stuff import dict_to_json
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.utils.l1_sparsity import apply_l1_sparsity_to_model
from einops import rearrange
import torch
import torch.nn.functional as F
from lightning import seed_everything

# Fix the global random seeds via Lightning's `seed_everything` so any
# randomness in this run is reproducible.
# NOTE(review): this runs before the `openwebtext_loss_eval` import that
# follows; confirm whether that ordering is intentional.
seed_everything(seed=0)
from openwebtext_loss_eval import evaluate_model_openwebtext
import numpy as np

import argparse

# Command-line interface: the only option is where the JSON results go.
# `--output-filename` is required so a missing path is reported immediately,
# instead of crashing in the final `dict_to_json` call after every checkpoint
# has already been evaluated.
parser = argparse.ArgumentParser(
    description=(
        "Evaluate GPT-Neo-125M checkpoints on OpenWebText and write the "
        "per-checkpoint results to a JSON file."
    )
)
parser.add_argument(
    "--output-filename",
    type=str,
    required=True,
    help="Path of the JSON file the evaluation results are written to",
)

args = parser.parse_args()

# MLP input projection of each of the 12 transformer blocks (presumably the
# layers the topographic loss was applied to). Not referenced below.
topo_layer_names = [
    f"transformer.h.{block_idx}.mlp.c_fc" for block_idx in range(12)
]

# Hard-coded to the first CUDA device.
device = "cuda:0"

# Which training step to load, and where the checkpoints live.
global_step = 10500
checkpoint_dir = "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints"

# Not referenced below; kept so the module-level name still exists.
all_downsample_factors = list()

# Topographic scales evaluated in addition to the baseline (scale 0).
topo_scales = [1, 5, 10, 50]

# Evaluation budget: samples per evaluation and number of shuffled trials.
max_num_samples = 100_000
num_trials = 1

"""
Note:
For figure 1, the evaluation was run on 10k samples, i.e. max_num_samples=10000.
"""


def _checkpoint_path_for_scale(scale):
    """Return the GPT-Neo-125M checkpoint path for topographic ``scale`` at ``global_step``."""
    return get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        topo_scale=scale,
        global_step=global_step,
    )


# Checkpoint label -> value handed to `get_checkpoint` in the evaluation loop.
# "untrained" maps to None (presumably a freshly initialised model - confirm in
# `get_checkpoint`); "baseline" is topographic scale 0. A "pretrained" entry
# (value "pretrained") used to be listed here and is currently disabled.
checkpoints_map = {
    "untrained": None,
    "baseline": _checkpoint_path_for_scale(0),
    **{f"topo_{scale}": _checkpoint_path_for_scale(scale) for scale in topo_scales},
}

# Only the tokenizer is needed here; the untrained model returned alongside it
# is discarded immediately.
_unused_model, tokenizer = get_untrained_model_and_tokenizer(name="EleutherAI/gpt-neo-125m")
del _unused_model

# Keyword arguments shared by every `evaluate_model_openwebtext` call below.
# A `tokenized_dataset_path` entry (pointing at
# ../../../training/gpt_neo_125m/datasets/wikipedia_tokenized) was previously
# passed here as well and is currently disabled.
eval_args = {
    "max_num_samples": max_num_samples,
    "tokenizer": tokenizer,
}


# Evaluate every checkpoint `num_trials` times on OpenWebText. Each trial
# shuffles the evaluation data with a different seed (the trial index).
results = {}

for checkpoint_name, checkpoint_path in checkpoints_map.items():
    # The tokenizer returned with the checkpoint is not needed: evaluation uses
    # the tokenizer stored in `eval_args`.
    model, _ = get_checkpoint(checkpoint_path, device=device)

    # No compression is applied in this script, hence compression_type=None
    # and factor=0.
    results[checkpoint_name] = {
        "type": checkpoint_name,
        "compression_type": None,
        "factor": 0,
        "results": [],
    }

    for trial_idx in range(num_trials):
        result = evaluate_model_openwebtext(
            model=model,
            shuffle_seed=trial_idx,
            shuffle=True,
            **eval_args,
        )
        print(f"{checkpoint_name} trial: {trial_idx} Loss: {result}")
        results[checkpoint_name]["results"].append(result)

    # Release this model before the next checkpoint is loaded. Otherwise the
    # old model stays referenced (and resident on `device`) while the new one
    # is being created, doubling peak memory. `empty_cache` is a no-op when
    # CUDA was never initialised.
    del model
    torch.cuda.empty_cache()

# Persist all collected results to the JSON path given on the command line.
dict_to_json(filename=args.output_filename, dictionary=results)
