
import scipy.io
import numpy as np
#from scipy.spatial.distance import pdist, squareform
import networkx as nx
from sklearn.preprocessing import StandardScaler
#import matplotlib.pyplot as plt
#from sklearn.decomposition import PCA
#from mpl_toolkits.mplot3d import Axes3D
#import seaborn as sns
#matplotlib inline
#import time
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D

import random

# Scale sweep: for each s, build a coupled graph FA (node sets of A and W,
# linked by C) and a collapsed 4-node graph ntA, then record how far the lift
# J is from intertwining their Laplacians, resolvents, resolvent powers and
# sin/exp of the resolvents.
from scipy.linalg import cosm, expm, sinm

# Values of the scale parameter s (s = 1, 2, ..., 99). Each iteration appends
# exactly one entry to every result list below except resdifferences6.
SCALES = range(1, 100)

# Spectral norm of  TildeLaplacian @ J - J @ NotildeLaplacian,  one per s.
opdifferences = []
# Spectral norms of  R1^k @ J - J @ R2^k  for k = 1..5, where
# R1 = (I + TildeLaplacian)^-1 and R2 = (I + NotildeLaplacian)^-1.
resdifferences = []
resdifferences2 = []
resdifferences3 = []
resdifferences4 = []
resdifferences5 = []
resdifferences6 = []  # NOTE(review): never filled or plotted; kept for compatibility.

# Same comparison with sin(R) and exp(R) in place of R^k.
sinresdefs = []
expresdefs = []
# Output differences of the two stacked-filter "convolutional layers".
netdifs = []
netdifs2 = []


def build_graph(s):
    """Return the adjacency blocks ``(A, W, C)`` for scale ``s``.

    A is a fixed 4x4 adjacency, W is ``s`` times a fixed 4x4 adjacency, and
    C is the coupling block placed off-diagonal in the full graph. Only the
    last row of C is multiplied by ``s``.
    """
    A = np.array([[0, 4, 2, 10],
                  [4, 0, 17, 9],
                  [2, 17, 0, 42],
                  [10, 9, 42, 0]])
    W = s * np.array([[0, 6, 22, 3],
                      [6, 0, 1, 90],
                      [22, 1, 0, 23],
                      [3, 90, 23, 0]])
    C = np.array([[4, 5, 6, 7],
                  [8, 9, 10, 11],
                  [12, 13, 14, 15],
                  [s * 16, s * 7, s * 18, s * 19]])
    return A, W, C


def laplacian(adjacency):
    """Return the graph Laplacian ``diag(row sums) - adjacency``."""
    return np.diag(adjacency.sum(axis=1)) - adjacency


def embed(A, W, C):
    """Return ``(J, mus)`` for blocks produced by :func:`build_graph`.

    With ``H = laplacian(W) + diag(column sums of C)``, row i of ``psis`` is
    ``[e_i, H^-1 C[i]]``. J is ``psis`` transposed, so its i-th column is
    ``psis[i]``; shape ``(len(A) + C.shape[1], len(A))``. ``mus[i]`` is the
    l1-norm of ``psis[i]``.
    """
    n = len(A)
    H = laplacian(W) + np.diag(C.sum(axis=0))
    invH = np.linalg.inv(H)
    # Laplacian rows sum to 0, so H @ 1 = sum_i C[i]; hence the H^-1 C[i]
    # parts of psis sum to the all-ones vector.
    psis = np.concatenate(
        (np.identity(n), np.array([invH @ row for row in C])), axis=1)
    mus = np.array([np.linalg.norm(row, ord=1) for row in psis])
    return psis.T.copy(), mus


def collapse_boundary(A, C):
    """Return a copy of A with the couplings in C folded onto its last node.

    The last row of C is zeroed; the row sums of the result are added to
    both the last row and the last column of A. Neither argument is mutated.
    """
    boundary = C.copy()
    boundary[-1, :] = 0
    row_sums = boundary.sum(axis=1)
    # Promote the dtype so a float C (non-integer s) is not truncated into int A.
    folded = A.astype(np.result_type(A, C))
    folded[-1, :] += row_sums
    folded[:, -1] += row_sums
    return folded


def resolvent(L):
    """Return ``(I + L)^-1``."""
    return np.linalg.inv(np.identity(len(L)) + L)


def intertwining_gap(big, small, J):
    """Return the spectral norm of ``big @ J - J @ small``."""
    return np.linalg.norm(big @ J - J @ small, ord=2)


for s in SCALES:
    print(s)
    A, W, C = build_graph(s)
    # Full graph on both node sets, with C coupling them.
    FA = np.block([[A, C], [C.T, W]])

    J, mus = embed(A, W, C)
    print('J')
    print(J)

    TildeLaplacian = laplacian(FA)
    print('TildeLaplacian')
    print(TildeLaplacian)

    ntA = collapse_boundary(A, C)
    print('ntA')
    print(ntA)

    # diag(mus)^-1 @ laplacian(ntA): row i is scaled by 1 / mus[i].
    NotildeLaplacian = np.diag(1.0 / mus) @ laplacian(ntA)
    print('NotildeLaplacian')
    print(NotildeLaplacian)

    opdifferences.append(intertwining_gap(TildeLaplacian, NotildeLaplacian, J))

    # Invert each resolvent once and reuse it for every quantity below.
    R1 = resolvent(TildeLaplacian)
    R2 = resolvent(NotildeLaplacian)
    power_series = (resdifferences, resdifferences2, resdifferences3,
                    resdifferences4, resdifferences5)
    for k, series in enumerate(power_series, start=1):
        series.append(intertwining_gap(np.linalg.matrix_power(R1, k),
                                       np.linalg.matrix_power(R2, k), J))

    sin_R1, cos_R1, exp_R1 = sinm(R1), cosm(R1), expm(R1)
    sin_R2, cos_R2, exp_R2 = sinm(R2), cosm(R2), expm(R2)
    sinresdefs.append(intertwining_gap(sin_R1, sin_R2, J))
    expresdefs.append(intertwining_gap(exp_R1, exp_R2, J))

    print(opdifferences[-1])
    print(resdifferences[-1])
    print(sinresdefs[-1])
    print(expresdefs[-1])

    # Filter bank [sin; cos; exp] applied on each side of the comparison.
    big_bank = np.vstack([f @ J for f in (sin_R1, cos_R1, exp_R1)])
    small_bank = np.vstack([J @ f for f in (sin_R2, cos_R2, exp_R2)])

    # Network one: the same filter bank applied to both of two channels.
    netdifs.append(np.linalg.norm(
        np.hstack((big_bank, big_bank)) - np.hstack((small_bank, small_bank)),
        ord=2))
    print(netdifs[-1])

    # Network two: each channel feeds its own, disjoint half of six outputs.
    zeros = np.zeros_like(big_bank)
    K = np.block([[big_bank, zeros], [zeros, big_bank]])
    KL = np.block([[small_bank, zeros], [zeros, small_bank]])
    netdifs2.append(np.linalg.norm(K - KL, ord=2))
    print(netdifs2[-1])



import matplotlib.pyplot as plt


def plot_against_scale(series_with_labels, ylabel):
    """Plot per-scale series against s and show the figure.

    ``series_with_labels`` is an iterable of ``(values, label)`` pairs, where
    ``label`` may be None. ``values[k]`` is the metric recorded for
    ``s = k + 1``, because the sweep runs s = 1, 2, ... with step 1.
    A legend is drawn only if at least one series has a label.
    """
    any_label = False
    for values, label in series_with_labels:
        plt.plot(range(1, len(values) + 1), values, label=label)
        any_label = any_label or label is not None
    plt.xlabel('s')
    plt.ylabel(ylabel)
    if any_label:
        plt.legend()
    plt.show()


#########Operator Differences########
plot_against_scale([(opdifferences, None)], 'Operator Differences')

#########Resolvent Differences########
plot_against_scale([(resdifferences, None)], 'Resolvent Differences')

#########Resolvent monomial Differences########
# k = 4 and k = 5 (resdifferences4/5) are collected but intentionally not plotted.
plot_against_scale([(resdifferences, 'k=1'),
                    (resdifferences2, 'k=2'),
                    (resdifferences3, 'k=3')],
                   'Resolvent Monomial Differences')

#########Sin Differences########
plot_against_scale([(sinresdefs, None)], 'Sin Differences')

#########Exp Differences########
plot_against_scale([(expresdefs, None)], 'Exp Differences')

#####Output Differences###########
plot_against_scale([(netdifs, 'Convolutional Layer I'),
                    (netdifs2, 'Convolutional Layer II')],
                   'Convolutional Layer output differences')

print(netdifs)
print(netdifs2)