import torch
from sentence_transformers import util

from box.box_wrapper import CenterDeltaBoxTensor, CenterScalarDeltaBoxTensor


def similarity_function(
    sentence1_embeddings,
    sentence2_embeddings,
    volume_temp=1.0,
    intersection_temp=0.001,
):
    """Score every sentence1 box against every sentence2 box (an a x b grid).

    Both inputs are converted with ``CenterDeltaBoxTensor.from_split``.

    - If a resulting box's ``.data`` is 2-D, it gets a leading axis of size 1.
      Presumably that case is a single, unbatched embedding (TODO confirm).
    - The two batches are then laid out on orthogonal axes so the
      intersection broadcasts over all pairs.

    Args:
        sentence1_embeddings: embeddings for the first set of sentences.
        sentence2_embeddings: embeddings for the second set of sentences.
        volume_temp: forwarded to ``gumbel_intersection_log_volume``.
        intersection_temp: forwarded to ``gumbel_intersection_log_volume``.

    Returns:
        The ``gumbel_intersection_log_volume`` score for every
        (sentence1, sentence2) pair; presumably shaped (a, b) -- verify
        against box_wrapper.
    """
    boxes1 = CenterDeltaBoxTensor.from_split(sentence1_embeddings)
    boxes2 = CenterDeltaBoxTensor.from_split(sentence2_embeddings)

    # Promote an unbatched box to a batch of one. Reshape the raw ``.data``
    # tensor directly so that ``.data`` remains a tensor.
    if boxes1.data.ndim == 2:
        boxes1.data = boxes1.data[None, :, :]
    if boxes2.data.ndim == 2:
        boxes2.data = boxes2.data[None, :, :]

    # (a, ...) -> (a, 1, ...) and (b, ...) -> (1, b, ...) so the intersection
    # broadcasts over all a x b pairs.
    boxes1.data = boxes1.data[:, None, :, :]
    boxes2.data = boxes2.data[None, :, :, :]

    return boxes1.gumbel_intersection_log_volume(
        boxes2,
        volume_temp=volume_temp,
        intersection_temp=intersection_temp,
    )


def similarity_function_square(
    sentence1_embeddings,
    sentence2_embeddings,
    volume_temp=1.0,
    intersection_temp=0.001,
):
    """Score every sentence1 box against every sentence2 box (a x b grid).

    Same as ``similarity_function`` except that the boxes are built with
    ``CenterScalarDeltaBoxTensor.from_split``. Presumably that class uses a
    single scalar width per box, i.e. hypercubes (TODO confirm in
    box_wrapper).

    If a resulting box's ``.data`` is 2-D, it gets a leading axis of size 1.
    Presumably that case is a single, unbatched embedding.

    Args:
        sentence1_embeddings: embeddings for the first set of sentences.
        sentence2_embeddings: embeddings for the second set of sentences.
        volume_temp: forwarded to ``gumbel_intersection_log_volume``.
        intersection_temp: forwarded to ``gumbel_intersection_log_volume``.

    Returns:
        The ``gumbel_intersection_log_volume`` score for every
        (sentence1, sentence2) pair; presumably shaped (a, b).
    """
    boxes1 = CenterScalarDeltaBoxTensor.from_split(sentence1_embeddings)
    boxes2 = CenterScalarDeltaBoxTensor.from_split(sentence2_embeddings)

    # Promote an unbatched box to a batch of one. Reshape the raw ``.data``
    # tensor directly so that ``.data`` remains a tensor.
    if boxes1.data.ndim == 2:
        boxes1.data = boxes1.data[None, :, :]
    if boxes2.data.ndim == 2:
        boxes2.data = boxes2.data[None, :, :]

    # (a, ...) -> (a, 1, ...) and (b, ...) -> (1, b, ...) so the intersection
    # broadcasts over all a x b pairs.
    boxes1.data = boxes1.data[:, None, :, :]
    boxes2.data = boxes2.data[None, :, :, :]

    return boxes1.gumbel_intersection_log_volume(
        boxes2,
        volume_temp=volume_temp,
        intersection_temp=intersection_temp,
    )


def similarity_function_entailment_square(
    sentence1_embeddings,
    sentence2_embeddings,
    volume_temp=1.0,
    intersection_temp=0.001,
):
    """Entailment score of every sentence2 box w.r.t. every sentence1 box.

    Boxes are built with ``CenterScalarDeltaBoxTensor.from_split`` and scored
    in an a x b grid. Each score is:

    - the pair's ``gumbel_intersection_log_volume``,
    - minus the ``log_soft_volume_adjusted`` of the sentence2 box.

    Presumably this is ``log P(box1 | box2)``. Callers are expected to pass
    the "smaller" (more specific) sentences as ``sentence2_embeddings``.

    Args:
        sentence1_embeddings: embeddings for the first (broader) sentences.
        sentence2_embeddings: embeddings for the second (smaller) sentences.
        volume_temp: forwarded to both volume computations.
        intersection_temp: forwarded to both volume computations.

    Returns:
        Score for every (sentence1, sentence2) pair; presumably shaped (a, b).
    """
    boxes1 = CenterScalarDeltaBoxTensor.from_split(sentence1_embeddings)
    boxes2 = CenterScalarDeltaBoxTensor.from_split(sentence2_embeddings)

    # Promote an unbatched box to a batch of one. Reshape the raw ``.data``
    # tensor directly so that ``.data`` remains a tensor.
    if boxes1.data.ndim == 2:
        boxes1.data = boxes1.data[None, :, :]
    if boxes2.data.ndim == 2:
        boxes2.data = boxes2.data[None, :, :]

    # (a, ...) -> (a, 1, ...) and (b, ...) -> (1, b, ...) so the intersection
    # broadcasts over all a x b pairs.
    boxes1.data = boxes1.data[:, None, :, :]
    boxes2.data = boxes2.data[None, :, :, :]

    log_intersection = boxes1.gumbel_intersection_log_volume(
        boxes2,
        volume_temp=volume_temp,
        intersection_temp=intersection_temp,
    )
    # boxes2 is laid out as (1, b, ...), so its log-volume broadcasts across
    # the a rows of the intersection scores.
    log_volume2 = boxes2.log_soft_volume_adjusted(
        volume_temp=volume_temp, intersection_temp=intersection_temp
    )

    return log_intersection - log_volume2


def similarity_function_pairwise(
    sentence1_embeddings,
    sentence2_embeddings,
    volume_temp=1.0,
    intersection_temp=0.001,
):
    """Score aligned (sentence1[i], sentence2[i]) box pairs.

    Unlike ``similarity_function``, no axes are inserted. The two box batches
    are intersected with whatever broadcasting their shapes already allow;
    presumably they have matching leading shapes.

    Args:
        sentence1_embeddings: embeddings converted via
            ``CenterDeltaBoxTensor.from_split``.
        sentence2_embeddings: embeddings converted the same way.
        volume_temp: forwarded to ``gumbel_intersection_log_volume``.
        intersection_temp: forwarded to ``gumbel_intersection_log_volume``.

    Returns:
        The ``gumbel_intersection_log_volume`` of each aligned pair.
    """
    left_box = CenterDeltaBoxTensor.from_split(sentence1_embeddings)
    right_box = CenterDeltaBoxTensor.from_split(sentence2_embeddings)
    return left_box.gumbel_intersection_log_volume(
        right_box,
        volume_temp=volume_temp,
        intersection_temp=intersection_temp,
    )


def similarity_function_entailment_pairwise(
    sentence1_embeddings,
    sentence2_embeddings,
    volume_temp=1.0,
    intersection_temp=0.001,
):
    """Entailment score of aligned (sentence1[i], sentence2[i]) box pairs.

    Each score is:

    - the pair's ``gumbel_intersection_log_volume``,
    - minus the ``log_soft_volume_adjusted`` of the sentence2 box.

    Presumably this is ``log P(box1 | box2)``. No axes are inserted, so the
    batches are combined element-wise (with whatever broadcasting their
    shapes allow).

    Args:
        sentence1_embeddings: embeddings converted via
            ``CenterDeltaBoxTensor.from_split``.
        sentence2_embeddings: embeddings converted the same way.
        volume_temp: forwarded to both volume computations.
        intersection_temp: forwarded to both volume computations.

    Returns:
        The conditional score of each aligned pair.
    """
    left_box = CenterDeltaBoxTensor.from_split(sentence1_embeddings)
    right_box = CenterDeltaBoxTensor.from_split(sentence2_embeddings)

    joint_log_volume = left_box.gumbel_intersection_log_volume(
        right_box,
        volume_temp=volume_temp,
        intersection_temp=intersection_temp,
    )
    right_log_volume = right_box.log_soft_volume_adjusted(
        volume_temp=volume_temp, intersection_temp=intersection_temp
    )
    return joint_log_volume - right_log_volume


def similarity_function_entailment(
    sentence1_embeddings,
    sentence2_embeddings,
    volume_temp=1.0,
    intersection_temp=0.001,
):
    """Entailment score of every sentence2 box w.r.t. every sentence1 box.

    Boxes are built with ``CenterDeltaBoxTensor.from_split`` and scored in an
    a x b grid. Each score is:

    - the pair's ``gumbel_intersection_log_volume``,
    - minus the ``log_soft_volume_adjusted`` of the sentence2 box.

    Presumably this is ``log P(box1 | box2)``. Callers are expected to pass
    the "smaller" (more specific) sentences as ``sentence2_embeddings``.

    Args:
        sentence1_embeddings: embeddings for the first (broader) sentences.
        sentence2_embeddings: embeddings for the second (smaller) sentences.
        volume_temp: forwarded to both volume computations.
        intersection_temp: forwarded to both volume computations.

    Returns:
        Score for every (sentence1, sentence2) pair; presumably shaped (a, b).
    """
    boxes1 = CenterDeltaBoxTensor.from_split(sentence1_embeddings)
    boxes2 = CenterDeltaBoxTensor.from_split(sentence2_embeddings)

    # Promote an unbatched box to a batch of one. Reshape the raw ``.data``
    # tensor directly so that ``.data`` remains a tensor.
    if boxes1.data.ndim == 2:
        boxes1.data = boxes1.data[None, :, :]
    if boxes2.data.ndim == 2:
        boxes2.data = boxes2.data[None, :, :]

    # (a, ...) -> (a, 1, ...) and (b, ...) -> (1, b, ...) so the intersection
    # broadcasts over all a x b pairs.
    boxes1.data = boxes1.data[:, None, :, :]
    boxes2.data = boxes2.data[None, :, :, :]

    log_intersection = boxes1.gumbel_intersection_log_volume(
        boxes2,
        volume_temp=volume_temp,
        intersection_temp=intersection_temp,
    )
    # boxes2 is laid out as (1, b, ...), so its log-volume broadcasts across
    # the a rows of the intersection scores.
    log_volume2 = boxes2.log_soft_volume_adjusted(
        volume_temp=volume_temp, intersection_temp=intersection_temp
    )

    return log_intersection - log_volume2


def similarity_function_entailment_opposite(
    sentence1_embeddings,
    sentence2_embeddings,
    volume_temp=1.0,
    intersection_temp=0.001,
):
    """Entailment score of every sentence1 box w.r.t. every sentence2 box.

    This mirrors ``similarity_function_entailment``, but normalizes by the
    sentence1 box instead of the sentence2 box. Each score is:

    - the pair's ``gumbel_intersection_log_volume``,
    - minus the ``log_soft_volume_adjusted`` of the sentence1 box.

    Presumably this is ``log P(box2 | box1)``.

    Args:
        sentence1_embeddings: embeddings converted via
            ``CenterDeltaBoxTensor.from_split``.
        sentence2_embeddings: embeddings converted the same way.
        volume_temp: forwarded to both volume computations.
        intersection_temp: forwarded to both volume computations.

    Returns:
        Score for every (sentence1, sentence2) pair; presumably shaped (a, b).
    """
    boxes1 = CenterDeltaBoxTensor.from_split(sentence1_embeddings)
    boxes2 = CenterDeltaBoxTensor.from_split(sentence2_embeddings)

    # Promote an unbatched box to a batch of one. Reshape the raw ``.data``
    # tensor directly so that ``.data`` remains a tensor.
    if boxes1.data.ndim == 2:
        boxes1.data = boxes1.data[None, :, :]
    if boxes2.data.ndim == 2:
        boxes2.data = boxes2.data[None, :, :]

    # (a, ...) -> (a, 1, ...) and (b, ...) -> (1, b, ...) so the intersection
    # broadcasts over all a x b pairs.
    boxes1.data = boxes1.data[:, None, :, :]
    boxes2.data = boxes2.data[None, :, :, :]

    log_intersection = boxes1.gumbel_intersection_log_volume(
        boxes2,
        volume_temp=volume_temp,
        intersection_temp=intersection_temp,
    )
    # boxes1 is laid out as (a, 1, ...), so its log-volume broadcasts across
    # the b columns of the intersection scores.
    log_volume1 = boxes1.log_soft_volume_adjusted(
        volume_temp=volume_temp, intersection_temp=intersection_temp
    )

    return log_intersection - log_volume1


def vector_entailment_similarity_csdelta(
    embedding1: torch.Tensor,
    embedding2: torch.Tensor,
) -> torch.Tensor:
    """Score every embedding1 row against every embedding2 row (a x b grid).

    Each score is the pair's cosine similarity multiplied by the difference
    of their L1 norms, ``||e1||_1 - ||e2||_1``.

    Args:
        embedding1: tensor of shape (a, d), or a single vector of shape (d,).
        embedding2: tensor of shape (b, d), or a single vector of shape (d,).

    Returns:
        Tensor of shape (a, b); a 1-D input is treated as a batch of one.
    """
    # util.cos_sim already treats a 1-D input as a batch of one. Do the same
    # here so the L1-norm reduction over dim 1 does not fail on such input.
    if embedding1.dim() == 1:
        embedding1 = embedding1.unsqueeze(0)
    if embedding2.dim() == 1:
        embedding2 = embedding2.unsqueeze(0)

    similarity_scores = util.cos_sim(embedding1, embedding2)  # shape: (a, b)

    norms_1 = torch.linalg.norm(embedding1, ord=1, dim=1)  # shape: (a,)
    norms_2 = torch.linalg.norm(embedding2, ord=1, dim=1)  # shape: (b,)

    # Broadcast to (a, b); positive where embedding1 has the larger L1 norm.
    difference = norms_1[:, None] - norms_2[None, :]

    return similarity_scores * difference


def vector_entailment_similarity_csdelta_pairwise(
    embedding1: torch.Tensor,
    embedding2: torch.Tensor,
) -> torch.Tensor:
    """Cosine similarity scaled by the L1-norm gap, over all a x b pairs.

    NOTE(review): despite the ``_pairwise`` suffix, this computes the full
    a x b matrix, exactly like ``vector_entailment_similarity_csdelta``.
    Contrast ``similarity_function_pairwise``, which scores aligned pairs
    only. If aligned-pair scores were intended, ``util.pairwise_cos_sim``
    times ``norms_1 - norms_2`` would give a shape-(a,) result. Confirm with
    callers before changing, since that alters the return shape.

    Args:
        embedding1: tensor of shape (a, d).
        embedding2: tensor of shape (b, d).

    Returns:
        Tensor of shape (a, b): ``cos_sim(e1, e2) * (||e1||_1 - ||e2||_1)``.
    """
    cosine = util.cos_sim(embedding1, embedding2)
    l1_rows = torch.linalg.norm(embedding1, ord=1, dim=1).unsqueeze(1)  # (a, 1)
    l1_cols = torch.linalg.norm(embedding2, ord=1, dim=1).unsqueeze(0)  # (1, b)
    return cosine * (l1_rows - l1_cols)
