from typing import *
from enum import Enum
import torch
import math
from .. import SparseTensor
from .. import DEBUG, ATTN

if ATTN == 'xformers':
    import xformers.ops as xops
elif ATTN == 'flash_attn':
    import flash_attn
else:
    raise ValueError(f"Unknown attention module: {ATTN}")


__all__ = [
    'sparse_serialized_scaled_dot_product_self_attention',
]


class SerializeMode(Enum):
    Z_ORDER = 0
    Z_ORDER_TRANSPOSED = 1
    HILBERT = 2
    HILBERT_TRANSPOSED = 3


SerializeModes = [
    SerializeMode.Z_ORDER,
    SerializeMode.Z_ORDER_TRANSPOSED,
    SerializeMode.HILBERT,
    SerializeMode.HILBERT_TRANSPOSED
]


def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor, torch.Tensor): Forwards and backwards indices.
    """
    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    offsets = [0]
    
    if 'vox2seq' not in globals():
        import vox2seq

    # Serialize the input
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
    if serialize_mode == SerializeMode.Z_ORDER:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
    elif serialize_mode == SerializeMode.HILBERT:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")
    
    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        num_windows = (num_points + window_size - 1) // window_size
        valid_window_size = num_points / num_windows
        to_ordered = torch.argsort(code[s.start:s.stop])
        if num_windows == 1:
            fwd_indices.append(to_ordered)
            bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
            fwd_indices[-1] += s.start
            bwd_indices[-1] += offsets[-1]
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            offsets.append(offsets[-1] + seq_lens[-1])
        else:
            # Partition the input
            offset = 0
            mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
            split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
            bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
            for i in range(num_windows):
                mid = mids[i]
                valid_start = split[i]
                valid_end = split[i + 1]
                padded_start = math.floor(mid - 0.5 * window_size)
                padded_end = padded_start + window_size
                fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
                offset += valid_start - padded_start
                bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
                offset += padded_end - valid_start
                fwd_indices[-1] += s.start
            seq_lens.extend([window_size] * num_windows)
            seq_batch_indices.extend([bi] * num_windows)
            bwd_indices.append(bwd_index + offsets[-1])
            offsets.append(offsets[-1] + num_windows * window_size)

    fwd_indices = torch.cat(fwd_indices)
    bwd_indices = torch.cat(bwd_indices)

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
    

def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
        shift (int): The shift to use.
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]
    T = qkv.feats.shape[0]
    H = qkv.feats.shape[2]
    C = qkv.feats.shape[3]
    
    qkv_feats = qkv.feats[fwd_indices]      # [M, 3, H, C]

    if DEBUG:
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)                              # [M, H, C]
    else:
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)                       # [M, H, C]
            q = q.unsqueeze(0)                                      # [1, M, H, C]
            k = k.unsqueeze(0)                                      # [1, M, H, C]
            v = v.unsqueeze(0)                                      # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                        .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]

    out = out[bwd_indices]      # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
