from PIL import Image
import numpy as np
import os
import torch
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.experiments.gpt_neo_125m import get_checkpoint
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.l1_sparsity import apply_l1_sparsity_to_model
from nesim.grid.two_dimensional import BaseGrid2dLinear
from nesim.utils.grid_size import find_rectangle_dimensions
from einops import rearrange
import cv2
import matplotlib.pyplot as plt

# --- Experiment configuration ------------------------------------------------
topo_scale = 50
layer_index = 10
category = "technology"
# The L1 step below keeps 1/downsample_factor of the layer's weights. The
# category map is later shrunk by sqrt(downsample_factor) along each axis,
# which also keeps ~1/downsample_factor of its pixels, presumably so that
# both reductions are comparable.
downsample_factor = 9

softmax_maps_dir = "../category_maps/assets"
layer_name = f"transformer.h.{layer_index}.mlp.c_fc"
checkpoints_dir = "../../training/gpt_neo_125m/checkpoints"
checkpoint_name = f"topo_{topo_scale}"

# Grid cells whose mean absolute weight falls below this value are treated
# as "pruned" when building the L1 mask.
l1_mask_threshold = 3e-3

# --- Load the checkpoint and sparsify the chosen layer -----------------------
checkpoint_filename = get_checkpoint_path_gpt_neo_125m(
    checkpoints_dir=checkpoints_dir,
    topo_scale=topo_scale,
    global_step=10500,
)
model, tokenizer = get_checkpoint(checkpoint_filename=checkpoint_filename)
model = apply_l1_sparsity_to_model(
    model=model,
    # Fraction of weights to mask: 1 - 1/downsample_factor (kept in this
    # exact arithmetic form so the floating-point value is unchanged).
    fraction_of_masked_weights=(100 - (100 / downsample_factor)) / 100,
    layer_names=[layer_name],
)

# --- Lay the layer's output units out on a 2D grid ---------------------------
sparse_layer = get_module_by_name(module=model, name=layer_name)
# NOTE(review): presumably factorizes the number of output units (dim 0 of
# the weight) into a height x width rectangle; confirm in nesim.utils.grid_size.
size = find_rectangle_dimensions(sparse_layer.weight.data.shape[0])
grid = BaseGrid2dLinear(
    linear_layer=sparse_layer,
    height=size.height,
    width=size.width,
    device="cpu",
)

# grid.grid is (h, w, e); flatten the spatial axes to score each cell.
grid_reshaped = rearrange(grid.grid, "h w e -> (h w) e")
# 1.0 where a cell still carries weight mass, 0.0 where it was (near-)zeroed.
# The comparison produces a tensor without grad history, so .numpy() is safe.
l1_mask = 1 - (
    torch.mean(grid_reshaped.abs(), dim=-1) < l1_mask_threshold
).float().numpy()
l1_mask = rearrange(l1_mask, "(h w) -> h w", h=size.height, w=size.width)

# --- Locate the precomputed category map -------------------------------------
map_folder = os.path.join(
    softmax_maps_dir,
    checkpoint_name,
    layer_name,
    "dprime",
)
image_filename = os.path.join(map_folder, f"{category}.png")
# Raw map values alongside the PNG (not used by the rendering below).
array_filename = os.path.join(map_folder, f"{category}.npy")

# Read the PNG and release the file handle immediately.
with Image.open(image_filename) as pil_image:
    image = np.array(pil_image)

image_height, image_width = image.shape[:2]

# Upsample the per-cell L1 mask to the map's pixel resolution.
# `interpolation` MUST be passed by keyword: the third positional parameter of
# cv2.resize is `dst`, so passing INTER_NEAREST positionally silently falls
# back to bilinear interpolation. That blurs the binary mask into fractional
# values, which astype(np.uint8) below would then truncate to 0.
l1_mask = cv2.resize(
    l1_mask,
    (image_width, image_height),
    interpolation=cv2.INTER_NEAREST,
)

# Shrink each axis by sqrt(downsample_factor) so ~1/downsample_factor of the
# pixels remain. Clamp to at least 1 px, because cv2.resize rejects a 0 size.
axis_scale = downsample_factor ** 0.5
downsampled_size = (
    max(1, int(image_width // axis_scale)),
    max(1, int(image_height // axis_scale)),
)
downsampled_image = cv2.resize(
    image,
    downsampled_size,
    interpolation=cv2.INTER_NEAREST,
)

# --- Write the three comparison images ---------------------------------------
Image.fromarray(image).save("original.png")
# Keep only RGB (drops alpha, if any) and zero out pixels of pruned cells.
Image.fromarray(image[:, :, :3] * l1_mask[..., None].astype(np.uint8)).save("l1.png")
Image.fromarray(downsampled_image).save("downsampled.png")
