import torch
import triton
from torch import nn
from typing import Tuple
import triton.language as tl
import torch.nn.functional as F

from inference.torch.codebook import CodebookManager
from inference.torch.layers.encoder import CallableEncoder


@triton.jit
def hyper_embedding_forward_kernel(
    embeddings_ptr,
    hyper_embeddings_ptr,
    indices_ptr,
    output_ptr,
    n_elements,
    sequence_length,
    initial_vocab_size,
    max_codebook_size,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Gather one embedding row per index from the base or the hyper table.

    Each program handles a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of the output:
    BLOCK_SIZE_M consecutive flattened indices by BLOCK_SIZE_N consecutive
    embedding columns.

    Args:
        embeddings_ptr: base table, addressed as ``row * embedding_dim + col``.
        hyper_embeddings_ptr: hyper table, addressed the same way using
            ``index - initial_vocab_size`` as the row.
        indices_ptr: flattened indices, ``n_elements`` entries.
        output_ptr: output, addressed as ``position * embedding_dim + col``.
        n_elements: number of indices. Deliberately a runtime argument (not
            ``tl.constexpr``) so a new token count does not trigger a
            recompilation.
        sequence_length: unused (see note below).
        initial_vocab_size: indices below this value read the base table,
            indices at or above it read the hyper table.
        max_codebook_size: unused (see note below).
        embedding_dim: row width of both tables and of the output.
        BLOCK_SIZE_M: tile height (indices per program).
        BLOCK_SIZE_N: tile width (columns per program).

    Note:
        The per-batch codebook offset
        ``(offsets_m // sequence_length) * max_codebook_size`` is not applied,
        so every index addresses the same hyper table regardless of its batch
        position.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offsets_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_m = offsets_m < n_elements
    mask_n = offsets_n < embedding_dim
    tile_mask = mask_m[:, None] & mask_n[None, :]

    # Widen to int64 before multiplying by embedding_dim: with int32 indices
    # or many tokens the flat element offset can exceed the int32 range.
    indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0).to(tl.int64)
    rows = offsets_m.to(tl.int64)

    embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
    hyper_embedding_offsets = (indices - initial_vocab_size)[
        :, None
    ] * embedding_dim + offsets_n[None, :]

    # Exactly one of the two loads is active per row; the inactive one yields
    # 0.0 and its (possibly negative / out-of-range) offset is never touched.
    embeddings = tl.load(
        embeddings_ptr + embedding_offsets,
        mask=tile_mask & (indices < initial_vocab_size)[:, None],
        other=0.0,
    )
    hyper_embeddings = tl.load(
        hyper_embeddings_ptr + hyper_embedding_offsets,
        mask=tile_mask & (indices >= initial_vocab_size)[:, None],
        other=0.0,
    )

    output_offsets = rows[:, None] * embedding_dim + offsets_n[None, :]
    tl.store(
        output_ptr + output_offsets,
        embeddings + hyper_embeddings,
        mask=tile_mask,
    )


@triton.jit
def hyper_embedding_backward_kernel(
    grad_output_ptr,
    grad_weight_ptr,
    grad_hyper_weight_ptr,
    indices_ptr,
    n_elements,
    sequence_length,
    initial_vocab_size,
    max_codebook_size,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Scatter-add output gradients into the base or the hyper gradient table.

    Each program handles a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of
    ``grad_output`` and atomically accumulates every row into the gradient row
    selected by the corresponding index.

    Args:
        grad_output_ptr: upstream gradient, addressed as
            ``position * embedding_dim + col``.
        grad_weight_ptr: gradient buffer of the base table; receives rows
            whose index is below ``initial_vocab_size``.
        grad_hyper_weight_ptr: gradient buffer of the hyper table; receives
            rows whose index is at or above ``initial_vocab_size``, at row
            ``index - initial_vocab_size``.
        indices_ptr: flattened indices, ``n_elements`` entries.
        n_elements: number of indices.
        sequence_length: unused (see note below).
        initial_vocab_size: split point between the two tables.
        max_codebook_size: unused (see note below).
        embedding_dim: row width of all buffers.
        BLOCK_SIZE_M: tile height (indices per program).
        BLOCK_SIZE_N: tile width (columns per program).

    Note:
        Gradients are cast to float32 before the atomic add, so both gradient
        buffers are expected to be float32. The per-batch codebook offset
        ``(offsets_m // sequence_length) * max_codebook_size`` is not applied.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offsets_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_m = offsets_m < n_elements
    mask_n = offsets_n < embedding_dim
    tile_mask = mask_m[:, None] & mask_n[None, :]

    # Widen to int64 before multiplying by embedding_dim: with int32 indices
    # or many tokens the flat element offset can exceed the int32 range.
    indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0).to(tl.int64)
    rows = offsets_m.to(tl.int64)

    grad_output = tl.load(
        grad_output_ptr + rows[:, None] * embedding_dim + offsets_n[None, :],
        mask=tile_mask,
        other=0.0,
    )
    grad_tile = grad_output.to(tl.float32)

    # Repeated indices hit the same destination row, hence the atomic adds.
    grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
    tl.atomic_add(
        grad_weight_ptr + grad_weight_offsets,
        grad_tile,
        mask=tile_mask & (indices < initial_vocab_size)[:, None],
    )

    grad_hyper_weight_offsets = (indices - initial_vocab_size)[
        :, None
    ] * embedding_dim + offsets_n[None, :]
    tl.atomic_add(
        grad_hyper_weight_ptr + grad_hyper_weight_offsets,
        grad_tile,
        mask=tile_mask & (indices >= initial_vocab_size)[:, None],
    )


class HyperEmbeddingFunction(torch.autograd.Function):
    """Differentiable embedding lookup over a base table plus a hyper table.

    Indices below ``initial_vocab_size`` select rows of ``embeddings``; all
    other indices select row ``index - initial_vocab_size`` of
    ``hyper_embeddings`` (flattened to two dimensions). The work is done by
    ``hyper_embedding_forward_kernel`` / ``hyper_embedding_backward_kernel``.
    """

    @staticmethod
    def _launch_config(
        n_elements: int, embedding_dim: int
    ) -> Tuple[int, int, Tuple[int, int]]:
        """Return ``(BLOCK_SIZE_M, BLOCK_SIZE_N, grid)`` shared by both kernels."""
        # NOTE(review): both tile sizes are derived from embedding_dim, as in
        # the original launch code; BLOCK_SIZE_M tiles the index dimension.
        block_size_m = triton.next_power_of_2(min(128, embedding_dim))
        block_size_n = triton.next_power_of_2(min(128, embedding_dim))
        grid = (
            triton.cdiv(n_elements, block_size_m),
            triton.cdiv(embedding_dim, block_size_n),
        )
        return block_size_m, block_size_n, grid

    @staticmethod
    def forward(
        ctx,
        embeddings: torch.Tensor,
        hyper_embeddings: torch.Tensor,
        indices: torch.Tensor,
        initial_vocab_size: int,
        max_codebook_size: int,
    ) -> torch.Tensor:
        """Look up ``indices`` and return a tensor of shape ``(*indices.shape, D)``.

        Args:
            ctx: autograd context.
            embeddings: base table, ``(V, D)``.
            hyper_embeddings: hyper table whose last dimension is ``D``
                (``(B, EV, D)`` at the existing call site).
            indices: integer indices, ``(B, S)`` at the existing call site.
            initial_vocab_size: split point between the two tables.
            max_codebook_size: forwarded to the kernel, which does not
                currently use it.

        Returns:
            The gathered rows, in ``embeddings.dtype``.
        """
        indices = indices.contiguous()  # (B, S)
        embeddings = embeddings.contiguous()  # (V, D)
        hyper_embeddings = hyper_embeddings.contiguous()  # (B, EV, D)

        original_shape = indices.shape
        # Remember the caller's shape so backward can return a gradient that
        # matches the input exactly, whatever its leading dimensions are.
        hyper_embeddings_shape = hyper_embeddings.shape
        n_elements = indices.numel()
        embedding_dim = embeddings.shape[1]
        # Only forwarded to the kernel (unused there); the last dimension
        # equals shape[1] for the usual (B, S) layout.
        sequence_length = original_shape[-1] if len(original_shape) > 0 else 1

        indices = indices.view(-1)
        hyper_embeddings = hyper_embeddings.view(-1, embedding_dim)

        output = torch.empty(
            n_elements,
            embedding_dim,
            device=indices.device,
            dtype=embeddings.dtype,
        )

        # Nothing to gather for empty input; a zero-sized grid must not be
        # launched.
        if output.numel() > 0:
            BLOCK_SIZE_M, BLOCK_SIZE_N, grid = HyperEmbeddingFunction._launch_config(
                n_elements, embedding_dim
            )
            hyper_embedding_forward_kernel[grid](
                embeddings,
                hyper_embeddings,
                indices,
                output,
                n_elements,
                sequence_length,
                initial_vocab_size,
                max_codebook_size,
                embedding_dim=embedding_dim,
                BLOCK_SIZE_M=BLOCK_SIZE_M,
                BLOCK_SIZE_N=BLOCK_SIZE_N,
            )

        ctx.sequence_length = sequence_length
        ctx.initial_vocab_size = initial_vocab_size
        ctx.max_codebook_size = max_codebook_size
        ctx.hyper_embeddings_shape = hyper_embeddings_shape
        ctx.save_for_backward(indices, embeddings, hyper_embeddings)

        # An explicit last dimension (instead of -1) keeps the view
        # unambiguous when the output holds zero elements.
        return output.view(*original_shape, embedding_dim)

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
        """Accumulate ``grad_output`` into gradients for both tables.

        Returns:
            ``(grad_embeddings, grad_hyper_embeddings, None, None, None)``.
            Each gradient has the dtype of its input; the hyper gradient has
            the shape ``hyper_embeddings`` had when passed to ``forward``.
        """
        sequence_length = ctx.sequence_length
        max_codebook_size = ctx.max_codebook_size
        initial_vocab_size = ctx.initial_vocab_size

        indices, embeddings, hyper_embeddings = ctx.saved_tensors
        n_elements = indices.numel()
        embedding_dim = embeddings.shape[1]
        grad_output = grad_output.contiguous().view(n_elements, embedding_dim)

        # Accumulate in float32: the kernel casts to float32 before its
        # atomic adds.
        grad_weight = torch.zeros_like(embeddings, dtype=torch.float32)
        grad_hyper_weight = torch.zeros_like(hyper_embeddings, dtype=torch.float32)

        if grad_output.numel() > 0:
            BLOCK_SIZE_M, BLOCK_SIZE_N, grid = HyperEmbeddingFunction._launch_config(
                n_elements, embedding_dim
            )
            hyper_embedding_backward_kernel[grid](
                grad_output,
                grad_weight,
                grad_hyper_weight,
                indices,
                n_elements,
                sequence_length,
                initial_vocab_size,
                max_codebook_size,
                embedding_dim=embedding_dim,
                BLOCK_SIZE_M=BLOCK_SIZE_M,
                BLOCK_SIZE_N=BLOCK_SIZE_N,
            )

        return (
            grad_weight.to(embeddings.dtype),
            grad_hyper_weight.to(hyper_embeddings.dtype).view(
                ctx.hyper_embeddings_shape
            ),
            None,
            None,
            None,
        )


class PreInitializedEmbedding(nn.Module):
    """Embedding layer that reuses an already-constructed weight parameter.

    The module does not allocate its own table: ``weight`` is stored as given
    and looked up with ``F.embedding``, using ``pad_token_id`` as the padding
    index.
    """

    def __init__(
        self,
        initial_vocab_size: int,
        embedding_dim: int,
        weight: nn.Parameter,
        pad_token_id: int,
    ) -> None:
        super().__init__()
        # Shared parameter; assigning it registers it on this module.
        self.weight = weight
        self.pad_token_id = pad_token_id
        # Plain metadata kept for subclasses and callers.
        self.embedding_dim = embedding_dim
        self.initial_vocab_size = initial_vocab_size

    def forward(self, input: torch.LongTensor) -> torch.Tensor:
        """Return the rows of ``weight`` selected by ``input``."""
        rows = F.embedding(input, self.weight, padding_idx=self.pad_token_id)
        return rows


class HyperEmbedding(PreInitializedEmbedding):
    """Embedding layer extended with a hyper-embedding table.

    On every call the hyper table is requested from ``codebook_manager`` and
    passed, together with the shared base ``weight``, to
    ``HyperEmbeddingFunction``.
    """

    def __init__(
        self,
        initial_vocab_size: int,
        max_codebook_size: int,
        embedding_dim: int,
        weight: nn.Parameter,
        callable_encoder: CallableEncoder,
        codebook_manager: CodebookManager,
        pad_token_id: int,
    ) -> None:
        super().__init__(
            initial_vocab_size=initial_vocab_size,
            embedding_dim=embedding_dim,
            weight=weight,
            pad_token_id=pad_token_id,
        )
        self.callable_encoder = callable_encoder
        self.codebook_manager = codebook_manager
        self.max_codebook_size = max_codebook_size

    def forward(self, input: torch.LongTensor) -> torch.Tensor:
        """Embed ``input`` using the base table and the hyper table.

        Only the first row of ``input`` is handed to the codebook manager
        (batching is not supported yet); the resulting table gets a leading
        dimension of size one before the lookup.
        """
        # TODO: we don't support batching for now
        first_sequence = input[0]
        hyper_table = self.codebook_manager.get_hyper_embedding(
            first_sequence, self.weight, self.callable_encoder
        )
        batched_hyper_table = hyper_table.unsqueeze(0)

        return HyperEmbeddingFunction.apply(
            self.weight,
            batched_hyper_table,
            input,
            self.initial_vocab_size,
            self.max_codebook_size,
        )

    @staticmethod
    def from_embedding(
        embedding: nn.Embedding,
        initial_vocab_size: int,
        max_codebook_size: int,
        callable_encoder: CallableEncoder,
        pad_token_id: int,
        codebook_manager: CodebookManager,
    ) -> "HyperEmbedding":
        """Build a ``HyperEmbedding`` that shares the weight of ``embedding``.

        The embedding dimension is read from the second dimension of
        ``embedding.weight``.
        """
        shared_weight = embedding.weight
        return HyperEmbedding(
            initial_vocab_size=initial_vocab_size,
            max_codebook_size=max_codebook_size,
            embedding_dim=shared_weight.shape[1],
            weight=shared_weight,
            callable_encoder=callable_encoder,
            codebook_manager=codebook_manager,
            pad_token_id=pad_token_id,
        )
