import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset

# Preprocessing pipeline: convert each image to a tensor, then normalise every
# one of the three channels with mean 0.5 and std 0.5.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalize])

# Number of images in each subset that is analysed.
train_size = 500

# CIFAR-10 training split stored under ./data (download=True asks torchvision
# to fetch it when needed).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# train_set = Subset(trainset, range(train_size))

# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
# ])
# Per-trial results; one entry is appended for every subset processed below.
#   R1, R2         |cosine| between the all-ones direction and the first /
#                  second eigenvector of the Gram matrix (bias entry dropped)
#   R11, R12, R13  |cosine| between the all-ones direction and the 9 entries
#                  of the first eigenvector belonging to channel 0 / 1 / 2
#   EIG            eigenvalues of the Gram matrix, largest first
R1 = []
R11 = []
R12 = []
R13 = []
R2 = []
EIG = []


def _label_weighted_window_means(images, labels):
    """Return the label-weighted mean of every shifted window, plus a bias.

    For each channel and each of the 3x3 (row, column) offsets, the window of
    size (H - 2) x (W - 2) starting at that offset is flattened, multiplied by
    the label of its image and averaged over the batch.  Vectors are ordered
    by channel, then row offset, then column offset.  The last vector is the
    bias term: a constant vector filled with the mean label.

    Args:
        images: tensor indexed as (batch, channel, row, column).
        labels: tensor with one label per image in the batch.

    Returns:
        A list of ``channels * 9 + 1`` one-dimensional tensors.
    """
    y = labels.float().unsqueeze(1)
    _, channels, height, width = images.shape
    rows, cols = height - 2, width - 2

    means = []
    for channel in range(channels):
        for dr in range(3):
            for dc in range(3):
                window = images[:, channel, dr:dr + rows, dc:dc + cols]
                means.append((window.flatten(start_dim=1) * y).mean(dim=0))
    # bias
    means.append(torch.ones_like(means[0]) * y.mean())
    return means


def _unit(vec):
    """Return *vec* scaled to unit Euclidean norm."""
    return vec / np.linalg.norm(vec)


train_size = 500
for st in range(0, 50000, 1000):
    train_set = Subset(trainset, range(st, st + train_size))
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=train_size,
                                              shuffle=True, num_workers=2)
    # batch_size equals the subset size, so the loader yields a single batch.
    images, labels = next(iter(trainloader))
    z = _label_weighted_window_means(images, labels)
    print(z[0].shape)

    # Gram matrix of the feature vectors: Ma[i, j] = <z[i], z[j]>.
    Z = torch.stack(z)
    Ma = torch.matmul(Z, Z.t()).numpy()
    print(Ma.shape)

    # The Gram matrix is symmetric, so use eigh: it returns real eigenpairs
    # in ascending order.  Reverse them so that column 0 is the eigenvector
    # of the largest eigenvalue (np.linalg.eig gives no ordering guarantee).
    eigval, eigvec = np.linalg.eigh(Ma)
    eigval = eigval[::-1].copy()
    eigvec = eigvec[:, ::-1]
    EIG.append(eigval)
    print(eigval)

    # Drop the trailing bias entry and renormalise what is left.
    eigveccut = _unit(eigvec[:, 0][:-1])
    eigveccut1 = _unit(eigvec[:, 1][:-1])
    onecut = _unit(np.ones_like(eigveccut))
    onecut9 = _unit(np.ones_like(eigveccut[:9]))

    print(np.dot(eigveccut, onecut))
    # The sign of an eigenvector is arbitrary, hence the absolute values.
    R1.append(abs(np.dot(eigveccut, onecut)))
    # Entries 0-8, 9-17 and 18-26 hold the nine offsets of channel 0, 1, 2.
    R11.append(abs(np.dot(_unit(eigveccut[0:9]), onecut9)))
    R12.append(abs(np.dot(_unit(eigveccut[9:18]), onecut9)))
    R13.append(abs(np.dot(_unit(eigveccut[18:]), onecut9)))
    R2.append(abs(np.dot(eigveccut1, onecut)))
# whole vector
# One curve per eigenvector: |cosine| with the all-ones direction per trial.
for _series, _name in ((R1, 'first'), (R2, 'second')):
    plt.plot(_series, label=_name)
# split vector into 3 channels
# plt.plot(R11,label='channel 1')
# plt.plot(R12,label='channel 2')
# plt.plot(R13,label='channel 3')
# print(R1)
# print(R2)

# Fixed y-range with ticks at both ends and at 1.0.
_y_low, _y_high = 0.6, 1.4
_axis_label_size = 20
_tick_label_size = 18

plt.yticks([_y_low, 1.0, _y_high])
plt.xlabel('trials', fontsize=_axis_label_size)
plt.ylabel('cosine similarity', fontsize=_axis_label_size)
plt.ylim(_y_low, _y_high)

# Enlarge the tick labels on both axes (x first, then y).
for _set_ticks in (plt.xticks, plt.yticks):
    _set_ticks(fontsize=_tick_label_size)
plt.tight_layout()
plt.legend(fontsize=_tick_label_size)

# Write the figure to disk and release it.
plt.savefig('test500cifar10.png')
plt.close()

# Turn the three per-channel result lists into NumPy arrays, then print the
# first one.
R11, R12, R13 = (np.array(_values) for _values in (R11, R12, R13))
print(R11)
# print(np.mean(R11),np.std(R11))
# print(np.mean(R12),np.std(R12))
# print(np.mean(R13),np.std(R13))

# Plot the mean of the eigenvalues collected in EIG (disabled; uncomment to use)
# EIG = np.asarray(EIG)
# print(EIG)
# # print(EIG.shape)
# sd = np.std(EIG,axis=0)
# # print(sd.shape)
# EIG_mean = np.mean(EIG,axis=0)
# # print(EIG_mean.shape)
# sd_e = sd/np.sqrt(50)
# print(sd_e)
# # plt.errorbar(range(0, 10), EIG_mean, sd_e, fmt="o")
# plt.scatter(range(15),EIG_mean[:15])
# plt.xticks(fontsize=20)
# plt.yticks(fontsize=20)
# plt.xlabel('index',fontsize=20)
# plt.ylabel('eigenvalue',fontsize=20)
# plt.yscale('log')
# plt.tight_layout()
# # plt.title(r'The eigenvalue of MNIST with convolution kernel is $3\times3$')
# # plt.yscale('log')
# plt.savefig('eigCIFAR10.png')
