from .base_model import BaseModel
from llmc.utils.registry_factory import MODEL_REGISTRY


@MODEL_REGISTRY
class Mistral(BaseModel):
    """Mistral adapter: tells the compression pipeline where things live.

    The hooks below locate the decoder layers, the token embedding, the
    non-block layers, the per-block norms, and the groups of linear layers
    that are processed together inside each decoder block.
    """

    def __init__(self, model_path, torch_dtype):
        # No Mistral-specific setup; construction is delegated to BaseModel.
        super().__init__(model_path, torch_dtype)

    def find_blocks(self):
        """Store the decoder layer container on ``self.blocks``."""
        decoder = self.model.model
        self.blocks = decoder.layers

    def find_embed_layers(self):
        """Store the token-embedding module on ``self.embed_tokens``."""
        decoder = self.model.model
        self.embed_tokens = decoder.embed_tokens

    def find_block_name(self):
        """Record the state-dict prefix of the blocks and the layer-name tags.

        ``self.pairs`` maps a projection name to a short group tag
        (presumably used by callers to label qkv / out / fc1 groups --
        confirm against the consumer).
        """
        self.block_name_prefix = "model.layers"
        self.pairs = dict(q_proj="qkv", o_proj="out", up_proj="fc1")

    def get_embed_layers(self):
        """Return the embedding modules as a one-element list."""
        embeddings = [self.embed_tokens]
        return embeddings

    def get_layers_except_blocks(self):
        """Return the modules outside the decoder blocks.

        Order: token embedding, final norm, LM head.
        """
        embed = self.embed_tokens
        final_norm = self.model.model.norm
        head = self.model.lm_head
        return [embed, final_norm, head]

    def has_bias(self):
        """Always ``False``: this adapter reports its linear layers as bias-free."""
        return False

    def get_layernorms_in_block(self, block):
        """Map each norm attribute name in ``block`` to its module."""
        norm_names = ("input_layernorm", "post_attention_layernorm")
        return {name: getattr(block, name) for name in norm_names}

    def get_subsets_in_block(self, block):
        """Describe the groups of linear layers inside one decoder block.

        Each returned dict has:
          * ``layers``   -- qualified name -> linear module in the group;
          * ``prev_op``  -- module(s) directly feeding the group;
          * ``input``    -- name of the layer whose input is (presumably)
                            captured for the group;
          * ``inspect``  -- module (presumably) run to observe the group's
                            effect on outputs;
          * ``has_kwargs`` -- whether ``inspect`` needs the block's extra
                            forward kwargs (True only for self-attention).

        Groups are returned in forward order: q/k/v, o, gate/up, down.
        """
        attn = block.self_attn
        mlp = block.mlp

        # q/k/v all read the output of the first norm of the block.
        qkv_group = {
            "layers": {
                "self_attn.q_proj": attn.q_proj,
                "self_attn.k_proj": attn.k_proj,
                "self_attn.v_proj": attn.v_proj,
            },
            "prev_op": [block.input_layernorm],
            "input": ["self_attn.q_proj"],
            "inspect": attn,
            "has_kwargs": True,
        }

        # NOTE(review): v_proj is declared as o_proj's predecessor. If the
        # checkpoint uses grouped-query attention, v_proj's output width can
        # differ from o_proj's input width. Confirm that the consuming
        # algorithm checks this before scaling across the pair.
        out_group = {
            "layers": {"self_attn.o_proj": attn.o_proj},
            "prev_op": [attn.v_proj],
            "input": ["self_attn.o_proj"],
            "inspect": attn.o_proj,
            "has_kwargs": False,
        }

        # gate/up both read the output of the post-attention norm.
        gate_up_group = {
            "layers": {
                "mlp.gate_proj": mlp.gate_proj,
                "mlp.up_proj": mlp.up_proj,
            },
            "prev_op": [block.post_attention_layernorm],
            "input": ["mlp.gate_proj"],
            "inspect": mlp,
            "has_kwargs": False,
        }

        down_group = {
            "layers": {"mlp.down_proj": mlp.down_proj},
            "prev_op": [mlp.up_proj],
            "input": ["mlp.down_proj"],
            "inspect": mlp.down_proj,
            "has_kwargs": False,
        }

        return [qkv_group, out_group, gate_up_group, down_group]
