import torch
import dataclasses as dc
from litgpt.model import (
    GPT,
    KVCache as _LitKVCache,
    CausalSelfAttention,
    build_mask_cache,
    Block,
)
from typing import Iterable, Callable, Sequence, Any, cast
from .th import TensorDataClass, TensorList, clone_to
import math


def _iter_blocks(gpt: GPT) -> Iterable[Block]:
    """Return `gpt.transformer.h`, typed as an iterable of `Block`."""
    blocks = gpt.transformer.h
    # `cast` only informs the type checker; the object is returned unchanged.
    return cast(Iterable[Block], blocks)


def _kv_buf_write(
    buf: torch.Tensor,  # [N, head, T, embedding]
    pos: torch.Tensor, # indices of token positions: [len] or [batch, len]
    emb: torch.Tensor,  # embeddings: [batch, head, len, embeddings]
    b: int | torch.Tensor | None, # batch size, or indices with batch: [batch]
):
    pos_dim = -2

    assert buf.ndim == 4 and pos.ndim <= 2
    if pos.dim() == 1:
        if b is None:
            return buf.index_copy_(pos_dim, pos, emb)
        elif isinstance(b, int):
            return buf[:b].index_copy_(pos_dim, pos, emb)
        else:
            assert b.dtype == torch.int64 and b.ndim == 1 and b.size(0) == emb.size(0)
            buf_b = torch.index_select(buf, 0, b)
            buf_b.index_copy_(pos_dim, pos, emb)
            buf.index_copy_(0, b, buf_b)
            return buf_b
    else:
        _out_idx: torch.Tensor | slice | None
        if isinstance(b, torch.Tensor):
            assert b.dtype == torch.int64 and b.ndim == 1, "batch indices must be 1-dimensional index tensor"
            _batch_indices = b.tolist()
            _out_idx = b
        elif isinstance(b, int):
            b = min(emb.size(0), b)
            _batch_indices = range(b)
            _out_idx = slice(b)
        else:
            _batch_indices = range(emb.size(0))
            _out_idx = None
        assert len(_batch_indices) == pos.size(0) == emb.size(0), "batch sizes of arguments must be consistent"
        for i in _batch_indices:
            buf[i].index_copy_(pos_dim, pos[i], emb[i])
        return buf if _out_idx is None else buf[_out_idx, ...]


class KVCache(_LitKVCache):
    """
    KV cache whose next write can be limited to selected batch rows.

    `_batch_pos` holds an optional tensor of batch-row indices. `forward`
    consumes it: it is read once, cleared, and passed to `_kv_buf_write`.
    When it is unset, the leading `k.size(0)` rows are written instead.
    """

    _batch_pos: torch.Tensor | None

    def __init__(
        self,
        k_shape: tuple[int, int, int, int],
        v_shape: tuple[int, int, int, int],
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(k_shape, v_shape, device, dtype)
        # Non-persistent buffer: it is left out of the module's state_dict.
        self.register_buffer("_batch_pos", None, persistent=False)

    def forward(
        self,
        input_pos: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Store `k` and `v` at `input_pos` and return the addressed cache rows."""
        pending = self._batch_pos
        rows = k.size(0) if pending is None else pending
        # The row selection applies to a single forward call only.
        self._batch_pos = None

        # Match the stored dtype to the incoming tensors before writing.
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        written_k = _kv_buf_write(self.k, input_pos, k, rows)
        written_v = _kv_buf_write(self.v, input_pos, v, rows)
        return written_k, written_v

    def reset_parameters(self) -> None:
        """Run the base-class reset and drop any pending row selection."""
        super().reset_parameters()
        self._batch_pos = None


def mask_kv_caches(gpt: GPT, mask: torch.Tensor | None):
    """
    Mask the kv-cache so that only a part of cache is used.
    Note that the mask will be cleared after the forward pass.

    mask: boolean tensor; it is flattened and the indices of its `True`
        entries become the batch rows written by the next forward pass.
        Passing `None` clears any selection that was set before.
    """
    # Default to "no selection" so that `mask=None` is a valid call.
    pos: torch.Tensor | None = None
    if isinstance(mask, torch.Tensor):
        assert mask.dtype == torch.bool
        # `nonzero()` yields shape [count, 1]; drop the trailing axis.
        pos = mask.reshape(-1).nonzero().squeeze(1)
    for cache in iterate_kv_caches(gpt):
        cache._batch_pos = pos


def build_kv_caches(
    gpt: GPT,
    batch_size: int,
    max_seq_length: int | None = None,
    rope_cache_length: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None
) -> None:
    """
    Attach a fresh `KVCache` to every block's attention module.

    `rope_cache_length` defaults to `gpt.cos.size(-1)` and `max_seq_length`
    to `gpt.max_seq_length`. The model's mask cache is rebuilt when it is
    missing or its size along axis 3 differs from the sequence length.
    """
    rope_len = gpt.cos.size(-1) if rope_cache_length is None else rope_cache_length
    seq_len = gpt.max_seq_length if max_seq_length is None else max_seq_length

    for block in _iter_blocks(gpt):
        attn = block.attn
        attn.kv_cache = _build_kv_cache(attn, batch_size, seq_len, rope_len, device, dtype)

    mask = gpt.mask_cache
    if mask is None or mask.size(3) != seq_len:
        gpt.mask_cache = build_mask_cache(seq_len, device)


def _build_kv_cache(
    attn: CausalSelfAttention,
    batch_size: int,
    max_seq_length: int,
    rope_cache_length: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None
) -> KVCache:
    """
    Create a `KVCache` sized from `attn.config`.

    A single head is used when `n_query_groups == 1`, otherwise `n_head`.
    With `rope_cache_length` given, the last axis of the key buffer is
    `rope_cache_length + head_size - rope_n_elem`; without it the key buffer
    has the value buffer's shape.

    Raises:
        TypeError: `rope_cache_length` is missing while
            `rotary_percentage` is not 1.0.
    """
    cfg = attn.config
    n_heads = cfg.n_head if cfg.n_query_groups != 1 else 1
    head_size = cfg.head_size
    assert head_size is not None

    v_shape = (batch_size, n_heads, max_seq_length, head_size)
    if rope_cache_length is not None:
        k_width = rope_cache_length + head_size - cfg.rope_n_elem
        k_shape = (batch_size, n_heads, max_seq_length, k_width)
    elif cfg.rotary_percentage != 1.0:
        raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
    else:
        k_shape = v_shape
    return KVCache(k_shape, v_shape, device=device, dtype=dtype)


def iterate_kv_caches(model: GPT, invalid_ok=False) -> Iterable[KVCache]:
    """
    Yield each block's `attn.kv_cache`, in block order.

    At the first block whose cache is not a `KVCache`, iteration ends
    silently if `invalid_ok` is true and raises `ValueError` otherwise.
    """
    for block in _iter_blocks(model):
        cache = block.attn.kv_cache
        if isinstance(cache, KVCache):
            yield cache
            continue
        if not invalid_ok:
            raise ValueError("KV-cache has not been set properly.")
        return


def enumerate_kv_caches(model: GPT, invalid_ok=False) -> Iterable[tuple[int, KVCache]]:
    """
    Yield `(block_index, kv_cache)` pairs, in block order.

    At the first block whose cache is not a `KVCache`, iteration ends
    silently if `invalid_ok` is true and raises `ValueError` otherwise.
    """
    # Iteration stops at the first invalid cache, so the running count of
    # yielded caches is always equal to the block index.
    yield from enumerate(iterate_kv_caches(model, invalid_ok))


def iterate_kv_tensors(model: GPT, invalid_ok=False) -> Iterable[tuple[torch.Tensor, torch.Tensor]]:
    """
    Yield the `(k, v)` tensors of each block's KV cache, in block order.

    At the first block whose cache is not a `KVCache`, iteration ends
    silently if `invalid_ok` is true and raises `ValueError` otherwise.
    """
    for cache in iterate_kv_caches(model, invalid_ok):
        yield cache.k, cache.v


@dc.dataclass
class KVTensors(TensorDataClass):
    """
    Key/value tensors taken from a model's KV caches, one pair per layer,
    together with the token positions they were taken from.
    """

    # Token positions (third axis of the cache tensors) covered by `ks` / `vs`.
    pos: slice | list[int]
    # Per-layer key tensors and per-layer value tensors, in layer order.
    ks: TensorList
    vs: TensorList

    @property
    def n_layers(self):
        """Number of layers stored (`ks` and `vs` must have equal length)."""
        assert len(self.ks) == len(self.vs)
        return len(self.ks)

    @staticmethod
    def extract_from(
        model: GPT,
        batch_shape: int | Sequence[int],
        pos: slice | list[int] | torch.Tensor,
        device: torch.device | str | None = None,
        clone: bool = True
    ):
        """
        Read the leading `prod(batch_shape)` rows of every KV cache at `pos`.

        Each extracted tensor has shape `[*batch_shape, heads, len, feature]`.

        batch_shape: batch size, or a multi-dimensional batch shape that the
            leading cache rows are reshaped to.
        pos: token positions to read; a tensor must be a 1-D int64 index and
            is stored as a plain list in the result.
        device: if given, the extracted tensors are moved to this device.
        clone: if true, the result never shares memory with the caches.
        """
        # Validate an index tensor before it is used, so that a malformed
        # index fails here rather than somewhere inside the extraction.
        stored_pos: slice | list[int]
        if isinstance(pos, torch.Tensor):
            assert pos.ndim == 1 and pos.dtype == torch.int64
            stored_pos = pos.tolist()
        else:
            stored_pos = pos

        size = batch_shape if isinstance(batch_shape, int) else math.prod(batch_shape)
        batch_shape = (batch_shape,) if isinstance(batch_shape, int) else batch_shape

        def get(x: torch.Tensor):
            assert x.ndim == 4
            y = x[:size, :, pos, :]
            y = y.reshape(*batch_shape, *y.shape[1:])
            if clone and device:
                return clone_to(y, device)
            elif clone:
                return y.clone()
            elif device:
                return y.to(device=device)
            else:
                return y

        ks = TensorList()
        vs = TensorList()
        for kv_cache in iterate_kv_caches(model):
            ks.append(get(kv_cache.k))
            vs.append(get(kv_cache.v))

        return KVTensors(stored_pos, ks, vs)

    def insert_to(self, model: GPT):
        """
        Write the stored tensors back into `model`'s KV caches at `self.pos`.

        Leading batch axes of the stored tensors are flattened into one; the
        result fills the leading rows of each cache. The number of stored
        layers must equal the number of caches (`zip(..., strict=True)`).
        """
        pos = self.pos

        def set_(dst: torch.Tensor, src: torch.Tensor):
            if src.ndim > 4:
                src = src.flatten(0, -4)
            assert dst.ndim == src.ndim == 4
            size = src.size(0)
            assert size <= dst.size(0)
            # Align dtype/device explicitly so that assignment through a list
            # index behaves like assignment through a slice, which converts
            # implicitly. This is a no-op when they already match.
            dst[:size, :, pos, :] = src.to(device=dst.device, dtype=dst.dtype)

        for cache, k, v in zip(iterate_kv_caches(model), self.ks, self.vs, strict=True):
            set_(cache.k, k)
            set_(cache.v, v)


def zero_kv_cache(model: GPT, invalid_ok=False):
    """Call `reset_parameters()` on every KV cache yielded by `iterate_kv_caches`."""
    caches = iterate_kv_caches(model, invalid_ok)
    for cache in caches:
        cache.reset_parameters()


def get_kv_size(model: GPT):
    """
    Acquire the size of KV-cache with (batchsize, max_length).
    Returns (-1, -1) if KV-cache has not been set.

    The size is read from the key tensor of the first cache only.
    """
    first = next(iter(iterate_kv_tensors(model, invalid_ok=True)), None)
    if first is None:
        return -1, -1
    keys = first[0]
    return keys.size(0), keys.size(2)


def reset(
    model: GPT,
    batch_size: int,
    context_length: int,
    device=None,
    dtype=None,
    lazy: bool = True,
):
    """
    Prepare the model's KV caches for the given batch size and context length.

    With `lazy` set, existing caches are kept and merely reset when they are
    at least as large as requested in both dimensions; `device` and `dtype`
    are not used in that case. Otherwise the caches are cleared and rebuilt.

    Returns the `(batch_size, max_length)` of the caches now in place.

    Raises:
        ValueError: `batch_size` or `context_length` is not positive.
    """
    if batch_size <= 0 or context_length <= 0:
        raise ValueError

    if lazy:
        cur_batch, cur_len = get_kv_size(model)
        if cur_batch >= batch_size and cur_len >= context_length:
            zero_kv_cache(model)
            return cur_batch, cur_len

    model.clear_kv_cache()
    build_kv_caches(model, batch_size, context_length, device=device, dtype=dtype)
    return batch_size, context_length


def _shape_size(size: Sequence[int] | int) -> tuple[torch.Size, int]:
    if isinstance(size, int):
        return torch.Size((size,)), size
    if not isinstance(size, torch.Size):
        size = torch.Size(size)
    return size, math.prod(size)


def kv_apply(
    gpt: GPT,
    func: Callable[[torch.Tensor], torch.Tensor],
    in_shape: Sequence[int] | int | None = None,
    pos: torch.Tensor | slice | None = None,
):
    """
    Replace (a part of) every KV-cache tensor by `func` applied to it.

    func: Tensor[*in_shape, heads, length, feature] -> Tensor[*out_shape, heads, length, feature]
    in_shape: by default: batch_size.
    pos: token positions that are read and written; all positions by default.

    The result of `func` is flattened over its leading batch axes and stored
    in the leading rows of the cache tensor.

    Raises:
        IndexError: `in_shape` or the result of `func` needs more rows than
            the cache holds.
    """

    # no operation needed if `func` is `clone`.
    if func is torch.clone:
        return

    shape_size = None if in_shape is None else _shape_size(in_shape)
    if pos is None:
        pos = slice(None)

    def update(x: torch.Tensor):

        elem_ndim = x.ndim - 1  # number of non-batch axes (heads, length, feature)
        if shape_size is not None:
            shape, size = shape_size
            if size > x.size(0):
                raise IndexError("The input size exceeds the maximal batch_size of KV cache.")
            y = x[:size, :, pos, :]
            y = y.view(shape + y.shape[1:])
        else:
            y = x[:, :, pos, :]

        y = func(y)
        y_batch_ndim = y.ndim - elem_ndim
        shape, size = _shape_size(y.shape[:y_batch_ndim])
        if size > x.size(0):
            raise IndexError("The result size exceeds the maximal batch_size of KV cache.")

        result = y.reshape(size, *y.shape[y_batch_ndim:])
        # Store by index assignment, not `x[...].copy_(...)`: with an index
        # tensor as `pos`, `x[...]` is a temporary copy and `copy_` on it
        # would leave `x` untouched. The explicit `.to` keeps the implicit
        # dtype/device conversion that `copy_` performed.
        x[:size, :, pos, :] = result.to(device=x.device, dtype=x.dtype)

    for k, v in iterate_kv_tensors(gpt):
        update(k)
        update(v)


def kv_inplace_apply(
    gpt: GPT,
    func: Callable[[torch.Tensor], Any],
    in_shape: Sequence[int] | int | None = None,
    pos: torch.Tensor | slice | None = None,
):
    """
    Let `func` modify (a part of) every KV-cache tensor in place.

    func: receives Tensor[*in_shape, heads, length, feature] and is expected
        to modify it in place; its return value is ignored.
    in_shape: by default: batch_size.
    pos: token positions handed to `func`; all positions by default.

    Raises:
        IndexError: `in_shape` needs more rows than the cache holds.
    """

    if in_shape is None:
        shape, size = None, None
    else:
        shape, size = _shape_size(in_shape)

    if pos is None:
        pos = slice(None)

    # Indexing with a slice yields a view of the cache, so in-place changes
    # land in the cache directly. Any other index (e.g. an index tensor)
    # yields a copy, which has to be written back after `func` ran.
    needs_write_back = not isinstance(pos, slice)

    def update(x: torch.Tensor):
        if size is not None and size > x.size(0):
            raise IndexError("The input size exceeds the maximal batch_size of KV cache.")
        # `x[:None]` selects every row, which covers the "no in_shape" case.
        window = x[:size, :, pos, :]
        # `view` shares memory with `window`, so changes made through the
        # reshaped tensor are visible in `window` as well.
        y = window if shape is None else window.view(shape + window.shape[1:])
        func(y)
        if needs_write_back:
            x[:size, :, pos, :] = window

    for k, v in iterate_kv_tensors(gpt):
        update(k)
        update(v)
