from typing import Optional

import torch
import torch.nn as nn

import liger_kernel.transformers.functional as F


def fixed_fused_linear_cross_entropy(
    hidden_states: torch.Tensor,
    lm_head_weight: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    final_logit_softcapping: Optional[float] = None,
    accum_dtype: Optional[torch.dtype] = None,
    **kwargs,
):
    """Compute the fused LM-head projection + cross-entropy loss.

    When ``num_items_in_batch`` is provided, the kernel is asked for a
    ``"sum"`` reduction and the result is divided by ``num_items_in_batch``.
    Otherwise the kernel's ``"mean"`` reduction is returned as is.

    Args:
        hidden_states: Activations fed to the fused kernel (presumably
            ``(N, hidden_size)`` — TODO confirm against callers).
        lm_head_weight: LM-head weight passed straight to the kernel.
        target: Target token ids passed straight to the kernel.
        num_items_in_batch: Normalizer for the summed loss; ``None`` selects
            mean reduction instead.
        ignore_index: Target value the kernel should ignore.
        final_logit_softcapping: Forwarded to the kernel as ``softcap``.
        accum_dtype: Forwarded to the kernel unchanged.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The loss produced by the kernel, normalized as described above.
    """
    normalize_by_count = num_items_in_batch is not None
    loss = F.liger_fused_linear_cross_entropy(
        hidden_states,
        lm_head_weight,
        target,
        reduction="sum" if normalize_by_count else "mean",
        ignore_index=ignore_index,
        softcap=final_logit_softcapping,
        accum_dtype=accum_dtype,
    )
    if not normalize_by_count:
        return loss
    # NOTE(review): there is no guard for num_items_in_batch == 0 here.
    return loss / num_items_in_batch


def LigerForCausalLMLoss(
    hidden_states,
    lm_head_weight,
    labels,
    hidden_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    final_logit_softcapping: Optional[float] = None,
    **kwargs,
):
    """Causal-LM loss computed with the fused liger cross-entropy kernel.

    Args:
        hidden_states: Model activations; flattened to ``(-1, hidden_size)``.
        lm_head_weight: LM-head weight passed to the fused kernel.
        labels: Unshifted labels. Used only when ``shift_labels`` is None.
        hidden_size: Size of the last dimension used for flattening.
        num_items_in_batch: When given, the summed loss is divided by it;
            otherwise the mean reduction is used.
        ignore_index: Label value excluded from the loss; also used to pad
            the final position when shifting.
        shift_labels: Pre-shifted labels. When provided, ``labels`` is not used.
        final_logit_softcapping: Optional logit softcap forwarded to the kernel.
        **kwargs: Forwarded to ``fixed_fused_linear_cross_entropy``
            (e.g. ``accum_dtype``).

    Returns:
        The loss returned by ``fixed_fused_linear_cross_entropy``.
    """
    # No fp32 upcast here: the kernel keeps its intermediate values in fp32.
    if shift_labels is None:
        # Append one ignore_index column, then drop the first column, so that
        # the hidden state at position t is scored against label t + 1.
        padded = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = padded[..., 1:].contiguous()

    flat_hidden = hidden_states.view(-1, hidden_size)
    # Move targets next to the activations (supports model parallelism).
    flat_targets = shift_labels.view(-1).to(flat_hidden.device)

    return fixed_fused_linear_cross_entropy(
        flat_hidden,
        lm_head_weight,
        flat_targets,
        num_items_in_batch,
        ignore_index,
        final_logit_softcapping,
        **kwargs,
    )



#########################
### For matching only ###
#########################
def LigerForCausalLMLossMatchingScore(
    hidden_states,
    lm_head_weight,
    labels,
    hidden_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    final_logit_softcapping: Optional[float] = None,
    **kwargs,
):
    """Per-sequence mean next-token loss, used as a matching score.

    The fused kernel returns an unreduced per-token loss. That loss is
    reshaped to ``(B, T)`` and averaged over the positions whose shifted
    label is not ``ignore_index``.

    Args:
        hidden_states: Activations of shape ``(B, T, H)``; the 3-way unpack
            below enforces rank 3.
        lm_head_weight: LM-head weight passed to the fused kernel.
        labels: Unshifted labels. Used only when ``shift_labels`` is None.
        hidden_size: Size of the last dimension used for flattening.
        num_items_in_batch: Forwarded to the unreduced kernel wrapper.
        ignore_index: Label value excluded from the average; also used to pad
            the final position when shifting.
        shift_labels: Pre-shifted labels with ``B * T`` elements. When
            provided, ``labels`` is not used.
        final_logit_softcapping: Optional logit softcap forwarded to the kernel.
        **kwargs: Forwarded to the kernel wrapper.

    Returns:
        Tensor of shape ``(B,)`` with the mean loss over valid positions of
        each sequence. Sequences with no valid position get 0, because their
        count is clamped to 1 and their masked sum is 0.
    """
    # Skip upcast: the kernel keeps its intermediate values in fp32.
    batch_size, seq_len, _ = hidden_states.shape

    if shift_labels is None:
        # Shift so that token < n predicts n; the last position is padded
        # with ignore_index.
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # reshape (not view) so a non-contiguous caller-supplied tensor also works.
    flat_hidden = hidden_states.reshape(-1, hidden_size)
    # Move targets next to the activations (supports model parallelism).
    flat_targets = shift_labels.reshape(-1).to(flat_hidden.device)

    loss_per_token = fixed_fused_linear_cross_entropy_matching_score(
        flat_hidden,
        lm_head_weight,
        flat_targets,
        num_items_in_batch,
        ignore_index,
        final_logit_softcapping,
        **kwargs,
    ).reshape(batch_size, seq_len)

    valid_mask = flat_targets.view(batch_size, seq_len) != ignore_index
    # masked_fill rather than multiplying by the mask: 0 * NaN/inf is NaN,
    # so a non-finite value at an ignored position would otherwise poison
    # the whole sequence's score.
    loss_sum = loss_per_token.masked_fill(~valid_mask, 0.0).sum(dim=1)
    valid_count = valid_mask.sum(dim=1).clamp_min(1)
    return loss_sum / valid_count

def fixed_fused_linear_cross_entropy_matching_score(
    hidden_states: torch.Tensor,
    lm_head_weight: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    final_logit_softcapping: Optional[float] = None,
    **kwargs,
):
    """Run the fused matching-score kernel without any reduction.

    The kernel is called with ``reduction=None`` and its output is returned
    unchanged. Presumably this is one loss value per row of ``target`` —
    confirm against the kernel.

    ``num_items_in_batch`` and ``**kwargs`` are accepted for signature parity
    with ``fixed_fused_linear_cross_entropy`` but are not used: this wrapper
    applies no normalization.
    TODO(review): if a normalized variant is ever added, guard against
    ``num_items_in_batch == 0`` before dividing.

    Args:
        hidden_states: Activations passed straight to the kernel.
        lm_head_weight: LM-head weight passed straight to the kernel.
        target: Target token ids passed straight to the kernel.
        num_items_in_batch: Unused.
        ignore_index: Target value the kernel should ignore.
        final_logit_softcapping: Forwarded to the kernel as ``softcap``.
        **kwargs: Unused.

    Returns:
        The unreduced loss tensor produced by the kernel.
    """
    # No reduction here: callers are expected to aggregate the raw values.
    return F.liger_fused_linear_cross_entropy_matching_score(
        hidden_states,
        lm_head_weight,
        target,
        reduction=None,
        ignore_index=ignore_index,
        softcap=final_logit_softcapping,
    )
