import math
import os
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from safetensors import safe_open
from torch import nn
from transformers import LlamaConfig
from transformers.cache_utils import Cache
from transformers.utils import logging

# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(name=__name__)


def rotate_half(x):
    """Swap the two halves of the last dimension and negate the new front half.

    With ``x`` split along its last axis into ``lower = x[..., :n // 2]`` and
    ``upper = x[..., n // 2:]``, the result is ``cat(-upper, lower)``. This is
    the rotation used by rotary position embeddings.
    """
    size = x.shape[-1]
    half = size // 2
    lower, upper = torch.split(x, [half, size - half], dim=-1)
    return torch.cat((upper.neg(), lower), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with rotary position embeddings.

    Each input ``t`` becomes ``t * cos + rotate_half(t) * sin``. Before that,
    ``cos`` and ``sin`` get a new axis at ``unsqueeze_dim`` so they broadcast
    against ``q`` and ``k``.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine part of the rotary embedding.
        sin (`torch.Tensor`): Sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated; ignored.
        unsqueeze_dim (`int`, *optional*, defaults to 1): Axis inserted into
            ``cos``/``sin``. If ``cos``/``sin`` are ``[batch, seq_len, head_dim]``
            and ``q``/``k`` are ``[batch, heads, seq_len, head_dim]``, use 1. If
            ``q``/``k`` are ``[batch, seq_len, heads, head_dim]``, use 2.

    Returns:
        `tuple(torch.Tensor)`: The rotated ``(q, k)`` pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``. The
    tensor goes from ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. The input must
    be 4-D, even when ``n_rep == 1``, because the shape is unpacked first.
    """
    b, n_kv, seq, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(b, n_kv, n_rep, seq, dim)
    return widened.reshape(b, n_kv * n_rep, seq, dim)


class NTKConfig:
    """Default hyper-parameters for ``LlamaNTKAttention``.

    The values are class attributes. They can be overridden on the class or on
    an instance before it is passed to the attention layer.
    """

    # Order of the feature map phi; 1..4 select phi_order1..phi_order4.
    poly_order: int = 2
    # Rank of the Z_A @ Z_B factorisation: "adapt", "optimal", or an explicit integer.
    r: Union[int, str] = "adapt"


# Shared ELU activation (alpha=1, out-of-place), used by the phi_* feature maps.
elu = nn.ELU(alpha=1.0, inplace=False)


def phi_order1(x):
    """First-order feature map: ``elu(x / d ** 0.25) + 1`` with ``d = x.size(-1)``.

    The output has the same shape as ``x`` and is non-negative. Unlike the
    higher-order maps, this one does not upcast the input to float32.
    """
    scale = x.size(-1) ** 0.25
    return F.elu(x / scale) + 1


def phi_order2(x):
    """Second-order feature map in float32.

    With ``z = x.float() / d ** 0.25`` and ``d = x.size(-1)``, the result is
    ``[elu(z) + 1, z ** 2 / 2]`` concatenated along the last axis. The last
    dimension therefore doubles.
    """
    z = x.float() / x.size(-1) ** 0.25
    return torch.cat([F.elu(z) + 1, z.pow(2) / 2], dim=-1)


def phi_order3(x):
    """Third-order feature map in float32.

    With ``z = x.float() / d ** 0.25`` and ``d = x.size(-1)``, the result
    concatenates, along the last axis:

    * ``elu(z) + 1``
    * ``z ** 2 / 2``
    * ``exp(z)`` where ``z < 0``, else ``z ** 3 / 6``

    The last dimension therefore triples.
    """
    z = x.float() / x.size(-1) ** 0.25
    first_order = F.elu(z) + 1
    second_order = (z ** 2) / 2
    # Select the branch instead of multiplying by a 0/1 mask. With a mask, an
    # overflowed exp(z) (z > ~88 in float32) gives inf * 0 = NaN. Clamping
    # before exp keeps the unselected branch finite, so torch.where does not
    # leak NaN gradients either.
    third_order = torch.where(z < 0, torch.exp(z.clamp(max=0)), (z ** 3) / 6)
    return torch.cat([first_order, second_order, third_order], dim=-1)


def phi_order4(x):
    """Fourth-order feature map in float32.

    With ``z = x.float() / d ** 0.25`` and ``d = x.size(-1)``, the result
    concatenates, along the last axis:

    * ``elu(z) + 1``
    * ``z ** 2 / 2``
    * ``exp(z)`` where ``z < 0``, else ``z ** 3 / 6``
    * ``z ** 4 / 24``

    The last dimension therefore quadruples.
    """
    z = x.float() / x.size(-1) ** 0.25
    first_order = F.elu(z) + 1
    second_order = (z ** 2) / 2
    # Select the branch instead of multiplying by a 0/1 mask. With a mask, an
    # overflowed exp(z) gives inf * 0 = NaN. Clamping before exp keeps the
    # unselected branch (and its gradient) finite.
    third_order = torch.where(z < 0, torch.exp(z.clamp(max=0)), (z ** 3) / 6)
    fourth_order = (z ** 4) / 24
    return torch.cat([first_order, second_order, third_order, fourth_order], dim=-1)


class LlamaNTKAttention(nn.Module):
    """Llama self-attention extended with a learned NTK "memory" term.

    :meth:`load_qkvo` copies the q/k/v/o projections and the rotary embedding
    from a pretrained attention module and freezes them. On top of the usual
    scaled dot-product scores, each query ``q`` gets an extra contribution
    through a polynomial feature map ``phi`` (``phi_order1`` .. ``phi_order4``):

    * ``phi(q) @ k`` is added to the normaliser of the exponentiated scores.
    * ``phi(q) @ Z`` is added to the attended values, where
      ``Z = Z_A @ Z_B * head_dim / r`` is a rank-``r`` product.

    ``k``, ``Z_A`` and ``Z_B`` are the trainable parameters. In eval mode,
    :meth:`load_zk` may cache the dense ``Z``.
    """

    def __init__(self, config: LlamaConfig, ntk_config: NTKConfig, layer_idx: Optional[int] = None):
        """Build the projections and the NTK parameters.

        Args:
            config: Llama model config (hidden size, head counts, attention bias, ...).
            ntk_config: Supplies ``poly_order`` (1-4, selects ``phi``) and ``r``
                (``"adapt"``, ``"optimal"`` or an explicit rank).
            layer_idx: Index of this layer, passed to ``past_key_value.update``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
            KeyError: If ``ntk_config.poly_order`` is not one of 1, 2, 3, 4.
        """
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        # Set by load_qkvo(); forward() calls it to obtain (cos, sin).
        self.rotary_emb = None

        order = ntk_config.poly_order
        r = ntk_config.r

        # Rank of the Z_A @ Z_B factorisation.
        if r == "adapt":
            self.r = (self.head_dim - order) // (order + 1)
        elif r == "optimal":
            self.r = self.head_dim // 2
        else:
            self.r = r

        self.phi = {
            1: phi_order1,
            2: phi_order2,
            3: phi_order3,
            4: phi_order4
        }[order]

        # phi_order<n> multiplies the last dimension by n, hence head_dim * order.
        self.k = nn.Parameter(torch.randn(1, self.num_heads, self.head_dim * order, 1))
        self.Z_A = nn.Parameter(torch.randn(1, self.num_heads, self.head_dim * order, self.r))
        self.Z_B = nn.Parameter(torch.zeros(1, self.num_heads, self.r, self.head_dim))
        # Optional eval-mode cache of the dense Z; None means "rebuild from Z_A/Z_B".
        self.Z = None

    def _merged_Z(self) -> torch.Tensor:
        """Return the dense ``Z = Z_A @ Z_B * head_dim / r`` as float32."""
        return torch.matmul(self.Z_A, self.Z_B).float() * self.head_dim / self.r

    def train(self, mode: bool = True):
        """Set train/eval mode, dropping the cached ``Z`` when entering training.

        ``Z_A`` and ``Z_B`` can change during training, so a ``Z`` cached earlier
        would go stale. ``forward`` rebuilds ``Z`` whenever the cache is empty.
        """
        if mode:
            self.Z = None
        return super().train(mode)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention with the NTK correction.

        Let ``s`` be the masked, scaled scores and ``m = max(s)`` per query row::

            A = exp(s - m)
            D = sum(A) + (phi(q) @ k) / exp(m)
            out = dropout(A / D) @ V + (phi(q) @ Z) / D

        Args:
            hidden_states: ``(batch, q_len, hidden_size)`` input.
            attention_mask: Additive mask, sliced to the key length before use.
            position_ids: Positions passed to the rotary embedding.
            past_key_value: Optional cache; its ``update`` returns the full keys/values.
            output_attentions: If true, also return the post-dropout weights ``A / D``.
            use_cache: Not read here; accepted for interface compatibility.
            cache_position: Forwarded to the cache update.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: If the attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        attn_weights = attn_weights.float()
        phi_Q = self.phi(query_states.float())
        k = self.k.abs().float()
        if self.training or self.Z is None:
            # No usable cache: load_zk() did not run in eval mode, or train()
            # cleared it. Rebuild Z from the current low-rank factors.
            Z = self._merged_Z()
        else:
            Z = self.Z.float()

        # Subtract the row max before exp for numerical stability; the k term is
        # rescaled by exp(-max) so both parts of D share the same scale.
        max_attn_weights = attn_weights.max(-1).values.unsqueeze(-1)
        A = torch.exp(attn_weights - max_attn_weights)
        exp_max_attn_weights = torch.exp(max_attn_weights)
        D = A.sum(-1).unsqueeze(-1) + torch.matmul(phi_Q, k) / exp_max_attn_weights
        attn_weights = A / D

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        # NOTE(review): phi_Q @ Z is divided by D, which is already scaled by
        # exp(-max). Unlike phi_Q @ k, it is not itself rescaled by exp(-max).
        # Kept as-is because trained checkpoints depend on this exact formula;
        # confirm whether that asymmetry is intended.
        attn_output = (torch.matmul(attn_weights, value_states) + torch.matmul(phi_Q, Z) / D).to(
            query_states.dtype)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def load_qkvo(self, q, k, v, o, rotary_emb):
        """Copy the q/k/v/o projection weights from ``q``, ``k``, ``v``, ``o`` and freeze them.

        ``rotary_emb`` is stored by reference (not copied) and used by ``forward``.
        """
        self.q_proj.load_state_dict(q.state_dict())
        self.k_proj.load_state_dict(k.state_dict())
        self.v_proj.load_state_dict(v.state_dict())
        self.o_proj.load_state_dict(o.state_dict())
        self.rotary_emb = rotary_emb

        self.q_proj.requires_grad_(False)
        self.k_proj.requires_grad_(False)
        self.v_proj.requires_grad_(False)
        self.o_proj.requires_grad_(False)

    def load_zk(self, Z_A, Z_B, k):
        """Replace the NTK parameters with clones of ``Z_A``, ``Z_B`` and ``k``.

        In eval mode the dense ``Z`` is cached for ``forward``. In training mode
        any previously cached ``Z`` is discarded so it cannot go stale.
        """
        self.Z_A = nn.Parameter(Z_A.clone())
        self.Z_B = nn.Parameter(Z_B.clone())

        self.Z = None if self.training else self._merged_Z()
        self.k = nn.Parameter(k.clone())


def prepare_model(path, cache_dir, settings):
    """Load a pretrained causal LM and swap every attention layer for ``LlamaNTKAttention``.

    All pretrained weights are frozen. The new attention modules reuse the
    original q/k/v/o projections (frozen again by ``load_qkvo``) and rotary
    embedding. Their freshly created NTK parameters (``k``, ``Z_A``, ``Z_B``)
    are left trainable.

    Args:
        path: Model id or local directory passed to ``from_pretrained``.
        cache_dir: Cache directory, used for both the weights and the config.
        settings: ``NTKConfig`` forwarded to each ``LlamaNTKAttention``.

    Returns:
        The modified model.
    """
    model = transformers.AutoModelForCausalLM.from_pretrained(path, cache_dir=cache_dir)
    # Resolve the config through the same cache as the weights; without
    # cache_dir a hub id would be fetched into the default cache instead.
    config = LlamaConfig.from_pretrained(path, cache_dir=cache_dir)

    model.requires_grad_(False)

    for i in range(config.num_hidden_layers):
        layer = model.model.layers[i]
        old_attn = layer.self_attn

        ntk_attn = LlamaNTKAttention(config, settings, i)
        layer.self_attn = ntk_attn
        ntk_attn.load_qkvo(
            old_attn.q_proj,
            old_attn.k_proj,
            old_attn.v_proj,
            old_attn.o_proj,
            old_attn.rotary_emb,
        )
    return model


def load_model(path, settings, device=0):
    """Load a trained NTK-attention model from a local checkpoint directory.

    Base weights come from ``AutoModelForCausalLM.from_pretrained(path)``.
    Every attention layer is then replaced by ``LlamaNTKAttention``, reusing
    the loaded q/k/v/o projections and rotary embedding. Each layer's ``Z_A``,
    ``Z_B`` and ``k`` are read from ``<path>/model.safetensors`` under the keys
    ``model.layers.<i>.self_attn.{Z_A,Z_B,k}``.

    Args:
        path: Checkpoint directory containing a single ``model.safetensors``.
            Sharded checkpoints are not handled.
        settings: ``NTKConfig`` forwarded to each ``LlamaNTKAttention``.
        device: ``safe_open`` device the NTK tensors are loaded onto. Defaults
            to ``0`` (CUDA device 0).

    Returns:
        The model, with every module (including the new attention layers) in eval mode.
    """
    model = transformers.AutoModelForCausalLM.from_pretrained(path)
    config = LlamaConfig.from_pretrained(path)

    weights_file = os.path.join(path, "model.safetensors")
    # Fetch only the NTK tensors each layer needs instead of materialising
    # every tensor in the checkpoint on `device`.
    with safe_open(weights_file, framework="pt", device=device) as f:
        for i in range(config.num_hidden_layers):
            layer = model.model.layers[i]
            old_attn = layer.self_attn

            ntk_attn = LlamaNTKAttention(config, settings, i)
            # New modules start in training mode. Switch to eval *before*
            # load_zk so it caches the dense Z and dropout stays off.
            ntk_attn.eval()
            ntk_attn.load_qkvo(
                old_attn.q_proj,
                old_attn.k_proj,
                old_attn.v_proj,
                old_attn.o_proj,
                old_attn.rotary_emb,
            )
            layer.self_attn = ntk_attn

            prefix = f"model.layers.{i}.self_attn."
            ntk_attn.load_zk(
                f.get_tensor(prefix + "Z_A"),
                f.get_tensor(prefix + "Z_B"),
                f.get_tensor(prefix + "k"),
            )

    model.eval()
    return model
