import math
from typing import Tuple, List, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_func

from utils.utils import describe_tensor


# Taken from facebookresearch/llama/model.py
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result keeps
    those two extents on axis 1 and on the last axis of ``x`` and uses size 1
    on every other axis.
    """
    rank = x.ndim
    assert rank > 1
    wanted = (x.shape[1], x.shape[-1])
    assert freqs_cis.shape == wanted
    # Start from all-ones, then restore the sequence axis and the channel axis.
    target = [1] * rank
    target[1] = x.shape[1]
    target[-1] = x.shape[-1]
    return freqs_cis.view(*target)


def apply_rotary_emb(
    xq: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embeddings to ``xq``.

    Consecutive channel pairs ``(x0, x1), (x2, x3), ...`` of the last axis are
    treated as complex numbers. They are multiplied by the first ``xq.shape[1]``
    rows of ``freqs_cis``, so axis 1 of ``xq`` is the position axis.

    Args:
        xq: tensor whose last axis has even size ``D``. Axis 1 indexes
            positions, e.g. ``(B, T, n_head, D)`` or ``(B, T, D)``.
        freqs_cis: complex table of shape ``(>= T, D // 2)``.

    Returns:
        A tensor with the same shape and dtype as ``xq``.
    """
    # Pair up channels as complex numbers (computed in float32).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))

    T_q = xq.shape[1]
    freqs_cis_q = reshape_for_broadcast(freqs_cis[:T_q], xq_)
    # flatten(-2) merges the trailing (D // 2, 2) axes back into D for any
    # input rank; the former flatten(3) only did so for 4-D inputs.
    xq_out = torch.view_as_real(xq_ * freqs_cis_q).flatten(-2)
    return xq_out.type_as(xq)


# Taken from facebookresearch/llama/model.py
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Build the complex rotary table ``exp(i * pos * theta**(-2k/dim))``.

    Args:
        dim: per-head channel count; ``dim // 2`` frequencies are produced.
        end: number of positions (rows), ``0 .. end-1``.
        theta: frequency base.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)`` with unit magnitude.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    # polar(1, angle) == exp(i * angle)
    return torch.polar(torch.ones_like(angles), angles)


class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with an optional learned bias.

    ``nn.LayerNorm`` cannot drop only the bias, so this wraps
    ``F.layer_norm`` directly (eps = 1e-5). When ``bias`` is False the
    ``bias`` attribute is ``None`` and no bias parameter is registered.
    """

    def __init__(self, ndim: int, bias: bool):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        normalized_shape = self.weight.shape
        return F.layer_norm(
            input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5
        )


class KVCache:
    """Key/value buffers for one attention layer during incremental decoding.

    ``k`` and ``v`` have shape ``(B, num_heads, T, emb_dim // num_heads)`` and
    start as zeros (``torch.zeros`` default dtype) on ``device``.
    """

    def __init__(
        self,
        B: int,
        T: int,
        emb_dim: int,
        num_heads: int,
        device: "torch.device | str | int",
    ):
        # ``device`` is whatever torch.zeros accepts: callers pass both strings
        # (e.g. "cpu") and torch.device objects.
        head_dim = emb_dim // num_heads
        self.k = torch.zeros((B, num_heads, T, head_dim), device=device)
        self.v = torch.zeros((B, num_heads, T, head_dim), device=device)

    def reset(self):
        """Zero both buffers in place."""
        self.k.zero_()
        self.v.zero_()

    def update(self, pos: int, k: torch.Tensor, v: torch.Tensor):
        """Write one time step: ``k``/``v`` of shape (B, num_heads, 1, head_dim) go to slot ``pos``."""
        self.k[:, :, pos : pos + 1, :] = k
        self.v[:, :, pos : pos + 1, :] = v

    def assign(self, k: torch.Tensor, v: torch.Tensor):
        """Replace the buffers with ``k`` and ``v`` (no copy; shapes may differ)."""
        self.k = k
        self.v = v


def clone_cache(cache: KVCache):
    """Return a new KVCache whose buffers are independent copies of ``cache``'s.

    A tiny placeholder cache is built only to get an instance; its buffers are
    replaced right away by the cloned tensors, which keep the source's device,
    dtype and shape.
    """
    k_copy = cache.k.clone()
    v_copy = cache.v.clone()
    duplicate = KVCache(1, 1, 1, 1, "cpu")
    duplicate.assign(k_copy, v_copy)
    return duplicate


def create_attn_mask(block_size, causality):
    """Build a boolean (block_size, block_size) band mask.

    Entry ``[i, j]`` is True when ``i - lookbehind <= j <= i + lookahead``,
    where ``causality = (lookbehind, lookahead)``. A ``None`` bound is replaced
    by ``block_size``, which leaves that side unrestricted within the block.
    """
    behind, ahead = causality
    behind = block_size if behind is None else behind
    ahead = block_size if ahead is None else ahead
    positions = torch.arange(block_size)
    # offset[i, j] = j - i
    offset = positions[None, :] - positions[:, None]
    return (offset >= -behind) & (offset <= ahead)


# (lookbehind, lookahead): how many positions before / after a query position
# an attention layer may look at. ``None`` leaves that side unbounded:
# create_attn_mask substitutes block_size, and MultiheadAttention falls back to
# the sequence edge (get_start_pos / get_end_pos) and to dense attention.
Causality = Tuple[Optional[int], Optional[int]]


class MultiheadAttention(nn.Module):
    """Multi-head attention with rotary position embeddings and a banded mask.

    ``causality = (lookbehind, lookahead)`` limits query position ``i`` to key
    positions ``j`` with ``i - lookbehind <= j <= i + lookahead``; ``None``
    leaves that side unbounded.
    - When both bounds are integers, ``flash_attn_func`` with ``window_size``
      is used (``window_attention``).
    - Otherwise dense attention over a precomputed boolean mask is used
      (``dense_attention``).

    For self-attention pass the same tensor as ``x`` and ``mem`` to
    ``forward``; for cross-attention pass the encoder output as ``mem``.
    """

    def __init__(
        self,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causality: Causality,
        disable_flash: bool = False,
    ):
        super().__init__()
        assert embed_dim % n_head == 0
        self.n_head = n_head
        self.n_embd = embed_dim
        self.dropout = dropout
        self.causality = causality

        # Query / key / value projections covering all heads at once.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Output projection.
        self.output_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Regularisation.
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # ``self.flash`` selects torch's fused scaled_dot_product_attention
        # (PyTorch >= 2.0) inside dense_attention. It does not affect the
        # windowed flash_attn path.
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention")
            and not disable_flash
        )
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )

        # (block_size, block_size) boolean mask, True where attention is allowed.
        # Non-persistent: rebuilt from constructor args, not saved in state_dict.
        self.register_buffer(
            "attn_mask",
            create_attn_mask(block_size, causality),
            persistent=False
        )
        assert self.attn_mask.dtype == torch.bool

    def dense_attention(self, q, k, v, force_noncausal):
        """Attention over the full (T_q, T_k) score matrix; q, k, v are (B, nh, T, hs).

        With ``force_noncausal`` no mask is applied (callers pass an already
        windowed k/v slice). Otherwise the top-left (T_q, T_k) corner of
        ``attn_mask`` is used.
        """
        if self.flash:
            if force_noncausal:
                attn_mask = None
            else:
                attn_mask = self.attn_mask[:q.shape[2], :k.shape[2]]
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.dropout if self.training else 0,
            )
        else:
            # Manual implementation: softmax(q k^T / sqrt(hs)) v.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            if not force_noncausal:
                att = att.masked_fill(self.attn_mask[None, None, :att.shape[2], :att.shape[3]] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T_q, T_k) x (B, nh, T_k, hs) -> (B, nh, T_q, hs)
        return y

    def window_attention(self, q, k, v):
        """Sliding-window attention via ``flash_attn_func``; q, k, v are (B, nh, T, hs).

        flash_attn takes (B, T, nh, hs) half-precision inputs, so the tensors are
        transposed and cast to bfloat16. The result is cast back to the input
        dtype.

        NOTE(review): flash_attn aligns the window to the bottom-right corner
        when T_q != T_k, while dense_attention's mask is top-left aligned.
        They agree only for equal lengths.
        """
        dtype = q.dtype
        q = q.transpose(1, 2).to(torch.bfloat16)
        k = k.transpose(1, 2).to(torch.bfloat16)
        v = v.transpose(1, 2).to(torch.bfloat16)
        # Apply attention dropout while training, matching dense_attention.
        y = flash_attn_func(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            window_size=self.causality,
        )
        return y.transpose(1, 2).to(dtype)

    def attention(self, q, k, v, force_noncausal=False):
        """Run dense or windowed attention, merge heads, then project.

        Dense attention is used when ``force_noncausal`` is set or either side
        of ``causality`` is ``None``; otherwise the windowed kernel is used.
        Returns a (B, T_q, nh * hs) tensor.
        """
        if force_noncausal or self.causality[0] is None or self.causality[1] is None:
            y = self.dense_attention(q, k, v, force_noncausal)
        else:
            y = self.window_attention(q, k, v)

        y = rearrange(y, "B nh T hs -> B T (nh hs)")
        y = self.resid_dropout(self.output_proj(y))
        return y

    def forward(
        self, x: torch.Tensor, mem: torch.Tensor, freqs_cis: torch.Tensor
    ) -> torch.Tensor:
        """Attend from ``x`` (B, T_x, C) to ``mem`` (B, T_mem, C).

        Rotary embeddings are applied to queries and keys; v is left unrotated.
        """
        B, T_x, C = x.size()
        B, T_mem, C = mem.size()

        q = self.q_proj(x).view(B, T_x, self.n_head, C // self.n_head)
        k = self.k_proj(mem).view(B, T_mem, self.n_head, C // self.n_head)
        v = self.v_proj(mem).view(B, T_mem, self.n_head, C // self.n_head)

        k = apply_rotary_emb(k, freqs_cis)
        q = apply_rotary_emb(q, freqs_cis)

        # -> (B, nh, T, hs)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        return self.attention(q, k, v)

    def get_start_pos(self, pos: int):
        """First key index visible from query ``pos`` (clamped at 0)."""
        lookbehind = self.causality[0]
        return 0 if lookbehind is None else max(0, pos - lookbehind)

    def get_end_pos(self, pos: int, kv_len: int):
        """Last key index (inclusive) visible from query ``pos`` (clamped at kv_len - 1)."""
        lookahead = self.causality[1]
        return kv_len - 1 if lookahead is None else min(kv_len - 1, pos + lookahead)

    # x: (B, 1, emb_dim)
    def kv_inference(self, x: torch.Tensor, pos: int, freqs_cis: torch.Tensor, kv_cache: KVCache):
        """Cross-attention for one decoding step at position ``pos``.

        ``kv_cache`` must already hold rotated keys/values (see fill_cache).
        Only cache slots ``[start_pos, end_pos]`` inside the window are
        attended, so no mask is needed.
        """
        start_pos = self.get_start_pos(pos)
        end_pos = self.get_end_pos(pos, kv_cache.k.shape[2])
        B, T_x, C = x.size()
        assert T_x == 1
        q = self.q_proj(x).view(B, 1, self.n_head, C // self.n_head)
        q = apply_rotary_emb(q, freqs_cis[pos: pos + 1])
        q = q.transpose(1, 2)
        return self.attention(q, kv_cache.k[:, :, start_pos:end_pos + 1, :], kv_cache.v[:, :, start_pos:end_pos + 1, :], force_noncausal=True)

    def kv_inference_self_attn(self, x: torch.Tensor, pos: int, freqs_cis: torch.Tensor, kv_cache: KVCache):
        """Causal self-attention for one decoding step at position ``pos``.

        Writes this step's key/value into ``kv_cache`` at ``pos``, then attends
        to slots ``[start_pos, pos]``. Requires ``lookahead == 0``.
        """
        assert self.causality[1] == 0
        start_pos = self.get_start_pos(pos)
        B, T_x, C = x.size()
        assert T_x == 1
        q = self.q_proj(x).view(B, 1, self.n_head, C // self.n_head)
        k = self.k_proj(x).view(B, 1, self.n_head, C // self.n_head)
        v = self.v_proj(x).view(B, 1, self.n_head, C // self.n_head)
        q = apply_rotary_emb(q, freqs_cis[pos: pos + 1])
        k = apply_rotary_emb(k, freqs_cis[pos: pos + 1])
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        kv_cache.update(pos, k, v)
        return self.attention(q, kv_cache.k[:, :, start_pos:pos+1, :], kv_cache.v[:, :, start_pos:pos+1, :], force_noncausal=True)

    def fill_cache(self, mem: torch.Tensor, freqs_cis: torch.Tensor, kv_cache: KVCache):
        """Project ``mem`` (B, T, C) to rotated keys and values and store them in ``kv_cache``."""
        B, T_x, C = mem.shape
        k = self.k_proj(mem).view(B, T_x, self.n_head, C // self.n_head)
        v = self.v_proj(mem).view(B, T_x, self.n_head, C // self.n_head)
        k = apply_rotary_emb(k, freqs_cis)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        kv_cache.assign(k, v)


class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times ``embed_dim`` wide.
    """

    def __init__(
        self,
        embed_dim: int,
        bias: bool,
        dropout: float,
    ):
        super().__init__()
        hidden = 4 * embed_dim
        # Registration order (c_fc, gelu, c_proj, dropout) fixes state_dict key order.
        self.c_fc = nn.Linear(embed_dim, hidden, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))


class EncoderBlock(nn.Module):
    """Pre-norm encoder layer: ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``.

    The constructor keyword ``causaliy`` (sic) is kept as-is because
    ``Encoder`` passes it by that name.
    """

    def __init__(
        self,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causaliy: Causality,
    ):
        super().__init__()
        # Registration order (ln_1, attn, ln_2, mlp) fixes state_dict key order.
        self.ln_1 = LayerNorm(embed_dim, bias=bias)
        self.attn = MultiheadAttention(
            embed_dim=embed_dim,
            n_head=n_head,
            bias=bias,
            dropout=dropout,
            block_size=block_size,
            causality=causaliy,
        )
        self.ln_2 = LayerNorm(embed_dim, bias=bias)
        self.mlp = MLP(embed_dim=embed_dim, bias=bias, dropout=dropout)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        attn_in = self.ln_1(x)
        # Self-attention: the normalised stream is both query and memory.
        after_attn = x + self.attn(attn_in, attn_in, freqs_cis)
        return after_attn + self.mlp(self.ln_2(after_attn))


class Encoder(nn.Module):
    """Stack of ``n_layers`` EncoderBlocks applied in order."""

    def __init__(
        self,
        n_layers: int,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causality: Causality,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderBlock(
                embed_dim=embed_dim,
                n_head=n_head,
                bias=bias,
                dropout=dropout,
                block_size=block_size,
                causaliy=causality,
            )
            for _ in range(n_layers)
        )

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        hidden = x
        for block in self.layers:
            hidden = block(hidden, freqs_cis)
        return hidden


class DecoderBlock(nn.Module):
    """Pre-norm decoder layer with three residual sub-layers.

    1. Self-attention over the decoder stream (``attn_1``, ``causality``).
    2. Cross-attention from the decoder stream to ``mem`` (``attn_2``, ``cross_causality``).
    3. Position-wise MLP.

    ``forward`` processes whole sequences. ``get_caches`` + ``inference`` run
    the same layer one position at a time using KV caches.
    """

    def __init__(
        self,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causality: Causality,
        cross_causality: Causality
    ):
        super().__init__()
        self.n_head = n_head
        # Constructor arguments shared by both attention sub-layers.
        shared_attn_args = dict(
            embed_dim=embed_dim,
            n_head=n_head,
            bias=bias,
            dropout=dropout,
            block_size=block_size,
        )
        # Registration order fixes state_dict key order.
        self.ln_1 = LayerNorm(embed_dim, bias=bias)
        self.attn_1 = MultiheadAttention(causality=causality, **shared_attn_args)
        self.ln_2 = LayerNorm(embed_dim, bias=bias)
        self.attn_2 = MultiheadAttention(causality=cross_causality, **shared_attn_args)
        self.ln_3 = LayerNorm(embed_dim, bias=bias)
        self.mlp = MLP(embed_dim=embed_dim, bias=bias, dropout=dropout)

    def forward(
        self, x: torch.Tensor, mem: torch.Tensor, freqs_cis: torch.Tensor
    ) -> torch.Tensor:
        self_in = self.ln_1(x)
        x = x + self.attn_1(self_in, self_in, freqs_cis)
        cross_out = self.attn_2(self.ln_2(x), mem, freqs_cis)
        x = x + cross_out
        return x + self.mlp(self.ln_3(x))

    def get_caches(self, mem: torch.Tensor, seq_len: int, freqs_cis: torch.Tensor):
        """Allocate this layer's caches for incremental decoding.

        Returns ``(self_cache, mem_cache)``:
        - a zero-initialised self-attention cache with ``seq_len`` slots;
        - a cross-attention cache filled from ``mem``.
        """
        batch, mem_len, width = mem.shape
        device = mem.device
        self_cache = KVCache(batch, seq_len, width, self.n_head, device)
        mem_cache = KVCache(batch, mem_len, width, self.n_head, device)
        self.attn_2.fill_cache(mem, freqs_cis, mem_cache)
        return self_cache, mem_cache

    def inference(self, x: torch.Tensor, pos: int, freqs_cis: torch.Tensor, caches: Tuple[KVCache, KVCache]):
        """Run this layer on the single position ``pos``; ``x`` is (B, 1, C)."""
        self_cache, mem_cache = caches
        self_in = self.ln_1(x)
        x = x + self.attn_1.kv_inference_self_attn(self_in, pos, freqs_cis, self_cache)
        cross_out = self.attn_2.kv_inference(self.ln_2(x), pos, freqs_cis, mem_cache)
        x = x + cross_out
        return x + self.mlp(self.ln_3(x))


class Decoder(nn.Module):
    """Stack of DecoderBlocks followed by a final LayerNorm and output projection."""

    def __init__(
        self,
        output_dim: int,
        n_layers: int,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causality: Causality,
        cross_causality: Causality,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    embed_dim=embed_dim,
                    n_head=n_head,
                    bias=bias,
                    dropout=dropout,
                    block_size=block_size,
                    causality=causality,
                    cross_causality=cross_causality
                )
                for _ in range(n_layers)
            ]
        )
        self.final_layer_norm = LayerNorm(embed_dim, bias=bias)
        self.output_projection = nn.Linear(embed_dim, output_dim)

    def forward(
        self, x: torch.Tensor, mem: torch.Tensor, freqs_cis: torch.Tensor
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mem, freqs_cis)
        x = self.output_projection(self.final_layer_norm(x))
        return x

    def get_caches(
            self, mem: torch.Tensor, seq_len: int, freqs_cis: torch.Tensor, beam_k: int
    ) -> List[List[Tuple[KVCache, KVCache]]]:
        """Create independent KV caches for ``beam_k`` beams.

        Returns a list indexed ``caches[beam][layer] -> (self_cache, mem_cache)``.
        """
        # Built layer-major: per_layer[layer][beam].
        per_layer = [
            [layer.get_caches(mem, seq_len, freqs_cis) for _ in range(beam_k)]
            for layer in self.layers
        ]
        # Transpose to beam-major by index rather than zip(*per_layer), so a
        # decoder with zero layers still yields beam_k (empty) per-beam lists.
        return [
            [layer_caches[beam] for layer_caches in per_layer]
            for beam in range(beam_k)
        ]

    def inference(
        self, x: torch.Tensor, pos: int, freqs_cis: torch.Tensor, caches: List[Tuple[KVCache, KVCache]]
    ) -> torch.Tensor:
        """Decode the single position ``pos`` using one beam's per-layer caches."""
        for layer, cache in zip(self.layers, caches):
            x = layer.inference(x, pos, freqs_cis, cache)
        x = self.output_projection(self.final_layer_norm(x))
        return x


class Transformer(nn.Module):
    """Encoder-decoder transformer with rotary position embeddings.

    The encoder reads the conditioning sequence ``cond``. The decoder reads the
    target sequence and cross-attends to the encoder output.
    - Targets are continuous vectors (``input_projection``) or, when
      ``quantized_io``, token ids (``input_embeddings``).
    - The conditioning is handled the same way, chosen independently by
      ``quantized_cond``.
    """

    def __init__(
        self,
        n_encoder_layers: int,
        n_decoder_layers: int,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        encoder_causality: Causality,
        decoder_causality: Causality,
        cross_causality: Causality,
        max_seq_len: int,
        quantized_io=False,
        quantized_cond=False,
        input_dim=None,
        output_dim=None,
        input_vocab_size=None,
        cond_dim=None,
        cond_vocab_size=None,
        share_input_cond_emb=False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.n_head = n_head
        self.bias = bias
        self.dropout = dropout
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        # Stored for estimate_mfu, which needs the total layer count.
        self.n_encoder_layers = n_encoder_layers
        self.n_decoder_layers = n_decoder_layers

        self.quantized_io = quantized_io
        self.quantized_cond = quantized_cond
        if quantized_io:
            # One extra row (index input_vocab_size) is reserved for the start token.
            self.input_embeddings = nn.Embedding(num_embeddings=input_vocab_size + 1, embedding_dim=embed_dim)
            self.start_token = input_vocab_size
            # The decoder predicts logits over the real vocabulary only.
            output_dim = input_vocab_size
        else:
            self.input_projection = nn.Linear(input_dim, embed_dim)
            # Learned start-of-sequence embedding, shape (1, 1, embed_dim).
            self.register_parameter(
                "start_token", nn.Parameter(torch.randn(1, 1, embed_dim))
            )
        if quantized_cond:
            if share_input_cond_emb:
                assert input_vocab_size == cond_vocab_size
                self.cond_embeddings = self.input_embeddings
            else:
                self.cond_embeddings = nn.Embedding(num_embeddings=cond_vocab_size, embedding_dim=embed_dim)
        else:
            if share_input_cond_emb:
                self.cond_projection = self.input_projection
            else:
                self.cond_projection = nn.Linear(cond_dim, embed_dim)

        self.encoder = Encoder(
            n_layers=n_encoder_layers,
            embed_dim=embed_dim,
            n_head=n_head,
            bias=bias,
            dropout=dropout,
            block_size=block_size,
            causality=encoder_causality,
        )

        self.decoder = Decoder(
            output_dim=output_dim,
            n_layers=n_decoder_layers,
            embed_dim=embed_dim,
            n_head=n_head,
            bias=bias,
            dropout=dropout,
            block_size=block_size,
            causality=decoder_causality,
            cross_causality=cross_causality,
        )
        # Plain attribute (not a buffer): moved to the input's device on each use.
        self.freqs_cis = precompute_freqs_cis(embed_dim // n_head, max_seq_len)

        # Init all weights
        self.apply(self._init_weights)
        # Apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("output_proj.weight"):
                torch.nn.init.normal_(
                    p,
                    mean=0.0,
                    std=0.02 / math.sqrt(n_encoder_layers + n_decoder_layers),
                )

    @property
    def num_params(self) -> int:
        """Total number of parameter elements."""
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def _init_weights(self, module: nn.Module) -> None:
        """GPT-2 style init: Linear weights ~ N(0, 0.02), biases zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def right_shift_input(self, input: torch.Tensor) -> torch.Tensor:
        """Prepend the learned start token and drop the last step; input is (B, T, C)."""
        start_token = self.start_token.expand(input.shape[0], -1, -1)
        return torch.concat([start_token, input[:, :-1, :]], dim=1)

    def embed_input(self, input: torch.Tensor) -> torch.Tensor:
        """Map decoder inputs (token ids or vectors) to embed_dim."""
        if self.quantized_io:
            return self.input_embeddings(input)
        else:
            return self.input_projection(input)

    def embed_cond(self, cond: torch.Tensor) -> torch.Tensor:
        """Map conditioning inputs (token ids or vectors) to embed_dim."""
        if self.quantized_cond:
            return self.cond_embeddings(cond)
        else:
            return self.cond_projection(cond)

    def forward(
            self, input: torch.Tensor, cond: torch.Tensor, start_pos: int = 0, apply_start_token: bool = False
    ) -> torch.Tensor:
        """Teacher-forced forward pass.

        Args:
            input: decoder inputs, (B, T) token ids or (B, T, input_dim) vectors.
            cond: conditioning sequence for the encoder.
            start_pos: offset into the rotary table.
            apply_start_token: overwrite the first embedded step with the
                start token's embedding.

        Returns:
            Decoder outputs of shape (B, T, output_dim).
        """
        input = self.embed_input(input)
        if apply_start_token:
            if self.quantized_io:
                # start_token is an index; use its embedding row, not the raw id.
                start_emb = self.input_embeddings.weight[self.start_token]
            else:
                start_emb = self.start_token
            input[:, 0] = start_emb
        _, t, _ = input.shape
        assert t <= self.block_size, (
            f"Cannot forward sequence of length {t}, "
            f"block size is only {self.block_size}"
        )

        freqs_cis = self.freqs_cis.to(input.device)
        freqs_cis = freqs_cis[start_pos : start_pos + t]

        # Encode the conditioning information
        cond = self.embed_cond(cond)
        cond_feature = self.encoder(cond, freqs_cis)
        # Get the decoded features
        pred = self.decoder(input, cond_feature, freqs_cis)

        return pred

    def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float:
        """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS"""
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.num_params
        L, H, Q, T = (
            self.n_encoder_layers + self.n_decoder_layers,
            self.n_head,
            self.embed_dim // self.n_head,
            self.block_size,
        )
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0 / dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate_ref(
        self,
        cond: torch.Tensor,
        window_size: int,
        context_size: int,
        start_pos: int = 0,
    ) -> torch.Tensor:
        """Reference (cache-free) autoregressive generation for continuous outputs.

        The full decoder runs over the whole prefix at each step. Each
        prediction is reshaped to (B, 2, window_size). The next input is that
        prediction prepended with the last ``context_size`` samples of the
        previous prediction.
        """
        assert not self.quantized_io, "Quantized representations are not supported in legacy generation"

        T = cond.shape[1]

        freqs_cis = self.freqs_cis.to(cond.device)
        freqs_cis = freqs_cis[start_pos : start_pos + T]

        # Encode the conditioning information
        cond = self.embed_cond(cond)
        cond_feature = self.encoder(cond, freqs_cis)

        input_feature = self.start_token.expand(cond.shape[0], -1, -1)
        recons = [torch.zeros(cond.shape[0], 2, window_size, device=cond.device)]

        for t in range(T):
            recon = self.decoder(input_feature, cond_feature, freqs_cis)

            # Add the new reconstruction
            recons.append(recon[:, -1, :].reshape(-1, 2, window_size))

            feature_next = torch.concat(
                [recons[-2][:, :, -context_size:], recons[-1]], dim=-1
            ).reshape(cond.shape[0], 1, -1)
            input_feature = torch.concat(
                [input_feature, self.input_projection(feature_next)], dim=1
            )

        return torch.stack(recons[1:], dim=1).flatten(start_dim=2)

    def generate_non_quantized(
        self,
        cond: torch.Tensor,
        window_size: int,
        context_size: int,
        caches: List[List[Tuple[KVCache, KVCache]]],
        freqs_cis: torch.Tensor
    ) -> torch.Tensor:
        """KV-cached counterpart of generate_ref.

        ``cond`` is the embedded conditioning and ``caches`` holds one beam's
        per-layer caches.
        """
        B, T, _ = cond.shape
        cur_token = self.start_token.expand(B, -1, -1) # (B, 1, emb_dim)
        recons = [torch.zeros(cond.shape[0], 2, window_size, device=cond.device)]
        for t in range(T):
            recon = self.decoder.inference(cur_token, t, freqs_cis, caches)
            recons.append(recon[:, -1, :].reshape(-1, 2, window_size))
            feature_next = torch.concat(
                [recons[-2][:, :, -context_size:], recons[-1]], dim=-1
            ).reshape(cond.shape[0], 1, -1)
            cur_token = self.embed_input(feature_next)
        return torch.stack(recons[1:], dim=1).flatten(start_dim=2)

    def assign_kv_cache(self,
                        kv1: List[Tuple[KVCache, KVCache]],
                        kv2: List[Tuple[KVCache, KVCache]],
                        b: int):
        """
        Copy batch row ``b`` of every per-layer cache in ``kv1`` (source) into
        the corresponding cache in ``kv2`` (destination), for both the
        self-attention and the memory caches.
        """
        for (self1, mem1), (self2, mem2) in zip(kv1, kv2):
            self2.k[b] = self1.k[b]
            self2.v[b] = self1.v[b]
            mem2.k[b] = mem1.k[b]
            mem2.v[b] = mem1.v[b]

    def generate_quantized_old(
            self,
            cond,
            caches,
            freqs_cis):
        """Greedy (argmax) token decoding with a single beam's caches; returns (B, T) ids."""
        B, T, _ = cond.shape
        cur_token = self.embed_input(torch.full((B, 1), self.start_token, device=cond.device))
        out = []
        for t in range(T):
            logits = self.decoder.inference(cur_token, t, freqs_cis, caches)
            preds = torch.argmax(logits, dim=2)
            out.append(preds)
            cur_token = self.embed_input(preds)
        return torch.concat(out, dim=1)

    def generate_quantized(
            self,
            cond: torch.Tensor,
            caches: List[List[Tuple[KVCache, KVCache]]],
            freqs_cis: torch.Tensor,
            beam_k: int = 1
    ) -> torch.Tensor:
        """Beam-search decoding over discrete tokens.

        ``caches[i]`` holds the per-layer caches for beam ``i``. Each step:
        1. Every live beam is expanded by its top ``beam_k`` next tokens.
        2. The best ``beam_k`` cumulative log-probabilities are kept per batch
           element.
        3. The surviving beams' cache rows are copied into a second cache set,
           which is then swapped in.

        Returns the highest-scoring sequence per batch element, shape (B, T),
        without the start token.
        """
        B, T, _ = cond.shape
        device = cond.device
        sequences = torch.full((1, B, 1), self.start_token, device=device)  # (beam_cur_k=1, B, seq_len=1)
        cur_token = self.embed_input(sequences)  # (beam_cur_k=1, B, seq_len=1, emb_dim)
        scores = torch.zeros(B, 1, device=device)  # (B, beam_k_cur=1)
        new_caches = [[(clone_cache(s), clone_cache(m)) for s, m in lst] for lst in caches]
        for t in range(T):
            k = scores.shape[1]  # number of live beams

            best: List[List[Tuple[float, int, int]]] = [[] for _ in range(B)]  # best[i] will contain (new_score, beam_seq, new_token) tuples
            for i in range(k):
                cur_logits = self.decoder.inference(cur_token[i], t, freqs_cis, caches[i]).squeeze(1)  # (B, vocab)
                log_probs = torch.log_softmax(cur_logits, dim=1)  # (B, vocab)
                rem = min(log_probs.shape[1], beam_k)
                top_probs, top_idx = torch.topk(log_probs, rem, sorted=True, dim=1)  # (B, rem)
                new_scores = scores[:, i:i+1] + top_probs  # (B, rem)
                for b in range(B):
                    for j in range(rem):
                        best[b].append((new_scores[b, j].item(), i, top_idx[b, j].item()))

            nk = min(beam_k, len(best[0]))
            new_token = torch.zeros((nk, B), device=device, dtype=sequences.dtype)
            new_scores = torch.zeros((B, nk), device=device)
            new_sequences = torch.zeros((nk, B, t+2), device=device, dtype=sequences.dtype)
            for b in range(B):
                best[b] = sorted(best[b], reverse=True)
                for i in range(nk):
                    new_score, beam_seq, token = best[b][i]
                    new_token[i, b] = token
                    new_scores[b, i] = new_score
                    new_sequences[i, b, :t+1] = sequences[beam_seq, b, :]
                    new_sequences[i, b, t+1] = token
                    # The beam now in slot i continues from beam_seq's cache state.
                    self.assign_kv_cache(caches[beam_seq], new_caches[i], b)

            # Double-buffer the caches: the old set is overwritten next step.
            caches, new_caches = new_caches, caches
            sequences = new_sequences
            scores = new_scores
            cur_token = self.embed_input(new_token.unsqueeze(2))
        return sequences[0, :, 1:]  # beam dim is sorted by score, ignore start token

    @torch.no_grad()
    def generate(
        self,
        cond: torch.Tensor,
        window_size: Optional[int] = None,
        context_size: Optional[int] = None,
        start_pos: int = 0,
        beam_k: int = 1,
    ) -> torch.Tensor:
        """Autoregressive generation with KV caches.

        - Quantized outputs: greedy decoding when ``beam_k == 1``, beam search
          otherwise.
        - Continuous outputs: ``window_size`` and ``context_size`` are required
          and ``beam_k`` must be 1.
        """
        cond = self.embed_cond(cond)
        B, T, _ = cond.shape

        freqs_cis = self.freqs_cis.to(cond.device)
        freqs_cis = freqs_cis[start_pos : start_pos + T]
        cond_feature = self.encoder(cond, freqs_cis)

        caches = self.decoder.get_caches(cond_feature, T, freqs_cis, beam_k)
        if self.quantized_io:
            if beam_k == 1:
                return self.generate_quantized_old(cond, caches[0], freqs_cis)
            else:
                return self.generate_quantized(cond, caches, freqs_cis, beam_k)
        else:
            assert window_size is not None
            assert context_size is not None
            assert beam_k == 1, "Beam search is not supported for non-quantized generation"
            return self.generate_non_quantized(cond, window_size, context_size, caches[0], freqs_cis)


def test_generation_consistency():
    """Compare cache-free generate_ref against KV-cached generate on a tiny model.

    Prints both outputs and their mean squared difference; nothing is asserted.
    """
    print("test_generation_consistency")
    model_kwargs = dict(
        n_encoder_layers=2,
        n_decoder_layers=2,
        embed_dim=8,
        n_head=2,
        bias=True,
        dropout=0.0,
        block_size=10,
        encoder_causality=(3, 1),
        decoder_causality=(3, 0),
        cross_causality=(3, 1),
        max_seq_len=10,
        input_dim=6,
        cond_dim=6,
        output_dim=4,
    )
    model = Transformer(**model_kwargs)

    # Draw order matters for reproducibility under the global seed.
    mixture = torch.randn((2, 8, 6))
    target = torch.randn((2, 8, 6))
    for tensor, label in ((mixture, "mixture"), (target, "input")):
        describe_tensor(tensor, label)

    model(target, mixture)
    reference = model.generate_ref(mixture, window_size=2, context_size=1)
    cached = model.generate(mixture, window_size=2, context_size=1)
    describe_tensor(reference, "res1")
    describe_tensor(cached, "res2")
    mse = (reference - cached).pow(2).mean()
    print("mse", mse)


def test_quantized_io():
    """Smoke-test the token (quantized) path: a forward pass and beam-search generation."""
    print("test_quantized_io")
    model_kwargs = dict(
        n_encoder_layers=2,
        n_decoder_layers=2,
        embed_dim=8,
        n_head=2,
        bias=True,
        dropout=0.0,
        block_size=10,
        encoder_causality=(3, 1),
        decoder_causality=(3, 0),
        cross_causality=(3, 1),
        max_seq_len=10,
        quantized_io=True,
        input_vocab_size=7,
        cond_dim=6,
    )
    model = Transformer(**model_kwargs)

    # Draw order matters for reproducibility under the global seed.
    mixture = torch.randn((3, 5, 6))
    tokens = torch.randint(7, (3, 5))
    for tensor, label in ((mixture, "mixture"), (tokens, "input")):
        describe_tensor(tensor, label)

    logits = model(tokens, mixture)
    describe_tensor(logits, "output")
    generated = model.generate(mixture, beam_k=2)
    describe_tensor(generated, "gen_out")
    print(generated)


if __name__ == "__main__":
    # Fixed seed and compact float printing for the manual checks below.
    torch.manual_seed(48)
    torch.set_printoptions(precision=6, sci_mode=False)
    test_generation_consistency()
    # test_quantized_io()  # uncomment to exercise the token / beam-search path
