import torch as t
from loguru import logger


@t.no_grad()
def sinkhorn_balancing_dual(
    pre_act_features_BsF: t.Tensor,
    feature_capacities_F: t.Tensor,
    regularisation_eps: float = 0.01,
    max_iter: int = 100,
    tolerance: float = 1e-9,
) -> t.Tensor:
    """
    Balance token-to-feature assignments with log-domain Sinkhorn iterations.

    The dual potentials ``f`` (one per token) and ``g`` (one per feature) are
    updated alternately so that ``exp((scores + f + g) / eps)`` has row sums of
    one and column sums equal to the requested feature capacities. All updates
    go through ``logsumexp`` rather than exponentiating the scaled scores
    directly.

    Args:
        pre_act_features_BsF (t.Tensor): Scores of shape
            (num_tokens, num_features). Any batch/sequence dimensions must
            already be flattened into the first axis.
        feature_capacities_F (t.Tensor): Target column sums of shape
            (num_features,). Must sum to ``num_tokens``. Integer tensors are
            converted to the default floating dtype, and the tensor is moved
            to the device of ``pre_act_features_BsF``.
        regularisation_eps (float): Entropic regularisation strength (> 0).
        max_iter (int, optional): Maximum number of f/g update sweeps.
        tolerance (float, optional): Stop once the mean absolute change of
            ``f`` between two sweeps falls below this value.

    Returns:
        t.Tensor: Assignment matrix of shape (num_tokens, num_features).

    Raises:
        ValueError: If the scores are not 2-D, the capacities do not have
            shape (num_features,), or ``regularisation_eps`` is not positive.
        AssertionError: If the capacities do not sum to the number of tokens.
    """
    if pre_act_features_BsF.ndim != 2:
        raise ValueError(
            "pre_act_features_BsF must have shape (num_tokens, num_features), "
            f"got {tuple(pre_act_features_BsF.shape)}."
        )
    if regularisation_eps <= 0:
        raise ValueError(
            f"regularisation_eps must be > 0, got {regularisation_eps}."
        )

    bs, num_features = pre_act_features_BsF.shape
    device = pre_act_features_BsF.device

    # Bring the capacities onto the scores' device and into a floating dtype so
    # that the comparison and the log below work for integer / off-device input.
    capacities_F = feature_capacities_F.to(device=device)
    if not capacities_F.is_floating_point():
        capacities_F = capacities_F.to(dtype=t.get_default_dtype())

    if capacities_F.shape != (num_features,):
        raise ValueError(
            f"feature_capacities_F must have shape ({num_features},), "
            f"got {tuple(capacities_F.shape)}."
        )

    total_feature_capacity = capacities_F.sum()

    # Raised explicitly (instead of a bare `assert`) so the check survives
    # `python -O` while callers still see the same exception type and message.
    # `full_like` keeps dtype and device identical, which `t.isclose` requires.
    if not t.isclose(
        t.full_like(total_feature_capacity, float(bs)),
        total_feature_capacity,
        atol=1e-6,
    ):
        raise AssertionError("Total token count must equal total feature capacity.")

    # Dual potentials, initialised to zero.
    f_Bs = t.zeros(bs, device=device)
    g_F = t.zeros(num_features, device=device)

    # Log of the target marginals: every token carries unit mass, every
    # feature carries its capacity.
    tokens_log_marginal_Bs = t.log(t.ones(bs, device=device))
    features_log_marginal_F = t.log(capacities_F)

    for _ in range(max_iter):
        # `f_Bs` is rebound (never mutated in place) below, so a plain
        # reference is enough to remember the previous iterate.
        f_prev_Bs = f_Bs

        # Row update: make every token's row sum match its marginal.
        f_Bs = regularisation_eps * (
            tokens_log_marginal_Bs
            - t.logsumexp(
                (pre_act_features_BsF + g_F.unsqueeze(0)) / regularisation_eps, dim=1
            )
        )

        # Column update: make every feature's column sum match its capacity.
        g_F = regularisation_eps * (
            features_log_marginal_F
            - t.logsumexp(
                (pre_act_features_BsF + f_Bs.unsqueeze(1)) / regularisation_eps, dim=0
            )
        )

        # Converged once the row potentials stop moving.
        error = t.mean(t.abs(f_Bs - f_prev_Bs))
        if error < tolerance:
            break

    # Recover the assignment matrix from the potentials.
    assignments_BsF = t.exp(
        (pre_act_features_BsF + f_Bs.unsqueeze(1) + g_F.unsqueeze(0)) / regularisation_eps
    )

    return assignments_BsF


if __name__ == "__main__":
    # Smoke test: balance random scores for 100 tokens across 10 features and
    # check that both marginal constraints hold on the returned plan.
    n_tokens, n_feats = 100, 10

    # Random scores standing in for real pre-activation features.
    scores = t.randn(n_tokens, n_feats)

    # Per-feature capacities; they have to add up to the number of tokens.
    capacities = t.tensor(
        [10, 15, 5, 20, 10, 10, 5, 5, 10, 10],
        dtype=scores.dtype,
    )
    expected_total = t.tensor(n_tokens, dtype=capacities.dtype)
    assert t.isclose(capacities.sum(), expected_total)

    # Entropic regularisation strength used for this run.
    eps = 0.1

    plan = sinkhorn_balancing_dual(scores, capacities, regularisation_eps=eps)

    # Each token should carry unit mass; each feature should receive exactly
    # its capacity.
    per_token_mass = plan.sum(dim=1)
    per_feature_mass = plan.sum(dim=0)

    rows_ok = t.allclose(per_token_mass, t.ones(n_tokens), atol=1e-3)
    assert rows_ok, f"Row sums: {per_token_mass}, assignments: {plan}, expected 1"

    cols_ok = t.allclose(per_feature_mass, capacities, atol=1e-3)
    assert cols_ok
