import os
import time
from typing import Optional

import triton
import triton.language as tl
import torch
from torch import Tensor
import torch.autograd


@triton.jit
def memory_efficient_llm_ce_cuda(
        HIDDEN, stride_hidden_n, stride_hidden_hid,
        PROJ, stride_proj_kout, stride_proj_kin,
        LABEL, stride_label_n,
        LOGIT_MULTIPLIER,

        LOSS, stride_loss_n,

        N, HID, KOUT,

        BLOCK_N: tl.constexpr,
        BLOCK_HID: tl.constexpr,
        BLOCK_KOUT: tl.constexpr,
):
    """Per-row cross-entropy of ``LOGIT_MULTIPLIER * (HIDDEN @ PROJ^T)``.

    Each program instance handles ``BLOCK_N`` consecutive rows.  The logits
    are produced one ``(BLOCK_N, BLOCK_KOUT)`` tile at a time and are never
    written out; only one loss value per row is stored.

    Args:
        HIDDEN: pointer to the hidden states, indexed as ``[n, hid]`` through
            ``stride_hidden_n`` / ``stride_hidden_hid``.
        PROJ: pointer to the projection weight, indexed as ``[kout, hid]``
            through ``stride_proj_kout`` / ``stride_proj_kin``.
        LABEL: pointer to the integer target labels, one per row.
        LOGIT_MULTIPLIER: scalar applied to every logit tile.
        LOSS: output pointer; receives ``logsumexp(logits) - logits[label]``
            per row, or NaN for rows whose label is negative.
        N, HID, KOUT: number of rows, hidden size and number of classes.
        BLOCK_N, BLOCK_HID, BLOCK_KOUT: compile-time tile sizes.
    """
    # The HIDDEN/PROJ loads below are not masked along the hidden axis, so the
    # hidden size has to be an exact multiple of the tile width.
    assert (HID % BLOCK_HID) == 0

    idx_n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = idx_n < N

    # Out-of-range rows get a defined (negative) label; their loss is never
    # stored because the final store is masked with ``mask_n``.
    target_label = tl.load(
        LABEL +
        idx_n * stride_label_n,
        mask=mask_n,
        other=-1,
    )

    # Pass 1: streaming log-sum-exp over the class axis.
    # ``m_i`` is the running row maximum and ``l_i`` the running sum of
    # exp(logit - m_i).  The initial value of ``l_i`` is irrelevant because it
    # is multiplied by exp(-inf - m) == 0 on the first tile.
    l_i = tl.full((BLOCK_N,), value=1.0, dtype=tl.float32)
    m_i = tl.full((BLOCK_N,), value=float('-inf'), dtype=tl.float32)
    for idx_blabel in range(tl.cdiv(KOUT, BLOCK_KOUT)):
        idx_label = tl.arange(0, BLOCK_KOUT) + idx_blabel * BLOCK_KOUT
        mask_label = idx_label < KOUT

        acc = tl.zeros((BLOCK_N, BLOCK_KOUT), dtype=tl.float32)
        for idx_bhid in range(0, tl.cdiv(HID, BLOCK_HID)):
            idx_hid = tl.arange(0, BLOCK_HID) + BLOCK_HID * idx_bhid
            hidden = tl.load(
                HIDDEN +
                idx_n[:, None] * stride_hidden_n +
                idx_hid[None, :] * stride_hidden_hid,
                mask=mask_n[:, None],
                other=0
            )
            proj = tl.load(
                PROJ +
                idx_label[None, :] * stride_proj_kout +
                idx_hid[:, None] * stride_proj_kin,
                mask=mask_label[None, :],
                other=0
            )
            acc += tl.dot(
                hidden,
                proj.to(hidden.dtype),
                allow_tf32=True,
            ).to(acc.dtype)

        acc = acc * LOGIT_MULTIPLIER.to(acc.dtype)

        # Padded class columns were loaded as zeros, i.e. they would look like
        # real classes with logit 0.  Push them to -inf (after scaling, so the
        # sign of the multiplier cannot flip them) so they contribute
        # exp(-inf) == 0 to the normaliser when KOUT is not a multiple of
        # BLOCK_KOUT.
        acc = tl.where(mask_label[None, :], acc, float('-inf'))

        m_last = m_i
        m_i = tl.maximum(m_i, tl.max(acc, axis=1))
        P_tilde_i = tl.exp(acc - m_i[:, None])
        exp_m_diff = tl.exp(m_last - m_i)
        l_i = l_i * exp_m_diff + tl.sum(P_tilde_i, axis=1)

    # Row-wise logsumexp of the scaled logits.
    L_i = m_i + tl.log(l_i)

    # Pass 2: recompute the logit tiles and pick out the target column.
    # No -inf masking is needed here: a padded column is only selected if the
    # label itself is >= KOUT.
    loss = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for idx_blabel in range(tl.cdiv(KOUT, BLOCK_KOUT)):
        idx_label = tl.arange(0, BLOCK_KOUT) + idx_blabel * BLOCK_KOUT
        mask_label = idx_label < KOUT

        acc = tl.zeros((BLOCK_N, BLOCK_KOUT), dtype=tl.float32)
        for idx_bhid in range(0, tl.cdiv(HID, BLOCK_HID)):
            idx_hid = tl.arange(0, BLOCK_HID) + BLOCK_HID * idx_bhid
            hidden = tl.load(
                HIDDEN +
                idx_n[:, None] * stride_hidden_n +
                idx_hid[None, :] * stride_hidden_hid,
                mask=mask_n[:, None],
                other=0
            )
            proj = tl.load(
                PROJ +
                idx_label[None, :] * stride_proj_kout +
                idx_hid[:, None] * stride_proj_kin,
                mask=mask_label[None, :],
                other=0
            )
            acc += tl.dot(
                hidden,
                proj.to(hidden.dtype),
                allow_tf32=True,
            ).to(acc.dtype)

        acc = acc * LOGIT_MULTIPLIER.to(acc.dtype)

        # Negative log-probability of every class in this tile; only the
        # column equal to the row's label survives the select below.
        P_ij = L_i[:, None] - acc
        loss += tl.sum(tl.where(
            target_label[:, None] == idx_label[None, :],
            P_ij, 0,
        ), axis=1)

    # Rows with a negative label are reported as NaN.
    loss = tl.where(target_label >= 0, loss, float('nan'))

    tl.store(
        LOSS +
        idx_n * stride_loss_n,
        mask=mask_n,
        value=loss
    )


def memory_efficient_llm_ce_fwd(
        hidden_states: Tensor,
        out_proj_weight: Tensor,
        labels: Tensor,
        logit_multiplier: float,
):
    """Launch the forward cross-entropy kernel and return per-row losses.

    Args:
        hidden_states: 2-D tensor of shape ``(N, HID)``.
        out_proj_weight: 2-D tensor of shape ``(KOUT, HID)``.
        labels: 1-D integer tensor of length ``N`` (int32 or int64).
        logit_multiplier: scalar forwarded to the kernel, which applies it to
            the logits.

    Returns:
        A float32 tensor of shape ``(N,)`` on the device of ``hidden_states``
        holding one loss value per row as written by the kernel.

    Raises:
        AssertionError: if ranks, dtypes, devices or shapes do not match the
            description above, or if ``HID`` is not a multiple of the hidden
            tile size.
    """
    assert hidden_states.ndim == 2, f"{hidden_states.shape}"
    assert out_proj_weight.ndim == 2, f"{out_proj_weight.shape}"
    assert labels.ndim == 1, f"{labels.shape}"
    assert labels.dtype in [torch.int32, torch.int64, torch.long]
    assert hidden_states.device == out_proj_weight.device, f"{hidden_states.device} != {out_proj_weight.device}"
    assert labels.device == hidden_states.device, f"{labels.device} != {hidden_states.device}"
    N, HID = hidden_states.shape
    KOUT, KIN = out_proj_weight.shape
    _N, = labels.shape
    assert N == _N, f'{N} == {_N}'
    assert HID == KIN

    losses = torch.empty((N,), dtype=torch.float32, device=hidden_states.device)

    BLOCK_N = 128
    BLOCK_HID = 32  # FIXME: revert back to 128
    BLOCK_KOUT = 32  # FIXME: revert back to 128

    # The kernel does not mask the hidden axis, so it must tile exactly.
    assert (HID % BLOCK_HID) == 0

    # One program per BLOCK_N rows.
    grid = (triton.cdiv(N, BLOCK_N),)

    # The default device is process-wide state: switch it for the launch and
    # always put it back, even if compilation or the launch raises.
    pre_device = torch.get_default_device()
    torch.set_default_device(hidden_states.device)
    try:
        memory_efficient_llm_ce_cuda[grid](
            hidden_states, *hidden_states.stride(),
            out_proj_weight, *out_proj_weight.stride(),
            labels, *labels.stride(),
            logit_multiplier,

            losses, *losses.stride(),

            N, HID, KOUT,

            BLOCK_N,
            BLOCK_HID,
            BLOCK_KOUT,

            num_warps=16,
        )
    finally:
        torch.set_default_device(pre_device)

    return losses


@triton.jit
def memory_efficient_llm_ce_bwd_cuda(
        h_buffer, h_stride_n, h_stride_h,
        w_buffer, w_stride_d, w_stride_h,
        t_buffer, t_stride_n,
        logit_multiplier,
        do_buffer, do_stride_n,
        da_buffer, da_stride_n, da_stride_d,
        N, H, D,
        N_BLOCK_SIZE: tl.constexpr, H_BLOCK_SIZE: tl.constexpr, D_BLOCK_SIZE: tl.constexpr):
    """Gradient of the per-row cross-entropy with respect to the logits.

    With ``A = X @ W^T`` (``X`` read from ``h_buffer`` as ``[n, h]`` and ``W``
    from ``w_buffer`` as ``[d, h]``), each program handles ``N_BLOCK_SIZE``
    rows and stores ``dA = (softmax(A) - onehot(T)) * do`` into ``da_buffer``.
    Rows whose label is negative receive an all-zero gradient.

    Args:
        h_buffer: pointer to ``X``; rows ``>= N`` and columns ``>= H`` are
            masked to zero.
        w_buffer: pointer to ``W``; entries outside ``(D, H)`` are masked to
            zero.
        t_buffer: pointer to the integer target labels, one per row.
        logit_multiplier: accepted for interface symmetry but not read by the
            kernel body, so any logit scaling must already be folded into
            ``h_buffer``.
        do_buffer: pointer to the upstream gradient, one value per row.
        da_buffer: output pointer for ``dA``, indexed as ``[n, d]``.
        N, H, D: number of rows, hidden size and number of classes.
        N_BLOCK_SIZE, H_BLOCK_SIZE, D_BLOCK_SIZE: compile-time tile sizes.
    """

    n = tl.program_id(0) * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)

    do = tl.load(do_buffer + n * do_stride_n, mask=n < N, other=0.0)
    T = tl.load(t_buffer + n * t_stride_n, mask=n < N, other=-1)  # target labels

    # Step 1: compute logsumexp(A) with a running max / running sum
    A_max = tl.full((N_BLOCK_SIZE,), float('-inf'), dtype=tl.float32)
    L = tl.zeros((N_BLOCK_SIZE,), dtype=tl.float32)

    for d_begin in range(tl.cdiv(D, D_BLOCK_SIZE)):
        d = d_begin * D_BLOCK_SIZE + tl.arange(0, D_BLOCK_SIZE)

        A_part = tl.zeros((N_BLOCK_SIZE, D_BLOCK_SIZE), dtype=tl.float32)

        for h_begin in range(tl.cdiv(H, H_BLOCK_SIZE)):
            h = h_begin * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)

            W = tl.load(
                w_buffer + h[:, None] * w_stride_h + d[None, :] * w_stride_d,
                mask=(h[:, None] < H) & (d[None, :] < D),
                other=0
            )

            X = tl.load(
                h_buffer + n[:, None] * h_stride_n + h[None, :] * h_stride_h,
                mask=(n[:, None] < N) & (h[None, :] < H),
                other=0
            )

            A_part += tl.dot(X, W)

        # Padded class columns were loaded as zeros and would otherwise count
        # as real classes with logit 0.  Mask them to -inf so they add
        # exp(-inf) == 0 to the normaliser when D is not a multiple of
        # D_BLOCK_SIZE.
        A_part = tl.where(d[None, :] < D, A_part, float('-inf'))

        A_max_prev = A_max
        A_max = tl.maximum(tl.max(A_part, axis=1), A_max)
        L = L * tl.exp(A_max_prev - A_max) + tl.sum(tl.exp(A_part - A_max[:, None]), axis=1)

    L = tl.log(L) + A_max

    # Step 2: recompute A tile by tile and store dA = (softmax(A) - 1_t) * do.
    # Only dA is produced here; nothing else is written by this kernel.
    for d_begin in range(tl.cdiv(D, D_BLOCK_SIZE)):
        d = d_begin * D_BLOCK_SIZE + tl.arange(0, D_BLOCK_SIZE)

        A_part = tl.zeros((N_BLOCK_SIZE, D_BLOCK_SIZE), dtype=tl.float32)

        for h_begin in range(tl.cdiv(H, H_BLOCK_SIZE)):
            h = h_begin * H_BLOCK_SIZE + tl.arange(0, H_BLOCK_SIZE)

            W = tl.load(
                w_buffer + h[:, None] * w_stride_h + d[None, :] * w_stride_d,
                mask=(h[:, None] < H) & (d[None, :] < D),
                other=0
            )

            X = tl.load(
                h_buffer + n[:, None] * h_stride_n + h[None, :] * h_stride_h,
                mask=(n[:, None] < N) & (h[None, :] < H),
                other=0
            )

            A_part += tl.dot(X, W)

        S = tl.exp(A_part - L[:, None])  # softmax tile, (N_BLOCK_SIZE, D_BLOCK_SIZE)
        one_t = tl.where(T[:, None] == d, 1.0, 0.0)  # (N_BLOCK_SIZE, D_BLOCK_SIZE)
        dA = (S - one_t) * do[:, None]  # (N_BLOCK_SIZE, D_BLOCK_SIZE)
        # Zero everything outside the valid (N, D) region ...
        dA = tl.where((n[:, None] < N) & (d[None, :] < D), dA, 0.0)
        # ... and every row whose label is negative.
        dA = tl.where(T[:, None] >= 0, dA, 0.0)

        tl.store(
            da_buffer + n[:, None] * da_stride_n + d[None, :] * da_stride_d,
            dA,
            mask=(n[:, None] < N) & (d[None, :] < D)
        )


@torch.no_grad()
def memory_efficient_llm_ce_bwd(
        hidden_states: Tensor,
        out_proj_weight: Tensor,
        labels: Tensor,
        logit_multiplier: float,
        grad_output: Tensor):
    """Backward pass of the memory-efficient cross-entropy.

    The rows are processed in chunks of at most ``BIG_BLOCK_N``.  For each
    chunk the kernel writes the logit gradient ``dA`` into a reusable scratch
    buffer, and two dense matmuls turn it into the input gradients.

    Args:
        hidden_states: tensor of shape ``(N, HID)``.
        out_proj_weight: tensor of shape ``(KOUT, HID)``.
        labels: per-row integer labels, sliced per chunk along dim 0.
        logit_multiplier: scalar applied to the logits in the forward pass.
        grad_output: upstream gradient, one value per row.

    Returns:
        ``(grad_hidden, grad_weight)`` shaped like ``hidden_states`` and
        ``out_proj_weight`` respectively.

    Raises:
        AssertionError: if ``HID`` is not a multiple of the hidden tile size.
        ValueError: if ``CE_BLOCK_HID`` / ``CE_BLOCK_KOUT`` are set to
            something that is not an integer.

    Environment:
        ``CE_BLOCK_HID`` and ``CE_BLOCK_KOUT`` override the kernel tile sizes
        (default 128 each).
    """

    N, HID = hidden_states.shape
    KOUT, _ = out_proj_weight.shape

    BIG_BLOCK_N = 4096
    BLOCK_N = 128
    # Environment values arrive as strings; convert so that the modulo check
    # below and the kernel's compile-time tile sizes receive integers.
    BLOCK_HID = int(os.environ.get("CE_BLOCK_HID", 128))
    BLOCK_KOUT = int(os.environ.get("CE_BLOCK_KOUT", 128))

    assert (HID % BLOCK_HID) == 0

    # Fold the multiplier into the activations: the kernel then sees the
    # scaled logits, and ``dA^T @ h`` below already carries the factor that
    # the weight gradient needs.
    hidden_states = hidden_states * logit_multiplier

    with torch.cuda.device(hidden_states.device):

        # Scratch buffer for one chunk of dA; never larger than the input.
        dA = torch.zeros((min(BIG_BLOCK_N, N), KOUT), dtype=torch.float32, device=hidden_states.device)
        grad_hidden = torch.zeros_like(hidden_states)
        grad_weight = torch.zeros_like(out_proj_weight)

        for n_begin in range(0, N, BIG_BLOCK_N):
            n_end = min(n_begin + BIG_BLOCK_N, N)
            chunk_n = n_end - n_begin

            # The kernel indexes every per-row buffer relative to the start
            # of the chunk, so each of them is sliced the same way.
            h = hidden_states[n_begin:n_end]
            t = labels[n_begin:n_end]
            do = grad_output[n_begin:n_end]
            da = dA[:chunk_n]

            grid = (triton.cdiv(chunk_n, BLOCK_N),)
            memory_efficient_llm_ce_bwd_cuda[grid](
                h, *h.stride(),
                out_proj_weight, *out_proj_weight.stride(),
                t, *t.stride(),
                logit_multiplier,

                do, *do.stride(),
                da, *da.stride(),

                # Row bound is the chunk length, not the total row count;
                # otherwise the last chunk's mask lets the kernel read past
                # the end of ``h`` and ``t``.
                chunk_n, HID, KOUT,

                BLOCK_N,
                BLOCK_HID,
                BLOCK_KOUT,

                num_warps=16,
            )

            da = da.to(hidden_states.dtype)
            # d(hidden) for this chunk: dA @ W (scaled by the multiplier below).
            torch.mm(da, out_proj_weight, out=grad_hidden[n_begin:n_end])
            # d(weight) accumulates dA^T @ (multiplier * hidden) over chunks.
            torch.addmm(grad_weight, da.T, h, out=grad_weight)

    grad_hidden = grad_hidden * logit_multiplier

    return grad_hidden, grad_weight


class MemoryEfficientLLMCE(torch.autograd.Function):
    """Autograd binding between the forward and backward cross-entropy launchers.

    ``forward`` returns whatever ``memory_efficient_llm_ce_fwd`` produces and
    keeps its tensor inputs for the backward pass; ``backward`` delegates to
    ``memory_efficient_llm_ce_bwd``.
    """

    @staticmethod
    def forward(ctx, hidden_states, out_proj_weight, labels, logit_multiplier: float):  # noqa
        per_row_loss = memory_efficient_llm_ce_fwd(
            hidden_states,
            out_proj_weight,
            labels,
            logit_multiplier,
        )
        # The multiplier is a plain Python value, so it is stored as an
        # attribute on ctx; the tensors go through save_for_backward.
        ctx.logit_multiplier = logit_multiplier
        ctx.save_for_backward(hidden_states, out_proj_weight, labels)
        return per_row_loss

    @staticmethod
    def backward(ctx, grad_output):  # noqa
        saved_hidden, saved_weight, saved_labels = ctx.saved_tensors
        input_grads = memory_efficient_llm_ce_bwd(
            saved_hidden,
            saved_weight,
            saved_labels,
            ctx.logit_multiplier,
            grad_output,
        )
        d_hidden, d_weight = input_grads
        # One entry per forward argument: labels and the multiplier get None.
        return d_hidden, d_weight, None, None


def memory_efficient_llm_ce(
        hidden_states: Tensor,
        out_proj_weight: Tensor,
        labels: Tensor,
        logit_multiplier: Optional[float],
        reduction: str = 'mean',
        threshold: int = 8192):
    """Cross-entropy of ``logit_multiplier * hidden_states @ out_proj_weight^T``.

    Inputs with at most ``threshold`` rows take the dense path
    (``F.linear`` followed by ``F.cross_entropy`` with its default
    arguments).  Larger inputs go through ``MemoryEfficientLLMCE``, which
    yields one loss per row and marks rows with a negative label as NaN;
    those rows are skipped by the reductions below.

    Args:
        hidden_states: tensor whose first dimension is the row count.
        out_proj_weight: projection weight passed to ``F.linear``.
        labels: integer target labels, one per row.
        logit_multiplier: optional scalar applied to the logits; ``None``
            means no scaling.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.
        threshold: largest row count that still uses the dense path.

    Returns:
        A scalar tensor for ``'mean'`` / ``'sum'``, or the per-row losses for
        ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """
    # Reject a bad reduction before any work is done, instead of after the
    # kernel has already run.
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(
            f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")

    if hidden_states.shape[0] <= threshold:
        logits = torch.nn.functional.linear(hidden_states, out_proj_weight)
        if logit_multiplier is not None:
            logits = logits * logit_multiplier
        return torch.nn.functional.cross_entropy(
            logits, labels,
            reduction=reduction,
        )

    losses = MemoryEfficientLLMCE.apply(
        hidden_states,
        out_proj_weight,
        labels,
        1.0 if logit_multiplier is None else logit_multiplier,
    )

    if reduction == 'mean':
        return losses.nanmean()
    if reduction == 'sum':
        return losses.nansum()
    # reduction == 'none': report ignored rows as 0, which is what the dense
    # path returns for them, rather than leaking the kernel's NaN marker.
    return losses.masked_fill(labels < 0, 0.0)
