import copy
from dataclasses import dataclass, field
import os
from torch import Tensor
from typing import Dict, Literal, Optional, List, TYPE_CHECKING

import torch
if TYPE_CHECKING:
    from hip.models.hip_attention.gen3.uvm_gpu_cache import (
        HiPOffloadCache
    )

def safe_stride(x: Optional[Tensor], ndim: int):
    """Return the strides of ``x``, or ``ndim`` zeros when ``x`` is None.

    Args:
        x: Tensor whose strides are wanted, or None.
        ndim: Number of stride entries the caller expects.

    Returns:
        ``x.stride()`` for a tensor, otherwise a tuple of ``ndim`` zeros.

    Raises:
        AssertionError: If ``x`` is a tensor whose stride count is not ``ndim``.
    """
    if x is not None:
        strides = x.stride()
        assert len(strides) == ndim
        return strides
    # No tensor: hand back placeholder strides of the requested length.
    return (0,) * ndim

@dataclass
class Stage:
    """Base record describing one stage; subclasses override the flag defaults.

    The five ``stage_*`` fields are required. The three ``require_*`` flags
    all default to False here. How each value is consumed is not visible in
    this module.
    """

    # Required settings (no defaults); ``stage_k`` may be None.
    stage_block_size_q: int
    stage_block_stride_q: int
    stage_chunk_size: int
    stage_k: Optional[int]
    stage_stride: int

    # Behaviour flags, all disabled in the base class.
    require_realign_index: bool = False
    require_reset_score: bool = False
    require_post_sort: bool = False

@dataclass
class NopStage(Stage):
    """Stage variant that adds no fields and only changes flag defaults.

    Index realignment and post-sorting default to True; score reset stays
    False.
    """

    require_realign_index: bool = True
    require_reset_score: bool = False
    require_post_sort: bool = True

@dataclass
class EvalScoreStage(Stage):
    """Stage variant adding ``block_chunk`` and ``stage_extend_backend``.

    Score reset and post-sorting default to True; ``require_realign_index``
    keeps the base default.
    """

    # Extra settings introduced by this stage type.
    block_chunk: int = 64
    stage_extend_backend: Optional[str] = None

    # Flag defaults overridden relative to ``Stage``.
    require_reset_score: bool = True
    require_post_sort: bool = True

@dataclass
class ScanStage(Stage):
    """Stage variant adding ``stage_extend_backend``, with every flag on.

    All three ``require_*`` flags default to True.
    """

    # Extra setting introduced by this stage type.
    stage_extend_backend: Optional[str] = None

    # Flag defaults overridden relative to ``Stage``.
    require_realign_index: bool = True
    require_reset_score: bool = True
    require_post_sort: bool = True

@dataclass
class EnsembleScoreStage(Stage):
    """Stage variant adding ``reduce_method`` (default ``'sum'``).

    Score reset and post-sorting default to True; ``require_realign_index``
    keeps the base default.
    """

    # Extra setting introduced by this stage type.
    reduce_method: str = 'sum'

    # Flag defaults overridden relative to ``Stage``.
    require_reset_score: bool = True
    require_post_sort: bool = True

# Names of the entries in the dictionary built by
# HiPAttentionCacheAccessStatistics.compute_statistics.
StatKeys = Literal['unique_access_count', 'access_count', 'cache_miss_count', 'cache_hit_ratio']

@dataclass
class HiPAttentionCacheAccessStatistics:
    """Access and cache-miss counters with an aggregate summary.

    Either counter may be None; ``compute_statistics`` then reports None for
    every value that depends on the missing counter.
    """

    # [BSZ, HEAD_KV, MAX_TSRC]
    access_counter: Optional[Tensor]
    # [BSZ, HEAD_KV, MAX_TSRC]
    cache_miss_counter: Optional[Tensor]

    def compute_statistics(self) -> "Dict[StatKeys, Optional[Tensor]]":
        """Reduce the counters to scalar statistics.

        Statistics are opt-in: they are computed only when the environment
        variable ``HIP_DISABLE_COMPUTE_STATISTICS`` is set to ``'0'`` (it
        defaults to ``'1'``) and ``access_counter`` is not None. Otherwise
        every value in the result is None.

        Returns:
            A dictionary with the keys ``'unique_access_count'``,
            ``'access_count'``, ``'cache_miss_count'`` and
            ``'cache_hit_ratio'``. The two miss-derived values are None when
            ``cache_miss_counter`` is None. The hit ratio is not guarded
            against an access total of zero.
        """
        stats = {
            'unique_access_count': None,
            'access_count': None,
            'cache_miss_count': None,
            'cache_hit_ratio': None,
        }

        # NOTE: the switch is named "disable" but defaults to disabled.
        enabled = os.getenv('HIP_DISABLE_COMPUTE_STATISTICS', '1') == '0'
        if not enabled or self.access_counter is None:
            return stats

        access_counts = self.access_counter.sum()
        # Clamping to [0, 1] makes every touched entry count exactly once.
        stats['unique_access_count'] = self.access_counter.clamp(0, 1).sum()
        stats['access_count'] = access_counts

        # The miss counter is optional as well; without it only the
        # access-derived values can be reported.
        if self.cache_miss_counter is not None:
            cache_miss_counts = self.cache_miss_counter.sum()
            stats['cache_miss_count'] = cache_miss_counts
            stats['cache_hit_ratio'] = 1 - (cache_miss_counts / access_counts)

        return stats

@dataclass
class HiPAttentionStageInputCache:
    """Bundle of three tensors cached as a stage's input.

    Shapes and dtypes are not fixed by this module.
    """

    # Left and right index tensors.
    indices_left: Tensor
    indices_right: Tensor
    # Score tensor stored alongside the indices.
    out_scores: Tensor

@dataclass
class HiPAttentionOutputMetadata:
    """Optional metadata returned next to an attention output.

    Every field is required at construction time but may be None.
    """

    # index tensors and their per-row bookkeeping
    indices: Optional[Tensor]
    ks: Optional[Tensor]
    ks_count: Optional[Tensor]
    ks_start_end: Optional[Tensor]

    # memory access statistics
    mask_cache_statistics: Optional[HiPAttentionCacheAccessStatistics]
    sa_cache_statistics: Optional[HiPAttentionCacheAccessStatistics]

    # stage caches
    stage_caches: Optional[List[HiPAttentionStageInputCache]]
    
@dataclass
class HiPAttentionArgs:
    """Configuration values and tensor handles for a HiP attention call.

    Besides the declared fields, ``update_flags`` (run from ``__post_init__``
    and ``clone``) maintains derived attributes that are not dataclass
    fields: ``using_paged_cache`` always, and ``paged_cache_page_count`` /
    ``paged_cache_page_size`` only while a paged cache is in use.

    The ``args_*`` methods flatten groups of fields into tuples made of
    ``tensor, *strides`` runs; a missing tensor contributes ``None`` followed
    by zero strides.
    """

    position_ids: Optional[Tensor] = None

    sink_token_size: int = 256
    sliding_window_size: int = 512
    block_size_k: int = 64  # for optimization this will be BLOCK_CHUNK

    block_size_q: int = 64  # no effect, set automatically
    mask_k: int = 512       # no effect, set automatically

    second_stage_k: int = 2048
    stages: List[Stage] = field(
        default_factory=lambda: [
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=4,
                stage_chunk_size=256,
                stage_k=None,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=4,
                stage_chunk_size=32,
                stage_k=32768,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=1,
                stage_chunk_size=8,
                stage_k=8192,
                stage_stride=1,
            ),
        ]
    )
    model_context_length: int = 131072
    extend_context_length: int = 512 * 1024

    # kernel args,
    mask_only: bool = False
    block_sparse_block_size_q: Optional[int] = 64

    scan_early_terminate: int = 1
    stage_early_terminate: int = 1
    scan_extend_backend: str = 'relative'
    sa_extend_backend: str = 'streaming'
    low_percent: float = 0.0
    low_k_ratio: float = 1.0
    dim_to_lower: Literal['head', 'seq'] = 'head'
    q_mask: Optional[Tensor] = None
    k_mask: Optional[Tensor] = None
    idx_pca_hid_q: Optional[Tensor] = None
    idx_pca_hid_k: Optional[Tensor] = None

    is_causal: bool = True

    using_extend: bool = False
    need_apply_rope: bool = False
    rope_cos: Optional[Tensor] = None
    rope_sin: Optional[Tensor] = None

    offload_cache: "Optional[HiPOffloadCache]" = None
    k_cache: Optional[Tensor] = None
    v_cache: Optional[Tensor] = None
    cache_seq_lens: Optional[Tensor] = None
    block_table: Optional[Tensor] = None

    # to support gemma2
    logit_softcap: Optional[float] = None

    online_update_cache: bool = False

    require_cache_statistics: bool = True

    disable_flashdecode: bool = False

    def __post_init__(self):
        """Flatten 3-D rope tables to 2-D, then compute the derived flags."""
        if self.rope_cos is not None and self.rope_cos.ndim == 3:
            self.rope_cos = self.rope_cos.view(-1, self.rope_cos.shape[-1])
            # rope_sin is flattened whenever rope_cos was 3-D; tolerate it
            # being absent instead of failing on None.
            if self.rope_sin is not None:
                self.rope_sin = self.rope_sin.view(-1, self.rope_sin.shape[-1])
        self.update_flags()

    def update_flags(self):
        """Recompute the attributes derived from the current field values.

        A ``logit_softcap`` of 0 is normalised to None. When ``k_cache`` or
        ``offload_cache`` is set, the page count and page size are refreshed
        (``k_cache`` takes precedence) and the page size is validated.

        Raises:
            AssertionError: If the page size is not one of 1, 2, 4, 8, 16, 32.
        """
        if self.logit_softcap == 0:
            self.logit_softcap = None

        self.using_paged_cache = (self.k_cache is not None) or (self.offload_cache is not None)
        if self.using_paged_cache:
            if self.k_cache is not None:
                self.paged_cache_page_count = self.k_cache.shape[0]
                self.paged_cache_page_size = self.k_cache.shape[1]
            else:
                self.paged_cache_page_count = self.offload_cache.get_page_count()
                self.paged_cache_page_size = 1
            assert self.paged_cache_page_size in (1, 2, 4, 8, 16, 32)

    def clone(self):
        """Refresh the derived flags and return a shallow copy.

        The copy shares its tensors and its ``stages`` list with ``self``.
        """
        self.update_flags()
        return copy.copy(self)

    def json(self, convert_tensor_to_meta = True):
        """Return the dataclass fields as a ``{name: value}`` dictionary.

        Args:
            convert_tensor_to_meta: When true, each tensor value is replaced
                by a string holding its dtype, shape, device and data
                pointer. Other values are passed through unchanged.
        """
        from dataclasses import fields

        def describe(value):
            if convert_tensor_to_meta and isinstance(value, Tensor):
                return f'{value.dtype}{list(value.shape)}@{value.device}.{value.data_ptr():02X}'
            return value

        return {f.name: describe(getattr(self, f.name)) for f in fields(self)}

    def args_extend(self):
        """Return the extend flags followed by both rope argument groups."""
        return (
            self.using_extend,
            self.need_apply_rope,
            *self.args_rope_cos(),
            *self.args_rope_sin(),
        )

    def args_rope_cos(self):
        """Return ``rope_cos`` followed by its two strides."""
        return (self.rope_cos, *safe_stride(self.rope_cos, 2))

    def args_rope_sin(self):
        """Return ``rope_sin`` followed by its two strides."""
        return (self.rope_sin, *safe_stride(self.rope_sin, 2))

    def args_paged_kv_cache(self):
        """Return the flattened paged KV-cache arguments.

        The tuple is ``(using_page, page_size, k_cache, *strides4,
        v_cache, *strides4, block_table, *strides2, cache_seq_lens,
        *strides1)``. With an offload cache, the K/V tensors come from its
        GPU banks and the page size is 1.

        Raises:
            AssertionError: If a paged cache is in use but the cache tensors
                are missing or have unexpected ranks, or if an offload cache
                is set while ``using_paged_cache`` is false.
        """
        using_page = self.using_paged_cache

        if self.offload_cache is None:
            if using_page:
                assert self.v_cache is not None
                assert self.k_cache.ndim == self.v_cache.ndim
                assert self.k_cache.ndim == 4
                assert self.block_table is not None
                assert self.block_table.ndim == 2
                assert self.cache_seq_lens is not None
                assert self.cache_seq_lens.ndim == 1
                page_size = self.k_cache.shape[1]
            else:
                page_size = 0

            return (
                using_page,
                page_size,
                self.k_cache, *safe_stride(self.k_cache, 4),
                self.v_cache, *safe_stride(self.v_cache, 4),
                self.block_table, *safe_stride(self.block_table, 2),
                self.cache_seq_lens, *safe_stride(self.cache_seq_lens, 1),
            )
        else:
            assert using_page

            # Insert a size-1 axis at dim 1 so the banks pass the 4-stride
            # check below.
            k_cache = self.offload_cache.k_uvm.bank_gpu.unsqueeze(1)
            v_cache = self.offload_cache.v_uvm.bank_gpu.unsqueeze(1)

            return (
                True,
                1,
                k_cache, *safe_stride(k_cache, 4),
                v_cache, *safe_stride(v_cache, 4),
                self.block_table, *safe_stride(self.block_table, 2),
                self.cache_seq_lens, *safe_stride(self.cache_seq_lens, 1),
            )

    def args_offload_cache(self, is_masking):
        """Return the flattened offload-cache arguments.

        Args:
            is_masking: Selects ``offload_cache.mask_k_cache`` when true,
                otherwise ``offload_cache.sa_kv_cache``.

        Returns:
            A tuple starting with ``(enabled, is_packed, bank_rows)`` and
            followed by five ``tensor, *strides`` runs. Without an offload
            cache the same layout is filled with ``False`` / ``0`` / ``None``
            placeholders.
        """
        # Explicit None test, matching how the rest of the class checks
        # ``offload_cache``.
        if self.offload_cache is not None:
            gpu_cache = (
                self.offload_cache.mask_k_cache
                if is_masking
                else self.offload_cache.sa_kv_cache
            )
            is_packed = gpu_cache.kv_packed
            uvm_metadata = self.offload_cache.k_uvm.metadata
            return (
                True,
                is_packed,
                gpu_cache.bank.shape[0],
                uvm_metadata, *safe_stride(uvm_metadata, 2),
                gpu_cache.global_metadata, *safe_stride(gpu_cache.global_metadata, 2),
                gpu_cache.bank, *safe_stride(gpu_cache.bank, 2),
                gpu_cache.metadata, *safe_stride(gpu_cache.metadata, 2),
                gpu_cache.table, *safe_stride(gpu_cache.table, 3),
            )
        else:
            return (
                False,
                False,
                0,
                None, 0, 0,
                None, 0, 0,
                None, 0, 0,
                None, 0, 0,
                None, 0, 0, 0,
            )