import sys
import os
 
curr_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, curr_path)
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from autograd_function import SampledDenseMM, SparseDenseMM, ReduceSum
from kernel import sparse_max, sparse_mask_B, sparse_dense_mm_torch
 
def get_low_resolution_logit(Q, K, block_size, mask = None):
    batch_size, seq_len, head_dim = Q.size()
 
    num_block_per_row = seq_len // block_size
    Q_hat = Q.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim = -2)
    K_hat = K.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim = -2)
 
    low_resolution_logit = torch.matmul(Q_hat, K_hat.transpose(-1, -2)) / math.sqrt(head_dim)
 
    return low_resolution_logit
 
def get_block_idxes(low_resolution_logit, num_blocks, approx_mode, initial_prior_first_n_blocks, initial_prior_diagonal_n_blocks):
    batch_size, total_blocks_per_row, _ = low_resolution_logit.shape
 
    if initial_prior_diagonal_n_blocks > 0:
        offset = initial_prior_diagonal_n_blocks // 2
        temp_mask = torch.ones(total_blocks_per_row, total_blocks_per_row, device = low_resolution_logit.device)
        diagonal_mask = torch.tril(torch.triu(temp_mask, diagonal = -offset), diagonal = offset)
        low_resolution_logit = low_resolution_logit + diagonal_mask[None, :, :] * 5e3
 
    if initial_prior_first_n_blocks > 0:
        low_resolution_logit[:, :initial_prior_first_n_blocks, :] = low_resolution_logit[:, :initial_prior_first_n_blocks, :] + 5e3
        low_resolution_logit[:, :, :initial_prior_first_n_blocks] = low_resolution_logit[:, :, :initial_prior_first_n_blocks] + 5e3
 
    top_k_vals = torch.topk(low_resolution_logit.reshape(batch_size, -1), num_blocks, dim = -1, largest = True, sorted = False)
    indices = top_k_vals.indices
   
 
    threshold = top_k_vals.values.min(dim = -1).values
    high_resolution_mask = (low_resolution_logit >= threshold[:, None, None]).float()
 
    return indices, high_resolution_mask
 
def mra2_attention(
    Q, K, V, mask, num_blocks,
    approx_mode,
    block_size = 32,
    initial_prior_first_n_blocks = 0,
    initial_prior_diagonal_n_blocks = 0
):
 
    batch_size, num_head, seq_len, head_dim = Q.size()
    meta_batch = batch_size * num_head
 
    assert seq_len % block_size == 0
    num_block_per_row = seq_len // block_size
 
    Q = Q.reshape(meta_batch, seq_len, head_dim)
    K = K.reshape(meta_batch, seq_len, head_dim)
    V = V.reshape(meta_batch, seq_len, head_dim)
    # mask = None if torch.all(mask == 1).item() else mask[:, None, :].repeat(1, num_head, 1).reshape(meta_batch, seq_len)
 
    if mask is not None:
        Q = Q * mask[:, :, None]
        K = K * mask[:, :, None]
        V = V * mask[:, :, None]
 
    if approx_mode == "full":
        low_resolution_logit, token_count, low_resolution_logit_row_max, V_hat = get_low_resolution_logit(Q, K, block_size, mask, V)
    elif approx_mode == "sparse":
        with torch.no_grad():
            low_resolution_logit, token_count, low_resolution_logit_row_max, _ = get_low_resolution_logit(Q, K, block_size, mask)
    else:
        raise Exception()
 
    with torch.no_grad():
        low_resolution_logit_normalized = low_resolution_logit - low_resolution_logit_row_max
        indices, high_resolution_mask = get_block_idxes(low_resolution_logit_normalized, num_blocks, approx_mode, initial_prior_first_n_blocks, initial_prior_diagonal_n_blocks)
 
    high_resolution_logit = SampledDenseMM.operator_call(Q, K, indices, block_size = block_size) / math.sqrt(head_dim)
    max_vals, max_vals_scatter = sparse_max(high_resolution_logit, indices, num_block_per_row, num_block_per_row)
    high_resolution_logit = high_resolution_logit - max_vals_scatter
    if mask is not None:
        high_resolution_logit = high_resolution_logit - 1e4 * (1 - sparse_mask_B(mask, indices)[:, :, :, None])
    high_resolution_attn = torch.exp(high_resolution_logit)
    high_resolution_attn_out = SparseDenseMM.operator_call(high_resolution_attn, indices, V, num_block_per_row)
    high_resolution_normalizer = ReduceSum.operator_call(high_resolution_attn, indices, num_block_per_row, num_block_per_row)
 
    if approx_mode == "full":
        low_resolution_attn = torch.exp(low_resolution_logit - low_resolution_logit_row_max - 1e4 * high_resolution_mask) * token_count[:, None, :]
 
        low_resolution_attn_out = torch.matmul(low_resolution_attn, V_hat)[:, :, None, :].repeat(1, 1, block_size, 1).reshape(meta_batch, seq_len, head_dim)
        low_resolution_normalizer = low_resolution_attn.sum(dim = -1)[:, :, None].repeat(1, 1, block_size).reshape(meta_batch, seq_len)
 
        log_correction = low_resolution_logit_row_max.repeat(1, 1, block_size).reshape(meta_batch, seq_len) - max_vals
        if mask is not None:
            log_correction = log_correction * mask
 
        low_resolution_corr = torch.exp(log_correction * (log_correction <= 0).float())
        low_resolution_attn_out = low_resolution_attn_out * low_resolution_corr[:, :, None]
        low_resolution_normalizer = low_resolution_normalizer * low_resolution_corr
 
        high_resolution_corr = torch.exp(- log_correction * (log_correction > 0).float())
        high_resolution_attn_out = high_resolution_attn_out * high_resolution_corr[:, :, None]
        high_resolution_normalizer = high_resolution_normalizer * high_resolution_corr
 
        attn = (high_resolution_attn_out + low_resolution_attn_out) / (high_resolution_normalizer[:, :, None] + low_resolution_normalizer[:, :, None] + 1e-6)
 
    elif approx_mode == "sparse":
        attn = high_resolution_attn_out / (high_resolution_normalizer[:, :, None] + 1e-6)
    else:
        raise Exception()
 
    if mask is not None:
        attn = attn * mask[:, :, None]
 
    attn = attn.reshape(batch_size, num_head, seq_len, head_dim)
 
    return attn
 
def mra2_logit_only(
    Q, K, mask, num_blocks,
    approx_mode,
    block_size = 32,
    initial_prior_first_n_blocks = 0,
    initial_prior_diagonal_n_blocks = 0
):
 
    batch_size, num_head, seq_len, head_dim = Q.size()
    meta_batch = batch_size * num_head
 
    assert seq_len % block_size == 0
    num_block_per_row = seq_len // block_size
 
    Q = Q.reshape(meta_batch, seq_len, head_dim)
    K = K.reshape(meta_batch, seq_len, head_dim)
 
    if mask is not None:
        Q = Q * mask[:, :, None]
        K = K * mask[:, :, None]
    with torch.no_grad():
        low_resolution_logit = get_low_resolution_logit(Q, K, block_size, mask)
    with torch.no_grad():
        indices, high_resolution_mask = get_block_idxes(low_resolution_logit, num_blocks, approx_mode, initial_prior_first_n_blocks, initial_prior_diagonal_n_blocks)
    high_resolution_logit = SampledDenseMM.operator_call(Q, K, indices, block_size = block_size) / math.sqrt(head_dim)
    Identity_matrix = torch.eye(seq_len, device=Q.device).unsqueeze(0).repeat(meta_batch, 1, 1)
    high_resolution_logit = SparseDenseMM.operator_call(high_resolution_logit, indices, Identity_matrix, num_block_per_row)
   
    if approx_mode == "sparse":
        high_resolution_mask = high_resolution_mask.repeat_interleave(block_size, -2).repeat_interleave(block_size, -1)
        high_resolution_logit.masked_fill_(high_resolution_mask < 0.5, -float("Inf"))
        high_resolution_logit = high_resolution_logit.reshape(batch_size, num_head, seq_len, seq_len)
        return high_resolution_logit
    elif approx_mode == "full":
        low_resolution_logit.masked_fill_(high_resolution_mask > 0.5, 0)
        low_resolution_logit = low_resolution_logit.reshape(batch_size, num_head, num_block_per_row, num_block_per_row).repeat_interleave(block_size, -2).repeat_interleave(block_size, -1)
        high_resolution_logit = high_resolution_logit.reshape(batch_size, num_head, seq_len, seq_len)
        return low_resolution_logit + high_resolution_logit
    else:
        raise Exception