import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava import conversation as conversation_lib
from PIL import Image

# Special image-related markers gathered from llava.constants.
# Not referenced anywhere else in the visible part of this file.
# NOTE(review): despite the "_ids" name, DEFAULT_IM_START_TOKEN and
# DEFAULT_IM_END_TOKEN look like token *strings* rather than integer ids,
# so this list may mix types — confirm against llava.constants.
image_token_ids = [IMAGE_TOKEN_INDEX, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX]

class DataArguments:
    """Container of data-related defaults, exposed as class attributes.

    The class defines no ``__init__``; every value below is a class-level
    default that can be read from the class itself or from any instance.
    """

    # Multimodal flag, on by default.
    # NOTE(review): presumably tells the data pipeline that samples carry
    # images — confirm against the code that consumes it.
    is_multimodal: bool = True

    # Off by default.
    # NOTE(review): presumably controls wrapping image tokens with the
    # im_start / im_end markers — confirm against the consumer.
    mm_use_im_start_end: bool = False

    # Default location of the instruction-data JSON file.
    data_path: str = "llava_instruction_data.json"

def get_kernel_size(image_size):
    """Choose a kernel size from the first entry of ``image_size``.

    Only ``image_size[0]`` is inspected; any further entries are ignored.
    (Whether that entry is the width or the height depends on the caller.)

    Args:
        image_size: Indexable whose first element is a number.

    Returns:
        501 when the first entry is >= 2000, 301 when >= 800,
        201 when >= 500, and 101 otherwise.
    """
    leading_dim = image_size[0]
    # Thresholds are checked from largest to smallest; the first hit wins.
    for lower_bound, kernel in ((2000, 501), (800, 301), (500, 201)):
        if leading_dim >= lower_bound:
            return kernel
    return 101

def compute_token_loss(
    model,
    input_ids,
    labels,
    image_tensor=None,
    image_size=None,
    attention_mask=None,
    criterion=None,
):
    """Compute per-token cross-entropy and split the batch into even/odd rows.

    Fast path:
    - assumes tensors are already on the right device (the image tensor is
      cast to float16 here if it has another dtype)
    - does no tokenizer work; pass ``criterion`` to avoid building a loss
      object on every call
    - caller should wrap with torch.inference_mode()

    Args:
        model: Callable model whose output supports ``outputs["logits"]`` and
            which exposes ``prepare_inputs_labels_for_multimodal``.
        input_ids: Token ids forwarded to the model.
        labels: Target ids; positions equal to ``IGNORE_INDEX`` get zero loss
            with the default criterion.
        image_tensor: Optional image batch. ``None`` is accepted.
        image_size: Forwarded to the model as ``image_sizes``.
        attention_mask: Optional mask, forwarded only to
            ``prepare_inputs_labels_for_multimodal``.
        criterion: Optional loss callable taking ``(logits_2d, labels_1d)``
            and returning one value per token. Defaults to
            ``CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)``.

    Returns:
        ``(w_loss, w_mask, w_attention_mask, wo_loss, wo_mask,
        wo_attention_mask)``. The ``w_*`` items are rows 0, 2, 4, ... of the
        batch and the ``wo_*`` items are rows 1, 3, 5, ...
        (presumably "with"/"without" pairs interleaved by the caller —
        confirm). Each mask is True where the shifted label is not
        ``IGNORE_INDEX``. Both attention-mask items are ``None`` when the
        model helper returns no attention mask.
    """
    if image_tensor is not None and image_tensor.dtype != torch.float16:
        image_tensor = image_tensor.to(dtype=torch.float16)

    # NOTE(review): attention_mask is not forwarded to the forward pass, only
    # to the label-expansion helper below. If batches are padded, confirm
    # whether the model should receive it as well.
    outputs = model(
        input_ids=input_ids,
        labels=labels,
        images=image_tensor,
        image_sizes=image_size,
        output_logits=True,
    )

    # Re-run the multimodal preparation to obtain labels (and attention mask)
    # in the layout the model actually consumed.
    _, _, expanded_attention_mask, _, _, expanded_labels = (
        model.prepare_inputs_labels_for_multimodal(
            input_ids=input_ids,
            position_ids=None,
            attention_mask=attention_mask,
            past_key_values=None,
            labels=labels,
            images=image_tensor,
            image_sizes=image_size,
        )
    )

    # Shift labels left by one so position t is scored against token t + 1;
    # the final position has no target and is filled with IGNORE_INDEX.
    shifted_labels = torch.nn.functional.pad(
        expanded_labels, (0, 1), value=IGNORE_INDEX
    )[..., 1:]
    logits = outputs["logits"]

    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
    flat_loss = criterion(
        logits.reshape(-1, logits.size(-1)),
        shifted_labels.reshape(-1),
    )
    # Take the batch size from the labels rather than from image_tensor, so a
    # call without images works and loss rows always line up with mask rows.
    loss = flat_loss.view(shifted_labels.size(0), -1)

    # The comparison already yields a boolean tensor.
    mask = (shifted_labels != IGNORE_INDEX).to(loss.device)

    if expanded_attention_mask is None:
        # The helper may hand back no mask (e.g. when none was supplied);
        # propagate None instead of failing on the slice.
        w_attention_mask = None
        wo_attention_mask = None
    else:
        w_attention_mask = expanded_attention_mask[0::2]
        wo_attention_mask = expanded_attention_mask[1::2]

    return (
        loss[0::2],
        mask[0::2],
        w_attention_mask,
        loss[1::2],
        mask[1::2],
        wo_attention_mask,
    )