import re

from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM

from camelidae.configuration_moe import LlamaConfig
from camelidae.modeling_moe import LlamaForCausalLM as LoraMoeModel

# Upcycle a dense LLaMA-architecture SFT checkpoint into an MoE model: for
# every layer, each expert's FFN projection weights are initialised from that
# layer's dense ``mlp.<proj>.weight`` in the SFT checkpoint.

BASE_MODEL_PATH = "/root/paddlejob/workspace/env_run/huitingfeng/models/llama-2.7b-hf"
SFT_CKPT_PATH = "output/sheared-sft-steps-500k/checkpoint-7808"
OUTPUT_PATH = "output/sheared-moe-upcycle-8e-500k"
NUM_EXPERTS = 8

# Expert parameters are expected to be named
#   model.layers.<layer>.block_sparse_moe.experts.<expert>.<proj>.weight
# Compiled once, and used with fullmatch so names with unexpected suffixes are
# rejected instead of being mapped by prefix.
EXPERT_PARAM_RE = re.compile(
    r"model\.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(\w+)\.weight"
)

model_config = LlamaConfig.from_pretrained(BASE_MODEL_PATH)
model_config.num_local_experts = NUM_EXPERTS
# NOTE(review): every parameter whose name does not contain "experts"
# (embeddings, attention, norms, ...) keeps the value loaded here from
# BASE_MODEL_PATH; only expert FFN weights are taken from SFT_CKPT_PATH below.
# Confirm that mixing base and SFT weights this way is intended.
model = LoraMoeModel.from_pretrained(BASE_MODEL_PATH, config=model_config)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

# Dense SFT weights, keyed by parameter name; used only as a copy source.
dense_state = AutoModelForCausalLM.from_pretrained(SFT_CKPT_PATH).state_dict()

for name, param in model.named_parameters():
    if "experts" not in name:
        continue

    match = EXPERT_PARAM_RE.fullmatch(name)
    if match is None:
        raise ValueError(f"Unexpected expert parameter name: {name!r}")
    layer_idx, _, proj = match.groups()  # expert index is irrelevant: all experts share the source

    dense_name = f"model.layers.{layer_idx}.mlp.{proj}.weight"
    if dense_name not in dense_state:
        raise KeyError(
            f"Source weight {dense_name!r} for {name!r} not found in {SFT_CKPT_PATH}"
        )
    src = dense_state[dense_name]
    if src.shape != param.shape:
        raise ValueError(
            f"Shape mismatch for {name!r}: expert {tuple(param.shape)} "
            f"vs dense {tuple(src.shape)} ({dense_name!r})"
        )

    # Copy into the expert's own storage rather than rebinding ``param.data``:
    # rebinding made every expert of a layer alias one tensor (shared with the
    # dense checkpoint) and would silently accept a different shape/dtype.
    param.data.copy_(src)

# The dense weights are no longer referenced by the MoE model; free them
# before serialisation.
del dense_state

model.save_pretrained(OUTPUT_PATH, safe_serialization=False)
tokenizer.save_pretrained(OUTPUT_PATH)