import numpy as np
import torch

def sigmoid_normalize_embeddings(embeddings, alpha=0.5):
    """
    Applies sigmoid normalization to embeddings.

    Each row is shifted by its minimum, scaled by ``alpha``, shifted back by
    the same minimum and passed through a sigmoid:

        sigmoid(alpha * (x - row_min) + row_min)

    which equals ``sigmoid(alpha * x + (1 - alpha) * row_min)``. With
    ``alpha=1.0`` this reduces to a plain element-wise sigmoid. For
    ``alpha > 0`` the transform is increasing within each row.

    Args:
        embeddings (np.ndarray or array-like): The encoder outputs. The row
            minimum is taken over axis 1, so the input must be at least 2-D.
        alpha (float): Scaling factor for the sigmoid.

    Returns:
        np.ndarray: Sigmoid-normalized embeddings, same shape as the input.
    """
    # np.asarray is a no-op for ndarrays and lets plain lists through too.
    embeddings_torch = torch.from_numpy(np.asarray(embeddings))
    # Row-wise minimum, computed once and reused for both shifts.
    row_min = embeddings_torch.min(dim=1, keepdim=True).values
    normalized = torch.sigmoid(alpha * (embeddings_torch - row_min) + row_min)
    return normalized.numpy()

def top_k_accuracy(logits, labels, k=5):
    """
    Computes the top-k accuracy for the given logits and labels.

    A sample counts as correct when its true label is among the ``k`` column
    indices with the highest scores in its row of ``logits``.

    Args:
        logits (np.ndarray): The predicted logits, one row per sample and one
            column per class.
        labels (np.ndarray): The true labels as integer class indices, one per
            row of ``logits``. Shape ``(N,)`` or ``(N, 1)``.
        k (int): The number of top predictions to consider; must be >= 1.
            Values larger than the number of classes consider every class.

    Returns:
        float: The top-k accuracy in [0, 1], or ``nan`` when there are no
        samples.

    Raises:
        ValueError: If ``k < 1`` or the number of labels does not match the
            number of rows in ``logits``.
    """
    if k < 1:
        # ``[:, -0:]`` would select *every* column and report accuracy 1.0.
        raise ValueError(f"k must be >= 1, got {k}")
    logits = np.asarray(logits)
    # Flatten so column-vector labels (N, 1) don't broadcast to (N, N, k).
    labels = np.asarray(labels).reshape(-1)
    if labels.shape[0] != logits.shape[0]:
        raise ValueError(
            f"got {labels.shape[0]} labels for {logits.shape[0]} rows of logits"
        )
    if labels.shape[0] == 0:
        # Accuracy over zero samples is undefined; avoid a 0/0 RuntimeWarning.
        return float("nan")
    top_k_preds = np.argsort(logits, axis=1)[:, -k:]
    correct = np.sum(np.any(top_k_preds == labels[:, None], axis=1))
    return float(correct / len(labels))

def combined_vision_location_encoding_evaluation(
    location_embeddings, vision_embeddings, labels, k=5, alpha=1.0
):
    """
    Evaluates the combined vision-location encoding model.

    The location scores are sigmoid-normalized (``sigmoid_normalize_embeddings``)
    and multiplied element-wise with the vision scores to form the combined
    scores. Top-1 and top-k accuracy are then computed for the combined
    scores and for each model on its own.

    Args:
        location_embeddings (np.ndarray): The location encoder outputs; must
            broadcast element-wise against ``vision_embeddings``.
        vision_embeddings (np.ndarray): The vision model outputs.
        labels (np.ndarray): The true labels.
        k (int): The number of top predictions to consider.
        alpha (float): Scaling factor passed to ``sigmoid_normalize_embeddings``
            for the location scores. Defaults to 1.0, the value previously
            hard-coded here.

    Returns:
        tuple: ``(top_1_combined, top_k_combined, top_1_vision, top_k_vision,
        top_1_location, top_k_location)`` accuracies.
    """
    # Only the combined scores use the normalized location outputs. The
    # location-only accuracy below is measured on the raw outputs: the sigmoid
    # can saturate (e.g. to exactly 1.0), creating ties that distort ranking.
    normalized_location = sigmoid_normalize_embeddings(location_embeddings, alpha=alpha)
    combined_logits = normalized_location * vision_embeddings
    # Alternative fusion (not used): softmax over dim=1 of (location + vision).

    scored = (combined_logits, vision_embeddings, location_embeddings)
    # Order: (top-1, top-k) for combined, then vision, then location.
    return tuple(
        top_k_accuracy(scores, labels, k=kk)
        for scores in scored
        for kk in (1, k)
    )
