import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaAttention


class DASAttention(nn.Module):
    """Block-sparse replacement for a Llama self-attention module.

    The q/k/v/o projections and the rotary embedding are borrowed from
    ``original_attention``; only the score -> softmax -> value step is replaced.
    The (right-padded) sequence is cut into ``block_size`` blocks, and a
    per-(batch, head) block mask decides which key blocks each query block may
    attend to:

    * STC selection: block representatives (the mean query / mean key of each
      block) are scored against each other; for every query block the
      highest-probability key blocks are kept until at least ``tau`` of that
      row's probability mass is covered.
    * SP preservation: a block whose tokens have low mean pairwise cosine
      similarity (< ``theta_B``) is poorly summarised by its mean, so every tile
      touching such a query or key block is kept.
    * The diagonal tile is always kept, so each query row has at least one
      admissible key (its own block) even under a causal mask.

    Inside the kept tiles one softmax is taken over all admissible keys of each
    query row, so weights sum to one as in dense attention. Afterwards, key
    blocks whose mean-over-queries maximum probability does not exceed
    ``theta_V`` have their contribution dropped (MPV pruning); the surviving
    weights are not renormalised.

    Every query strip still computes scores against all keys, so this
    implementation changes *which* keys contribute but does not reduce FLOPs.

    NOTE(review): mirrors the legacy Hugging Face Llama attention interface
    (``rotary_emb(x, seq_len=...)`` returning ``(cos, sin)`` indexable by
    ``position_ids``; ``original_attention.num_heads``) -- confirm against the
    pinned ``transformers`` version. No KV cache is read or updated: extra
    keyword arguments (e.g. ``past_key_value``) are accepted and ignored.
    """

    def __init__(self, original_attention, config):
        """Wrap ``original_attention``'s weights with DAS block sparsity.

        Args:
            original_attention: Module providing ``config``, ``hidden_size``,
                ``num_heads``, ``head_dim``, ``max_position_embeddings``, the
                ``q_proj``/``k_proj``/``v_proj``/``o_proj`` projections and
                ``rotary_emb``. An optional ``num_key_value_heads`` attribute
                enables grouped-query attention; it defaults to ``num_heads``.
            config: Mapping of DAS hyper-parameters: ``block_size`` (default
                64), ``theta_B`` (0.85), ``tau`` (0.9), ``theta_V_scale`` (1e-4).

        Raises:
            ValueError: If ``block_size`` is not positive, or ``num_heads`` is
                not a positive multiple of ``num_key_value_heads``.
        """
        super().__init__()
        self.config = original_attention.config
        self.hidden_size = original_attention.hidden_size
        self.num_heads = original_attention.num_heads
        self.head_dim = original_attention.head_dim
        self.max_position_embeddings = original_attention.max_position_embeddings
        # Grouped-query attention: keys/values may have fewer heads than queries.
        self.num_key_value_heads = getattr(original_attention, "num_key_value_heads", self.num_heads)
        if self.num_key_value_heads <= 0 or self.num_heads % self.num_key_value_heads:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be a multiple of "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = original_attention.q_proj
        self.k_proj = original_attention.k_proj
        self.v_proj = original_attention.v_proj
        self.o_proj = original_attention.o_proj
        self.rotary_emb = original_attention.rotary_emb

        self.block_size = config.get("block_size", 64)
        if self.block_size <= 0:
            raise ValueError(f"block_size must be positive, got {self.block_size}")
        self.theta_B = config.get("theta_B", 0.85)
        self.tau = config.get("tau", 0.9)
        self.theta_V = config.get("theta_V_scale", 1e-4)

    @staticmethod
    def _intra_block_similarity(blocks):
        """Mean pairwise cosine similarity of the tokens inside each block.

        Args:
            blocks: Tensor of shape ``(..., block_size, head_dim)``.

        Returns:
            Tensor of shape ``(...)``. Self-pairs are included, as in a full
            ``block_size x block_size`` similarity matrix.
        """
        # Normalise-then-matmul gives the same values as F.cosine_similarity on
        # all token pairs without materialising a (..., bs, bs, head_dim) tensor.
        unit = F.normalize(blocks, dim=-1, eps=1e-8)
        return torch.matmul(unit, unit.transpose(-1, -2)).mean(dim=(-1, -2))

    def _compute_mg_mask(self, q_blocks, k_blocks):
        """Select which (query block, key block) tiles may be attended.

        Args:
            q_blocks: ``(bsz, heads, n_q_blocks, block_size, head_dim)``.
            k_blocks: ``(bsz, heads, n_k_blocks, block_size, head_dim)``.

        Returns:
            Bool tensor ``(bsz, heads, n_q_blocks, n_k_blocks)``: the union of
            the STC (top-``tau`` mass per query-block row) and SP
            (low intra-block similarity) selections.
        """
        head_dim = q_blocks.shape[-1]

        # SP: blocks that are not well represented by their mean are kept whole.
        q_preserved = self._intra_block_similarity(q_blocks) < self.theta_B
        k_preserved = self._intra_block_similarity(k_blocks) < self.theta_B
        sp_mask = q_preserved.unsqueeze(-1) | k_preserved.unsqueeze(-2)

        # STC: score block representatives (plain block means) against each other.
        q_rep = q_blocks.mean(dim=3)
        k_rep = k_blocks.mean(dim=3)
        block_scores = torch.matmul(q_rep, k_rep.transpose(-1, -2)) / (head_dim ** 0.5)
        block_probs = F.softmax(block_scores, dim=-1)

        # Per query-block row, keep key blocks in descending probability order
        # while the mass *before* each one is below tau. This always keeps the
        # top block and includes the block that crosses tau.
        sorted_probs, order = torch.sort(block_probs, dim=-1, descending=True)
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        keep_sorted = mass_before < self.tau
        stc_mask = torch.zeros_like(keep_sorted).scatter(-1, order, keep_sorted)

        return stc_mask | sp_mask

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Block-sparse self-attention over ``hidden_states``.

        Args:
            hidden_states: ``(bsz, q_len, hidden_size)``.
            attention_mask: Optional additive mask added to the raw scores.
                Assumed broadcastable as ``(bsz, 1, q_len, q_len)`` (the legacy
                HF 4D mask) -- TODO confirm against callers.
            position_ids: Forwarded to the rotary embedding lookup.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``(attn_output, None, None)`` with ``attn_output`` of shape
            ``(bsz, q_len, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()
        block = self.block_size

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(
            bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(
            bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, seq_len=q_len)
        query_states, key_states = self._apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if self.num_key_value_groups > 1:
            # Query head h reads key/value head h // num_key_value_groups.
            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        # Right-pad to a whole number of blocks; the padding is masked out as
        # keys below and sliced off the output at the end.
        pad_len = (-q_len) % block
        padded_len = q_len + pad_len
        n_blocks = padded_len // block
        if pad_len > 0:
            query_states = F.pad(query_states, (0, 0, 0, pad_len))
            key_states = F.pad(key_states, (0, 0, 0, pad_len))
            value_states = F.pad(value_states, (0, 0, 0, pad_len))
            if attention_mask is not None:
                attention_mask = F.pad(attention_mask, (0, pad_len, 0, pad_len))

        q_blocks = query_states.reshape(bsz, self.num_heads, n_blocks, block, self.head_dim)
        k_blocks = key_states.reshape(bsz, self.num_heads, n_blocks, block, self.head_dim)

        mg_mask = self._compute_mg_mask(q_blocks, k_blocks)
        # The diagonal tile holds each query's own position, so keeping it
        # guarantees a non-empty (and, under a causal mask, unmasked) key set.
        mg_mask = mg_mask | torch.eye(n_blocks, dtype=torch.bool, device=mg_mask.device)

        is_real = torch.arange(padded_len, device=hidden_states.device) < q_len
        # (bsz, heads, n_blocks, padded_len): admissible key tokens per query block.
        key_allowed = mg_mask.repeat_interleave(block, dim=-1) & is_real

        attn_output = torch.zeros_like(query_states)
        for i in range(n_blocks):
            rows = slice(i * block, (i + 1) * block)

            scores = torch.matmul(query_states[:, :, rows], key_states.transpose(-1, -2)) / (self.head_dim ** 0.5)
            if attention_mask is not None:
                scores = scores + attention_mask[:, :, rows]
            scores = scores.masked_fill(~key_allowed[:, :, i].unsqueeze(2), float("-inf"))
            # One softmax over every admissible key of the row, across all blocks.
            probs = F.softmax(scores, dim=-1)

            # MPV: per key block, average (over real query rows) the max
            # probability; drop blocks at or below theta_V without renormalising.
            block_max = probs.reshape(bsz, self.num_heads, block, n_blocks, block).amax(dim=-1)
            mpv_keep = block_max[:, :, is_real[rows]].mean(dim=2) > self.theta_V
            probs = probs.masked_fill(~mpv_keep.repeat_interleave(block, dim=-1).unsqueeze(2), 0.0)

            attn_output[:, :, rows] = torch.matmul(probs, value_states)

        attn_output = attn_output[:, :, :q_len].transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, None

    def _apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
        """Apply rotary position embedding to ``q`` and ``k``.

        ``cos``/``sin`` are indexed by ``position_ids`` and given a head axis at
        dim 1 so they broadcast over heads. Presumably ``(seq_len, head_dim)``
        tables from ``rotary_emb`` -- TODO confirm.
        """
        cos = cos[position_ids].unsqueeze(1)
        sin = sin[position_ids].unsqueeze(1)
        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)
        return q_embed, k_embed

    def _rotate_half(self, x):
        """Return ``cat(-second_half, first_half)`` along the last axis."""
        x1 = x[..., : self.head_dim // 2]
        x2 = x[..., self.head_dim // 2:]
        return torch.cat((-x2, x1), dim=-1)