import numpy as np
import math

# Number of trials; each trial appends one loss ratio to allRats.
numIters = 1

# Per-group matrix dimensions (in the code shown, referenced only by the
# commented-out noisy-matrix experiment inside the trial loop).
rows, cols = 4, 4

# Seed NumPy's global RNG so any random draws are reproducible.
np.random.seed(1)

# Rank of the SVD baseline: the number of top right-singular vectors kept.
k = 1

# Collected (fair_loss / standard_loss) ** 2 ratios, one entry per trial.
allRats = []

def frobenius_noise(A, vectors):
    """Return the Frobenius norm of the part of ``A`` outside span(``vectors``).

    Each row of ``A`` is orthogonally projected onto the subspace spanned by
    the rows of ``vectors``. The projector is built as ``pinv(V) @ V``, which
    is valid even when the rows of ``V`` are not orthonormal or are linearly
    dependent. The Frobenius norm of the residual ``A - A @ P`` is returned.

    Parameters
    ----------
    A : array_like, 2-D
        Data matrix whose rows are projected; its column count must match
        ``vectors``.
    vectors : array_like, 2-D
        Rows span the target subspace.

    Returns
    -------
    numpy.floating
        ``||A - A @ pinv(V) @ V||_F``.
    """
    basis = np.asarray(vectors)
    data = np.asarray(A)
    # pinv(V) @ V is the (n x n) orthogonal projector onto the row space of V.
    projector = np.linalg.pinv(basis) @ basis
    residual = data - data @ projector
    return np.linalg.norm(residual, ord='fro')

for numIter in range(numIters):
    # Alternative (disabled) experiment: two noisy 4x4 groups with a 2-D
    # candidate subspace. To use it, uncomment these lines and replace the
    # active group definitions below with groups = [A1, A2].
##    opt_vectors = [[1,0,0,0],[0,1,0,0]]
##    A1 = [[4,10,0,0],[2,-1,0,0],[0,3,0,0],[-2,0,0,0]]
##    # Generate a random Gaussian matrix
##    mean = 0  # Mean of the Gaussian distribution
##    std_dev = 0.001  # Standard deviation of the Gaussian distribution
##    random_gaussian_matrix = np.random.normal(mean, std_dev, size=(rows, cols))
##    A1 = A1 + random_gaussian_matrix
##    A2 = [[2,7,0,0],[-10,9,0,0],[1,4,0,0],[5,0,0,0]]
##    # Generate a random Gaussian matrix
##    mean = 0  # Mean of the Gaussian distribution
##    std_dev = 0.00001  # Standard deviation of the Gaussian distribution
##    random_gaussian_matrix = np.random.normal(mean, std_dev, size=(rows, cols))
##    A2 = A2 + random_gaussian_matrix

    # Four single-row groups: one along e1, three along e2.
    A1 = [[1,0,0,0]]
    A2 = [[0,1,0,0]]
    A3 = [[0,1,0,0]]
    A4 = [[0,1,0,0]]
    # Single source of truth for the group list, so the fair loss, the
    # stacked matrix and the standard loss always see the same groups.
    groups = [A1, A2, A3, A4]

    # Hand-picked candidate subspace (the e1+e2 direction) for the fair loss.
    opt_vectors = [[1/2,1/2,0,0]]
    # Worst (largest) per-group reconstruction error under the candidate.
    fair_loss = max(frobenius_noise(G, opt_vectors) for G in groups)

    # Baseline: project onto the top-k right singular vectors of the stacked
    # data, then take the worst per-group error under that projection.
    A = np.vstack(groups)
    U, S, VT = np.linalg.svd(A, full_matrices=False)
    vectors = VT[:k,:]
    standard_loss = max(frobenius_noise(G, vectors) for G in groups)

    # NOTE(review): if standard_loss is 0 (e.g. k >= rank(A)) this ratio is
    # undefined; NumPy float division then yields inf/nan with a
    # RuntimeWarning, which propagates into the summary below.
    allRats.append(math.pow(fair_loss/standard_loss,2))

# np.min raises ValueError on an empty list (numIters == 0), so only
# summarize when at least one trial ran.
if allRats:
    print(np.min(allRats))
    print(np.mean(allRats))
else:
    print("No trials were run (numIters == 0); nothing to summarize.")
