from typing import Optional, Tuple
import torch
import math
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# from .model_base import Model_Base
from .lambda_attention import lambda_matmul


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(vec, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    vec_embed = (vec * cos) + (rotate_half(vec) * sin)
    return vec_embed

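# For intuition (a note added here, not part of the original code): RoPE rotates each
# (x_i, x_{i + d/2}) pair of a query/key vector by a position-dependent angle, i.e.
#     q_rot = q * cos(theta_pos) + rotate_half(q) * sin(theta_pos)
# so the dot product of a rotated query and a rotated key depends only on their
# relative position. `apply_rotary_pos_emb` above applies exactly this rotation.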

# Efficient implementation using `models/lambda_attention.py`
def attn_forward_factory(
    self, use_lambda_mask, local_branch, global_branch,
    limit_distance, triangle_offset
):

    def limited_distance_forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        dtype = query_states.dtype
        device = query_states.device
        kv_seq_len = key_states.shape[-2]

        # Changed from the stock implementation: the cache stores the *un-rotated*
        # key and value states (plus their position ids), because the un-rotated
        # keys are needed for the stationary (distance-limited) attention branch.
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
            key_position_ids = torch.cat([past_key_value[2], position_ids], dim=1)
            kv_seq_len += past_key_value[0].shape[-2]
        else:
            key_position_ids = position_ids

        past_key_value = (key_states, value_states, key_position_ids) if use_cache else None

        # Keep inv_freq in float32: the rotation phase (position * inv_freq) can grow
        # large, and computing it in half precision loses accuracy at long range.
        self.rotary_emb.inv_freq = self.rotary_emb.inv_freq.to(torch.float32)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        rot_query_states = apply_rotary_pos_emb(
            query_states, cos, sin, position_ids)
        rot_key_states = apply_rotary_pos_emb(
            key_states, cos, sin, key_position_ids)

        if limit_distance is None:
            # No distance ceiling: the "stationary" branch degenerates to ordinary
            # rotary attention.
            stationary_key_states = rot_key_states
            stationary_query_states = rot_query_states
        else:
            # Distance ceiling: rotate the query as if it sat exactly
            # `limit_distance` tokens after an un-rotated key, so every key farther
            # than `limit_distance` away receives the same positional treatment.
            stationary_key_states = key_states
            effective_limit_distance = min(limit_distance, kv_seq_len-1)
            # Indexing into cos/sin differs across transformers versions.
            # transformers 4.35.2:
            # stationary_query_states = \
            #     (query_states * cos[effective_limit_distance]) + \
            #     (rotate_half(query_states) * sin[effective_limit_distance])
            # transformers 4.32.1:
            stationary_query_states = \
                (query_states * cos[0, 0, effective_limit_distance]) + \
                (rotate_half(query_states) * sin[0, 0, effective_limit_distance])

        # Lambda-mask path: `lambda_matmul` only evaluates scores for the first
        # `global_branch` tokens and the most recent `local_branch` tokens, which is
        # far cheaper than materializing the full attention matrix.
        if use_lambda_mask:
            # Empirical threshold tuned on an A100 GPU; beyond this sequence length
            # we loop over heads to keep peak memory manageable.
            headwise_limit = 33000
            if q_len > headwise_limit:
                for head_i in range(self.num_heads):
                    query_states[:, head_i] = (
                        lambda_matmul(
                            rot_key_states[:, head_i],
                            stationary_key_states[:, head_i],
                            rot_query_states[:, head_i],
                            stationary_query_states[:, head_i],
                            local_branch, global_branch
                        ) / math.sqrt(self.head_dim)
                    ).softmax().matmul(value_states[:, head_i])
            else:
                query_states = (
                    lambda_matmul(
                        rot_key_states,
                        stationary_key_states,
                        rot_query_states,
                        stationary_query_states,
                        local_branch, global_branch
                    ) / math.sqrt(self.head_dim)
                ).softmax().matmul(value_states)

        # Without the lambda mask, fall back to a costlier per-head dense
        # implementation that materializes the full attention matrix.
        else:
            for head_i in range(self.num_heads):
                attn_weights = rot_query_states[:, head_i].matmul(
                    rot_key_states[:, head_i].transpose(1, 2)) / math.sqrt(self.head_dim)

                if limit_distance is not None:
                    # Keep the rotary (relative-position) scores only inside the
                    # local window of `local_branch` tokens.
                    attn_weights = attn_weights.triu(-local_branch+1+kv_seq_len-q_len)

                    # Beyond the distance ceiling, reuse the stationary query/key
                    # states computed above: every key farther than `limit_distance`
                    # tokens away is scored as if it were exactly that far away.
                    attn_weights += \
                        stationary_query_states[:, head_i].matmul(
                            stationary_key_states[:, head_i].transpose(1, 2)
                        ).tril(-limit_distance+kv_seq_len-q_len
                               ) / math.sqrt(self.head_dim)

                # Triangular zone between the global branch and the local window:
                # subtract log(zone width) * triangle_offset from those scores.
                if triangle_offset != 0:
                    for line_i in range(max(0, global_branch + local_branch -
                                            kv_seq_len + q_len),
                                        q_len):
                        col_high = line_i - local_branch + 1 + kv_seq_len - q_len
                        attn_weights[:, line_i, global_branch: col_high] -= \
                            math.log(col_high - global_branch) * triangle_offset

                if attention_mask is not None:
                    attn_weights = attn_weights + attention_mask[:, 0]
                    attn_weights = torch.max(
                        attn_weights,
                        torch.tensor(torch.finfo(attn_weights.dtype).min,
                                     device=device)
                    )

                # Reuse the `query_states` buffer to store this head's output.
                query_states[:, head_i] = torch.matmul(
                    F.softmax(
                        attn_weights, dim=-1, dtype=torch.float32).to(dtype),
                    value_states[:, head_i]
                )

            # legacy codes used for extracting attention weights
            # 1. last token attention
            # with open('attn_weights.txt', 'a') as f:
            #     f.write(str(attn_weights[0, -1].cpu().numpy().tolist()) + '\n')
            #     print(attn_weights[0, -1].reshape(-1, 1024).mean(-1))
            # 2. entropy in overall attention
            # import pickle
            # attn_weights = attn_weights[0]
            # p = attn_weights.exp()
            # p = p / p.sum(-1, keepdim=True)
            # entropy = (-p * p.log()).tril().sum(-1)
            # with open('attn_weights.pkl', 'wb') as f:
            #     pickle.dump(entropy.cpu().numpy(), f)
            # exit()
            # 3. implicit position information
            # if print_or_not:
            #     import pickle
            #     import os
            #     if os.path.exists('features.pkl'):
            #         with open('features.pkl', 'rb') as f:
            #             features = pickle.load(f)
            #     else:
            #         features = []
            #     features.append(query_states[0, 0].cpu().numpy())
            #     with open('features.pkl', 'wb') as f:
            #         pickle.dump(features, f)

        attn_output = query_states
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # `attn_weights` only holds the last head's scores (and is never defined in
        # the lambda-mask path), so returning attentions is not supported here.
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    return limited_distance_forward
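

# For reference only: a dense sketch of the Lambda-shaped attention pattern that the
# lambda-mask branch above is presumed to realize (the first `global_branch` tokens
# plus the most recent `local_branch` tokens, under a causal constraint). This helper
# is illustrative and is never called by the forward pass, which avoids materializing
# such a matrix.
def lambda_mask_sketch(q_len, kv_seq_len, local_branch, global_branch, device="cpu"):
    """Return a [q_len, kv_seq_len] bool mask, True where attention is allowed."""
    q_pos = torch.arange(kv_seq_len - q_len, kv_seq_len, device=device).unsqueeze(1)
    k_pos = torch.arange(kv_seq_len, device=device).unsqueeze(0)
    causal = k_pos <= q_pos                       # never attend to future tokens
    is_global = k_pos < global_branch             # the first `global_branch` tokens
    is_local = (q_pos - k_pos) < local_branch     # the last `local_branch` tokens
    return causal & (is_global | is_local)
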


# class LLAMA_Model(Model_Base):
#     wrap_module = LlamaDecoderLayer

#     def __init__(
#         self, model_name_or_path, tokenizer_path, max_length, truncation_side,
#         load_in_4bit, device_map,
#         use_lambda_mask, local_branch, global_branch,
#         limit_distance, triangle_offset
#     ):
#         super().__init__(max_length, truncation_side)
#         self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
#         self.model = LlamaForCausalLM.from_pretrained(
#             model_name_or_path,
#             load_in_4bit=load_in_4bit, device_map=device_map
#         )
#         self.load_in_4bit = load_in_4bit
#         self.device_map = device_map
#         self.tokenizer.pad_token = self.tokenizer.eos_token
#         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

#         # hack arguments
#         self.use_lambda_mask = use_lambda_mask
#         self.local_branch = local_branch
#         self.global_branch = global_branch
#         self.limit_distance = limit_distance
#         self.triangle_offset = triangle_offset

#         for layer_i, hidden_layer in enumerate(self.model.model.layers):
#             attn = hidden_layer.self_attn
#             attn.forward = attn_forward_factory(
#                 attn, use_lambda_mask, local_branch, global_branch,
#                 limit_distance, triangle_offset
#             )

#     def to(self, device):
#         if self.load_in_4bit:
#             return self

#         self.device = device
#         self.model.to(device)
#         return self


def convert_llama_model(model, local_branch, global_branch):
    """Patch every decoder layer of a Llama model in place so its attention uses the
    lambda mask (`use_lambda_mask=True`) with the given local/global branch sizes,
    a distance ceiling equal to `local_branch`, and no triangle offset."""
    for layer_i, hidden_layer in enumerate(model.model.layers):
        attn = hidden_layer.self_attn
        attn.forward = attn_forward_factory(
            attn, True, local_branch, global_branch, local_branch, 0
        )
    return model
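

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): load a
    # Llama checkpoint with a transformers version matching the cos/sin indexing
    # above (e.g. 4.32.1) and patch its attention layers. The model id and branch
    # sizes below are placeholder assumptions.
    model_id = "meta-llama/Llama-2-7b-hf"
    tokenizer = LlamaTokenizer.from_pretrained(model_id)
    model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    model = convert_llama_model(model, local_branch=4096, global_branch=100)

    inputs = tokenizer("Hello, lambda attention!", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)  # [1, seq_len, vocab_size]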
