import json

import torch
from transformers import LlamaForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, ALL_ATTENTION_FUNCTIONS, logger
from transformers.models.llama.modeling_llama import eager_attention_forward

# sparse attn
from .minfer.minfer_attention import minfer_attention
from .flex_prefill_attention import flex_prefill_attention
from .xattention.x_attention import Xattention_prefill, llama_fuse_16
from .auxhead_attention import auxhead_attention

def _build_attn_config(attn_temp, gamma, bias, stride):
    """Build the attention-config dict for the attention method ``attn_temp``.

    Args:
        attn_temp: one of 'minference', 'flexprefill', 'xattention',
            'auxhead', 'full'.
        gamma, bias, stride: forwarded into the config only for 'auxhead'.

    Returns:
        A dict whose "type" key names the method, plus method-specific settings.

    Raises:
        ValueError: if ``attn_temp`` is not one of the supported methods.
    """
    if attn_temp == 'minference':
        # Relative path: resolved against the current working directory,
        # not against this module's location.
        budget_path = "./ops/minfer/config/Llama_3.1_8B_Instruct_128k_kv_out_v32_fit_o_best_pattern_v2.json"
        # Use a context manager so the file handle is closed promptly.
        with open(budget_path, 'r') as f:
            budget = json.load(f)
        # The attention forward indexes this as budget[layer_idx].
        return {"type": "minference", "budget": budget}
    if attn_temp == 'flexprefill':
        # Fixed hyperparameters: the `gamma` argument is NOT used for this method.
        return {
            "type": "flexprefill",
            "gamma": 0.95,
            "tau": 0.1,
            "min_budget": 2048,
            "max_budget": None,
        }
    if attn_temp == 'xattention':
        # Per-layer thresholds from llama_fuse_16 (indexed by layer_idx at
        # attention time). Stride is fixed at 16 and the `stride` argument is
        # NOT used; presumably the table was tuned for stride 16 -- confirm
        # before changing.
        return {
            "type": "xattention",
            "stride": 16,
            "budget": torch.tensor(llama_fuse_16),
        }
    if attn_temp == 'auxhead':
        return {
            "type": "auxhead",
            "gamma": gamma,
            "bias": bias,
            "stride": stride,
        }
    if attn_temp == 'full':
        return {"type": "full"}
    # Previously an unknown name fell through and crashed later with an
    # UnboundLocalError; fail early with an actionable message instead.
    raise ValueError(
        f"Unsupported attn_temp {attn_temp!r}; expected one of: "
        "minference, flexprefill, xattention, auxhead, full"
    )


def load_llama(model, attn_temp, gamma, bias, stride, budget_save_dir):
    """Patch every decoder layer of ``model`` to use the selected attention method.

    Args:
        model: a Llama model whose decoder layers are reachable as
            ``model.model.layers``, each exposing a ``self_attn`` module.
        attn_temp: attention method name; one of 'minference', 'flexprefill',
            'xattention', 'auxhead', 'full'.
        gamma, bias, stride: settings used only by the 'auxhead' method.
        budget_save_dir: stored on each attention module and passed as
            ``sparse_file=`` to the sparse kernels (presumably where they record
            sparsity/budget statistics -- confirm against the kernels).

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: if ``attn_temp`` is not a supported method.
    """
    cur_attn_config = _build_attn_config(attn_temp, gamma, bias, stride)
    for layer in model.model.layers:
        attn = layer.self_attn
        # NOTE: one config dict object is shared by all layers.
        attn.attn_config = cur_attn_config
        attn.budget_save_dir = budget_save_dir
        # Bind as a method of this specific module so the patched forward
        # receives the module as `self`.
        attn.forward = llama_sparse_forward.__get__(attn)
    return model

def _sparse_prefill_attention(self, query_states, key_states, value_states):
    """Run the configured sparse-attention kernel for a multi-token (prefill) call.

    Args:
        self: the patched attention module; reads ``attn_config``,
            ``layer_idx``, ``budget_save_dir`` and, for xattention,
            ``num_key_value_groups``.
        query_states: post-RoPE queries, (batch, num_heads, q_len, head_dim).
        key_states, value_states: post-RoPE, post-cache keys/values in the same
            (batch, heads, len, head_dim) layout; the head count is presumably
            ``num_key_value_heads`` (GQA) -- confirm.

    Returns:
        The kernel output, transposed where needed. The caller reshapes it to
        (*input_shape, -1), so it is assumed to be (batch, q_len, num_heads,
        head_dim) for every kernel -- TODO confirm for auxhead, whose output is
        used as-is.

    Raises:
        ValueError: if ``attn_config['type']`` names no known sparse kernel.
    """
    cfg = self.attn_config
    attn_type = cfg['type']
    if attn_type == "minference":
        # Per-layer sparse pattern selected by this layer's index.
        return minfer_attention(query_states,
                                key_states,
                                value_states,
                                cfg['budget'][self.layer_idx],
                                sparse_file=self.budget_save_dir).transpose(1, 2)
    if attn_type == "flexprefill":
        # This kernel is fed (batch, len, heads, head_dim) tensors; its output
        # is used without a further transpose.
        return flex_prefill_attention(query_states.transpose(1, 2),
                                      key_states.transpose(1, 2),
                                      value_states.transpose(1, 2),
                                      cfg['gamma'],
                                      cfg['tau'],
                                      cfg['min_budget'],
                                      cfg['max_budget'],
                                      sparse_file=self.budget_save_dir,
                                      )
    if attn_type == "xattention":
        # Expand KV heads num_key_value_groups times so they line up with
        # the query heads before calling the kernel.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        return Xattention_prefill(query_states, key_states, value_states,
                                  stride=cfg['stride'],
                                  norm=1,
                                  threshold=cfg['budget'][self.layer_idx],
                                  use_triton=True,
                                  sparse_file=self.budget_save_dir,
                                  ).transpose(1, 2)
    if attn_type == "auxhead":
        return auxhead_attention(query_states, key_states, value_states,
                                 gamma=cfg['gamma'],
                                 stride=cfg['stride'],
                                 sparse_file=self.budget_save_dir,
                                 bias=cfg['bias'],
                                 )
    # Previously this fell through and crashed on an unbound `attn_output`.
    raise ValueError(f"Unsupported sparse attention type {attn_type!r}")


def llama_sparse_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings,
    attention_mask,
    past_key_value,
    cache_position,
    **kwargs,
):
    """Replacement ``forward`` for a Llama attention module (bound by load_llama).

    Projects ``hidden_states`` to queries/keys/values, applies rotary position
    embeddings, and updates the KV cache when one is given. Then:

    * multi-token calls (prefill) with a non-"full" config go to a sparse
      kernel; ``attention_mask`` and ``kwargs`` are NOT passed to those
      kernels, which presumably assume causal masking -- confirm;
    * single-token calls (decode) and the "full" config use the model's
      configured dense attention implementation.

    Args:
        hidden_states: layer input; the per-head view below implies
            (batch, seq_len, hidden_size).
        position_embeddings: ``(cos, sin)`` tuple for RoPE.
        attention_mask: mask for the dense backend.
        past_key_value: optional cache object exposing ``update``.
        cache_position: forwarded to the cache update.
        **kwargs: forwarded to the dense attention backend.

    Returns:
        ``(attn_output, attn_weights)``: the ``o_proj`` output shaped like
        ``hidden_states``, and whatever weights the dense backend returns
        (``None`` on the sparse path).
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # query_states: [batch, num_heads, q_len, head_dim]
    # key_states / value_states: [batch, kv_heads, kv_len, head_dim], where
    # kv_heads is presumably num_key_value_heads and kv_len presumably also
    # counts previously cached tokens -- confirm.

    if query_states.size(2) > 1 and self.attn_config['type'] != "full":
        attn_output = _sparse_prefill_attention(self, query_states, key_states, value_states)
        attn_weights = None
    else:
        impl = self.config._attn_implementation
        # "eager" is not a key of ALL_ATTENTION_FUNCTIONS; upstream
        # LlamaAttention falls back to eager_attention_forward for it, so do
        # the same instead of raising KeyError.
        if impl == "eager":
            attention_interface = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[impl]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
    # produce output
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    # Return the dense backend's weights (previously dropped) so
    # output_attentions works on the dense path.
    return attn_output, attn_weights


if __name__ == "__main__":
    # Intentionally a no-op when this file is executed directly.
    ...
