import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tool import loaddata, makedir
from torch.utils.data import Subset

# Run directory of the experiment being analysed; 'trainpro.pkl' and the
# 'epoch=<n>.pt' checkpoints are read from here.
savedir = '/home/***/data/undergraky/experiment/Cifar10/onehotMSE/230422140417tanh'
# NOTE(review): `loaddata` comes from the project's `tool` module; presumably
# it unpickles the recorded training progress -- confirm.
trainpro_path = os.path.join(savedir, 'trainpro.pkl')
data = loaddata(trainpro_path)
# Use the first GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# MNIST
# train_size = 500
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.1307,), (0.3081,))
# ])
# trainset = torchvision.datasets.MNIST(root='./data', train=True,
#                                         download=True, transform=transform)
# train_set = Subset(trainset, range(1000,1000+train_size))

# trainloader = torch.utils.data.DataLoader(train_set, batch_size=train_size,
#                                           shuffle=False, num_workers=2)

# CIFAR10: work on the first `train_size` images of the training split.
train_size = 500

# ToTensor followed by Normalize(mean=0.5, std=0.5) on each of the 3 channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_set = Subset(trainset, range(train_size))

# batch_size equals the subset size, so the loader yields one batch that
# contains the whole subset (in shuffled order).
trainloader = torch.utils.data.DataLoader(
    train_set,
    batch_size=train_size,
    shuffle=True,
    num_workers=2,
)
# Label-weighted patch statistics.
#
# For every input channel (0..2) and every (row, col) offset of a 3x3 window
# (0..2 each), take the 30x30 crop of the image starting at that offset,
# flatten it, multiply it by the sample's label and average over the batch.
# This gives 27 vectors of length 900, appended in (channel, row, col) order.
# A 28th "bias" vector holds the mean label in every position.
# NOTE(review): assumes images are (batch, 3, 32, 32) and labels are integer
# class indices, as delivered by the CIFAR10 loader -- confirm.
z = []
for images, labels in trainloader:
    # Column vector (batch, 1) so it broadcasts over the flattened pixels.
    y = labels.float().unsqueeze(1)

    for channel in range(3):
        for row in range(3):
            for col in range(3):
                crop = images[:, channel, row:row + 30, col:col + 30]
                weighted = crop.flatten(start_dim=1) * y
                z.append(weighted.mean(dim=0))

    # bias: constant vector (same shape/dtype as the entries above) filled
    # with the mean label.
    z.append(torch.ones_like(z[-1]) * y.mean())
# Gram matrix of the first 28 statistics vectors: M[i, j] = <z[i], z[j]>.
Z = torch.stack(z[:28])
M = Z @ Z.T
# NumPy copy of the 28x28 matrix, taken before M is flattened below.
Ma = np.array(M)

# Flattened, unit-norm version of M and a unit-norm all-ones vector of the
# same length.
# NOTE(review): neither M nor O is read again in the visible part of the
# script; kept in case code outside this view relies on them.
M = torch.reshape(M, (-1, 1))
M = M / torch.norm(M)
M = M.squeeze()
O = torch.ones_like(M)
O = O / torch.norm(O)

innre = []

# Ma is a real symmetric Gram matrix, so use the symmetric solver.
# np.linalg.eig gives no ordering guarantee for the eigenvalues, so indexing
# column 0 / 1 of its output does not reliably select the two largest ones.
# np.linalg.eigh returns real eigenvalues in ascending order with orthonormal
# eigenvectors; reorder them so that index 0 is the largest eigenvalue.
eigval, eigvec = np.linalg.eigh(Ma)
order = np.argsort(eigval)[::-1]
eigval = eigval[order]
eigvec = eigvec[:, order]

# Unit eigenvector of the largest eigenvalue, as a float64 tensor.
tarvec = eigvec[:, 0].astype(float)
# Debug output kept from the original script (prints True for float64).
print(isinstance(tarvec[0], float))
tarvec = tarvec / np.linalg.norm(tarvec)
tarvec = torch.tensor(tarvec)

# Unit eigenvector of the second largest eigenvalue, as a float64 tensor.
tarvec1 = eigvec[:, 1].astype(float)
tarvec1 = tarvec1 / np.linalg.norm(tarvec1)
tarvec1 = torch.tensor(tarvec1)
# savedir = '/home/***/data/undergraky/condenseMNIST/230305111633tanh'
data = loaddata(os.path.join(savedir, 'trainpro.pkl'))

# Per-checkpoint series (one entry per loaded checkpoint, as Python floats):
scaler1 = []    # norm of the kernels' projections onto tarvec
scaler2 = []    # norm of the kernels' projections onto tarvec1
scaleres2 = []  # same as scaler2, after removing the tarvec component first
ratior = []     # scaler1 / scaler2
fr = []         # reference curve f(t)


def f(t):
    """Return exp((sqrt(eigval[0]) - sqrt(eigval[1])) * t)."""
    return np.exp((np.sqrt(eigval[0]) - np.sqrt(eigval[1])) * t)


for i in range(0, 100):
    epoch = i * 100
    # map_location keeps the analysis runnable on a CPU-only host even when
    # the checkpoints were saved from a CUDA device.
    model_dict = torch.load(os.path.join(savedir, 'epoch=%s.pt' % epoch),
                            map_location='cpu')

    # One row per conv1 output channel: flattened kernel weights followed by
    # the bias.
    # NOTE(review): assumes conv1.weight is (out, 3, 3, 3) so each row has
    # 27 + 1 = 28 entries, matching the length of tarvec -- confirm.
    weight = torch.squeeze(model_dict['conv1.weight'])
    weight = torch.flatten(weight, start_dim=1, end_dim=3)
    bias = torch.unsqueeze(model_dict['conv1.bias'], -1)
    vector = torch.hstack((weight, bias))

    # float64 copy so the products with the float64 eigenvectors are
    # dtype-consistent.
    vectorclone = vector.detach().to(torch.float64)

    # Optional diagnostic (disabled): heat map of the normalised kernels.
    # norm = torch.unsqueeze(torch.norm(vector,dim=1),-1)
    # vector = vector/norm
    # rankvector = torch.matmul(vector,tarvec)
    # index = torch.argsort(rankvector)
    # vector = vector[index]
    # heat = torch.matmul(vector,vector.T)
    # heat = heat.cpu().detach().numpy()
    # plt.scatter(range(len(vector[:,-1].cpu().detach().numpy())),abs(vector[:,-1].cpu().detach().numpy()))
    # plt.savefig(os.path.join(savedir,'bdis%s.png'%epoch))
    # plt.close()
    # plt.pcolormesh(heat,vmin=-1,vmax=1)
    # plt.savefig(os.path.join(savedir,'vecbiaskernel%s.png'%epoch))
    # plt.close()

    # Projection coefficient of every kernel on the two eigen-directions.
    coeff1 = torch.matmul(vectorclone, tarvec)
    coeff2 = torch.matmul(vectorclone, tarvec1)
    scale1 = torch.norm(coeff1)
    scale2 = torch.norm(coeff2)
    # Divide as tensors so a zero denominator yields inf instead of raising.
    ratior.append((scale1 / scale2).item())
    scaler1.append(scale1.item())
    scaler2.append(scale2.item())
    # NOTE(review): 5e-6 presumably converts epochs to training time
    # (e.g. the learning rate) -- confirm.
    fr.append(f(epoch * 5e-6))

    # Subtract from every kernel its component along the eigenvector of the
    # largest eigenvalue, then measure what is left along the second one.
    residual = vectorclone - torch.unsqueeze(coeff1, dim=1) * tarvec
    scale2res = torch.norm(torch.matmul(residual, tarvec1))
    scaleres2.append(scale2res.item())



# Twin-axis figure: the two projection norms on the left axis and their
# ratio on the right axis, saved to '3CIFAR10.png' in the working directory.
fig, ax1 = plt.subplots()
ax1.set_xlabel('time (epoch/100)', fontsize=20)
ax1.set_ylabel(r'$||\theta_{W,v_i}||$', fontsize=20)

for series, legend_label in ((scaler1, r'$P_1$'), (scaler2, r'$P_2$')):
    ax1.plot(series, label=legend_label)
for tick_kind in ('minor', 'major'):
    ax1.tick_params(axis='y', which=tick_kind, labelsize='20')

# Right-hand axis shares the x axis and carries the ratio curve in green.
ax2 = ax1.twinx()
ax2.set_ylabel(r'$\frac{||\theta_{W,v_1}||}{||\theta_{W,v_2}||}$', fontsize=20)
ax2.plot(ratior, label=r'$\frac{P_1}{P_2}$', c='g')
ax2.tick_params(axis='y', which='both', labelsize=20)

plt.tight_layout()
# A single figure-level legend collects the labelled lines of both axes.
fig.legend(loc='upper left', bbox_to_anchor=(0.17, 0.98))
plt.savefig('3CIFAR10.png')
