from camelidae.configuration_moe import LlamaConfig
from camelidae.modeling_moe import LlamaForCausalLM as LoraMoeModel
from camelidae.modeling_moe_gate import LlamaForCausalLM as GateModel
from transformers import AutoTokenizer

# Single location for the base checkpoint; the config, the model weights and
# the tokenizer are all read from this same directory.
base_model_path = "/root/paddlejob/workspace/env_run/huitingfeng/models/llama-2.7b-hf"

# Start from the base config and request four local experts before the model
# is instantiated. The merge loop further down indexes the gate checkpoints by
# expert number, so this count is expected to match the number of checkpoints.
model_config = LlamaConfig.from_pretrained(base_model_path)
model_config.num_local_experts = 4

model = LoraMoeModel.from_pretrained(base_model_path, config=model_config)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# ckpts = ["output/sheared-sft-steps-500k/moe-gate-1/checkpoint-64", "output/sheared-sft-steps-500k/moe-gate-2/checkpoint-8", "output/sheared-sft-steps-500k/moe-gate-3/checkpoint-24", "output/sheared-sft-steps-500k/moe-gate-4/checkpoint-8", "output/sheared-sft-steps-500k/moe-gate-5/checkpoint-16", "output/sheared-sft-steps-500k/moe-gate-6/checkpoint-24", "output/sheared-sft-steps-500k/moe-gate-7/checkpoint-24", "output/sheared-sft-steps-500k/moe-gate-8/checkpoint-16"]
# One single-gate checkpoint per expert. The position in this list is the
# expert index used later when the weights are merged.
ckpts = [
    "output/sheared-steps/moe-gate-1/checkpoint-28",
    "output/sheared-steps/moe-gate-2/checkpoint-44",
    "output/sheared-steps/moe-gate-3/checkpoint-28",
    "output/sheared-steps/moe-gate-4/checkpoint-48",
]

# Load every gate checkpoint up front, preserving the order of `ckpts`.
gate_models = [GateModel.from_pretrained(ckpt_path) for ckpt_path in ckpts]

import re

# Expert parameters of the merged model are named
#   model.layers.<layer>.block_sparse_moe.experts.<expert>.<kind>.weight
# while each gate checkpoint stores its single expert under
#   model.layers.<layer>.block_sparse_moe.expert.<kind>.weight
# Compile the pattern once instead of on every loop iteration.
expert_weight_pattern = re.compile(
    r'model\.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(\w+)\.weight'
)

# state_dict() builds a fresh mapping on every call, so snapshot each gate
# model once rather than once per parameter. The tensors in the snapshot
# still share storage with the gate models; nothing is copied here.
gate_state_dicts = [gate_model.state_dict() for gate_model in gate_models]

# Fill every parameter of the merged model from the gate checkpoints:
#   * router ("gate.weight"): row i comes from gate checkpoint i,
#   * expert weights: expert E comes from gate checkpoint E,
#   * everything else: taken from the first gate checkpoint.
for name, param in model.named_parameters():
    if 'gate.weight' in name:
        for i, gate_state in enumerate(gate_state_dicts):
            # squeeze(0) drops the leading dimension when it has size 1
            # (presumably each checkpoint holds a one-row gate — TODO confirm).
            gate_column = gate_state[name].data.squeeze(0)
            param.data[i][:] = gate_column
    elif 'experts' in name:
        matches = expert_weight_pattern.match(name)
        if matches is None:
            # Fail with the offending name instead of an AttributeError on None.
            raise ValueError(f"Unexpected expert parameter name: {name!r}")
        layer_number, expert_num, layer_type = matches.groups()
        new_name = f"model.layers.{layer_number}.block_sparse_moe.expert.{layer_type}.weight"
        param.data = gate_state_dicts[int(expert_num)][new_name].data
    else:
        param.data = gate_state_dicts[0][name].data

# Write the merged model and its tokenizer side by side in one directory.
output_dir = "output/sheared-moe-router"
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)