from nesim.sparsity.linear import DownsampledLinear
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.getting_modules import get_module_by_name
from nesim.experiments.gpt_neo_125m import get_checkpoint, get_untrained_model_and_tokenizer
from nesim.utils.setting_attr import setattr_pytorch_model
from nesim.utils.model_info import count_model_parameters
from nesim.utils.json_stuff import dict_to_json
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.utils.l1_sparsity import apply_l1_sparsity_to_model
from einops import rearrange
import torch
import torch.nn.functional as F
from lightning import seed_everything
import os

# Seed RNGs with a fixed seed (Lightning's seed_everything) so evaluation runs are reproducible.
# NOTE(review): this call sits before the local eval-module imports below — presumably so any
# randomness at their import time is seeded too; confirm before moving it below the imports.
seed_everything(0)
from bookcorpus_loss_eval import evaluate_model_bookcorpus
from openwebtext_loss_eval import evaluate_model_openwebtext
import numpy as np

import argparse

# Command-line interface: which checkpoint to evaluate and on which dataset.
parser = argparse.ArgumentParser(
    description=(
        "Evaluate the loss of a GPT-Neo-125M checkpoint with its MLP c_fc layers "
        "compressed by downsampling or L1 sparsity."
    )
)
parser.add_argument(
    "--checkpoint-name",
    type=str,
    required=True,
    help="name of checkpoint",
)
# `choices` makes argparse reject unknown datasets with a usage message; the
# previous `assert` was silently skipped when Python runs with -O.
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    choices=["openwebtext", "bookcorpus"],
    help="name of dataset",
)

args = parser.parse_args()

# MLP input projections (c_fc) of all 12 transformer blocks; these are the
# layers compressed by both the downsampling and the L1-sparsity sweeps below.
topo_layer_names = [f"transformer.h.{i}.mlp.c_fc" for i in range(12)]
device = "cuda:0"
# Global step passed to get_checkpoint_path_gpt_neo_125m to select checkpoints.
global_step = 10500
# Overridable through the NESIM_CHECKPOINT_DIR environment variable so the
# script can run on other machines; the default keeps the original path.
checkpoint_dir = os.environ.get(
    "NESIM_CHECKPOINT_DIR",
    "/home/mdeb6/repos/nesim/training/gpt_neo_125m/checkpoints",
)

# Fractions of weights removed in each compression sweep (0 = nothing removed).
fractions_of_masked_weights = [0, 0.2, 0.4, 0.6, 0.8]
def get_downsample_factor_from_fraction(fraction):
    """Convert a fraction of masked weights into the equivalent downsample factor.

    Keeping ``1 - fraction`` of the weights corresponds to ``1 / (1 - fraction)``
    times fewer parameters, e.g. ``0.75 -> 4.0``.

    Args:
        fraction: fraction of weights removed; must satisfy ``0 <= fraction < 1``.

    Returns:
        The downsample factor (``1.0`` when ``fraction == 0``).

    Raises:
        ValueError: if ``fraction`` is outside ``[0, 1)``. ``fraction == 1``
            would otherwise surface as an opaque ``ZeroDivisionError``, and
            negative fractions would yield a meaningless factor below 1.
    """
    if not 0 <= fraction < 1:
        raise ValueError(f"fraction must be in [0, 1), got {fraction!r}")
    return 1 / (1 - fraction)

# Parameter-reduction factor for each masked fraction, in the same order
# (e.g. 0.8 -> ~5x fewer parameters).
all_downsample_factors = list(
    map(get_downsample_factor_from_fraction, fractions_of_masked_weights)
)
# topo_scale values of the checkpoints to evaluate
# (passed to get_checkpoint_path_gpt_neo_125m when building checkpoints_map).
topo_scales = [1, 5, 10, 50]

# Map of checkpoint name -> checkpoint path passed to get_checkpoint.
# "untrained" maps to None; "baseline" is the topo_scale=0 run, and one
# "topo_<scale>" entry is added per scale in topo_scales.
checkpoints_map = {
    "untrained": None,
    # "pretrained": "pretrained",
    "baseline": get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        topo_scale=0,
        global_step=global_step,
    ),
}
checkpoints_map.update(
    {
        f"topo_{scale}": get_checkpoint_path_gpt_neo_125m(
            checkpoints_dir=checkpoint_dir,
            topo_scale=scale,
            global_step=global_step,
        )
        for scale in topo_scales
    }
)

# The valid names are only known once the map is built, so this can't be an
# argparse `choices`; parser.error still gives a usage message and a clean
# exit (unlike `assert`, which is skipped under -O).
if args.checkpoint_name not in checkpoints_map:
    parser.error(
        f"--checkpoint-name must be one of {sorted(checkpoints_map)}, "
        f"got {args.checkpoint_name!r}"
    )

# Only the tokenizer is wanted here: load the untrained model/tokenizer pair
# and drop the model reference straight away.
_untrained_model, tokenizer = get_untrained_model_and_tokenizer(
    name="EleutherAI/gpt-neo-125m"
)
del _untrained_model

# Keyword arguments shared by every evaluation call below.
eval_args = {
    "max_num_samples": 1000,
    "tokenizer": tokenizer,
}

# Reference point: evaluate the selected checkpoint without any compression.
model, tokenizer = get_checkpoint(checkpoints_map[args.checkpoint_name], device=device)

eval_function = (
    evaluate_model_openwebtext
    if args.dataset == "openwebtext"
    else evaluate_model_bookcorpus
)

result = eval_function(model=model, **eval_args)
print(f"{args.checkpoint_name} Loss: {result}")

# One record per evaluation; the uncompressed run is tagged with factor 0.
results = [
    {
        "type": args.checkpoint_name,
        "compression_type": None,
        "factor": 0,
        "result": result,
    }
]

# Compression sweep 1: swap every topo layer for a DownsampledLinear.
# downsample_factor == 4 means 4x fewer parameters in the topo layers; the
# factor is split evenly between factor_h and factor_w (sqrt each) so their
# product equals downsample_factor.
for downsample_factor in all_downsample_factors:
    # Drop the reference to the previous model before loading a fresh copy so
    # its memory can be reclaimed first instead of both copies being resident.
    del model
    model, tokenizer = get_checkpoint(checkpoints_map[args.checkpoint_name], device=device)

    per_dim_factor = downsample_factor ** 0.5
    for layer_name in topo_layer_names:
        downsampled_layer = DownsampledLinear(
            linear_layer=get_module_by_name(module=model, name=layer_name),
            factor_h=per_dim_factor,
            factor_w=per_dim_factor,
            device=device,
        )
        setattr_pytorch_model(model, layer_name, downsampled_layer)

    result = eval_function(model=model, **eval_args)
    print(f"[Downsampled {downsample_factor}x] {args.checkpoint_name} Loss: {result}")
    results.append(
        {
            "type": args.checkpoint_name,
            "compression_type": "downsampling",
            "factor": downsample_factor,
            "result": result,
        }
    )

# Compression sweep 2: mask a fraction of the topo-layer weights via
# apply_l1_sparsity_to_model, at the same parameter budgets as the
# downsampling sweep.
# all_downsample_factors is derived element-wise from
# fractions_of_masked_weights, so zipping pairs each fraction with its factor.
# Using the fraction directly avoids the lossy float round-trip of
# recomputing it from the factor (e.g. 0.8 came back as 0.8000000000000002).
for fraction_of_masked_weights, downsample_factor in zip(
    fractions_of_masked_weights, all_downsample_factors
):
    # Drop the previous model before reloading so both copies aren't resident.
    del model
    model, tokenizer = get_checkpoint(checkpoints_map[args.checkpoint_name], device=device)

    model = apply_l1_sparsity_to_model(
        model=model,
        fraction_of_masked_weights=fraction_of_masked_weights,
        layer_names=topo_layer_names,
    )
    result = eval_function(model=model, **eval_args)

    print(f"[L1 sparse fraction: {fraction_of_masked_weights}] {args.checkpoint_name} Loss: {result}")
    results.append(
        {
            "type": args.checkpoint_name,
            "compression_type": "l1",
            "factor": downsample_factor,
            "result": result,
        }
    )

# Persist all evaluation records as <results_dir>/<checkpoint_name>.json;
# each dataset writes to its own directory.
results_dir = "results" if args.dataset == "openwebtext" else "results_bookcorpus"
# Create the directory if needed so a fresh checkout doesn't fail at the very
# end of a long evaluation run.
os.makedirs(results_dir, exist_ok=True)
dict_to_json(
    dictionary=results,
    filename=os.path.join(results_dir, f"{args.checkpoint_name}.json"),
)
