import torch
import torch.nn.functional as F

from pipeline.registry import registry


@registry.register_optimizer("qa_loss_v1")
def get_qa_loss_v1(
    txt_qr_logits,
    obj_qr_post_logits,
    obj_qr_pre_logits,
    obj_qr_raw_logits,
    og3d_logits,
    answer_scores,
    tgt_object_label,
    tgt_object_id,
    obj_labels,
    obj_masks,
    answer_label,
):
    """Compute the QA training loss as an unweighted sum of six terms.

    Terms:
        * ``og3d_loss``: BCE-with-logits between ``og3d_logits`` and
          ``tgt_object_id``, weighted element-wise by ``obj_masks``, summed
          and divided by ``tgt_object_id.shape[0]``.
        * ``txt_qr_loss``: BCE-with-logits between ``txt_qr_logits`` and
          ``tgt_object_label``, summed and divided by
          ``tgt_object_label.shape[0]``.
        * ``obj_qr_{raw,pre,post}_loss``: cross-entropy of the respective
          logits against ``obj_labels``, averaged over the entries selected
          by ``obj_masks``. The logits are 3-D with the class dimension last
          (they are permuted to put classes on dim 1 for ``F.cross_entropy``).
        * ``answer_loss``: BCE-with-logits between ``answer_scores`` and
          ``answer_label``, summed and divided by ``answer_scores.shape[0]``.

    ``og3d_logits`` is NOT modified: ``-inf`` entries are replaced with 0 on a
    copy, so the caller can keep using the original logits afterwards.

    Returns:
        Tuple ``(total_loss, og3d_loss, txt_qr_loss, obj_qr_raw_loss,
        obj_qr_pre_loss, obj_qr_post_loss, answer_loss)``.
    """
    # -inf logits would turn the BCE term into NaN (-inf * 0) even where the
    # weight is zero, so neutralise them first. Out-of-place on purpose: the
    # caller's tensor (and anything autograd saved from it) stays untouched.
    og3d_logits = og3d_logits.masked_fill(og3d_logits == -float("inf"), 0)
    og3d_loss = F.binary_cross_entropy_with_logits(
        og3d_logits, tgt_object_id.float(), reduction="sum", weight=obj_masks
    ) / float(tgt_object_id.shape[0])

    txt_qr_loss = F.binary_cross_entropy_with_logits(
        txt_qr_logits, tgt_object_label.float(), reduction="sum"
    ) / float(tgt_object_label.shape[0])

    # Number of unmasked entries; bumped from 0 to 1 so a batch without any
    # valid object yields a zero loss instead of 0/0 = NaN.
    valid_count = obj_masks.sum()
    valid_count = valid_count + (valid_count == 0)

    def _masked_object_ce(logits):
        """Cross-entropy against ``obj_labels`` averaged over unmasked entries."""
        per_object = F.cross_entropy(
            logits.permute(0, 2, 1), obj_labels, reduction="none"
        )
        return (per_object * obj_masks).sum() / valid_count

    obj_qr_raw_loss = _masked_object_ce(obj_qr_raw_logits)
    obj_qr_pre_loss = _masked_object_ce(obj_qr_pre_logits)
    obj_qr_post_loss = _masked_object_ce(obj_qr_post_logits)

    answer_loss = F.binary_cross_entropy_with_logits(
        answer_scores, answer_label.float(), reduction="sum"
    ) / answer_scores.shape[0]

    total_loss = (
        og3d_loss
        + txt_qr_loss
        + obj_qr_raw_loss
        + obj_qr_pre_loss
        + obj_qr_post_loss
        + answer_loss
    )
    return (
        total_loss,
        og3d_loss,
        txt_qr_loss,
        obj_qr_raw_loss,
        obj_qr_pre_loss,
        obj_qr_post_loss,
        answer_loss,
    )


@registry.register_optimizer("qa_loss_focal")
def get_qa_loss_focal(
    txt_qr_logits,
    obj_qr_post_logits,
    obj_qr_pre_logits,
    obj_qr_raw_logits,
    og3d_logits,
    answer_scores,
    tgt_object_label,
    tgt_object_id,
    obj_labels,
    obj_masks,
    answer_label,
    gamma: float = 2.0,
    alpha: float = 1.0,
):
    """Compute the QA training loss with a focal-weighted answer term.

    The five auxiliary terms are:
        * ``og3d_loss``: BCE-with-logits between ``og3d_logits`` and
          ``tgt_object_id``, weighted element-wise by ``obj_masks``, summed
          and divided by ``tgt_object_id.shape[0]``.
        * ``txt_qr_loss``: BCE-with-logits between ``txt_qr_logits`` and
          ``tgt_object_label``, summed and divided by
          ``tgt_object_label.shape[0]``.
        * ``obj_qr_{raw,pre,post}_loss``: cross-entropy of the respective
          logits (3-D, class dimension last) against ``obj_labels``, averaged
          over the entries selected by ``obj_masks``.

    The answer term is a softmax focal loss over the last dimension of
    ``answer_scores``: ``alpha * (1 - p_t) ** gamma * CE``, averaged over all
    rows, where the target class is ``answer_label.argmax(dim=-1)`` and
    ``p_t`` is the predicted probability of that class.

    NOTE(review): ``argmax`` collapses ``answer_label`` to one class index per
    row. If labels can be multi-hot or all-zero, only a single index is used
    (index 0 for an all-zero row) -- confirm this is intended.

    ``og3d_logits`` is NOT modified: ``-inf`` entries are replaced with 0 on a
    copy, so the caller can keep using the original logits afterwards.

    Args:
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Constant scale factor on the focal answer loss.

    Returns:
        Tuple ``(total_loss, og3d_loss, txt_qr_loss, obj_qr_raw_loss,
        obj_qr_pre_loss, obj_qr_post_loss, answer_loss)``.
    """
    # -inf logits would turn the BCE term into NaN (-inf * 0) even where the
    # weight is zero, so neutralise them first. Out-of-place on purpose: the
    # caller's tensor (and anything autograd saved from it) stays untouched.
    og3d_logits = og3d_logits.masked_fill(og3d_logits == -float("inf"), 0)
    og3d_loss = F.binary_cross_entropy_with_logits(
        og3d_logits, tgt_object_id.float(), reduction="sum", weight=obj_masks
    ) / float(tgt_object_id.shape[0])

    txt_qr_loss = F.binary_cross_entropy_with_logits(
        txt_qr_logits, tgt_object_label.float(), reduction="sum"
    ) / float(tgt_object_label.shape[0])

    # Number of unmasked entries; bumped from 0 to 1 so a batch without any
    # valid object yields a zero loss instead of 0/0 = NaN.
    valid_count = obj_masks.sum()
    valid_count = valid_count + (valid_count == 0)

    def _masked_object_ce(logits):
        """Cross-entropy against ``obj_labels`` averaged over unmasked entries."""
        per_object = F.cross_entropy(
            logits.permute(0, 2, 1), obj_labels, reduction="none"
        )
        return (per_object * obj_masks).sum() / valid_count

    obj_qr_raw_loss = _masked_object_ce(obj_qr_raw_logits)
    obj_qr_pre_loss = _masked_object_ce(obj_qr_pre_logits)
    obj_qr_post_loss = _masked_object_ce(obj_qr_post_logits)

    # One log-softmax pass provides both the cross-entropy (-log p_t) and the
    # target probability p_t used by the focal modulation.
    labels = answer_label.argmax(dim=-1)
    log_probs = F.log_softmax(answer_scores, dim=-1)
    log_p_t = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    ce_loss = -log_p_t
    focal_weight = (1 - log_p_t.exp()) ** gamma
    answer_loss = (alpha * focal_weight * ce_loss).mean()

    total_loss = (
        og3d_loss
        + txt_qr_loss
        + obj_qr_raw_loss
        + obj_qr_pre_loss
        + obj_qr_post_loss
        + answer_loss
    )
    return (
        total_loss,
        og3d_loss,
        txt_qr_loss,
        obj_qr_raw_loss,
        obj_qr_pre_loss,
        obj_qr_post_loss,
        answer_loss,
    )


@registry.register_optimizer("qa_loss_focal_smooth")
def get_qa_loss_focal_smooth(
    txt_qr_logits,
    obj_qr_post_logits,
    obj_qr_pre_logits,
    obj_qr_raw_logits,
    og3d_logits,
    answer_scores,
    tgt_object_label,
    tgt_object_id,
    obj_labels,
    obj_masks,
    answer_label,
    gamma: float = 2.0,
    alpha: float = 1.0,
    smoothing: float = 0.1,
    answer_loss_weight: float = 1.0,
    aux_loss_weight: float = 1.0,
):
    """Compute the QA loss with a label-smoothed, focal-weighted answer term.

    The five auxiliary terms are:
        * ``og3d_loss``: BCE-with-logits between ``og3d_logits`` and
          ``tgt_object_id``, weighted element-wise by ``obj_masks``, summed
          and divided by ``tgt_object_id.shape[0]``.
        * ``txt_qr_loss``: BCE-with-logits between ``txt_qr_logits`` and
          ``tgt_object_label``, summed and divided by
          ``tgt_object_label.shape[0]``.
        * ``obj_qr_{raw,pre,post}_loss``: cross-entropy of the respective
          logits (3-D, class dimension last) against ``obj_labels``, averaged
          over the entries selected by ``obj_masks``.

    The answer term uses the target class ``answer_label.argmax(dim=-1)``.
    The target distribution puts ``1 - smoothing + smoothing / num_answers``
    on that class and ``smoothing / num_answers`` on every other class, where
    ``num_answers`` is the size of the last dimension of ``answer_scores``.
    The resulting soft cross-entropy is scaled by
    ``alpha * (1 - p_t) ** gamma`` and averaged over all rows.

    NOTE(review): ``argmax`` collapses ``answer_label`` to one class index per
    row. If labels can be multi-hot or all-zero, only a single index is used
    (index 0 for an all-zero row) -- confirm this is intended.

    ``og3d_logits`` is NOT modified: ``-inf`` entries are replaced with 0 on a
    copy, so the caller can keep using the original logits afterwards.

    Args:
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Constant scale factor on the focal answer loss.
        smoothing: Probability mass spread uniformly over all answer classes.
        answer_loss_weight: Multiplier on the answer loss inside
            ``total_loss`` only; the returned ``answer_loss`` is unweighted.
        aux_loss_weight: Multiplier on the sum of the five auxiliary losses
            inside ``total_loss`` only; the returned auxiliary losses are
            unweighted.

    Returns:
        Tuple ``(total_loss, og3d_loss, txt_qr_loss, obj_qr_raw_loss,
        obj_qr_pre_loss, obj_qr_post_loss, answer_loss)``.
    """
    # -inf logits would turn the BCE term into NaN (-inf * 0) even where the
    # weight is zero, so neutralise them first. Out-of-place on purpose: the
    # caller's tensor (and anything autograd saved from it) stays untouched.
    og3d_logits = og3d_logits.masked_fill(og3d_logits == -float("inf"), 0)
    og3d_loss = F.binary_cross_entropy_with_logits(
        og3d_logits, tgt_object_id.float(), reduction="sum", weight=obj_masks
    ) / float(tgt_object_id.shape[0])

    txt_qr_loss = F.binary_cross_entropy_with_logits(
        txt_qr_logits, tgt_object_label.float(), reduction="sum"
    ) / float(tgt_object_label.shape[0])

    # Number of unmasked entries; bumped from 0 to 1 so a batch without any
    # valid object yields a zero loss instead of 0/0 = NaN.
    valid_count = obj_masks.sum()
    valid_count = valid_count + (valid_count == 0)

    def _masked_object_ce(logits):
        """Cross-entropy against ``obj_labels`` averaged over unmasked entries."""
        per_object = F.cross_entropy(
            logits.permute(0, 2, 1), obj_labels, reduction="none"
        )
        return (per_object * obj_masks).sum() / valid_count

    obj_qr_raw_loss = _masked_object_ce(obj_qr_raw_logits)
    obj_qr_pre_loss = _masked_object_ce(obj_qr_pre_logits)
    obj_qr_post_loss = _masked_object_ce(obj_qr_post_logits)

    # A single log-softmax pass feeds both the soft cross-entropy and p_t.
    log_probs = F.log_softmax(answer_scores, dim=-1)
    num_answers = answer_scores.size(-1)
    target_index = answer_label.argmax(dim=-1).unsqueeze(-1)

    # Smoothed target: uniform floor everywhere, remaining mass on the target.
    smooth_labels = torch.full_like(log_probs, smoothing / num_answers)
    smooth_labels.scatter_(
        -1, target_index, 1.0 - smoothing + smoothing / num_answers
    )

    # Focal modulation uses the probability of the hard target class, not the
    # smoothed distribution.
    p_t = log_probs.gather(dim=-1, index=target_index).squeeze(-1).exp()
    focal_weight = (1 - p_t) ** gamma
    ce_loss = -(smooth_labels * log_probs).sum(dim=-1)
    answer_loss = (alpha * focal_weight * ce_loss).mean()

    aux_loss = (
        og3d_loss
        + txt_qr_loss
        + obj_qr_raw_loss
        + obj_qr_pre_loss
        + obj_qr_post_loss
    )
    total_loss = aux_loss_weight * aux_loss + answer_loss_weight * answer_loss

    return (
        total_loss,
        og3d_loss,
        txt_qr_loss,
        obj_qr_raw_loss,
        obj_qr_pre_loss,
        obj_qr_post_loss,
        answer_loss,
    )


def subspace_orthogonality_loss(subspace_vectors: torch.Tensor, weight: float = 0.1) -> torch.Tensor:
    """Penalise non-orthogonality between the rows of ``subspace_vectors``.

    Rows are L2-normalised along the last dimension, the pairwise cosine
    similarity matrix is formed, and the mean squared similarity over all
    distinct row pairs (strict upper triangle) is returned, scaled by
    ``weight``.

    Args:
        subspace_vectors: 2-D tensor with one vector per row (``torch.mm``
            requires a matrix).
        weight: Scale factor applied to the mean squared cosine similarity.

    Returns:
        A scalar tensor. With fewer than two rows there are no pairs to
        compare, so a zero is returned instead of the NaN that the mean of an
        empty selection would produce.
    """
    k = subspace_vectors.size(0)
    if k < 2:
        # Stay attached to the autograd graph and keep the input's
        # dtype/device so the result can be added to other losses unchanged.
        return subspace_vectors.sum() * 0.0

    normalized = F.normalize(subspace_vectors, dim=-1)
    cosine_matrix = torch.mm(normalized, normalized.t())
    # Strict upper triangle: each unordered pair once, diagonal excluded.
    row_idx, col_idx = torch.triu_indices(k, k, offset=1, device=cosine_matrix.device)
    off_diagonal = cosine_matrix[row_idx, col_idx]
    return (off_diagonal ** 2).mean() * weight
