


"""
Functions which have been removed from the submission but may still be useful.
"""



def check_kcon(A:np.array, B:np.array, C:np.array, g:float, k:int):
    """ 
    Function for verifying the system is k-contracting according to Theorem 2 - 
    Contraction and k-contraction in Lurie systems with applications to networked 
    systems; Ofir, et al.

    Computes
        alpha_k = (1 / (2k)) * (sum of the k largest eigenvalues of A + A^T),
        LHS     = g^2 * sum_{i<=k} sigma_i(B)^2 * sigma_i(C)^2,
        RHS     = k * alpha_k^2,
    prints them together with the two checks to perform (alpha_k < 0 and
    LHS < RHS), and returns the candidate metric P = -(1/alpha_k) * I.
    The checks are only printed for manual inspection, not evaluated.

    args:
            A: Lurie system parameter (square; presumably n x n with n the
               row count of B).
            B: Lurie system parameter (2-D, n x m).
            C: Lurie system parameter.
            g: Upper bound on slope of nonlinearity.
            k: k-contraction value, 1 <= k <= number of rows of A.
    returns:
            P: contraction metric, an n x n multiple of the identity. Only
               meaningful when alpha_k < 0 (alpha_k == 0 gives non-finite
               entries).
    raises:
            ValueError: if k is outside 1..(number of eigenvalues of A).
    """

    n, _ = B.shape

    # Eigenvalues of the symmetric matrix A + A^T in descending order.
    # eigvalsh is the symmetric solver: it returns real values in ascending
    # order, so reverse them.
    A_sym = A + np.transpose(A)
    A_eigs = np.linalg.eigvalsh(A_sym)[::-1]

    if not 1 <= k <= A_eigs.size:
        raise ValueError(f"k must be in 1..{A_eigs.size}, got {k}")

    # alpha_k = (1 / (2k)) * sum of the k largest eigenvalues.
    # Parenthesise the divisor: `1 / 2*k` would parse as (1/2)*k.
    alpha_k = np.sum(A_eigs[:k]) / (2 * k)

    # calculate LHS of inequality (3).
    # svd(..., compute_uv=False) returns singular values sorted descending.
    B_sing = np.linalg.svd(B, compute_uv=False)
    C_sing = np.linalg.svd(C, compute_uv=False)
    B_sing2 = B_sing * B_sing
    C_sing2 = C_sing * C_sing
    # Singular values beyond a matrix's rank bound are zero, so truncating both
    # vectors to a common length leaves the sum unchanged and avoids a shape
    # mismatch when B and C have different numbers of singular values.
    r = min(k, B_sing2.size, C_sing2.size)
    LHS = g * g * np.sum(B_sing2[:r] * C_sing2[:r])

    # calculate RHS of inequality (3)
    RHS = alpha_k * alpha_k * k

    # calculate P
    P = -(1/alpha_k)*np.eye(n)

    print("Check 1: Is alpha_k < 0 ?")
    print("Check 2: Is LHS < RHS ?")
    print("System is k-contracting if both are true!")

    print("alpha_k: ", alpha_k)
    print("LHS: ", LHS)
    print("RHS: ", RHS)
    print("A eigs: ", A_eigs/2)
    print("B sing2: ", B_sing2)
    print("C sing2: ", C_sing2)
    
    return P



def compare_components(loc:str, file:str, title:str, T:np.array, X:np.array, q:int, y_lim=None):
    """
    Compare individual components of trajectories from dataset X.

    Picks ``q`` trajectories at random (repeats possible) and plots the first
    three state components of each on a ``q`` x 3 grid. The figure is saved to
    ``loc + file + '.png'`` and then shown.

    NOTE(review): another function named ``compare_components`` is defined
    later in this file. That later definition replaces this one when the
    module is executed, so this version is unreachable unless it is renamed.

    args:
      loc: location to save figure. It is concatenated directly with
           ``file``, so include a trailing path separator.
     file: name to save figure as (``.png`` is appended).
    title: title of figure.
        T: Time points.
        X: Dataset, indexed as ``X[t, trajectory, component]``; must have at
           least 3 components.
        q: Number of trajectories to compare. 
    y_lim: optional limits on the y axis plots, indexed as ``y_lim[i][j]``
           for row ``i`` and component ``j``.
    returns:
        0
    """

    tmax, N, n = X.shape

    # randomly pick a batch of trajectories to plot.
    a = np.random.randint(N, size=q)

    # squeeze=False keeps axs 2-D even when q == 1, so axs[i, j] always works.
    fig, axs = plt.subplots(q, 3, squeeze=False)
    fig.suptitle(title)

    for i in range(q):
        for j in range(3):
            axs[i, j].plot(T, X[:, a[i], j])
        axs[i, 0].set(ylabel='Traj. {0}'.format(a[i]))

    axs[0, 0].set_title('x1')
    axs[0, 1].set_title('x2')
    axs[0, 2].set_title('x3')

    # `is not None` rather than `!= None`: if y_lim is a NumPy array, `!=` is
    # elementwise and the result cannot be used as a truth value.
    if y_lim is not None:
        for i in range(q):
            for j in range(3):
                axs[i, j].set_ylim(y_lim[i][j])

    fig.savefig(loc + file + '.png')
    plt.show()

    return 0



def compare_components(loc:str, file:str, T:torch.tensor, X:torch.tensor, X2:list, names:list, q:int, y_lim=None):
    """
    Compare individual components of trajectories from dataset X against predictions X2.

    Only the first three state components are plotted, on a ``q`` x 3 grid.
    The ground truth is drawn thick and dashed. Each model's prediction uses
    its own line style: dashed, dotted, dashdot, cycling for further models.
    A 'True' legend goes on the top-left axes. The per-model legends go on
    row 1 (row 0 when q == 1), one model per column; models beyond the third
    share columns. The figure is saved to ``loc + file + '.png'`` and shown.

    args:
      loc: location to save figure. It is concatenated directly with
           ``file``, so include a trailing path separator.
     file: Name to save figure as (``.png`` is appended).
        T: Time points.
        X: Dataset 1, indexed as ``X[batch, traj, t, component]``.
       X2: List containing predictions from a set of models; ``X2[k]`` is
           indexed as ``X2[k][traj, t, component]``.
    names: list of the model names (one per entry of X2).
        q: Number of trajectories to compare.
    y_lim: optional limits on the y axis plots, indexed as ``y_lim[i][j]``.
    returns:
        0
    """

    # Line style per model; models beyond the third reuse these cyclically.
    model_styles = ('dashed', 'dotted', 'dashdot')

    N = len(names)
    batch, bs, tmax, n = X.shape

    # randomly pick a batch and trajectory from that batch to compare.
    a = torch.randint(0, batch, (q,))
    b = torch.randint(0, bs, (q,))
    
    # squeeze=False keeps axs 2-D even when q == 1, so axs[i, j] always works.
    fig, axs = plt.subplots(q, 3, figsize=(6, 6), squeeze=False)
    fig.tight_layout()

    # line[i][j][k] is the Line2D for model k on axes (i, j); index N holds the
    # ground truth. Build independent lists: `[[...] * 3] * q` would make
    # every line[i][j] the same shared list.
    line = [[[None] * (N + 1) for _ in range(3)] for _ in range(q)]

    for i in range(q):
        for j in range(3):
            line[i][j][N], = axs[i, j].plot(T, X[a[i], b[i], :, j], linewidth=2.0, linestyle='dashed')
            for k in range(N):
                # NOTE(review): the truth uses batch a[i], but the prediction
                # is indexed only by trajectory b[i]. The two match only if
                # X2[k] holds predictions for batch a[i] (e.g. batch == 1).
                # Confirm against callers.
                line[i][j][k], = axs[i, j].plot(
                    T, X2[k][b[i], :, j], linestyle=model_styles[k % len(model_styles)]
                )
        # axs[i,j].set(ylabel='batch {0}, traj. {1}'.format(a[i], b[i]))
        axs[i, 0].set(ylabel='traj. {0}'.format(b[i]))

    # Collect legend entries per axes, so a single axes can hold several
    # entries (q == 1, or more than three models) and no legend call
    # overwrites another.
    legend_row = 1 if q > 1 else 0
    legends = {(0, 0): ([line[0][0][N]], ['True'])}
    for k in range(N):
        handles, labels = legends.setdefault((legend_row, k % 3), ([], []))
        handles.append(line[0][0][k])
        labels.append(names[k])
    for (r, c), (handles, labels) in legends.items():
        axs[r, c].legend(handles, labels)
    
    axs[0, 0].set_title('x1')
    axs[0, 1].set_title('x2')
    axs[0, 2].set_title('x3')

    # `is not None` rather than `!= None`: if y_lim is an array, `!=` is
    # elementwise and the result cannot be used as a truth value.
    if y_lim is not None:
        for i in range(q):
            for j in range(3):
                axs[i, j].set_ylim(y_lim[i][j])

    plt.subplots_adjust(left=0.1)
    
    fig.savefig(loc + file +'.png')
    plt.show()

    return 0





