
from ..base import BasePatch
from .base import BaseBTMoEHFModel
from tqdm import tqdm


# Patch functions
class DeepSeekMoEPatch(BasePatch):
    """Patching recipe for DeepSeek-MoE style Hugging Face models.

    Every ``patch_*`` method walks ``model.model.layers`` and replaces
    sub-modules in place with whatever the caller-supplied ``patch_fct``
    returns. The first decoder layer is treated as a dense MLP block
    (``mlp.{gate,up,down}_proj``); every later layer is treated as a
    mixture-of-experts block exposing ``mlp.experts``, ``mlp.shared_experts``
    and ``mlp.gate``.
    """

    # Projection attribute names, in the order in which they are patched.
    _ATTN_PROJS = ("q_proj", "k_proj", "v_proj", "o_proj")
    _MLP_PROJS = ("gate_proj", "up_proj", "down_proj")

    # Number of leading decoder layers that use a dense MLP instead of experts.
    _FIRST_K_DENSE = 1

    @classmethod
    def _is_dense_layer(cls, layer_idx):
        """Return True when decoder layer ``layer_idx`` uses the dense MLP."""
        return layer_idx < cls._FIRST_K_DENSE

    @staticmethod
    def _patch_attr(module, name, patch_fct, *extra_args):
        """Replace ``module.<name>`` with ``patch_fct(module.<name>, *extra_args)``."""
        setattr(module, name, patch_fct(getattr(module, name), *extra_args))

    @classmethod
    def _patch_projs(cls, module, proj_names, patch_fct, patch_params, tag_prefix):
        """Patch each projection of ``module`` with ``patch_params[f"{tag_prefix}.{proj}"]``."""
        for proj in proj_names:
            cls._patch_attr(
                module, proj, patch_fct, patch_params[f"{tag_prefix}.{proj}"]
            )

    @classmethod
    def _patch_projs_mixbit(
        cls, module, proj_names, patch_fct, patch_params, leaf_prefix, tag_prefix
    ):
        """Patch projections, preferring a per-leaf config over the per-tag one.

        For each projection the config is looked up under
        ``f"{leaf_prefix}.{proj}"`` first and ``f"{tag_prefix}.{proj}"`` as a
        fallback. A projection whose resolved config is ``None`` is left
        untouched.
        """
        for proj in proj_names:
            config = patch_params.get(
                f"{leaf_prefix}.{proj}", patch_params.get(f"{tag_prefix}.{proj}")
            )
            if config is not None:
                cls._patch_attr(module, proj, patch_fct, config)

    # These tags are used to specify the parameters of each layer type. For
    # example, if you want to give different quantization parameters to
    # different layers.
    @classmethod
    def get_linear_tags(cls):
        """Return the per-layer-type tags used as keys of ``patch_params``."""
        tags = [f"self_attn.{proj}" for proj in cls._ATTN_PROJS]
        for prefix in ("mlp", "mlp.experts", "mlp.shared_experts"):
            tags.extend(f"{prefix}.{proj}" for proj in cls._MLP_PROJS)
        return tags

    @classmethod
    def get_leaf_linear_tags(cls, num_layers=28, n_experts=64):
        """Return the fully-qualified name of every linear leaf module.

        Args:
            num_layers: Number of decoder layers to enumerate.
            n_experts: Number of routed experts per mixture-of-experts layer.

        Returns:
            A list of names such as ``model.layers.3.mlp.experts.7.up_proj``,
            ordered layer by layer: attention projections first, then the MLP
            (dense layer) or the routed experts followed by the shared experts.

        NOTE(review): the defaults are the values that used to be hard-coded
        here; presumably they match the deepseek-moe-16b layout -- confirm
        before relying on them for another checkpoint.
        """
        leaf_tags = []
        for i in range(num_layers):
            layer = f"model.layers.{i}"
            leaf_tags.extend(f"{layer}.self_attn.{proj}" for proj in cls._ATTN_PROJS)

            if cls._is_dense_layer(i):
                leaf_tags.extend(f"{layer}.mlp.{proj}" for proj in cls._MLP_PROJS)
                continue

            for k in range(n_experts):
                leaf_tags.extend(
                    f"{layer}.mlp.experts.{k}.{proj}" for proj in cls._MLP_PROJS
                )
            leaf_tags.extend(
                f"{layer}.mlp.shared_experts.{proj}" for proj in cls._MLP_PROJS
            )

        return leaf_tags

    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        """Apply ``patch_fct`` to every non-linear module of ``model`` in place.

        Covers ``lm_head``, the token embedding, the final norm and, per decoder
        layer, the rotary embedding, both layer norms, the activation functions
        and (for mixture-of-experts layers) the routing gate.

        Args:
            model: Model exposing ``lm_head`` and ``model.{embed_tokens,norm,layers}``.
            patch_fct: Callable taking a module and returning its replacement.
            verbose: Show a tqdm progress bar over the decoder layers.
        """
        base_model = model.model
        cls._patch_attr(model, "lm_head", patch_fct)
        cls._patch_attr(base_model, "embed_tokens", patch_fct)
        cls._patch_attr(base_model, "norm", patch_fct)

        for i, layer in enumerate(tqdm(base_model.layers, disable=not verbose)):
            cls._patch_attr(layer.self_attn, "rotary_emb", patch_fct)
            cls._patch_attr(layer, "input_layernorm", patch_fct)
            cls._patch_attr(layer, "post_attention_layernorm", patch_fct)

            mlp = layer.mlp
            if cls._is_dense_layer(i):
                cls._patch_attr(mlp, "act_fn", patch_fct)
                continue

            for expert in mlp.experts:
                cls._patch_attr(expert, "act_fn", patch_fct)
            cls._patch_attr(mlp.shared_experts, "act_fn", patch_fct)
            # Keep MOE gate as fp16 because it's small
            cls._patch_attr(mlp, "gate", patch_fct)

    @classmethod
    def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True):
        """Apply ``patch_fct`` to every linear projection of ``model`` in place.

        Args:
            model: Model exposing ``model.layers``.
            patch_fct: Callable ``(linear_module, params) -> replacement``.
            patch_params: Mapping from each tag of :meth:`get_linear_tags` to
                the params handed to ``patch_fct`` for that layer type.
            verbose: Show a tqdm progress bar over the decoder layers.

        Raises:
            KeyError: If ``patch_params`` lacks a tag needed by a visited layer.
        """
        layers = model.model.layers
        for i, layer in enumerate(tqdm(layers, disable=not verbose)):
            cls._patch_projs(
                layer.self_attn, cls._ATTN_PROJS, patch_fct, patch_params, "self_attn"
            )

            mlp = layer.mlp
            if cls._is_dense_layer(i):
                cls._patch_projs(mlp, cls._MLP_PROJS, patch_fct, patch_params, "mlp")
                continue

            for expert in mlp.experts:
                cls._patch_projs(
                    expert, cls._MLP_PROJS, patch_fct, patch_params, "mlp.experts"
                )
            cls._patch_projs(
                mlp.shared_experts,
                cls._MLP_PROJS,
                patch_fct,
                patch_params,
                "mlp.shared_experts",
            )

    @classmethod
    def patch_linearlayers_mixbit(cls, model, patch_fct, patch_params, verbose=True):
        """
        MixBit quantization for DeepSeek

        Same traversal as :meth:`patch_linearlayers`, but each projection's
        config is resolved individually: the fully-qualified leaf name (e.g.
        ``model.layers.3.mlp.experts.7.up_proj``) is looked up first and the
        layer-type tag (e.g. ``mlp.experts.up_proj``) is the fallback. A
        projection whose resolved config is ``None`` is left unpatched.

        Args:
            model: Model exposing ``model.layers``.
            patch_fct: Callable ``(linear_module, config) -> replacement``.
            patch_params: Mapping keyed by leaf names and/or layer-type tags.
            verbose: Show a tqdm progress bar over the decoder layers.
        """
        layers = model.model.layers
        # `verbose` used to be accepted but ignored here; honour it the same
        # way the other patch_* methods do.
        for i, layer in enumerate(tqdm(layers, disable=not verbose)):
            layer_name = f"model.layers.{i}"
            cls._patch_projs_mixbit(
                layer.self_attn,
                cls._ATTN_PROJS,
                patch_fct,
                patch_params,
                f"{layer_name}.self_attn",
                "self_attn",
            )

            mlp = layer.mlp
            if cls._is_dense_layer(i):
                cls._patch_projs_mixbit(
                    mlp,
                    cls._MLP_PROJS,
                    patch_fct,
                    patch_params,
                    f"{layer_name}.mlp",
                    "mlp",
                )
                continue

            for k, expert in enumerate(mlp.experts):
                cls._patch_projs_mixbit(
                    expert,
                    cls._MLP_PROJS,
                    patch_fct,
                    patch_params,
                    f"{layer_name}.mlp.experts.{k}",
                    "mlp.experts",
                )

            cls._patch_projs_mixbit(
                mlp.shared_experts,
                cls._MLP_PROJS,
                patch_fct,
                patch_params,
                f"{layer_name}.mlp.shared_experts",
                "mlp.shared_experts",
            )

class DeepSeekMoEBTMoE(DeepSeekMoEPatch, BaseBTMoEHFModel):
    """Model class combining ``DeepSeekMoEPatch`` with ``BaseBTMoEHFModel``.

    Adds no behaviour of its own; everything is inherited from the two bases.
    """
