from itertools import permutations

import scipy.sparse as sps

import numpy as np


def symbasis(n):
    """
    Return an orthonormal basis of S_n, the space of real symmetric n-by-n matrices.

    S_n has dimension ``ndof = n * (n + 1) // 2``. Each column of the result
    is the column-major vectorisation (``vec``) of one basis element, so the
    returned ``V`` has shape ``(n ** 2, ndof)`` and ``V.T @ V`` is the
    identity.

    Columns follow the lower-triangular positions (i, j), i >= j, enumerated
    column by column (``vech`` order):

    * diagonal (i == j): a single entry 1 at vec-index ``i + j * n``;
    * off-diagonal (i > j): ``sqrt(1/2)`` at vec-indices ``i + j * n`` and
      ``j + i * n`` (the two mirrored entries).
    """
    ndof = n * (n + 1) // 2

    # np.triu_indices yields (r, c) with r <= c in row-major order; swapping
    # the roles gives lower-triangular subscripts (c, r), ordered column by
    # column.
    cols, rows = np.triu_indices(n)

    # Column-major linear indices of (rows, cols) and of its mirror entry.
    idx_lower = rows + cols * n
    idx_upper = rows * n + cols

    # One basis vector (column of V) per degree of freedom.
    dof = np.arange(ndof)

    # Unit-norm columns: a diagonal element has one entry equal to 1, an
    # off-diagonal element splits sqrt(1/2) over its two mirrored entries.
    scale = np.where(idx_lower == idx_upper, 1.0, np.sqrt(0.5))

    V = np.zeros((n ** 2, ndof))
    # Use plain assignment rather than fancy-index ``+=``. The latter applies
    # repeated indices only once, so two 0.5 halves on the diagonal would not
    # sum to 1. For diagonal columns both writes hit the same slot with the
    # same value.
    V[idx_lower, dof] = scale
    V[idx_upper, dof] = scale

    return V


def spcomms(m, n):
    """
    Build the commutation matrix K(m, n) as a sparse CSR array.

    K has shape ``(m * n, m * n)`` and, for any m-by-n matrix A, satisfies
    ``K @ A.ravel(order='F') == A.T.ravel(order='F')``; i.e. it maps
    vec(A) to vec(A^T). Its nonzeros are ones stored as ``np.int8``.
    """
    size = m * n

    # A[i, j] sits at vec(A)[i + j*m] and at vec(A.T)[j + i*n], so row
    # i*n + j of K must select column i + j*m.
    src_cols = np.arange(size).reshape(n, m).T.ravel()
    ones = np.ones(size, dtype=np.int8)

    return sps.csr_array((ones, (np.arange(size), src_cols)), shape=(size, size))


def comms_invperm(m, n):
    """
    Return the inverse of the column permutation that defines K(m, n).

    Row ``r`` of the commutation matrix built by ``spcomms(m, n)`` has its
    single 1 in column ``col[r]``. The array ``inv`` returned here satisfies
    ``inv[col[r]] == r``, or equivalently ``inv[i + j * m] == i * n + j``
    for ``0 <= i < m`` and ``0 <= j < n``.
    """
    # Row-major fill gives R[i, j] = i*n + j; reading the transpose in
    # row-major order visits (i, j) at position i + j*m.
    return np.arange(m * n).reshape(m, n).T.ravel()


def coo_perm_col(A, perm):
    """
    Permute the columns of a sparse matrix by relabelling its column indices.

    Column ``c`` of ``A`` becomes column ``perm[c]`` of the result, so
    ``perm`` should be a permutation of ``range(A.shape[1])``.

    Parameters
    ----------
    A : scipy sparse matrix/array, or anything ``scipy.sparse.coo_matrix``
        accepts. COO input is used as-is; other formats are converted to COO.
    perm : array-like of int, length ``A.shape[1]``
        Destination column index for each source column.

    Returns
    -------
    scipy.sparse.coo_matrix
        Same shape and stored values as ``A``. The value and row-index arrays
        may be shared with ``A`` (no explicit copy is made).
    """
    # Normalise inputs: any sparse format (not only objects exposing
    # .row/.col) and list-like permutations.
    A = sps.coo_matrix(A)
    perm = np.asarray(perm)

    # Only the column indices change; data and row indices are reused.
    return sps.coo_matrix((A.data, (A.row, perm[A.col])), shape=A.shape)


def error_rate_perm(pred_label, true_label):
    """
    Smallest misclassification rate over all relabellings of the predictions.

    Predicted labels are treated as arbitrary cluster ids. Every permutation
    ``p`` of the sorted unique true labels is tried. Prediction ``k`` is
    mapped to ``p[k]``, and the fraction of samples whose mapped prediction
    differs from the true label is recorded. The permutation with the lowest
    rate wins; on ties, the first one in ``itertools.permutations`` order is
    kept.

    Parameters
    ----------
    pred_label : array-like of int
        Predicted cluster ids. They are used as indices into ``p``, so they
        must be integers in ``[0, K)``, where ``K`` is the number of distinct
        true labels. Negative values index from the end, as in NumPy.
    true_label : array-like
        Ground-truth labels, with the same length as ``pred_label``.

    Returns
    -------
    (float, numpy.ndarray)
        The minimal error rate, and the permutation ``p`` achieving it
        (``p[k]`` is the true label assigned to predicted id ``k``).

    Notes
    -----
    All ``K!`` permutations are enumerated, so this is only practical for a
    small number of classes.
    """
    # Accept lists/tuples as well as ndarrays (``.size`` needs an ndarray).
    pred_label = np.asarray(pred_label)
    true_label = np.asarray(true_label)
    n = true_label.size

    label_perm = []
    err_rate = []
    for p in permutations(np.unique(true_label).tolist()):
        p = np.array(p)
        label_perm.append(p)
        err_rate.append((p[pred_label] != true_label).sum() / n)

    # argmin returns the first minimum, which gives the documented tie-break.
    min_id = int(np.argmin(err_rate))

    return (err_rate[min_id].item(), label_perm[min_id])
