import re

from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM

from camelidae.configuration_moe import LlamaConfig
from camelidae.modeling_moe import LlamaForCausalLM as LoraMoeModel
from camelidae.modeling_moe_gate import LlamaForCausalLM as GateModel

# Merge several independently fine-tuned single-expert "gate" checkpoints into one
# LoRA-MoE Llama model:
#   * checkpoint i supplies the weights of expert i in every layer,
#   * checkpoint i supplies row i of every router ("gate.weight") matrix,
#   * every other (shared) parameter is the element-wise mean across checkpoints.
# The merged model and the base tokenizer are then written to OUTPUT_DIR.

BASE_MODEL_PATH = "/root/paddlejob/workspace/env_run/huitingfeng/models/llama-2.7b-hf"
OUTPUT_DIR = "output/sheared-moe-ckpt-500k"

model_config = LlamaConfig.from_pretrained(BASE_MODEL_PATH)
model = LoraMoeModel.from_pretrained(
    BASE_MODEL_PATH,
    config=model_config,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

# One checkpoint per expert; list position i becomes expert index i in the merged model.
ckpts = [
    "output/sheared-sft-steps-500k/moe-gate-1/checkpoint-48",
    "output/sheared-sft-steps-500k/moe-gate-2/checkpoint-48",
    "output/sheared-sft-steps-500k/moe-gate-3/checkpoint-48",
    "output/sheared-sft-steps-500k/moe-gate-4/checkpoint-20",
]

lora_models = [GateModel.from_pretrained(ckpt).state_dict() for ckpt in ckpts]

# Per-expert parameter names in the MoE model, e.g.
# "model.layers.<layer>.block_sparse_moe.experts.<expert>.<proj>.weight".
# Compiled once instead of on every loop iteration.
EXPERT_PARAM_PATTERN = re.compile(
    r'model\.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(\w+)\.weight'
)


def _copy_into(target, src, name):
    """Copy ``src`` into the tensor ``target`` in place.

    Raises ValueError if the shapes differ. Rebinding ``param.data`` (as opposed
    to copying) would silently change the parameter's shape/dtype and save a
    checkpoint that no longer matches the model config.
    """
    if target.shape != src.shape:
        raise ValueError(
            f"shape mismatch for {name}: model expects {tuple(target.shape)}, "
            f"checkpoint provides {tuple(src.shape)}"
        )
    target.copy_(src)


for name, param in model.named_parameters():
    if 'experts' in name:
        matches = EXPERT_PARAM_PATTERN.match(name)
        if matches is None:
            raise ValueError(f"unexpected expert parameter name: {name}")
        layer_number, expert_num, layer_type = matches.groups()
        expert_idx = int(expert_num)
        if expert_idx >= len(lora_models):
            raise ValueError(
                f"{name} refers to expert {expert_idx}, but only "
                f"{len(lora_models)} checkpoints were given"
            )
        # The single-expert gate model stores its expert under "expert" (singular).
        new_name = f"model.layers.{layer_number}.block_sparse_moe.expert.{layer_type}.weight"
        _copy_into(param.data, lora_models[expert_idx][new_name], name)
    elif 'gate.weight' in name:
        # Router matrix: row i is filled from checkpoint i's (squeezed) gate weight,
        # so its first dimension must equal the number of checkpoints.
        if param.shape[0] != len(lora_models):
            raise ValueError(
                f"{name} has {param.shape[0]} rows, expected one per checkpoint "
                f"({len(lora_models)})"
            )
        for row, state_dict in enumerate(lora_models):
            gate_column = state_dict[name].data.squeeze(0)
            _copy_into(param.data[row], gate_column, f"{name}[{row}]")
    else:
        # Shared parameter: uniform average over all checkpoints.
        averaged = sum(state_dict[name] for state_dict in lora_models) / len(lora_models)
        _copy_into(param.data, averaged, name)

# All weights have been copied into `model`; release the source state dicts
# before serialization to lower peak memory.
del lora_models

model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)