import re

from camelidae.configuration_loramoe import LlamaConfig
from camelidae.modeling_loramoe import LlamaForCausalLM as LoraMoeModel
from safetensors.torch import load_file
from transformers import AutoTokenizer

# The config, the model weights and the tokenizer are all read from this one
# local directory.
base_model_dir = "/root/paddlejob/workspace/env_run/huitingfeng/models/llama-2-7b-hf"

model_config = LlamaConfig.from_pretrained(base_model_dir)
# LoraMoeModel is camelidae's LlamaForCausalLM (see the import alias); it is
# built with the config loaded above rather than one it would resolve itself.
model = LoraMoeModel.from_pretrained(base_model_dir, config=model_config)
tokenizer = AutoTokenizer.from_pretrained(base_model_dir)

# ckpts = ["output/lora-500k-steps-new/checkpoint-4392", "output/lora-500k-steps-new/checkpoint-4880", "output/lora-500k-steps-new/checkpoint-5368", "output/lora-500k-steps-new/checkpoint-5856", "output/lora-500k-steps-new/checkpoint-6344", "output/lora-500k-steps-new/checkpoint-6832", "output/lora-500k-steps-new/checkpoint-7320", "output/lora-500k-steps-new/checkpoint-7808"]

# Adapter checkpoint directories. Order matters: the entry at position i is
# loaded into lora_models[i], which is later looked up by expert index.
ckpts = [
    "output/lora-200k-steps/checkpoint-1755",
    "output/lora-200k-steps/checkpoint-1950",
    "output/lora-200k-steps/checkpoint-2145",
    "output/lora-200k-steps/checkpoint-2340",
    "output/lora-200k-steps/checkpoint-2535",
    "output/lora-200k-steps/checkpoint-2730",
    "output/lora-200k-steps/checkpoint-2925",
    "output/lora-200k-steps/checkpoint-3120",
]

# One loaded adapter file per checkpoint, in the same order as `ckpts`.
lora_models = [load_file(ckpt_dir + '/adapter_model.safetensors') for ckpt_dir in ckpts]

# Copy adapter weights into the model's expert parameters.
#
# Every parameter whose name contains 'experts' must have the form
#   model.layers.<layer>.experts.<expert>.<module>.<variable>.weight
# and is overwritten with the tensor stored under
#   base_model.model.model.layers.<layer>.mlp.<module minus last 5 chars>_proj.lora_<variable>.weight
# in lora_models[<expert>], i.e. expert i takes its weights from ckpts[i].
#
# Raises ValueError for an expert parameter whose name does not fit the
# pattern or whose shape differs from the adapter tensor, IndexError when the
# expert index has no corresponding checkpoint, and KeyError when the adapter
# file lacks the expected tensor.
expert_param_pattern = re.compile(
    r'model\.layers\.(\d+)\.experts\.(\d+)\.(\w+)\.(\w+)\.weight'
)

for name, param in model.named_parameters():
    if 'experts' not in name:
        continue

    matches = expert_param_pattern.match(name)
    if matches is None:
        # Fail with the offending name instead of an AttributeError on None.
        raise ValueError(
            f"expert parameter {name!r} does not match the expected naming pattern"
        )
    layer_number, expert_num, layer_type, variable_type = matches.groups()

    expert_index = int(expert_num)
    if expert_index >= len(lora_models):
        raise IndexError(
            f"parameter {name!r} refers to expert {expert_index}, "
            f"but only {len(lora_models)} adapter checkpoints were loaded"
        )

    # layer_type[:-5] drops the last five characters of the module name before
    # '_proj' is appended -- presumably a fixed 5-character suffix such as
    # '_lora'; TODO confirm against the expert module names in modeling_loramoe.
    new_name = f"base_model.model.model.layers.{layer_number}.mlp.{layer_type[:-5]}_proj.lora_{variable_type}.weight"

    adapter_weights = lora_models[expert_index]
    if new_name not in adapter_weights:
        raise KeyError(
            f"adapter tensor {new_name!r} (needed for {name!r}) "
            f"is missing from {ckpts[expert_index]}"
        )

    source_tensor = adapter_weights[new_name]
    if source_tensor.shape != param.shape:
        # Assigning .data accepts any shape silently; a mismatch here would
        # otherwise only surface when the saved checkpoint is used.
        raise ValueError(
            f"shape mismatch for {name!r}: model has {tuple(param.shape)}, "
            f"adapter tensor {new_name!r} has {tuple(source_tensor.shape)}"
        )

    param.data = source_tensor.data

# Write the merged model and its tokenizer side by side into one directory.
merged_output_dir = "output/lora-moe-lora-ckpt-200k"
for artifact in (model, tokenizer):
    artifact.save_pretrained(merged_output_dir)