from camelidae.configuration_loramoe import LlamaConfig
from camelidae.modeling_loramoe import LlamaForCausalLM as LoraMoeModel
from transformers import AutoTokenizer
import random
from random import sample
import torch.nn.functional as F
import torch
import numpy as np
import os
from numpy.linalg import norm
import re
from safetensors.torch import save_file

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Seeding only ``random`` is not enough for reproducibility here: code in
    this file also draws from ``np.random`` and works with torch tensors, so
    all three generators are seeded from the same value.

    Args:
        seed: Integer seed. Defaults to 42.
    """
    random.seed(seed)
    # NumPy's legacy seeding only accepts values in [0, 2**32); fold the seed
    # into that range so negative or very large integers remain usable.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)

def copy_param(base_model, target_model):
    """Initialise ``target_model`` parameters from ``base_model``.

    For every named parameter of ``target_model``:

    * names containing ``'experts'`` are taken from the base model only when
      the base state dict has an entry with the same name; otherwise the
      target parameter is left untouched;
    * other names containing ``'gate.weight'`` have their leading rows
      overwritten in place with the base tensor (the target may have more
      rows than the base);
    * every remaining parameter is pointed at the base tensor.

    Rebinding ``.data`` does not clone, so the affected target parameters
    reference the tensors returned by ``base_model.state_dict()``.

    Args:
        base_model: Module providing the source weights via ``state_dict()``.
        target_model: Module whose parameters are updated in place.

    Raises:
        KeyError: If a non-expert parameter of the target is absent from the
            base state dict.
    """
    source_state = base_model.state_dict()
    for param_name, target_param in target_model.named_parameters():
        is_expert = 'experts' in param_name

        # Expert slots that do not exist in the base model keep their values.
        if is_expert and param_name not in source_state:
            continue

        if not is_expert and 'gate.weight' in param_name:
            # Only the first len(base) rows are filled; extra rows are kept.
            router_rows = source_state[param_name].data
            target_param.data[:len(router_rows)] = router_rows
            continue

        target_param.data = source_state[param_name].data
    

def _expert_signature(state, layer, expert):
    """Return one matrix combining the gate/up/down LoRA products of an expert."""
    prefix = f"model.layers.{layer}.experts.{expert}"
    down_weight = torch.matmul(state[f"{prefix}.down_lora.B.weight"], state[f"{prefix}.down_lora.A.weight"])
    gate_weight = torch.matmul(state[f"{prefix}.gate_lora.B.weight"], state[f"{prefix}.gate_lora.A.weight"])
    up_weight = torch.matmul(state[f"{prefix}.up_lora.B.weight"], state[f"{prefix}.up_lora.A.weight"])
    return torch.matmul(down_weight, gate_weight * up_weight)


def _pairwise_cosine_similarity(expert_weights):
    """Return the symmetric matrix of averaged cosine similarities between experts."""
    num_experts = len(expert_weights)
    cos_sim = np.zeros((num_experts, num_experts))
    for i in range(num_experts):
        # The measure is symmetric, so each pair is evaluated once and mirrored.
        for j in range(i, num_experts):
            score = torch.cosine_similarity(expert_weights[i], expert_weights[j], dim=0).sum() / len(expert_weights[i])
            cos_sim[i, j] = cos_sim[j, i] = score.item()
    return cos_sim


def _least_similar_pair(cos_sim):
    """Return the (row, col) index of the first minimum entry of ``cos_sim``."""
    rows, cols = np.where(cos_sim == np.min(cos_sim))
    return int(rows[0]), int(cols[0])


def genetic_algo(base_model_name_or_path, target_model_name_or_path, save_path, num_local_experts=8, num_target_experts=16):
    """Create new experts by merging the least similar pair of existing experts.

    The base model's parameters are copied into a target model configured with
    ``num_target_experts`` experts. Then, for each expert slot to be added
    (``num_target_experts - num_local_experts`` steps), every layer picks the
    two experts whose combined LoRA matrices have the lowest averaged cosine
    similarity, and each LoRA tensor of the new slot is written as a random
    convex combination (Dirichlet-distributed coefficients, drawn afresh for
    every tensor) of the corresponding tensors of that pair.

    The merged tensors of step ``k`` are saved to
    ``<save_path>/expand_expert_<k>/adapter_model.safetensors`` under keys of
    the form ``base_model.model.model.layers.<layer>.mlp.<proj>_proj.lora_<A|B>.weight``.

    Args:
        base_model_name_or_path: Checkpoint holding the existing experts.
        target_model_name_or_path: Checkpoint used to instantiate the target
            model with the enlarged expert count.
        save_path: Directory under which one sub-directory per step is written.
        num_local_experts: Number of experts already present in the base model.
        num_target_experts: Number of experts the target model is built with.
    """
    # load base model
    base_model_config = LlamaConfig.from_pretrained(base_model_name_or_path)
    base_model = LoraMoeModel.from_pretrained(base_model_name_or_path, config=base_model_config)

    # create target model from pretrained model
    target_model_config = LlamaConfig.from_pretrained(target_model_name_or_path)
    target_model_config.num_local_experts = num_target_experts
    target_model = LoraMoeModel.from_pretrained(target_model_name_or_path, config=target_model_config)

    copy_param(base_model, target_model)

    # Loop invariants: compiled once instead of once per step.
    pattern_experts = re.compile(r'model\.layers\.(\d+)\.experts\.(\d+)\.(\w+)\.(\w+)\.weight')
    num_layers = int(target_model_config.num_hidden_layers)

    for step in range(num_target_experts - num_local_experts):
        new_expert = num_local_experts + step
        target_model_param = target_model.state_dict()

        # Per layer, select the two experts that differ the most
        # (lowest averaged cosine similarity) among experts 0..new_expert-1.
        # NOTE(review): the merged tensors below are only saved to disk and are
        # never written back into target_model, so slots produced by earlier
        # steps still hold their initial values when they are compared here.
        # Confirm whether that is intended.
        select_ids = []
        for layer in range(num_layers):
            expert_weights = [
                _expert_signature(target_model_param, layer, expert)
                for expert in range(new_expert)
            ]
            select_ids.append(_least_similar_pair(_pairwise_cosine_similarity(expert_weights)))

        # Merge the selected pair into the tensors of the new expert slot.
        tensor_dict = {}
        for name, _ in target_model.named_parameters():
            if "experts" not in name:
                continue
            matches = pattern_experts.match(name)
            if matches is None:
                # Not a LoRA expert weight in the expected naming scheme.
                continue
            layer_num = int(matches.group(1))
            expert_num = int(matches.group(2))
            layer_type = matches.group(3)
            variable_type = matches.group(4)
            if expert_num != new_expert:
                continue

            id1, id2 = select_ids[layer_num]
            new_name_id1 = f"model.layers.{layer_num}.experts.{id1}.{layer_type}.{variable_type}.weight"
            new_name_id2 = f"model.layers.{layer_num}.experts.{id2}.{layer_type}.{variable_type}.weight"
            # Two non-negative coefficients that sum to one.
            random_weight = np.random.dirichlet(np.ones(2), size=1)
            weight_sum = (
                float(random_weight[0][0]) * target_model_param[new_name_id1].data
                + float(random_weight[0][1]) * target_model_param[new_name_id2].data
            )
            # layer_type[:-5] drops the trailing five characters
            # (presumably the "_lora" suffix, e.g. "gate_lora" -> "gate").
            new_name = f'base_model.model.model.layers.{layer_num}.mlp.{layer_type[:-5]}_proj.lora_{variable_type}.weight'
            tensor_dict[new_name] = weight_sum

        # exist_ok avoids losing a finished step's work when the directory is
        # already there from a previous run.
        step_dir = os.path.join(save_path, f"expand_expert_{step}")
        os.makedirs(step_dir, exist_ok=True)
        save_file(tensor_dict, os.path.join(step_dir, "adapter_model.safetensors"))
    
if __name__ == "__main__":
    # Script entry point: run the expert expansion with the hard-coded
    # checkpoint locations below and the default expert counts.
    run_arguments = {
        "base_model_name_or_path": "/root/paddlejob/workspace/env_run/huitingfeng/MoE/output/lora-moe-lora-ckpt",
        "target_model_name_or_path": "/root/paddlejob/workspace/env_run/huitingfeng/models/llama-2-7b-hf",
        "save_path": "/root/paddlejob/workspace/env_run/huitingfeng/MoE/output/lora-500k-steps-new",
    }
    genetic_algo(**run_arguments)