import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset

# Preprocessing applied to every image: convert it to a tensor, then
# normalise the single channel with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
# ])
def _label_weighted_patch_means(images, labels, kernel=3):
    """Return the label-weighted mean feature vectors for one batch.

    For each of the ``kernel * kernel`` offsets of a sliding window over
    channel 0 of ``images`` (row offset outermost, column offset innermost),
    the window is flattened per sample, multiplied by the sample's label and
    averaged over the batch. One extra vector is appended at the end: a
    constant vector of the same length whose entries all equal the mean label.

    Args:
        images: tensor indexed as ``images[:, 0, rows, cols]``.
        labels: one label per sample; cast to float before use.
        kernel: window offsets per axis (3 reproduces the original
            ``0:26 / 1:27 / 2:28`` slices on 28x28 inputs).

    Returns:
        A list of ``kernel * kernel + 1`` one-dimensional tensors.
    """
    weights = labels.float().unsqueeze(1)
    rows = images.shape[2] - kernel + 1
    cols = images.shape[3] - kernel + 1
    feats = []
    for dr in range(kernel):
        for dc in range(kernel):
            window = images[:, 0, dr:dr + rows, dc:dc + cols]
            flat = torch.flatten(window, start_dim=1)
            feats.append(torch.mean(flat * weights, dim=0))
    # Constant feature: every entry is the mean label of the batch.
    feats.append(torch.ones_like(feats[-1]) * torch.mean(weights))
    return feats


def _unit(vec):
    """Return ``vec`` scaled to unit Euclidean norm."""
    return vec / np.linalg.norm(vec)


# Per-trial results. R1 / R2 hold |cosine similarity| between the all-ones
# direction and the leading / second eigenvector (constant-feature component
# dropped); EIG holds each trial's eigenvalues in descending order.
R1 = []
R2 = []
EIG = []

train_size = 500
# Build the dataset once; only the index window changes between trials.
# Alternative dataset: torchvision.datasets.CIFAR10(root='./data', train=True,
# download=True, transform=transform) together with a 3-channel Normalize.
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)

for st in range(0, 50000, 1000):
    train_set = Subset(trainset, range(st, st + train_size))
    # batch_size == train_size, so the loader yields exactly one batch.
    # num_workers=0: worker processes bring nothing for a single small batch,
    # and spawning them from an unguarded script fails under the "spawn"
    # start method.
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=train_size,
                                              shuffle=True, num_workers=0)
    images, labels = next(iter(trainloader))

    z = _label_weighted_patch_means(images, labels)
    print(z[0].shape)

    # Gram matrix of the feature vectors, in float64 for the eigensolver.
    Z = torch.stack(z).double()
    Ma = (Z @ Z.t()).numpy()
    print(Ma.shape)

    # The Gram matrix is symmetric, so use eigh: real eigenvalues and
    # orthonormal eigenvectors. eigh sorts ascending; flip so that column 0
    # is the leading eigenvector. (np.linalg.eig gives no ordering guarantee,
    # so columns 0 and 1 were not reliably the top two.)
    eigval, eigvec = np.linalg.eigh(Ma)
    eigval = eigval[::-1]
    eigvec = eigvec[:, ::-1]
    EIG.append(eigval)
    print(eigval)

    # Overlap of every (unit-norm) eigenvector with the normalised ones vector.
    One = _unit(np.ones(Ma.shape[0]))
    innre = eigvec.T @ One
    print(innre[0])

    # Drop the last component (the constant feature) and re-normalise before
    # comparing with the all-ones direction.
    eigveccut = _unit(eigvec[:-1, 0])
    print(eigveccut)
    eigveccut1 = _unit(eigvec[:-1, 1])
    print(eigveccut1)
    onecut = _unit(One[:-1])

    print(np.dot(eigveccut, onecut))
    print(np.dot(eigveccut1, onecut))
    print(np.dot(eigveccut, eigveccut1))
    # Eigenvector sign is arbitrary, hence the absolute value.
    R1.append(abs(np.dot(eigveccut, onecut)))
    R2.append(abs(np.dot(eigveccut1, onecut)))
# Draw the per-trial |cosine similarity| curves for both eigenvectors on the
# current axes and write the figure to disk.
axes = plt.gca()
for series, tag in ((R1, 'first'), (R2, 'second')):
    axes.plot(series, label=tag)

axes.set_yticks([0.8, 1.0, 1.2])
axes.set_xlabel('trials', fontsize=20)
axes.set_ylabel('cosine similarity', fontsize=20)
axes.set_ylim(0.8, 1.2)

# Enlarge the tick labels on both axes.
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.tight_layout()
plt.legend(fontsize=20)
plt.savefig('test500.png')

# Summary statistics of the leading-eigenvector similarity across trials.
R1 = np.array(R1)
print(R1.mean(), R1.std())

# Plot the mean eigenvalues over all trials with standard-error bars (currently disabled).
# EIG = np.asarray(EIG)
# print(EIG)
# print(EIG.shape)
# sd = np.std(EIG,axis=0)
# # print(sd.shape)
# EIG_mean = np.mean(EIG,axis=0)
# # print(EIG_mean.shape)
# sd_e = sd/np.sqrt(50)
# print(sd_e)
# plt.errorbar(range(0, 10), EIG_mean, sd_e, fmt="o")
# plt.scatter(range(10),EIG_mean)
# plt.xticks(fontsize=20)
# plt.yticks(fontsize=20)
# plt.xlabel('index',fontsize=20)
# plt.ylabel('eigenvalue',fontsize=20)
# plt.yscale('log')
# plt.tight_layout()
# # plt.title(r'The eigenvalue of MNIST with convolution kernel is $3\times3$')
# # plt.yscale('log')
# plt.savefig('eig.png')
