from camelidae.configuration_moe import LlamaConfig
from camelidae.modeling_moe import LlamaForCausalLM as LoraMoeModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from random import sample
import torch.nn.functional as F
import torch
import numpy as np
import os
from numpy.linalg import norm
import re
from safetensors.torch import save_file

def set_seed(seed=42):
    """Seed every random number generator this script can draw from.

    Python's ``random`` alone is not enough for reproducibility here: the
    merge step samples its drop mask from NumPy / PyTorch, so those global
    generators are seeded as well.

    Args:
        seed: Integer seed. It is reduced modulo ``2**32`` for NumPy, whose
            legacy global seeding only accepts values in ``[0, 2**32)``.
    """
    random.seed(seed)
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)

def copy_param(base_model, target_model):
    """Initialise ``target_model`` parameters from ``base_model``.

    Parameters are matched by name against ``base_model.state_dict()``:

    * names containing ``'experts'`` are taken from the base model when the
      base model has an entry with that name, and left untouched otherwise;
    * names containing ``'gate.weight'`` have their leading rows overwritten
      in place with the base tensor, leaving any extra rows as they were;
    * every other parameter is taken from the base model (a missing name
      raises ``KeyError``).

    Except for the in-place gate copy, the target parameter's ``.data`` is
    rebound to the base tensor rather than copied.

    Args:
        base_model: Model providing the source weights via ``state_dict()``.
        target_model: Model whose ``named_parameters()`` are overwritten.
    """
    source_state = base_model.state_dict()
    for param_name, target_param in target_model.named_parameters():
        is_expert = 'experts' in param_name

        # Expert slots that do not exist in the base model keep their values.
        if is_expert and param_name not in source_state:
            continue

        if not is_expert and 'gate.weight' in param_name:
            # Fill only the first len(base) rows of the target gate.
            base_rows = source_state[param_name].data
            target_param.data[:len(base_rows)][:] = base_rows
            continue

        target_param.data = source_state[param_name].data
    

# Parameter-name patterns for expert projection weights and the router gate.
_EXPERT_WEIGHT_PATTERN = re.compile(r'model\.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(\w+)\.weight')
_GATE_WEIGHT_PATTERN = re.compile(r'model\.layers\.(\d+)\.block_sparse_moe\.gate\.weight')

# Probability of keeping each delta entry in the drop-and-rescale step.
_DARE_KEEP_PROB = 0.5


def _expert_signatures(state, layer, num_experts):
    """Return ``down_proj @ (gate_proj * up_proj)`` for each expert of ``layer``."""
    prefix = f"model.layers.{layer}.block_sparse_moe.experts"
    signatures = []
    for expert in range(num_experts):
        gate_weight = state[f"{prefix}.{expert}.gate_proj.weight"]
        up_weight = state[f"{prefix}.{expert}.up_proj.weight"]
        down_weight = state[f"{prefix}.{expert}.down_proj.weight"]
        signatures.append(torch.matmul(down_weight, gate_weight * up_weight))
    return signatures


def _pairwise_cosine_similarity(signatures):
    """Return a NumPy matrix of averaged cosine similarities between signatures."""
    count = len(signatures)
    similarity = np.zeros((count, count))
    for i, left in enumerate(signatures):
        for j, right in enumerate(signatures):
            # Cosine similarity is taken along dim 0, summed, then divided by
            # the size of the first dimension.
            per_column = torch.cosine_similarity(left, right, dim=0)
            similarity[i, j] = (per_column.sum() / len(left)).item()
    return similarity


def _least_similar_pair(similarity):
    """Return the (row, col) of the first minimum entry of ``similarity``."""
    rows, cols = np.where(similarity == np.min(similarity))
    return int(rows[0]), int(cols[0])


def _dare_rescale(merged, reference):
    """Randomly drop half of ``merged - reference`` and rescale the rest.

    The mask is generated with torch in the dtype and on the device of
    ``merged`` so the result keeps the parameter's dtype instead of being
    promoted through a NumPy integer array.
    """
    reference = reference.to(device=merged.device, dtype=merged.dtype)
    delta = merged - reference
    keep = (torch.rand(delta.shape, device=delta.device) < _DARE_KEEP_PROB).to(delta.dtype)
    return reference + delta * keep * (1.0 / _DARE_KEEP_PROB)


def genetic_algo(
    base_model_name_or_path,
    target_model_name_or_path,
    save_path,
    num_local_experts=4,
    num_target_experts=8,
):
    """Grow an MoE checkpoint from ``num_local_experts`` to ``num_target_experts``.

    The target model is instantiated with ``num_target_experts`` experts and
    initialised from the base model via ``copy_param``. Then, one expert at a
    time, every layer picks the two existing experts whose weight signatures
    have the lowest cosine similarity, averages their projection weights,
    applies a random drop-and-rescale ("DARE") of the difference to the
    pretrained dense MLP weights, and writes the result into the next free
    expert slot. The matching router row is set to the average of the two
    parents' router rows.

    The resulting model and the base model's tokenizer are written to
    ``save_path`` once, after all experts have been added.

    Args:
        base_model_name_or_path: Checkpoint of the existing MoE model; also
            the source of the tokenizer.
        target_model_name_or_path: Checkpoint used both to instantiate the
            enlarged MoE model and, loaded through ``AutoModelForCausalLM``,
            as the reference whose ``model.layers.N.mlp.*`` weights anchor
            the drop-and-rescale step.
        save_path: Output directory for the model and tokenizer.
        num_local_experts: Number of experts already present per layer.
        num_target_experts: Number of experts per layer after expansion.
    """
    # Load tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    # Load base model.
    base_model_config = LlamaConfig.from_pretrained(base_model_name_or_path)
    base_model = LoraMoeModel.from_pretrained(base_model_name_or_path, config=base_model_config)

    # Create the enlarged target model from the pretrained checkpoint.
    target_model_config = LlamaConfig.from_pretrained(target_model_name_or_path)
    target_model_config.num_local_experts = num_target_experts
    target_model = LoraMoeModel.from_pretrained(target_model_name_or_path, config=target_model_config)

    copy_param(base_model, target_model)

    # The reference weights never change, so load them once rather than on
    # every expansion step.
    pretrain_param = AutoModelForCausalLM.from_pretrained(target_model_name_or_path).state_dict()
    num_layers = int(target_model_config.num_hidden_layers)

    for step in range(num_target_experts - num_local_experts):
        # Index of the expert slot filled in this step; also the number of
        # experts that already hold weights.
        new_expert = num_local_experts + step
        target_model_param = target_model.state_dict()

        # Per layer, pick the two existing experts that differ the most.
        select_ids = []
        for layer in range(num_layers):
            signatures = _expert_signatures(target_model_param, layer, new_expert)
            similarity = _pairwise_cosine_similarity(signatures)
            select_ids.append(_least_similar_pair(similarity))

        # Merge the selected parents into the new expert slot.
        for name, param in target_model.named_parameters():
            if "experts" in name:
                matches = _EXPERT_WEIGHT_PATTERN.match(name)
                layer_num = int(matches.group(1))
                expert_num = int(matches.group(2))
                layer_type = matches.group(3)
                if expert_num != new_expert:
                    continue
                id1, id2 = select_ids[layer_num]
                new_name_id1 = f"model.layers.{layer_num}.block_sparse_moe.experts.{id1}.{layer_type}.weight"
                new_name_id2 = f"model.layers.{layer_num}.block_sparse_moe.experts.{id2}.{layer_type}.weight"
                weight_sum = 0.5 * target_model_param[new_name_id1].data + 0.5 * target_model_param[new_name_id2].data
                new_name = f'model.layers.{layer_num}.mlp.{layer_type}.weight'
                # DARE: drop half of the delta to the pretrained weight and
                # rescale what is kept.
                param.data = _dare_rescale(weight_sum, pretrain_param[new_name])
            elif 'gate.weight' in name:
                matches = _GATE_WEIGHT_PATTERN.match(name)
                layer_num = int(matches.group(1))
                id1, id2 = select_ids[layer_num]
                gate_rows = target_model_param[name].data
                weight_sum = (0.5 * gate_rows[id1] + 0.5 * gate_rows[id2]).squeeze(0)
                # Written in place so only the new expert's router row changes.
                param.data[new_expert][:] = weight_sum

    # Save once, after every new expert has been filled in.
    target_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    
if __name__ == "__main__":
    # Script entry point: run the expansion on fixed checkpoint locations,
    # relying on genetic_algo's default expert counts.
    checkpoint_paths = dict(
        base_model_name_or_path="/root/paddlejob/workspace/env_run/huitingfeng/MoE/output/sheared-moe-ckpt-500k",
        target_model_name_or_path="/root/paddlejob/workspace/env_run/huitingfeng/models/llama-2.7b-hf",
        save_path="/root/paddlejob/workspace/env_run/huitingfeng/MoE/output/sheared-moe-expand-500k",
    )
    genetic_algo(**checkpoint_paths)