from nesim.sparsity.linear import DownsampledLinear, PrunedLinear
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.getting_modules import get_module_by_name
from nesim.experiments.gpt_neo_125m import get_checkpoint, get_untrained_model_and_tokenizer
from nesim.utils.setting_attr import setattr_pytorch_model
from nesim.utils.model_info import count_model_parameters
from nesim.utils.json_stuff import dict_to_json
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.utils.l1_sparsity import apply_l1_sparsity_to_model
from einops import rearrange
import torch
import torch.nn.functional as F
from lightning import seed_everything
import os

seed_everything(0)
# from bookcorpus_loss_eval import evaluate_model_bookcorpus
from openwebtext_loss_eval import evaluate_model_openwebtext
import numpy as np

import argparse

# Command-line interface: the single option selects which entry of
# `checkpoints_map` (defined below) is loaded, evaluated, and pruned.
# `required=True` makes argparse print a usage error when the option is
# missing, instead of the script continuing with checkpoint_name=None.
parser = argparse.ArgumentParser(
    description=(
        "Evaluate the OpenWebText loss of a GPT-Neo-125M checkpoint before "
        "and after pruning its MLP c_fc layers."
    )
)
parser.add_argument(
    '--checkpoint-name',
    type=str,
    required=True,
    help="name of checkpoint (e.g. 'untrained', 'baseline', 'topo_1')",
)

# Parse the arguments
args = parser.parse_args()

# Names of the modules that the pruning sweep replaces: the MLP `c_fc`
# projection in each of the 12 transformer blocks.
topo_layer_names = [f"transformer.h.{i}.mlp.c_fc" for i in range(12)]
device = "cuda:0"
# Every trained checkpoint below is taken at this same training step.
global_step = 10500
checkpoint_dir = "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints"
topo_scales = [1, 5, 10, 50]

# CLI-selectable checkpoint name -> value handed to get_checkpoint().
# "untrained" maps to None (presumably get_checkpoint(None) builds a fresh,
# untrained model -- TODO confirm); "baseline" is the topo_scale=0 run and
# "topo_<scale>" is the run trained with that topo scale.
checkpoints_map = {
    "untrained": None,
    # "pretrained": "pretrained",
    "baseline": get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        topo_scale=0,
        global_step=global_step,
    ),
}
checkpoints_map.update(
    {
        f"topo_{topo_scale}": get_checkpoint_path_gpt_neo_125m(
            checkpoints_dir=checkpoint_dir,
            topo_scale=topo_scale,
            global_step=global_step,
        )
        for topo_scale in topo_scales
    }
)

# Validate the CLI choice explicitly rather than with `assert`: asserts are
# removed under `python -O`, and parser.error prints the usage text plus a
# message naming the valid choices before exiting.
if args.checkpoint_name not in checkpoints_map:
    parser.error(
        f"--checkpoint-name must be one of {sorted(checkpoints_map)}, "
        f"got {args.checkpoint_name!r}"
    )

# Only the tokenizer is wanted from the untrained-model helper; the model half
# of the pair is released straight away.
_untrained_model, tokenizer = get_untrained_model_and_tokenizer(name="EleutherAI/gpt-neo-125m")
del _untrained_model

# Keyword arguments reused by every evaluate_model_openwebtext call. Note that
# this captures the tokenizer object loaded above; later reassignments of the
# global `tokenizer` (from get_checkpoint) do not change what is passed here.
eval_args = {
    "max_num_samples": 10,
    "tokenizer": tokenizer,
}

# Reference measurement: the selected checkpoint with no layers replaced.
model, tokenizer = get_checkpoint(checkpoints_map[args.checkpoint_name], device=device)
result = evaluate_model_openwebtext(model=model, **eval_args)
print(f"{args.checkpoint_name} Loss: {result}")
print(f"Old param count: {count_model_parameters(model)[1]}")

# NOTE(review): cossim_threshold is currently unused -- PrunedLinear below is
# constructed with cossim_threshold=None, so pruning is driven only by
# fraction_of_masked_weights. Kept for reference; wire it in or delete it.
cossim_threshold = 0.995

# Sweep 5 evenly spaced pruning fractions from 0.9 to 0.99 (inclusive). Each
# level starts from a freshly loaded checkpoint, since the layer swaps below
# presumably modify `model` in place.
for fraction_of_masked_weights in np.linspace(0.9, 0.99, 5):
    # Release the previous model before loading the next one so its memory can
    # be reclaimed instead of two copies being alive during the load.
    del model
    model, tokenizer = get_checkpoint(checkpoints_map[args.checkpoint_name], device=device)
    for layer_name in topo_layer_names:
        pruned_layer = PrunedLinear(
            linear_layer=get_module_by_name(module=model, name=layer_name),
            cossim_threshold=None,
            fraction_of_masked_weights=fraction_of_masked_weights,
        )
        setattr_pytorch_model(model, layer_name, pruned_layer)

    print(f"New param count: {count_model_parameters(model)[1]}")
    result = evaluate_model_openwebtext(
        model=model,
        **eval_args
    )
    # Report the loop value directly rather than reading it back from the last
    # layer built in the inner loop (undefined if topo_layer_names is empty).
    print(
        f"[fraction_of_masked_weights {fraction_of_masked_weights:.4f}] "
        f"{args.checkpoint_name} Loss: {result}"
    )