from nesim.losses.nesim_loss import (
    NesimConfig,
    NeighbourhoodCosineSimilarity,
    LaplacianPyramid,
)
import os

# Directory that receives the generated nesim config JSON files.
nesim_configs_dir = "./nesim_configs"

# Create the directory with the standard library rather than shelling out to
# `mkdir -p`: no POSIX shell is required and a failure raises instead of
# being silently ignored.
os.makedirs(nesim_configs_dir, exist_ok=True)

# Remove JSON configs already present in the directory. This targets the same
# set a shell `rm <dir>/*.json` would (non-hidden entries ending in ".json",
# never directories), but it stays quiet when there is nothing to delete.
# A file that cannot be removed raises OSError rather than being left behind.
for _entry in os.listdir(nesim_configs_dir):
    if _entry.startswith(".") or not _entry.endswith(".json"):
        continue
    _stale_path = os.path.join(nesim_configs_dir, _entry)
    if os.path.isfile(_stale_path) or os.path.islink(_stale_path):
        os.remove(_stale_path)

"""
GPT-Neo 1.3B model printout (module hierarchy, for looking up layer names)

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 2048)
    (wpe): Embedding(2048, 2048)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPTNeoBlock(
        (ln_1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
        )
        (ln_2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=2048, out_features=8192, bias=True)
          (c_proj): Linear(in_features=8192, out_features=2048, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2048, out_features=50257, bias=False)
)

References:
[1] - XXXX
[2] - XXXX
"""

# Layer groups to generate configs for: group name (used in the output
# filename) -> module names of the layers the losses are attached to.
layers_to_apply_nesim_upon = {
    "all_layers_c_fc": [f"transformer.h.{i}.mlp.c_fc" for i in range(24)],
}

# Each entry is passed unchanged as `shrink_factor` to LaplacianPyramid and is
# also interpolated into the filename, so a list such as [9.0] appears there
# with its brackets.
possible_shrink_factors = [
    [9.0]
]

# Values passed as `scale` to LaplacianPyramid. None yields the "baseline"
# config file.
possible_scales = [
    None,
    # 1,
    # 5,
    # 10,
    50,
]

# Remove JSON configs already present in the output directory so it ends up
# holding only the files written below. Done with the standard library
# instead of a shell `rm <dir>/*.json`: same targets (non-hidden "*.json"
# entries, never directories), no shell dependency, and no error output when
# nothing matches.
if os.path.isdir(nesim_configs_dir):
    for _entry in os.listdir(nesim_configs_dir):
        if _entry.startswith(".") or not _entry.endswith(".json"):
            continue
        _stale_path = os.path.join(nesim_configs_dir, _entry)
        if os.path.isfile(_stale_path) or os.path.islink(_stale_path):
            os.remove(_stale_path)

# One config file is written per (layer group, shrink factor, scale) triple.
num_nesim_configs = 0
for name, layer_names in layers_to_apply_nesim_upon.items():
    for shrink_factor in possible_shrink_factors:
        for scale in possible_scales:

            layer_wise_configs = []

            # Every layer gets two entries, in this order: a cosine-similarity
            # loss with scale=None, then the Laplacian pyramid loss with the
            # swept scale.
            for single_layer_name in layer_names:
                layer_wise_configs.extend(
                    [
                        ## scale = None -> just watch the layer's loss, do not backprop
                        NeighbourhoodCosineSimilarity(
                            layer_name=single_layer_name, scale=None
                        ),
                        LaplacianPyramid(
                            layer_name=single_layer_name,
                            scale=scale,
                            shrink_factor=shrink_factor,
                        ),
                    ]
                )
            nesim_config = NesimConfig(layer_wise_configs=layer_wise_configs)

            # scale=None configs are named "baseline_..."; all others carry
            # the scale value in the filename.
            prefix = "baseline" if scale is None else f"scale_{scale}"
            filename = (
                f"{prefix}_shrink_factor_{shrink_factor}_layer_names_{name}.json"
            )

            nesim_config.save_json(os.path.join(nesim_configs_dir, filename))
            num_nesim_configs += 1
print(f"Saved {num_nesim_configs} nesim configs")
