import torch
import numpy as np

def gram_schmidt_matrix_torch(matrices):
    """Orthonormalise a sequence of tensors with the Gram-Schmidt process.

    Each tensor is treated as a flat vector: the inner product of two
    tensors is the sum of their elementwise product (no conjugation is
    applied), and the norm is ``torch.linalg.norm`` of the whole tensor.

    Args:
        matrices: Iterable of tensors that can be multiplied elementwise
            with one another. Integer/bool tensors are promoted to the
            default floating dtype before processing.

    Returns:
        A list with one tensor per input, in input order. Each output is
        the input minus its components along all earlier outputs, divided
        by its norm. If that residual has zero norm (e.g. a zero tensor),
        the division yields non-finite entries; no error is raised.

    The inputs are never modified, and no in-place operation is used, so
    the result can be back-propagated through with autograd.
    """
    ortho_matrices = []
    for matrix in matrices:
        m = matrix
        # torch.linalg.norm only accepts floating/complex tensors.
        if not (m.is_floating_point() or m.is_complex()):
            m = m.to(torch.get_default_dtype())
        for mat in ortho_matrices:
            # Out-of-place on purpose: `mat * m` saves `m` for the backward
            # pass, and an in-place `m -= ...` would invalidate that saved
            # tensor and make `.backward()` raise.
            m = m - torch.sum(mat * m) * mat
        ortho_matrices.append(m / torch.linalg.norm(m))
    return ortho_matrices

def gram_schmidt_matrix_np(matrices):
    """Orthonormalise a sequence of NumPy arrays with Gram-Schmidt.

    Each array is treated as a flat vector: the inner product of two
    arrays is the sum of their elementwise product (no conjugation is
    applied), and the norm is ``np.linalg.norm`` of the whole array.

    Args:
        matrices: Iterable of ``numpy.ndarray`` objects that can be
            multiplied elementwise with one another. Arrays whose dtype
            is not floating or complex (e.g. integers) are promoted to
            ``float64``; floating/complex dtypes are kept.

    Returns:
        A list with one array per input, in input order. Each output is
        the input minus its components along all earlier outputs, divided
        by its norm. If that residual has zero norm (e.g. a zero array),
        the division yields non-finite entries; no error is raised.

    The inputs are never modified; all work happens on copies.
    """
    ortho_matrices = []
    for matrix in matrices:
        m = matrix.copy()
        # The projections are floating point, so an integer working copy
        # could not absorb `m -= proj` (NumPy refuses the float->int cast).
        if not np.issubdtype(m.dtype, np.inexact):
            m = m.astype(np.float64)
        for mat in ortho_matrices:
            # Remove the component of m along the already-normalised `mat`.
            m -= np.sum(mat * m) * mat
        ortho_matrices.append(m / np.linalg.norm(m))
    return ortho_matrices

def greedy_match(W1, W2):
    """Match every row of ``W1`` to its most cosine-similar row of ``W2``.

    The cosine similarity of row ``i`` of ``W1`` and row ``j`` of ``W2`` is
    their dot product divided by the product of their Euclidean norms.
    Matching is done independently per row of ``W1``, so several rows may
    be assigned the same row of ``W2``.

    Args:
        W1: 2-D tensor whose rows are the vectors to be matched.
        W2: 2-D tensor with the same number of columns as ``W1`` whose
            rows are the candidates.

    Returns:
        1-D integer tensor with one entry per row of ``W1``: the index of
        the best-matching row of ``W2``.

    A pair involving an all-zero row has no defined cosine similarity; it
    is scored as 0 (the convention ``torch.nn.functional.cosine_similarity``
    effectively uses) instead of NaN.
    """
    dots = W1 @ W2.T
    sq_norms1 = torch.sum(W1 ** 2, dim=1)
    sq_norms2 = torch.sum(W2 ** 2, dim=1)
    mags = torch.sqrt(sq_norms1[:, None] * sq_norms2[None, :])
    # Without this guard a zero row gives 0/0 = NaN, and argmax treats NaN
    # as the maximum, so a zero row in W2 would win every single match.
    rel_sims = torch.where(mags > 0, dots / mags, torch.zeros_like(mags))
    return torch.argmax(rel_sims, dim=1)