from typing import Optional, Tuple, Unpack, Callable, Dict, Any, List, Union
from dataclasses import dataclass

import torch
from torch import nn
from transformers import Cache
from torch.utils.cpp_extension import load
#############################################################################
# Kernel
#########
# JIT-build and load the CUDA extension at import time. The wrapper functions
# below reach its operators through ``torch.ops.libhogatt``.
# NOTE(review): the source path is relative, so it is presumably resolved
# against the current working directory rather than this file's location --
# confirm before importing this module from another directory.
libhogatt = load(
    name="libhogatt",
    sources=["../kernels/libhogatt.cu"],
    extra_cflags=["-O3"],
    # NOTE(review): the -U flags undefine the macros that would otherwise
    # switch off implicit half / bfloat16 conversions in the CUDA headers;
    # presumably the kernel relies on those conversions -- confirm.
    extra_cuda_cflags=["-O3","-U__CUDA_NO_BFLOAT16_CONVERSIONS__","-U__CUDA_NO_HALF_CONVERSIONS__"],
    verbose=True
)
#torch.ops.load_library("libhogatt.so")

def merge_caches(target: Cache, source: Cache, model):
    """Append the contents of ``source`` to ``target``, re-rotating the keys.

    A single set of rotary cos/sin values is computed for a position tensor
    in which every entry equals the current sequence length of ``target``.
    That same rotation is applied to the keys of each layer of ``source``
    before they are pushed into ``target`` together with the unmodified
    values.

    Args:
        target: cache that receives the merged entries; updated in place.
        source: cache whose keys/values are appended; not modified.
        model: object exposing ``rotary_emb(x, position_ids)`` that returns
            a ``(cos, sin)`` pair.

    Note:
        The loop runs over the layers currently present in ``target``, so a
        ``target`` without any layers receives nothing.
        NOTE(review): assumes the keys in ``source`` were rotated relative to
        the start of ``source`` so that one constant shift is sufficient --
        confirm against the callers.
    """
    reference_keys = source.key_cache[0]
    # Create the positions directly on the device that holds the cached keys
    # instead of building them on the CPU and moving them to a hard-coded
    # "cuda" device; this also works for CPU runs and non-default GPUs.
    locations = torch.full(
        (1, source.get_seq_length()),
        target.get_seq_length(),
        dtype=torch.int32,
        device=reference_keys.device,
    )
    cos, sin = model.rotary_emb(reference_keys, locations)
    for layer_id in range(len(target.key_cache)):
        # re-rotate
        key_states = apply_rotary_pos_emb(source.key_cache[layer_id], cos, sin)
        target.update(key_states, source.value_cache[layer_id], layer_id)
        
def hogwild_fused(queries: torch.Tensor, locations: torch.Tensor, keys: list[torch.Tensor], values: list[torch.Tensor],
                  scale: float, fragment_lengths, cosines: torch.Tensor, sines: torch.Tensor, *, rotated_queries, out):
    """Invoke the fused rope + attention operator of the ``libhogatt`` extension.

    All arguments are forwarded to ``torch.ops.libhogatt.hogwild_fused``; the
    only preprocessing done here is making ``queries`` contiguous.

    Args:
        queries: query tensor, handed over as a contiguous tensor.
        locations, fragment_lengths, cosines, sines: forwarded unchanged.
        keys, values: lists of tensors, forwarded unchanged.
        scale: scalar forwarded to the operator.
        rotated_queries: caller-provided buffer forwarded to the operator
            (presumably scratch space for the rotated queries -- the kernel
            source is not visible here).
        out: caller-provided buffer forwarded to the operator.

    Returns:
        The ``out`` object that was passed in.
    """
    fused_op = torch.ops.libhogatt.hogwild_fused
    dense_queries = queries.contiguous()
    fused_op(
        out,
        rotated_queries,
        scale,
        locations,
        dense_queries,
        fragment_lengths,
        keys,
        values,
        cosines,
        sines,
    )
    return out

def hogwild_sdpa(queries: torch.Tensor, locations: torch.Tensor, keys: list[torch.Tensor], values: list[torch.Tensor],
                 scale: float, fragment_lengths=None, out=None) -> torch.Tensor:
    """Invoke the attention-only operator of the ``libhogatt`` extension.

    Args:
        queries: query tensor; passed to the operator as a contiguous tensor.
        locations: forwarded unchanged.
        keys, values: lists of tensors; each entry is made contiguous first.
        scale: scalar forwarded to the operator.
        fragment_lengths: optional; when omitted, an int32 tensor holding
            ``size(2)`` of every entry of ``keys`` is created on the device
            of ``queries``.
        out: optional output buffer; when omitted, an uninitialised tensor of
            shape ``(queries.size(1), queries.size(2), queries.size(3),
            values[0].size(3))`` is allocated with the dtype and device of
            ``queries``.

    Returns:
        The output buffer (the supplied ``out`` or the freshly allocated one).
    """
    if out is None:
        leading = [queries.size(axis) for axis in (1, 2, 3)]
        out = torch.empty(
            (*leading, values[0].size(3)),
            dtype=queries.dtype,
            device=queries.device,
        )
    if fragment_lengths is None:
        lengths = [fragment.size(2) for fragment in keys]
        fragment_lengths = torch.tensor(lengths, dtype=torch.int32, device=queries.device)

    dense_keys = [fragment.contiguous() for fragment in keys]
    dense_values = [fragment.contiguous() for fragment in values]
    sdpa_op = torch.ops.libhogatt.hogwild_sdpa
    sdpa_op(out, scale, locations, queries.contiguous(), fragment_lengths, dense_keys, dense_values)
    return out


def hogwild_rope(queries: torch.Tensor, cosines: torch.Tensor, sines: torch.Tensor, out=None):
    """Invoke the rope-only operator of the ``libhogatt`` extension.

    Args:
        queries: query tensor; its first four sizes are reused for the output.
        cosines: forwarded to the operator; ``cosines.size(0)`` becomes the
            leading size of the output.
        sines: forwarded to the operator.
        out: optional output buffer; when omitted, an uninitialised tensor of
            shape ``(cosines.size(0), *queries sizes 0..3)`` is allocated
            with the dtype and device of ``queries``.

    Returns:
        The output buffer (the supplied ``out`` or the freshly allocated one).
    """
    if out is None:
        out_shape = [cosines.size(0)] + [queries.size(axis) for axis in range(4)]
        out = torch.empty(tuple(out_shape), dtype=queries.dtype, device=queries.device)
    rope_op = torch.ops.libhogatt.hogwild_rope
    rope_op(out, queries, cosines, sines)
    return out


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is split at ``x.shape[-1] // 2``; the result is the
    negated trailing part followed by the leading part.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((torch.neg(trailing), leading), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
    """Apply a rotary embedding to ``x``.

    Args:
        x: tensor to rotate.
        cos, sin: rotary tables; a singleton dimension is inserted at
            ``unsqueeze_dim`` in each so they broadcast against ``x``.
        unsqueeze_dim: position of the inserted dimension (default ``1``).

    Returns:
        ``x * cos + rotate_half(x) * sin`` using the expanded tables.
    """
    cos_expanded = cos.unsqueeze(unsqueeze_dim)
    sin_expanded = sin.unsqueeze(unsqueeze_dim)
    cos_term = x * cos_expanded
    sin_term = rotate_half(x) * sin_expanded
    return cos_term + sin_term


#############################################################################
# Cache
#########

@dataclass
class InternalCacheMeta:
    """Scratch record that ``HogwildCache.update`` keeps for one cache object.

    ``HogwildCache.update`` creates one instance per cache with every list
    pre-filled with ``None`` (one slot per worker) and then fills ``loc``
    worker by worker.
    """
    # Per-worker slots; initialised to None and not written in the code
    # visible in this file.
    cos: list[torch.Tensor] | list[None] | torch.Tensor
    sin: list[torch.Tensor] | list[None] | torch.Tensor
    # Per-worker position tensors for the tokens currently being added.
    loc: list[torch.Tensor] | list[None] | torch.Tensor
    # The cache object this record describes.
    cs: Optional[Cache] = None


@dataclass
class CacheStructure:
    """Bundle returned by ``HogwildCache.update`` for one attention layer.

    ``keys`` and ``values`` hold one tensor per cache segment for the layer
    being processed; the remaining fields are computed when layer 0 is
    updated and reused for the other layers.
    """
    keys: list[torch.Tensor]  # keys of the fragment
    values: list[torch.Tensor]  # values of this fragment
    frags: torch.Tensor  # fragment lengths (int32, one entry per segment)
    cos: torch.Tensor  # cosines to apply to query; reshaped to (segments, workers, query_len, dim)
    sin: torch.Tensor  # sines to apply to query; same layout as ``cos``
    loc: torch.Tensor  # relative location; reshaped to (segments, workers, query_len)


class HogwildCache(Cache):
    """Cache facade that lets several workers attend over shared cache segments.

    Each worker owns an ordered list of underlying caches. ``update`` writes
    the new keys/values of every worker into that worker's write target and
    returns, per layer, the list of segment tensors together with the
    per-segment rotary tables and positions needed to rotate the queries.
    """

    def __init__(
            self,
            cache_structure: List[List[Cache]],
            model,
            write_to: Optional[List[Cache]] = None,
    ):
        """Store the cache layout and initialise empty metadata.

        Args:
            cache_structure: one list of caches per worker.
            model: wrapper whose ``.model`` attribute provides ``rotary_emb``.
            write_to: cache that receives new tokens for each worker; when
                falsy, the last cache of each worker's list is used.
        """
        self.model = model.model
        self.cache_structure = cache_structure
        self.write_to = write_to if write_to else [cl[-1] for cl in cache_structure]
        # Metadata below is (re)computed by ``update`` whenever layer 0 is processed.
        self.cosines = []
        self.sines = []
        self.locations = []
        self.segments = []
        self.frags = []
        # Scratch buffers, allocated lazily by the ``get_*_buffer`` helpers.
        self.queries_buffer = None
        self.att_buffer = None

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the length of the last cache of the first worker.

        ``layer_idx`` is accepted but not used.
        """
        # TODO THIS IS WRONG IF WE DO MERGING
        return self.cache_structure[0][-1].get_seq_length()

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> CacheStructure:
        """Write new keys/values and return the attention inputs for ``layer_idx``.

        Args:
            key_states: new keys; dimension 0 must equal the number of workers
                and dimension 2 is used as the number of new tokens.
            value_states: new values, sliced along dimension 0 like the keys.
            layer_idx: layer being processed; the shared metadata is rebuilt
                only when this is ``0``.
            cache_kwargs: forwarded to the underlying caches' ``update``.

        Returns:
            A ``CacheStructure`` with the per-segment keys/values of this
            layer and the stored rotary tables, positions and lengths.
        """
        # update the worker caches
        assert key_states.shape[0] == len(self.cache_structure)

        # assume each batch index corresponds to one worker
        # TODO handle cases where not all workers are active
        for w in range(key_states.shape[0]):
            self.write_to[w].update(key_states[w:w + 1, ...], value_states[w:w + 1, ...], layer_idx, cache_kwargs)

        # ``self.segments`` is only assigned in the layer-0 branch below, so
        # during a layer-0 call this loop still sees the list from the
        # previous call (empty on the very first one).
        for cs in self.segments:
            cs.key_cache[layer_idx] = cs.key_cache[layer_idx].contiguous()
            cs.value_cache[layer_idx] = cs.value_cache[layer_idx].contiguous()

        if layer_idx == 0:
            mapping: Dict[int, InternalCacheMeta] = {}
            workers = len(self.cache_structure)

            # NOTE(review): the mapping is seeded from worker 0's list only;
            # a cache that appears only in another worker's list would raise
            # KeyError in the loop below -- confirm that every worker lists
            # the same cache objects.
            for cs in self.cache_structure[0]:
                mapping[id(cs)] = InternalCacheMeta(
                    cos=[None] * workers, sin=[None] * workers, loc=[None] * workers, cs=cs)

            # and construct the info we need to actually run attention
            for w in range(key_states.shape[0]):
                pos = 0
                # walk the worker's caches from last to first, accumulating lengths
                for cs in reversed(self.cache_structure[w]):
                    pos += cs.get_seq_length(layer_idx)
                    # at this point, pos already includes the newly-added tokens
                    # so, in order to match the right query position, we need to subtract the number of
                    # tokens currently being added
                    pos_t = torch.arange(pos - key_states.shape[2], pos, device=key_states.device, dtype=torch.int32)
                    mapping[id(cs)].loc[w] = pos_t

            # rearrange: flatten to segment-major, worker-minor order
            locations = []
            segments = []
            for entry in mapping.values():
                locations += entry.loc
                segments.append(entry.cs)

            # one row per (segment, worker) pair
            locations = torch.stack(locations, dim=0)
            cosines, sines = self.model.rotary_emb(key_states, locations)
            # split the leading dimension back into (segments, workers) and store as float32
            self.cosines = cosines.reshape(len(segments), workers, locations.shape[1], cosines.shape[2]).to(torch.float)
            self.sines = sines.reshape(len(segments), workers, locations.shape[1], cosines.shape[2]).to(torch.float)
            self.locations = locations.reshape(len(segments), workers, locations.shape[1])
            self.segments = segments
            self.frags = torch.tensor([cs.get_seq_length(layer_idx) for cs in self.segments], dtype=torch.int32, device=self.cosines.device)
            # for some reason, having an explicit graph break is *essential* for good performance
            torch._dynamo.graph_break()
        # collect this layer's tensors for every segment, in segment order
        keys = []
        vals = []
        for cs in self.segments:
            keys.append(cs.key_cache[layer_idx].contiguous())
            vals.append(cs.value_cache[layer_idx].contiguous())
        return CacheStructure(keys=keys, values=vals, cos=self.cosines, sin=self.sines, loc=self.locations, frags=self.frags)

    def get_queries_buffer(self, queries, layer_idx):
        """Return the buffer for rotated queries, reallocating it at layer 0.

        The buffer has shape ``(self.cosines.size(0), *queries sizes 0..3)``
        and the dtype/device of ``queries``. For any other layer the buffer
        allocated earlier is returned as is (``None`` if layer 0 never ran).
        """
        if layer_idx == 0:
            self.queries_buffer = torch.empty((self.cosines.size(0), queries.size(0), queries.size(1),
                                               queries.size(2), queries.size(3)),
                                              dtype=queries.dtype, device=queries.device)
        return self.queries_buffer

    def get_att_buffer(self, r_queries, layer_idx):
        """Return the attention output buffer, reallocating it at layer 0.

        The buffer takes sizes 1..3 of ``r_queries`` and, as last size,
        ``size(3)`` of the layer's values in the first cache of the first
        worker. For any other layer the earlier buffer is returned as is.
        """
        if layer_idx == 0:
            self.att_buffer = torch.empty((r_queries.size(1), r_queries.size(2), r_queries.size(3),
                                           self.cache_structure[0][0].value_cache[layer_idx].size(3)),
                                          dtype=r_queries.dtype, device=r_queries.device)
        return self.att_buffer


#############################################################################
# Model
#########
class AttentionModule(nn.Module):
    """Attention layer that reads and writes its KV state through a HogwildCache.

    Keys are rotated with the regular position embeddings before they are
    cached; queries are left unrotated here and handed to ``hogwild_fused``
    together with the per-segment rotary tables returned by the cache.
    """

    def __init__(self, config, layer_idx: int):
        """Create the four projections from the sizes found in ``config``.

        Args:
            config: model configuration providing ``hidden_size``,
                ``num_attention_heads``, ``num_key_value_heads``,
                ``attention_dropout`` and optionally ``head_dim``.
            layer_idx: index of this layer, passed to the cache on update.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # computed unconditionally, then used only if the config has no head_dim
        derived_head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", derived_head_dim)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # projections are created in q, k, v, o order
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_embeddings: Tuple[torch.Tensor, torch.Tensor],
            attention_mask: Optional[torch.Tensor],
            past_key_value: Optional[HogwildCache] = None,
            cache_position: torch.LongTensor = None,
            **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention for one step through the Hogwild cache.

        Args:
            hidden_states: layer input; all but its last dimension define the
                shape of the output.
            position_embeddings: ``(cos, sin)`` pair applied to the keys.
            attention_mask: accepted but not used.
            past_key_value: cache object whose ``update``,
                ``get_queries_buffer`` and ``get_att_buffer`` are called.
            cache_position: forwarded to the cache inside ``cache_kwargs``.
            **kwargs: accepted but not used.

        Returns:
            ``(attn_output, None)`` where ``attn_output`` is the result of the
            output projection.
        """
        leading_dims = hidden_states.shape[:-1]
        per_head_view = (*leading_dims, -1, self.head_dim)

        # project and split into heads, in q, k, v order
        query_states, key_states, value_states = [
            projection(hidden_states).view(per_head_view).transpose(1, 2)
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        ]

        # only the keys are rotated here; queries are rotated per segment later
        rope_cos, rope_sin = position_embeddings
        key_states = apply_rotary_pos_emb(key_states, rope_cos, rope_sin)

        cache_kwargs = {"sin": rope_sin, "cos": rope_cos, "cache_position": cache_position}
        cache = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # scratch buffers owned by the cache: rotated queries first, then output
        rotated_buffer = past_key_value.get_queries_buffer(query_states, layer_idx=self.layer_idx)
        output_buffer = past_key_value.get_att_buffer(rotated_buffer, layer_idx=self.layer_idx)
        fused_output = hogwild_fused(
            query_states,
            cache.loc,
            cache.keys,
            cache.values,
            self.scaling,
            cache.frags,
            cache.cos, cache.sin,
            rotated_queries=rotated_buffer,
            out=output_buffer,
        )

        merged_heads = fused_output.transpose(1, 2).reshape(*leading_dims, -1).contiguous()
        return self.o_proj(merged_heads), None


def model_surgery(model):
    """Swap every attention layer of ``model`` for an ``AttentionModule``.

    For each entry of ``model.model.layers`` a new ``AttentionModule`` is
    built from ``model.model.config`` and the old layer's ``layer_idx``; the
    four projection modules of the old layer are then reused in it. Finally
    ``_update_causal_mask`` is replaced on both ``model`` and ``model.model``
    by a callable that accepts anything and returns ``None``.
    """
    for decoder_layer in model.model.layers:
        previous_attn = decoder_layer.self_attn
        replacement = AttentionModule(model.model.config, previous_attn.layer_idx)
        # keep the existing projection modules instead of the freshly created ones
        replacement.k_proj = previous_attn.k_proj
        replacement.v_proj = previous_attn.v_proj
        replacement.q_proj = previous_attn.q_proj
        replacement.o_proj = previous_attn.o_proj
        decoder_layer.self_attn = replacement

    for owner in (model, model.model):
        owner._update_causal_mask = lambda *arg, **kwargs: None