from .manar import Manar
from .contextualization import Contextualization

def append_manar_to_wav2vec(model, M=256, m=64, cwl=128):
    """Attach a ``Manar`` module to every self-attention layer of a wav2vec model.

    For each layer in ``model.w2v_encoder.w2v_model.encoder.layers`` a new
    ``Manar`` is built with the same width (``num_heads * head_dim``) and head
    count as the layer's existing ``self_attn``. It is stored on that attention
    module as ``self_attn.manar``. Its ``q_proj``/``k_proj``/``v_proj``/``proj``
    attributes are then pointed at the attention's existing
    ``q_proj``/``k_proj``/``v_proj``/``out_proj`` modules. These are shared by
    reference, not copied.

    Args:
        model: The wav2vec model. It is modified in place.
        M: Forwarded to ``Manar`` as ``num_memory_cells``.
        m: Forwarded to ``Manar`` as ``conceptual_representation_size``.
        cwl: Forwarded to ``Manar`` as ``context_window_len``.

    Returns:
        A tuple ``(model, uncopied_params)``:

        - ``model`` is the same object that was passed in.
        - ``uncopied_params`` is a list built by concatenating, layer by layer,
          whatever each new module's ``get_trainable_params()`` returned.
          Presumably these are the parameters not inherited from the original
          attention; confirm against ``Manar``.
    """
    encoder_layers = model.w2v_encoder.w2v_model.encoder.layers
    collected = []

    for encoder_layer in encoder_layers:
        attention = encoder_layer.self_attn
        embed_width = attention.num_heads * attention.head_dim

        memory_module = Manar(
            dim=embed_width,
            num_heads=attention.num_heads,
            num_memory_cells=M,
            conceptual_representation_size=m,
            context_window_len=cwl,
            seperate_qkv=True,  # spelling matches Manar's keyword argument
            qkv_bias=True,
        )

        # Register the new module on the attention first, then make it reuse
        # the attention's projection modules (same objects, shared weights).
        attention.manar = memory_module
        memory_module.q_proj = attention.q_proj
        memory_module.k_proj = attention.k_proj
        memory_module.v_proj = attention.v_proj
        memory_module.proj = attention.out_proj

        collected.extend(memory_module.get_trainable_params())

    return model, collected


def append_manar_to_vit(model, knowledge_transfer=False, M=256, m=32, cwl=96):
    """Replace the attention module of every ViT block with a new ``Manar``.

    Each block in ``model.blocks`` gets a ``Manar`` sized from its current
    ``attn`` (``num_heads * head_dim`` wide, same head count). The new module
    then becomes ``block.attn``. Unlike the wav2vec variant, ``seperate_qkv``
    and ``qkv_bias`` are not passed, so ``Manar``'s own defaults apply.

    Args:
        model: The ViT model. It is modified in place.
        knowledge_transfer: When true, the new module reuses (by reference)
            the old attention's ``qkv``, ``q_norm``, ``k_norm`` and ``proj``
            submodules. Its ``get_trainable_params()`` result is also
            collected.
        M: Forwarded to ``Manar`` as ``num_memory_cells``.
        m: Forwarded to ``Manar`` as ``conceptual_representation_size``.
        cwl: Forwarded to ``Manar`` as ``context_window_len``.

    Returns:
        A tuple ``(model, uncopied_params)``:

        - ``model`` is the same object that was passed in.
        - ``uncopied_params`` concatenates each new module's
          ``get_trainable_params()``, but only when ``knowledge_transfer``
          is true. Otherwise it is an empty list.
          NOTE(review): with ``knowledge_transfer=False`` every Manar
          parameter is freshly initialised, yet nothing is returned here.
          Confirm this is what callers expect.
    """
    shared_submodules = ("qkv", "q_norm", "k_norm", "proj")
    collected = []

    for vit_block in model.blocks:
        previous_attn = vit_block.attn
        replacement = Manar(
            dim=previous_attn.num_heads * previous_attn.head_dim,
            num_heads=previous_attn.num_heads,
            num_memory_cells=M,
            conceptual_representation_size=m,
            context_window_len=cwl,
        )

        if knowledge_transfer:
            # Point the replacement at the old attention's submodules (same
            # objects), in the order qkv, q_norm, k_norm, proj.
            for attr_name in shared_submodules:
                setattr(replacement, attr_name, getattr(previous_attn, attr_name))
            collected.extend(replacement.get_trainable_params())

        vit_block.attn = replacement

    return model, collected
