import torch
from torchtyping import TensorType

def euclidean_distance_tensor(height, width):
    """Return pairwise Euclidean distances between all cells of a grid.

    Cells of a ``height x width`` grid are enumerated in row-major order
    (cell ``k`` is at row ``k // width``, column ``k % width``). The result
    is a square tensor of side ``height * width`` whose ``[p, q]`` entry is
    the straight-line distance between cells ``p`` and ``q``, measured in
    grid steps.
    """
    # (row, col) coordinates of every cell in row-major order: shape (N, 2).
    coords = torch.cartesian_prod(torch.arange(height), torch.arange(width))

    # Broadcast (N, 1, 2) against (1, N, 2) to get every pairwise offset.
    offsets = coords.unsqueeze(1) - coords.unsqueeze(0)

    # Squared offset lengths; for integer height/width these are integers,
    # and sqrt promotes them to the default floating dtype.
    return offsets.pow(2).sum(dim=-1).sqrt()


def correlation_matrix(tensor: TensorType["A", "B"]) -> TensorType["B", "B"]:
    """Compute the Pearson correlation matrix between the columns of ``tensor``.

    Rows are treated as observations and columns as variables. An ``(A, B)``
    input yields a ``(B, B)`` matrix whose ``[i, j]`` entry is the
    correlation between columns ``i`` and ``j``.

    Args:
        tensor: 2-D tensor of shape ``(A, B)``. Integer and bool inputs are
            promoted to the default floating dtype first, because
            ``torch.mean`` rejects them. Non-2-D inputs are rejected by
            ``torch.mm``.

    Returns:
        ``(B, B)`` correlation matrix. Real-valued entries are clamped to
        ``[-1, 1]``. A column with zero variance (including the ``A == 1``
        case) yields NaN in its row and column, since its standard
        deviation is 0.
    """
    if not (tensor.is_floating_point() or tensor.is_complex()):
        tensor = tensor.to(torch.get_default_dtype())

    # Center each column on its mean.
    centered = tensor - tensor.mean(dim=0, keepdim=True)

    # Gram matrix of the centered columns. This equals the covariance up to
    # the 1 / (A - 1) factor, which cancels in the normalization below.
    # NOTE(review): for complex input this uses a plain (non-conjugate)
    # transpose, so the result is not a proper complex correlation.
    gram = torch.mm(centered.T, centered)

    # Per-column standard deviation (up to the same cancelled factor).
    std = torch.sqrt(torch.diag(gram))

    corr = gram / (std[:, None] * std[None, :])

    # Rounding can push entries marginally outside [-1, 1]; clip them as
    # torch.corrcoef does. torch.clamp keeps NaN and does not accept complex.
    if not corr.is_complex():
        corr = torch.clamp(corr, -1.0, 1.0)
    return corr

def get_random_subset(tensor, fraction: float, *, generator=None):
    """Return a random subset of the entries of ``tensor`` along dimension 0.

    ``int(fraction * tensor.size(0))`` distinct entries are drawn without
    replacement and returned in random order.

    Args:
        tensor: Tensor with at least one dimension; sampling is along dim 0.
        fraction: Share of entries to keep, in ``[0, 1]``. The resulting
            count is truncated toward zero, so a small fraction may give an
            empty result.
        generator: Optional ``torch.Generator`` forwarded to
            ``torch.randperm`` for reproducible sampling. Its device must
            match the device ``torch.randperm`` draws on (the default
            device, normally CPU). ``None`` uses the global RNG.

    Returns:
        Tensor of shape ``(k, *tensor.shape[1:])``, where ``k`` is the
        subset size.

    Raises:
        ValueError: if ``fraction`` is outside ``[0, 1]``. NaN is also
            rejected because every comparison with NaN is False.
    """
    if not (0 <= fraction <= 1):
        raise ValueError("Fraction must be between 0 and 1.")

    total = tensor.size(0)
    subset_size = int(fraction * total)

    # A random permutation truncated to subset_size gives distinct indices.
    random_indices = torch.randperm(total, generator=generator)[:subset_size]
    return tensor[random_indices]