import torch
from utils import sllinear

# Names of the projection sub-layers this module deals with: the four
# attention projections followed by the three MLP projections.
target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
def get_weighted_layers_llm(model, W_list, qk_chain_list, chain_list, args):
    """Recursively collect projection weights and their index pairings.

    Walks the direct children of ``model``. For every child named
    ``"self_attn"`` or ``"mlp"`` the ``weight`` of each projection is appended
    to ``W_list`` and the projection module is tagged with a ``LAYER_INDEX``
    attribute holding the position of its weight in ``W_list``. Every other
    child (except ``"lm_head"`` and empty slots) is searched recursively.
    The three list arguments are mutated in place; nothing is returned.

    Args:
        model: Module whose ``_modules`` mapping is traversed.
        W_list: Output list receiving the projection weights.
        qk_chain_list: Output list receiving ``[k_proj, q_proj]`` index pairs
            for attention blocks and ``[gate_proj, up_proj]`` index pairs for
            MLP blocks.
        chain_list: Output list receiving ``[v_proj, o_proj]`` index pairs
            for attention blocks, and ``[gate_proj, down_proj]`` plus
            ``[up_proj, down_proj]`` index pairs for MLP blocks.
        args: Not used here; forwarded unchanged to the recursive calls.

    Raises:
        AssertionError: If a child other than ``lm_head`` is itself a
            ``torch.nn.Linear`` or ``sllinear``, i.e. a linear layer that
            would otherwise be skipped without being registered.
    """
    def add_W(module):
        # Record the weight and remember where it lives in W_list.
        W_list.append(module.weight)
        module.LAYER_INDEX = len(W_list) - 1

    for layer_name, child in model._modules.items():
        # The output head is never collected; a None slot has no children
        # to descend into and would crash the recursive call.
        if layer_name == 'lm_head' or child is None:
            continue

        # Raised explicitly (rather than via `assert`) so the check still
        # runs under `python -O`; the exception type is kept for callers.
        if isinstance(child, torch.nn.Linear) or isinstance(child, sllinear):
            raise AssertionError(
                f"unexpected linear layer {layer_name!r} outside of "
                "'self_attn'/'mlp'"
            )

        if layer_name == "self_attn":
            # Registration order fixes the indices: k, v, q, o.
            add_W(child.k_proj)
            add_W(child.v_proj)
            add_W(child.q_proj)
            add_W(child.o_proj)
            qk_chain_list.append([child.k_proj.LAYER_INDEX, child.q_proj.LAYER_INDEX])
            chain_list.append([child.v_proj.LAYER_INDEX, child.o_proj.LAYER_INDEX])
        elif layer_name == "mlp":
            # Registration order fixes the indices: gate, up, down.
            add_W(child.gate_proj)
            add_W(child.up_proj)
            add_W(child.down_proj)
            qk_chain_list.append([child.gate_proj.LAYER_INDEX, child.up_proj.LAYER_INDEX])
            chain_list.append([child.gate_proj.LAYER_INDEX, child.down_proj.LAYER_INDEX])
            chain_list.append([child.up_proj.LAYER_INDEX, child.down_proj.LAYER_INDEX])
        else:
            # Any other container: keep searching its children.
            get_weighted_layers_llm(
                child,
                W_list=W_list,
                qk_chain_list=qk_chain_list,
                chain_list=chain_list,
                args=args,
            )


def get_W(model, args):
    """Collect projection weights and index pairings from ``model``.

    Creates three empty lists, lets ``get_weighted_layers_llm`` fill them,
    and returns them.

    Args:
        model: Module handed to ``get_weighted_layers_llm``.
        args: Forwarded unchanged to ``get_weighted_layers_llm``.

    Returns:
        A tuple ``(W_list, chain_list, qk_chain_list)`` -- note that
        ``chain_list`` comes before ``qk_chain_list`` in the result.
    """
    weights, chains, qk_chains = [], [], []

    get_weighted_layers_llm(
        model,
        W_list=weights,
        qk_chain_list=qk_chains,
        chain_list=chains,
        args=args,
    )

    return weights, chains, qk_chains



