import torch
import numpy as np

#from scipy.linalg import qr, polar

def polar_factorization_torch(A):
    """Compute the right polar decomposition ``A = U @ P`` with PyTorch.

    Args:
        A: Array-like or tensor holding a matrix (or a batch of matrices in
            the leading dimensions). Integer input is promoted to the
            default floating dtype because SVD requires float/complex data.

    Returns:
        A tuple ``(U, P)`` of NumPy arrays where ``U`` has orthonormal
        columns (orthogonal/unitary for square ``A``) and ``P`` is
        Hermitian positive semi-definite, such that ``U @ P`` reconstructs
        ``A``.
    """
    A = torch.as_tensor(A)
    if not (A.is_floating_point() or A.is_complex()):
        A = A.to(torch.get_default_dtype())

    # NOTE: torch.linalg.svd returns (U, S, Vh) with A = U @ diag(S) @ Vh,
    # i.e. the third factor is already the conjugate transpose of V.
    W, S, Vh = torch.linalg.svd(A, full_matrices=False)
    V = Vh.conj().transpose(-2, -1)

    # V @ diag(S) @ Vh, written as a column scaling so that the real
    # singular values broadcast against a possibly complex V.
    P = (V * S.unsqueeze(-2)) @ Vh
    U = W @ Vh
    return U.detach().cpu().numpy(), P.detach().cpu().numpy()

def polar_factorization(A):
    """Compute the right polar decomposition ``A = U @ P`` with NumPy.

    Args:
        A: Array-like 2-D matrix, real or complex, square or rectangular.

    Returns:
        A tuple ``(U, P)`` where ``U`` has orthonormal columns (or rows when
        ``A`` has fewer rows than columns; orthogonal/unitary for square
        ``A``) and ``P`` is Hermitian positive semi-definite, such that
        ``U @ P`` reconstructs ``A``.
    """
    # Reduced SVD: A = W @ diag(S) @ Vh. For square input this is identical
    # to the full SVD; for rectangular input it keeps the shapes compatible.
    W, S, Vh = np.linalg.svd(A, full_matrices=False)

    # Conjugate transpose (not a plain transpose) so complex input yields a
    # Hermitian P; for real input this is just Vh.T.
    V = Vh.conj().T

    # Scaling the columns of V by S is V @ diag(S) without building diag(S).
    P = (V * S) @ Vh
    U = W @ Vh
    return U, P

def rotate_for_sparsity(S, n_iter=20):
    """Search for orthogonal ``Q1``, ``Q2`` applied as ``Q1.T @ S[k] @ Q2``.

    Starting from random orthogonal matrices (drawn from NumPy's global RNG),
    each sweep accumulates, over all pairs ``i < j``, two matrices that are
    (up to a constant factor) the gradients with respect to ``Q1`` and ``Q2``
    of ``sum((Ri * Rj) ** 2)``, where ``Rk = Q1.T @ S[k] @ Q2`` and ``*`` is
    the element-wise product. ``Q1`` and ``Q2`` are then replaced by the
    orthogonal polar factors of those accumulated matrices.

    Args:
        S: Non-empty sequence of 2-D real NumPy arrays that all share the
            shape ``(m, n)`` of ``S[0]``.
        n_iter: Number of update sweeps to run (defaults to 20).

    Returns:
        A tuple ``(Q1, Q2)`` of arrays with shapes ``(m, m)`` and ``(n, n)``.
    """
    m, n = S[0].shape
    count = len(S)

    # Random orthogonal starting points from QR of Gaussian matrices.
    Q1, _ = np.linalg.qr(np.random.randn(m, m))
    Q2, _ = np.linalg.qr(np.random.randn(n, n))

    for _ in range(n_iter):
        # These products depend only on the current Q1/Q2, so compute them
        # once per sweep instead of once per (i, j) pair.
        rotated = [Q1.T @ Sk @ Q2 for Sk in S]
        times_q2 = [Sk @ Q2 for Sk in S]
        times_q1 = [Sk.T @ Q1 for Sk in S]

        grad1 = np.zeros((m, m))
        grad2 = np.zeros((n, n))

        for i in range(count):
            Ri = rotated[i]
            for j in range(i + 1, count):
                Rj = rotated[j]

                grad1 += times_q2[i] @ (Rj.T * Rj.T * Ri.T) \
                    + times_q2[j] @ (Rj.T * Ri.T * Ri.T)

                grad2 += times_q1[i] @ (Rj * Rj * Ri) \
                    + times_q1[j] @ (Rj * Ri * Ri)

        # Map each accumulated matrix back to an orthogonal matrix by
        # keeping only the orthogonal factor of its polar decomposition.
        Q1, _ = polar_factorization(grad1)
        Q2, _ = polar_factorization(grad2)

    return Q1, Q2
    

if __name__ == "__main__":
    # Demo: build matrices with a hidden block structure, scramble them with
    # random orthogonal transforms, and check visually whether
    # rotate_for_sparsity finds rotations that reveal structure again.
    import matplotlib.pyplot as plt

    # Setting up the collection of block matrices
    nrepeat = 3
    blocksize = 5
    size = 2 * blocksize

    # First `nrepeat` matrices are non-zero only in the top-left block, the
    # remaining `nrepeat` only in the bottom-right block.
    S = [torch.zeros(size, size) for _ in range(2 * nrepeat)]
    for i in range(nrepeat):
        S[i][:blocksize, :blocksize] = torch.randn(blocksize, blocksize)
        S[i + nrepeat][blocksize:, blocksize:] = torch.randn(blocksize, blocksize)

    # Hide that structure by conjugating with random orthogonal matrices
    U, _ = torch.linalg.qr(torch.randn(size, size))
    V, _ = torch.linalg.qr(torch.randn(size, size))

    S_before_hiding = [s.clone().numpy() for s in S]

    # Scramble and convert PyTorch tensors to NumPy arrays
    S = [(U.T @ s @ V).cpu().numpy() for s in S]

    # Test if the algorithm works
    Q1, Q2 = rotate_for_sparsity(S)

    # Figure 1: one row per stage, one column per matrix.
    rows = [
        ('Before conjugating', S_before_hiding),
        ('Original', S),
        ('Transformed', [Q1.T @ s @ Q2 for s in S]),
    ]
    fig, axs = plt.subplots(len(rows), len(S), figsize=(15, 6))
    for row_axes, (label, matrices) in zip(axs, rows):
        for i, (ax, matrix) in enumerate(zip(row_axes, matrices)):
            ax.imshow(matrix, cmap='viridis')
            ax.axis('equal')
            ax.axis('off')
            ax.set_title(f'{label} S[{i+1}]')

    plt.show()

    # Figure 2: the recovered rotations, side by side. A dedicated 1x2 grid
    # avoids indexing a larger grid with a stale loop variable and leaving
    # most of its axes empty.
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    for ax, (label, Q) in zip(axs, (('Q1', Q1), ('Q2', Q2))):
        ax.imshow(Q, cmap='viridis')
        ax.axis('equal')
        ax.axis('off')
        ax.set_title(label)

    plt.show()