# %%
import numpy as np
from scipy.stats import entropy
import torch
from classification_utils import sigmoid_normalize_embeddings

# %%
def _normalize_rows(matrix):
    """Return ``matrix`` with every row divided by its own sum."""
    return matrix / np.sum(matrix, axis=1, keepdims=True)


def _mean_row_entropy(matrix):
    """Return the mean ``scipy.stats.entropy`` of the rows of ``matrix``.

    Each row is rescaled to sum to 1 before its entropy is computed.
    """
    return np.mean([entropy(row) for row in _normalize_rows(matrix)])


def test_sigmoid_normalize_location_embeddings_entropy():
    """
    Check the entropy behaviour of ``sigmoid_normalize_embeddings``.

    For every alpha in ``alpha_values`` the sample embeddings are transformed,
    each row is rescaled to sum to 1, and the mean row entropy is recorded.
    The test asserts that:

    * alpha=0 gives the entropy of a uniform distribution over the features;
    * alpha=1 gives the entropy of the row-normalized element-wise sigmoid
      of the raw embeddings;
    * entropy strictly decreases as alpha increases.

    Entropies and the per-sample transformed vectors are printed for
    inspection before the assertions run.
    """
    # Sample location embeddings: ramps, a constant row, one-hot spikes and
    # alternating patterns, all with 8 features.
    location_embeddings = np.array([
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
        [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0],
        [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
        [8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5],
        [0.1, 0.9, 0.1, 0.9, 0.1, 0.9, 0.1, 0.9],
        [0.9, 0.1, 0.9, 0.1, 0.9, 0.1, 0.9, 0.1],
    ])

    # The reference assertions below treat the first alpha as 0 and the
    # last alpha as 1; intermediate values only feed the monotonicity check.
    alpha_values = [0, 0.5, 1]
    entropies = [
        _mean_row_entropy(sigmoid_normalize_embeddings(location_embeddings, alpha=alpha))
        for alpha in alpha_values
    ]

    # Reference for alpha=0: entropy of a uniform distribution over the features.
    n_features = location_embeddings.shape[1]
    uniform_entropy = entropy(np.full(n_features, 1.0 / n_features))

    # Reference for alpha=1: entropy of the element-wise sigmoid of the raw
    # embeddings (computed in float32 via torch), rows rescaled to sum to 1.
    sigmoid_embeddings = torch.sigmoid(torch.from_numpy(location_embeddings).float()).numpy()
    sigmoid_entropy = _mean_row_entropy(sigmoid_embeddings)

    print("Entropies for different alpha values:")
    for alpha, ent in zip(alpha_values, entropies):
        print(f"  alpha={alpha}: entropy={ent:.6f}")
    print(f"Uniform vector entropy: {uniform_entropy:.6f}")
    print(f"Sigmoid embedding entropy: {sigmoid_entropy:.6f}")

    print("\nOriginal and transformed vectors for each alpha:")
    for sample_no, row in enumerate(location_embeddings, start=1):
        print(f"\nSample {sample_no}:")
        print(f"  Original: {row}")
        for alpha in alpha_values:
            # Each sample is transformed on its own as a (1, n_features) batch.
            normalized = _normalize_rows(
                sigmoid_normalize_embeddings(row[np.newaxis, :], alpha=alpha)
            )
            print(f"  alpha={alpha}: {normalized[0]}")

    # Assertions
    assert np.isclose(entropies[0], uniform_entropy, atol=1e-5), "Entropy for alpha=0 should match uniform vector entropy."
    assert np.isclose(entropies[-1], sigmoid_entropy, atol=1e-5), "Entropy for alpha=1 should match entropy of sigmoid embedding."
    # Pairwise check so the test works for any number of alpha values.
    strictly_decreasing = all(prev > nxt for prev, nxt in zip(entropies, entropies[1:]))
    assert strictly_decreasing, "Entropy should decrease as alpha increases."

    print("All tests passed!")

# %%
# Execute the entropy check defined above; it prints diagnostics and raises
# AssertionError if any expectation fails.
test_sigmoid_normalize_location_embeddings_entropy()
# %%
# Example: a 2x3 floating-point tensor holding the values 1..6 in row-major order
example_tensor = torch.arange(1.0, 7.0).reshape(2, 3)
print("Example tensor:\n", example_tensor)
# %%
# Row-wise minimum (values and indices) of the example tensor
torch.min(example_tensor, dim=1)
# %%
