from camelidae.configuration_loramoe import LlamaConfig
from camelidae.modeling_loramoe import LlamaForCausalLM as LoraMoeModel
from camelidae.modeling_gate import LlamaForCausalLM as GateModel
from transformers import AutoTokenizer

# Base Llama-2 checkpoint directory. The same location provides the config,
# the initial weights and the tokenizer of the model that receives the merge.
base_model_dir = "/root/paddlejob/workspace/env_run/huitingfeng/models/llama-2-7b-hf"

model_config = LlamaConfig.from_pretrained(base_model_dir)
model = LoraMoeModel.from_pretrained(base_model_dir, config=model_config)
tokenizer = AutoTokenizer.from_pretrained(base_model_dir)

# Checkpoints of the individually trained gate models. List position is
# significant: entry i supplies row i of every 'gate.weight' parameter and the
# weights of expert i when the models are merged further down.
ckpts = [
    "output/lora-moe-gate/lora-gate-1/checkpoint-24",
    "output/lora-moe-gate/lora-gate-2/checkpoint-24",
    "output/lora-moe-gate/lora-gate-3/checkpoint-24",
    "output/lora-moe-gate/lora-gate-4/checkpoint-20",
    "output/lora-moe-gate/lora-gate-5/checkpoint-8",
    "output/lora-moe-gate/lora-gate-6/checkpoint-12",
    "output/lora-moe-gate/lora-gate-7/checkpoint-16",
    "output/lora-moe-gate/lora-gate-8/checkpoint-16",
]
# Alternative checkpoint list (disabled):
# ckpts = ["output/lora-gate-488/checkpoint-28", "output/lora-gate-3904/checkpoint-28"]

# Load every gate model, in the same order as `ckpts`.
gate_models = [GateModel.from_pretrained(ckpt_path) for ckpt_path in ckpts]

import re

# Copy the trained gate rows and expert weights from the individual gate
# models into the combined model:
#   * every parameter whose name contains 'gate.weight' gets row i overwritten
#     with the same-named parameter of gate model i;
#   * every parameter whose name contains 'experts' is replaced by the
#     corresponding single-expert parameter of the gate model selected by the
#     expert index embedded in the name.

# Per-expert parameter names in the combined model look like
# "model.layers.<layer>.experts.<expert>.<module>.<sub>.weight".
expert_name_re = re.compile(
    r'model\.layers\.(\d+)\.experts\.(\d+)\.(\w+)\.(\w+)\.weight'
)

# state_dict() builds a fresh mapping on every call, so take one snapshot per
# gate model instead of rebuilding it inside the loops below.
gate_state_dicts = [gate_model.state_dict() for gate_model in gate_models]

for name, param in model.named_parameters():
    if 'gate.weight' in name:
        for i, gate_state in enumerate(gate_state_dicts):
            # NOTE(review): squeeze(0) presumes each gate model stores a
            # single gate row with a leading dimension of 1 - confirm.
            gate_column = gate_state[name].data.squeeze(0)
            param.data[i][:] = gate_column
    elif 'experts' in name:
        matches = expert_name_re.match(name)
        if matches is None:
            # Fail with the offending name rather than an AttributeError on None.
            raise ValueError(
                f"Unexpected expert parameter name, cannot map it to a gate model: {name!r}"
            )
        layer_number, expert_num, layer_type, variable_type = matches.groups()
        # The gate models hold one expert per layer under 'expert' (singular),
        # without an expert index in the name.
        new_name = f"model.layers.{layer_number}.expert.{layer_type}.{variable_type}.weight"
        param.data = gate_state_dicts[int(expert_num)][new_name].data

# Write the model and its tokenizer side by side into one output directory.
output_dir = "output/lora-moe-router"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)