

from modeling_llama2_noaffine_inner_ema_step_obs_stem_max3_3_actha2all_4000nas2_out_simple import LlamaForCausalLM

from inverse_llama2 import LlamaForCausalLM as slm
import torch
from transformers import AutoConfig

# Teacher: the masked checkpoint. Its lm_head, self_attn.o_proj and
# mlp.down_proj modules carry the keep-masks that shape() reads.
model = LlamaForCausalLM.from_pretrained(
    "path1",
    attn_implementation="flash_attention_2",
)

# Student: built from a config only (no weights are loaded here); shape()
# overwrites its parameters with slices of the teacher.
config = AutoConfig.from_pretrained("path2")
student_model = slm(config=config)
# Module-level handle on the stem (hidden-channel) mask; shape() re-reads it.
stem_mask = model.lm_head.mask
def _as_bool_mask(mask):
    """Return ``mask`` as a detached boolean tensor suitable for boolean indexing.

    An integer 0/1 mask used directly as an index would be treated as a list
    of *indices* (gathering rows 0 and 1), and a float mask is rejected by
    PyTorch indexing, so every mask is normalised to ``bool`` first. Any
    nonzero entry counts as "keep".
    """
    # NOTE(review): assumes hard 0/1 (or bool) masks; a soft float mask would
    # need an explicit threshold instead of nonzero -> True.
    return mask.detach().to(torch.bool)


def _masks_for(name, stem_mask, att_mask, ffn_mask):
    """Return ``(row_mask, col_mask)`` for a decoder-layer parameter, or None if unhandled.

    The row mask selects along dim 0 (``nn.Linear`` out-features, and the only
    dim of 1-D params such as norm weights or biases). The column mask selects
    along dim 1 (``nn.Linear`` in-features). ``None`` as the column mask means
    the parameter is 1-D.
    """
    if any(p in name for p in ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")):
        # NOTE(review): k/v reuse the query-head mask, which is only correct
        # when num_key_value_heads == num_attention_heads (no GQA) — confirm.
        return att_mask, stem_mask
    if "self_attn.o_proj" in name:
        return stem_mask, att_mask
    if "mlp.down_proj" in name:
        return stem_mask, ffn_mask
    if "mlp.gate_proj" in name or "mlp.up_proj" in name:
        return ffn_mask, stem_mask
    if "layernorm" in name:
        return stem_mask, None
    return None


def _select(param, row_mask, col_mask):
    """Restrict ``param`` to ``row_mask`` on dim 0 and, for 2-D params, ``col_mask`` on dim 1."""
    out = param[row_mask]  # boolean indexing already returns a new tensor
    if col_mask is not None and out.dim() > 1:
        out = out[:, col_mask]
    return out


def _copy_checked(dst, src, name):
    """Copy ``src`` into ``dst`` in place, refusing to silently broadcast.

    Raises:
        ValueError: if the two shapes differ. A plain ``copy_`` would broadcast
            size-1 dims instead of failing.
    """
    if dst.shape != src.shape:
        raise ValueError(
            f"shape mismatch for {name}: student {tuple(dst.shape)} "
            f"vs selected teacher {tuple(src.shape)}"
        )
    dst.copy_(src)


def shape(model, *, student=None, head_dim=64):
    """Copy the mask-selected slices of the teacher ``model`` into the student, in place.

    Three masks on the teacher drive the selection:

    * ``model.lm_head.mask`` (the "stem" mask) selects hidden-size channels. It
      is shared by every layer, the final norm, the embeddings and the LM head.
    * ``layer.self_attn.o_proj.head_mask`` selects attention heads. Each head
      flag is expanded to ``head_dim`` consecutive projection channels.
    * ``layer.mlp.down_proj.mask`` selects intermediate (FFN) channels.

    For each decoder layer, every teacher parameter that has a same-named
    parameter in the matching student layer is sliced and copied. Teacher
    parameters with no student counterpart are skipped. Shapes are printed
    for inspection.

    Args:
        model: Teacher model exposing the masks listed above.
        student: Student model to fill. Defaults to the module-level
            ``student_model`` so existing ``shape(model)`` calls keep working.
        head_dim: Channels per attention head (previously hard-coded to 64).

    Raises:
        ValueError: if a selected teacher slice does not exactly match the
            shape of its destination student parameter.
    """
    if student is None:
        student = student_model  # module-level student created by this script

    stem_mask = _as_bool_mask(model.lm_head.mask)

    # One no_grad scope for all in-place writes into leaf parameters.
    with torch.no_grad():
        for i, layer in enumerate(model.model.layers):
            student_params = dict(student.model.layers[i].named_parameters())

            mask_att = _as_bool_mask(layer.self_attn.o_proj.head_mask)
            mask_ffn = _as_bool_mask(layer.mlp.down_proj.mask)

            print(f"mask_att: {mask_att.sum()}, mask_ffn: {mask_ffn.sum()}, stem_mask: {stem_mask.sum()}")

            # Repeat each head's keep-flag head_dim times so it lines up with
            # the flattened (num_heads * head_dim) projection axis.
            mask_att_expanded = mask_att.unsqueeze(1).repeat(1, head_dim).reshape(-1)

            for name, param in layer.named_parameters():
                if name not in student_params:
                    continue
                print("pre shape ", name, param.shape)
                masks = _masks_for(name, stem_mask, mask_att_expanded, mask_ffn)
                if masks is None:
                    print("xxx：", name)
                else:
                    _copy_checked(student_params[name], _select(param, *masks), name)
                print("student_params shape ", name, student_params[name].shape)

        student_params = dict(student.named_parameters())
        teacher_params = dict(model.named_parameters())

        _copy_checked(
            student_params["model.norm.weight"],
            teacher_params["model.norm.weight"][stem_mask],
            "model.norm.weight",
        )
        _copy_checked(
            student_params["model.embed_tokens.weight"],
            teacher_params["model.embed_tokens.weight"][:, stem_mask],
            "model.embed_tokens.weight",
        )
        # The stem mask lives on lm_head, so its input columns must be sliced
        # the same way. Otherwise an untied student head keeps its random init.
        # With tied weights, named_parameters() de-duplicates the key away and
        # the embedding copy above already covers it.
        if "lm_head.weight" in student_params and "lm_head.weight" in teacher_params:
            _copy_checked(
                student_params["lm_head.weight"],
                teacher_params["lm_head.weight"][:, stem_mask],
                "lm_head.weight",
            )

    for name, param in student.named_parameters():
        print(f"S Parameter: {name}, Shape: {param.shape}")
        if "kv_idx" in name:
            print(f"kv_idx: {name}, {param}")

save_directory = "path"

# Fill the student in place from the masked teacher, then write it to disk.
shape(model)
student_model.save_pretrained(save_directory)
