import torch


def _normalize_rows(matrix: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale every row of ``matrix`` to unit Euclidean length.

    Rows whose norm is exactly zero are divided by ``eps`` instead of zero,
    so (as long as ``eps`` is representable in the tensor's dtype) an
    all-zero row stays all-zero rather than turning into NaN.
    """
    # Euclidean norm of each row, kept as a column: shape (n_samples, 1).
    norm = torch.sqrt((matrix * matrix).sum(dim=1, keepdim=True))

    # Replace exact zeros out-of-place. Writing into ``norm`` in place would
    # modify the tensor that sqrt saved for its backward pass and make
    # ``.backward()`` fail for inputs that require grad.
    norm = torch.where(norm == 0, torch.full_like(norm, eps), norm)

    return matrix / norm


def cosine_distance_torch(matrix_a: torch.Tensor, matrix_b: torch.Tensor) -> torch.Tensor:
    """
    Compute pairwise cosine distance between the rows of two feature matrices.

    Each row is treated as one sample and each column as one feature.

    Parameters
    ----------
    matrix_a : torch.Tensor
        Tensor of shape (n_samples_a, n_features)
    matrix_b : torch.Tensor
        Tensor of shape (n_samples_b, n_features)

    Returns
    -------
    torch.Tensor
        Cosine distance matrix of shape (n_samples_a, n_samples_b), where
        entry (i, j) is ``1 - cos(matrix_a[i], matrix_b[j])``. An all-zero
        row has similarity 0 (distance 1) to every other row.

    Raises
    ------
    ValueError
        If the two matrices differ in their second dimension (feature count).
    """
    if matrix_a.shape[1] != matrix_b.shape[1]:
        raise ValueError("Both matrices must have the same number of features (columns).")

    # Cosine similarity = dot product of the unit-length rows.
    cosine_sim = torch.matmul(_normalize_rows(matrix_a), _normalize_rows(matrix_b).T)

    # Rounding can push values marginally outside [-1, 1]; clamp them back.
    cosine_sim = cosine_sim.clamp(min=-1.0, max=1.0)

    return 1 - cosine_sim

# A = torch.rand(10, 5)  # 10 samples × 5 features (rows are samples)
# B = torch.rand(8, 5)   # 8 samples × 5 features
#
# distances = cosine_distance_torch(A, B)
# print(distances.shape)
# print(distances)
#
# distances = cosine_distance_torch(A, A)
# print(distances.shape)
# print(distances)