import torch
from functools import reduce, lru_cache
from operator import mul


def window_partition(x, window_size):
    """Split a 5-D volume into non-overlapping 3-D windows.

    Args:
        x: tensor of shape (B, D, H, W, C). D, H and W are expected to be
            multiples of the matching window extent; otherwise the
            internal ``view`` cannot be formed and raises.
        window_size (tuple[int]): window extent along (D, H, W).
    Returns:
        windows: (B*num_windows, window_size[0]*window_size[1]*window_size[2], C),
            i.e. the positions of every window flattened into one axis.
    """
    B, D, H, W, C = x.shape
    wd, wh, ww = window_size[0], window_size[1], window_size[2]

    # Break every spatial axis into (number of windows, extent inside window).
    grid = x.view(B, D // wd, wd, H // wh, wh, W // ww, ww, C)

    # Bring the three "number of windows" axes forward so that each window's
    # elements become adjacent in memory before flattening.
    grouped = grid.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()

    tokens_per_window = reduce(mul, window_size)
    return grouped.view(-1, tokens_per_window, C)


def window_reverse(windows, window_size, B, D, H, W):
    """Reassemble windows into a (B, D, H, W, C) volume.

    This undoes the layout produced by ``window_partition``.

    Args:
        windows: tensor holding all windows, e.g. of shape
            (B*num_windows, window_size[0]*window_size[1]*window_size[2], C).
            Only the total element count and ordering matter, because the
            tensor is re-viewed internally.
        window_size (tuple[int]): window extent along (D, H, W).
        B (int): batch size.
        D (int): depth of the volume.
        H (int): height of the volume.
        W (int): width of the volume.
    Returns:
        x: (B, D, H, W, C)
    """
    wd, wh, ww = window_size[0], window_size[1], window_size[2]

    # Expose the window grid and the in-window offsets as separate axes.
    grid = windows.view(B, D // wd, H // wh, W // ww, wd, wh, ww, -1)

    # Pair each grid axis with its in-window axis so they can be merged
    # back into the full D, H and W extents.
    interleaved = grid.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
    return interleaved.view(B, D, H, W, -1)

def get_window_size(x_size, window_size, shift_size=None):
    """Clamp the window (and optional shift) to the input extent.

    For every axis whose input extent is not larger than the requested
    window, the window is shrunk to the input extent and, when a shift is
    given, the shift on that axis is set to 0.

    Args:
        x_size: input extent per axis.
        window_size: requested window extent per axis.
        shift_size: optional requested shift per axis.
    Returns:
        ``tuple(window)`` when ``shift_size`` is None, otherwise the pair
        ``(tuple(window), tuple(shift))``.
    """
    has_shift = shift_size is not None

    clamped_window = list(window_size)
    if has_shift:
        clamped_shift = list(shift_size)

    for axis, extent in enumerate(x_size):
        if extent <= window_size[axis]:
            # The window would cover the whole axis: use the full extent
            # and disable shifting along it.
            clamped_window[axis] = extent
            if has_shift:
                clamped_shift[axis] = 0

    if has_shift:
        return tuple(clamped_window), tuple(clamped_shift)
    return tuple(clamped_window)


# cache each stage's results
@lru_cache()
def generate_shifted_window_mask(D, H, W, window_size, shift_size, device=None):
    """Build the additive attention mask used with shifted windows.

    A (1, D, H, W, 1) volume is labelled with a region id, using three
    slices per axis derived from the window and shift sizes. After
    partitioning into windows, pairs of positions with different region ids
    receive -100.0 and pairs with the same id receive 0.0.

    Because of ``lru_cache`` all arguments must be hashable (e.g. tuples for
    ``window_size`` / ``shift_size``), and the same tensor object is returned
    for repeated calls with equal arguments.

    Args:
        D (int): depth of the volume.
        H (int): height of the volume.
        W (int): width of the volume.
        window_size (tuple[int]): window extent along (D, H, W).
        shift_size (tuple[int]): shift along (D, H, W).
        device: device on which the mask is created.
    Returns:
        attn_mask: (num_windows, N, N) with N = product of ``window_size``.
    """
    def axis_segments(axis):
        # Leading part, the band between window and shift, and the trailing
        # part of one axis. Later segments overwrite earlier ones where they
        # overlap, so the iteration order below matters.
        ws, ss = window_size[axis], shift_size[axis]
        return slice(-ws), slice(-ws, -ss), slice(-ss, None)

    region_map = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1
    region_id = 0
    for d_seg in axis_segments(0):
        for h_seg in axis_segments(1):
            for w_seg in axis_segments(2):
                region_map[:, d_seg, h_seg, w_seg, :] = region_id
                region_id += 1

    # nW, ws[0]*ws[1]*ws[2]
    region_per_window = window_partition(region_map, window_size).squeeze(-1)

    # Pairwise difference of region ids inside each window.
    diff = region_per_window.unsqueeze(1) - region_per_window.unsqueeze(2)
    crosses_regions = diff != 0
    same_region = diff == 0
    blocked = diff.masked_fill(crosses_regions, float(-100.0))
    return blocked.masked_fill(same_region, float(0.0))

@lru_cache()
def generate_causal_mask(window_size, span=1, device=None):
    """Create a block-wise lower-triangular mask of ones and zeros.

    Positions are grouped into consecutive blocks of ``span`` elements.
    Entry ``[i, j]`` is 1.0 when ``j // span <= i // span`` and 0.0
    otherwise. With ``span=1`` this is a plain lower-triangular matrix.

    Because of ``lru_cache`` the same tensor object is returned for repeated
    calls with equal arguments.

    Args:
        window_size (int): side length of the square mask; must be a
            multiple of ``span``.
        span (int): block size.
        device: device on which the mask is created.
    Returns:
        float32 tensor of shape (window_size, window_size).
    """
    assert window_size % span == 0, '...'
    num_blocks = window_size // span

    lower = torch.ones(num_blocks, num_blocks, dtype=torch.float32, device=device).tril()

    # Repeat every block-level entry span x span times.
    tiled = lower[:, None, :, None].expand(num_blocks, span, num_blocks, span)
    return tiled.reshape(window_size, window_size)

def process_rollout_attention(attn_weights, discard_ratio=0.9, fusion='max'):
    """Fuse attention maps, drop the weakest entries and row-normalize.

    Steps:
      1. Reduce dimension 1 of ``attn_weights`` with the chosen ``fusion``.
      2. For every sample independently, set the lowest
         ``int(N*N*discard_ratio)`` entries of the flattened map to 0.
         The entry at flat index 0 is never discarded.
      3. Add the identity matrix and normalize each row to sum to 1.

    The discard step previously pooled the selected indices of all samples
    and applied them to sample 0 only; it is now applied per sample. Results
    for a batch of size 1 are unchanged.

    Args:
        attn_weights: tensor whose dimension 1 is fused; after fusing it is
            expected to be (B, N, N), e.g. an input of shape
            (B, num_heads, N, N).
        discard_ratio (float): fraction of entries per sample to zero out.
        fusion (str): one of 'average', 'min' or 'max'.
    Returns:
        Tensor of shape (B, N, N) with rows summing to 1.
    """
    assert fusion in {'average', 'min', 'max'}, '...'
    if fusion == 'average':
        attn = attn_weights.mean(1)
    elif fusion == 'min':
        attn = attn_weights.min(1).values
    else:
        attn = attn_weights.max(1).values
    batch_size, num_tokens = attn.shape[:2]

    # ``flat`` shares storage with ``attn``, so the in-place edits below
    # modify the fused map (not the caller's ``attn_weights``).
    flat = attn.view(attn.size(0), -1)
    num_discard = int(flat.size(-1) * discard_ratio)
    if num_discard > 0:
        _, indices = flat.topk(num_discard, dim=-1, largest=False)
        # Remember flat index 0 of every sample so it survives the discard.
        first_entry = flat[:, 0].clone()
        # Row-wise zeroing: each sample uses its own lowest-k indices.
        flat.scatter_(-1, indices, 0)
        flat[:, 0] = first_entry

    # Identity term, created directly on the target device.
    skip = torch.eye(num_tokens, device=attn.device).expand(batch_size, num_tokens, num_tokens)
    attn = attn + skip
    attn = attn / attn.sum(dim=-1, keepdim=True)
    return attn
