import torch

# Module-wide default device: CUDA whenever torch reports it as available,
# otherwise fall back to the CPU. The loss functions below move their results
# onto this device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def quota_loss(
    T: torch.Tensor, S_X: torch.Tensor, S_Y: torch.Tensor, F: torch.Tensor
) -> torch.Tensor:
    """
    Quota loss between two distributions X and Y matched through the optimal
    transport plan T with respect to their respective sensitive attributes S_X
    and S_Y.

    The plan's mass is aggregated per pair of sensitive groups, giving a
    (n_s_x, n_s_y) matrix whose entry (a, b) is the total mass T moves from
    points of X with S_X == a to points of Y with S_Y == b. The loss is the
    sum of squared differences between that matrix and the target F.

    Parameters:
    ----------
    T (torch.Tensor): The optimal transport plan, shape (n_x, n_y). Any
        floating dtype is accepted (e.g. float32 or float64).
    S_X (torch.Tensor): Integer group label in [0, n_s_x) for each of the
        n_x points of the first distribution.
    S_Y (torch.Tensor): Integer group label in [0, n_s_y) for each of the
        n_y points of the second distribution.
    F (torch.Tensor): Target fairness matrix of shape (n_s_x, n_s_y); its
        shape fixes the number of groups on each side.

    Returns:
    -------
    torch.Tensor: Scalar tensor ``sum((joint - F) ** 2)``, placed on the
        module-level ``device``.
    """
    n_s_x, n_s_y = F.shape

    # Build the one-hot group indicators directly in T's dtype and on T's
    # device. A hard-coded float32 indicator cannot be matrix-multiplied with
    # a float64 plan, and labels stored on another device than T would make
    # the matmul fail as well.
    S_X_onehot = torch.nn.functional.one_hot(
        S_X.long(), num_classes=n_s_x
    ).to(device=T.device, dtype=T.dtype)  # shape: (n_x, n_s_x)
    S_Y_onehot = torch.nn.functional.one_hot(
        S_Y.long(), num_classes=n_s_y
    ).to(device=T.device, dtype=T.dtype)  # shape: (n_y, n_s_y)

    # Sum the plan's mass over every (group of X, group of Y) pair.
    joint_distribution = S_X_onehot.T @ T @ S_Y_onehot  # shape: (n_s_x, n_s_y)

    # Compare against the target on the module-level default device.
    joint_distribution = joint_distribution.to(device)
    F = F.to(device)

    return torch.sum(torch.square(joint_distribution - F))


def weighted_quota_loss(
    T: torch.Tensor,
    C: torch.Tensor,
    S_X: torch.Tensor,
    S_Y: torch.Tensor,
    F: torch.Tensor,
) -> torch.Tensor:
    """
    Weighted quota loss between two distributions X and Y matched through the
    optimal transport plan T with respect to their respective sensitive
    attributes S_X and S_Y.

    Each entry of the plan is weighted by the matching entry of C. The
    weighted mass is then summed per pair of sensitive groups, giving a
    (n_s_x, n_s_y) matrix of group-wise costs. The loss is the sum of squared
    differences between that matrix and the target F.

    Parameters:
    ----------
    T (torch.Tensor): The optimal transport plan, shape (n_x, n_y).
    C (torch.Tensor): The weighting matrix, multiplied element-wise with T
        (expected to have T's shape).
    S_X (torch.Tensor): Integer group label in [0, n_s_x) for each of the
        n_x points of the first distribution.
    S_Y (torch.Tensor): Integer group label in [0, n_s_y) for each of the
        n_y points of the second distribution.
    F (torch.Tensor): Target fairness matrix of shape (n_s_x, n_s_y); its
        shape fixes the number of groups on each side.

    Returns:
    -------
    torch.Tensor: Scalar tensor ``sum((groupwise_cost - F) ** 2)``, placed on
        the module-level ``device``.
    """
    n_s_x, n_s_y = F.shape

    # Integer one-hot encodings of the sensitive attributes.
    S_X_onehot = torch.nn.functional.one_hot(
        S_X.long(), num_classes=n_s_x
    )  # shape: (n_x, n_s_x)
    S_Y_onehot = torch.nn.functional.one_hot(
        S_Y.long(), num_classes=n_s_y
    )  # shape: (n_y, n_s_y)

    weighted_cost = C * T  # element-wise multiplication

    # Cast the indicators to the weighted cost's dtype and device. A
    # hard-coded float32 indicator would make the matmul fail for float64
    # inputs, or when the labels live on another device.
    S_X_onehot = S_X_onehot.to(
        device=weighted_cost.device, dtype=weighted_cost.dtype
    )
    S_Y_onehot = S_Y_onehot.to(
        device=weighted_cost.device, dtype=weighted_cost.dtype
    )

    # Sum the weighted mass over every (group of X, group of Y) pair.
    groupwise_cost = (
        S_X_onehot.T @ weighted_cost @ S_Y_onehot
    )  # shape: (n_s_x, n_s_y)

    # Compare against the target on the module-level default device.
    groupwise_cost = groupwise_cost.to(device)
    F = F.to(device)

    return torch.sum(torch.square(groupwise_cost - F))
