from nesim.losses.nesim_loss import (
    NesimConfig,
    NeighbourhoodCosineSimilarity,
    LaplacianPyramid,
)
import os

# Output directory for the generated NesimConfig JSON files.
nesim_configs_dir = "./nesim_configs"

# Create the directory if it is missing, then delete stale *.json configs left
# over from earlier runs so the directory only holds configs from this run.
# Plain os calls are used instead of shelling out (`mkdir -p` / `rm *.json`):
# they need no POSIX shell and do not print "rm: cannot remove ..." errors when
# there is nothing to delete.
os.makedirs(nesim_configs_dir, exist_ok=True)
for _stale_name in os.listdir(nesim_configs_dir):
    _stale_path = os.path.join(nesim_configs_dir, _stale_name)
    # Only regular *.json files are removed; subdirectories and other files stay.
    if _stale_name.endswith(".json") and os.path.isfile(_stale_path):
        os.remove(_stale_path)

"""
The caption of Table 5 (page 10) of the LoRA paper [1] states:
"Adapting both Wq and Wv gives the best performance overall"

Wq and Wv are the weight matrices of the query and value projections in the
attention module.

The GPT-Neo codebase [2] shows that the inputs of q_proj and v_proj are the hidden
states (after layer norm ln_1), i.e. the outputs of the previous block, or the
input embeddings in the case of the first layer.

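The output of each block's "mlp.c_proj" is added to the residual stream that
becomes the next block's hidden states. Hence I decided to apply the nesim loss to
every "mlp.c_proj" module, so that the inputs received by q_proj and v_proj have
low effective dimensionality.
NOTE: the sweep defined below (`layers_to_apply_nesim_upon`) currently targets the
"mlp.c_fc" modules instead; keep this rationale and the config in sync.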

For reference, here's a printout of a single block in gpt neo 125m
GPTNeoBlock(
    (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): GPTNeoAttention(
        (attention): GPTNeoSelfAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear8bitLt(in_features=768, out_features=768, bias=False)
        (v_proj): Linear8bitLt(in_features=768, out_features=768, bias=False)
        (q_proj): Linear8bitLt(in_features=768, out_features=768, bias=False)
        (out_proj): Linear8bitLt(in_features=768, out_features=768, bias=True)
        )
    )
    (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): GPTNeoMLP(
        (c_fc): Linear8bitLt(in_features=768, out_features=3072, bias=True)
        (c_proj): Linear8bitLt(in_features=3072, out_features=768, bias=True)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.0, inplace=False)
    )
)

References:
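[1] - Hu et al., "LoRA: Low-Rank Adaptation of Large Language Models", arXiv:2106.09685 (2021)
[2] - Hugging Face transformers, src/transformers/models/gpt_neo/modeling_gpt_neo.py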
"""

# Target modules for the nesim loss, keyed by a short tag that is embedded in
# every output filename. range(12) walks transformer.h.0 ... transformer.h.11.
layers_to_apply_nesim_upon = {
    "all_layers_c_fc": [
        "transformer.h.{}.mlp.c_fc".format(block_idx) for block_idx in range(12)
    ],
}

# Shrink factors to sweep. Each entry is itself a list; it is passed as-is to
# LaplacianPyramid(shrink_factor=...) and is also formatted into the filename.
# Values explored previously: [5.0], [7.0].
possible_shrink_factors = [[9.0]]

# Loss scales to sweep for the LaplacianPyramid term. None yields a "baseline"
# config in which the loss is only watched, not backpropagated.
# Values explored previously: 1, 5, 10.
possible_scales = [None, 50]

# Generate one NesimConfig per (layer group, shrink factor, scale) combination
# and save each to its own JSON file under nesim_configs_dir.
# Stale *.json files are already cleared where nesim_configs_dir is set up, and
# nothing writes JSON in between. A second cleanup here would always find nothing
# to delete and would only print a spurious shell error, so none is done.
num_nesim_configs = 0
for name, layer_names in layers_to_apply_nesim_upon.items():
    for shrink_factor in possible_shrink_factors:
        for scale in possible_scales:
            layer_wise_configs = []
            for single_layer_name in layer_names:
                # Per target layer: a NeighbourhoodCosineSimilarity term that is only
                # monitored (scale = None -> just watch the layer's loss, do not
                # backprop), followed by a LaplacianPyramid term at the swept
                # scale / shrink factor.
                layer_wise_configs.extend(
                    [
                        NeighbourhoodCosineSimilarity(
                            layer_name=single_layer_name, scale=None
                        ),
                        LaplacianPyramid(
                            layer_name=single_layer_name,
                            scale=scale,
                            shrink_factor=shrink_factor,
                        ),
                    ]
                )
            nesim_config = NesimConfig(layer_wise_configs=layer_wise_configs)

            # scale=None means the LaplacianPyramid term is only watched as well, so
            # the file is tagged "baseline" instead of "scale_<value>".
            # NOTE(review): with the current sweep, shrink_factor is a list (e.g.
            # [9.0]), so its repr, brackets included, lands in the filename. Quote
            # such paths when using them in a shell.
            scale_tag = "baseline" if scale is None else f"scale_{scale}"
            filename = (
                f"{scale_tag}_shrink_factor_{shrink_factor}_layer_names_{name}.json"
            )

            nesim_config.save_json(os.path.join(nesim_configs_dir, filename))
            num_nesim_configs += 1
print(f"Saved {num_nesim_configs} nesim configs")
