"""VGG in PyTorch.


[1] Karen Simonyan, Andrew Zisserman

    Very Deep Convolutional Networks for Large-Scale Image Recognition.
    https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in PyTorch.'''

from unittest import result
import torch
import torch.nn as nn
import numpy as np
from sklearn.cluster import KMeans
import joblib
import os
import time
# VGG layer configurations, written one stage at a time.
# An int is the output-channel count of a 3x3 convolution (see make_layers);
# 'M' is a 2x2, stride-2 max-pool that closes the stage.
# Used by the factories below: 'A' -> vgg11_bn, 'B' -> vgg13_bn,
# 'D' -> vgg16_bn, 'E' -> vgg19_bn.
cfg = dict(
    A=[64, 'M'] + [128, 'M'] + [256, 256, 'M'] + [512, 512, 'M'] + [512, 512, 'M'],
    B=[64, 64, 'M'] + [128, 128, 'M'] + [256, 256, 'M'] + [512, 512, 'M'] + [512, 512, 'M'],
    D=[64, 64, 'M'] + [128, 128, 'M'] + [256, 256, 256, 'M'] + [512, 512, 512, 'M'] + [512, 512, 512, 'M'],
    E=[64, 64, 'M'] + [128, 128, 'M'] + [256] * 4 + ['M'] + [512] * 4 + ['M'] + [512] * 4 + ['M'],
)

class VGG(nn.Module):
    """VGG network with an optional first-order Taylor surrogate for its classifier.

    The classifier head can be replaced at inference time by a piecewise
    first-order Taylor expansion around K-means cluster centres of the
    flattened feature vectors (see ``Taylor_Learn`` / ``Taylor_Forward``).
    Boolean flags switch on feature collection, error measurement, output
    recording and timing.

    Args:
        features: feature extractor. Its output is flattened per sample and
            must have 512 elements per sample to match the first classifier
            layer.
        num_class: number of output classes.
    """

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

        # K-means data collection: when enabled, forward() stores every
        # flattened feature vector (one numpy row per sample) in kmeans_data.
        self.kmeans_collect_enable = False
        self.kmeans_data = []
        self.kmeans = None
        # Taylor expansion tables, one slice per cluster centre:
        # x0 (k, 512), fx0 (k, num_class), gradient (k, 512, num_class).
        self.x0 = None
        self.fx0 = None
        self.gradient = None
        self.taylor_enable = False
        # Sum of squared errors between the exact classifier output and
        # the Taylor output (accumulated while mse_enable is set).
        self.mse_enable = False
        self.mse = 0.0
        # Not updated inside this class; sized to the number of classes.
        self.mse_per_dim = torch.zeros(num_class)
        self.total_per_dim = torch.zeros(num_class)
        # Accumulated wall-clock timings in seconds. The flag name
        # ``timer_enbale`` (sic) is kept because callers set it by name.
        self.timer_enbale = False
        self.classifier_time = 0.0
        self.kmeans_predict_time = 0.0
        self.taylor_forward_time = 0.0
        self.conv_time = 0.0

        # When enabled (together with taylor_enable), the first 5 output
        # columns of every sample are recorded as numpy arrays.
        self.mean_enable = False
        self.result_list = []
        self.taylor_result_list = []

    def reset(self):
        """Zero the accumulated timings and MSE, and clear the recorded outputs."""
        self.classifier_time = 0.0
        self.kmeans_predict_time = 0.0
        self.taylor_forward_time = 0.0
        self.conv_time = 0.0
        self.mse = 0.0

        self.n_job = 1
        self.result_list = []
        self.taylor_result_list = []

    @staticmethod
    def _record_head(tensor, sink):
        """Append the first 5 columns of each row of *tensor* to *sink* as numpy arrays."""
        # .cpu() so recording also works when the model runs on a GPU.
        sink.extend(tensor[:, :5].detach().cpu().numpy())

    def forward(self, x):
        """Run the features and then either the classifier or its Taylor surrogate.

        Returns:
            The classifier output, or the output of ``Taylor_Forward`` when
            ``taylor_enable`` is set, with shape (batch, num_class).
        """
        if self.timer_enbale:
            conv_start = time.perf_counter()
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        if self.timer_enbale:
            self.conv_time += time.perf_counter() - conv_start

        if self.kmeans_collect_enable:
            # Save_Kmeans_Data turns kmeans_data into an ndarray; convert it
            # back so collection can continue after a save.
            if isinstance(self.kmeans_data, np.ndarray):
                self.kmeans_data = list(self.kmeans_data)
            self.kmeans_data.extend(output.detach().cpu().numpy())

        if not self.taylor_enable:
            start = time.perf_counter()
            result = self.classifier(output)
            self.classifier_time += time.perf_counter() - start
            return result

        result = self.Taylor_Forward(output)
        # NOTE(review): Taylor outputs go into ``result_list`` and exact
        # classifier outputs into ``taylor_result_list``. The names look
        # swapped; behaviour is kept so downstream consumers are unaffected.
        if self.mean_enable:
            self._record_head(result, self.result_list)
        if self.mse_enable:
            result_original = self.classifier(output)
            # detach() so the running sum does not keep every batch's
            # autograd graph alive.
            self.mse += ((result_original - result) ** 2).sum(dim=1).sum().detach()
            if self.mean_enable:
                self._record_head(result_original, self.taylor_result_list)
        return result

    def Taylor_Forward(self, x):
        """Approximate ``self.classifier(x)`` with a per-sample first-order Taylor expansion.

        Each row of *x* is assigned to its nearest cluster centre ``x0[c]``
        (Euclidean distance) and evaluated as
        ``fx0[c] + (x - x0[c]) @ gradient[c]``.

        Args:
            x: flattened features of shape (batch, 512).

        Returns:
            Tensor of shape (batch, num_class).
        """
        start = time.perf_counter()
        # (batch, k) distance matrix; argmin per row so that every sample in
        # the batch gets its own nearest centre.
        distances = torch.norm(x.unsqueeze(1) - self.x0.unsqueeze(0), dim=2)
        cluster_index = torch.argmin(distances, dim=1)
        self.kmeans_predict_time += time.perf_counter() - start

        start = time.perf_counter()
        delta = x - self.x0[cluster_index]                          # (batch, 512)
        gradient = self.gradient[cluster_index]                     # (batch, 512, num_class)
        taylor_term = (delta.unsqueeze(1) @ gradient).squeeze(1)    # (batch, num_class)
        result = self.fx0[cluster_index] + taylor_term
        self.taylor_forward_time += time.perf_counter() - start
        return result

    def Taylor_Learn(self, save_dir):
        """Build the Taylor tables from ``self.kmeans`` and save them to *save_dir*.

        For every cluster centre ``x0[i]`` this computes
        ``fx0[i] = classifier(x0[i])`` and the Jacobian ``gradient[i]`` of
        shape (512, num_class). It then writes ``x0_k{k}.pt``,
        ``fx0_k{k}.pt`` and ``gradient_k{k}.pt``, where k is
        ``self.kmeans.n_clusters``.

        The classifier runs in eval mode (Dropout disabled) so that the
        expansion is deterministic. Its previous train/eval mode is restored
        afterwards.
        """
        # Match the classifier's dtype and device so the forward pass below works.
        weight = self.classifier[0].weight
        self.x0 = torch.as_tensor(self.kmeans.cluster_centers_,
                                  dtype=weight.dtype, device=weight.device)

        was_training = self.classifier.training
        self.classifier.eval()
        try:
            with torch.no_grad():
                self.fx0 = self.classifier(self.x0)
            # jacobian() returns (num_class, 512); transpose to (512, num_class)
            # so that ``delta @ gradient`` yields num_class outputs.
            self.gradient = torch.stack([
                torch.autograd.functional.jacobian(self.classifier, center).T
                for center in self.x0
            ])
        finally:
            self.classifier.train(was_training)

        os.makedirs(save_dir, exist_ok=True)
        k = self.kmeans.n_clusters
        # Save the tables.
        torch.save(self.x0.detach(), os.path.join(save_dir, f'x0_k{k}.pt'))
        torch.save(self.fx0.detach(), os.path.join(save_dir, f'fx0_k{k}.pt'))
        torch.save(self.gradient.detach(), os.path.join(save_dir, f'gradient_k{k}.pt'))
        print(f"已保存数据到: {save_dir}")

    def Taylor_Load(self, save_path, k, map_location=None):
        """Load the tables written by ``Taylor_Learn`` for ``k`` clusters.

        Args:
            save_path: directory that contains ``x0_k{k}.pt``, ``fx0_k{k}.pt``
                and ``gradient_k{k}.pt``.
            k: cluster count used in the file names.
            map_location: passed to ``torch.load`` (for example ``'cpu'``) to
                remap the stored tensors' device. The default keeps the old
                behaviour.
        """
        def _load(name):
            # Load one table and make it contiguous.
            path = os.path.join(save_path, f'{name}_k{k}.pt')
            return torch.load(path, map_location=map_location).contiguous()

        self.x0 = _load('x0')
        self.fx0 = _load('fx0')
        self.gradient = _load('gradient')

    def Load_Kmeans(self, filename):
        """Load a fitted K-means model into ``self.kmeans`` with joblib."""
        # joblib.load unpickles the file: only load files from a trusted source.
        self.kmeans = joblib.load(filename)

    def Save_Kmeans_Data(self, filename):
        """Convert the collected feature rows to an ndarray and save them with ``np.save``.

        ``self.kmeans_data`` is replaced by the ndarray. ``forward`` converts
        it back to a list if collection continues.
        """
        self.kmeans_data = np.array(self.kmeans_data)
        np.save(filename, self.kmeans_data)


def make_layers(cfg, batch_norm=False, in_channels=3):
    """Build the VGG convolutional feature extractor from a layer configuration.

    Args:
        cfg: sequence whose entries are either an int, meaning a 3x3,
            padding-1 convolution with that many output channels followed by
            optional BatchNorm and a ReLU, or the string 'M', meaning 2x2
            max-pooling with stride 2.
        batch_norm: insert ``nn.BatchNorm2d`` after every convolution.
        in_channels: channel count of the network input. The default of 3
            keeps the previous behaviour.

    Returns:
        ``nn.Sequential`` holding the layers in configuration order.
    """
    layers = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue

        layers.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(spec))
        layers.append(nn.ReLU(inplace=True))
        channels = spec

    return nn.Sequential(*layers)

def vgg11_bn(num_class=100):
    """Build VGG-11 with batch normalisation (configuration 'A').

    Args:
        num_class: number of output classes of the classifier. Defaults to 100.
    """
    return VGG(make_layers(cfg['A'], batch_norm=True), num_class=num_class)

def vgg13_bn(num_class=100):
    """Build VGG-13 with batch normalisation (configuration 'B').

    Args:
        num_class: number of output classes of the classifier. Defaults to 100.
    """
    return VGG(make_layers(cfg['B'], batch_norm=True), num_class=num_class)

def vgg16_bn(num_class=100):
    """Build VGG-16 with batch normalisation (configuration 'D').

    Args:
        num_class: number of output classes of the classifier. Defaults to 100.
    """
    return VGG(make_layers(cfg['D'], batch_norm=True), num_class=num_class)

def vgg19_bn(num_class=100):
    """Build VGG-19 with batch normalisation (configuration 'E').

    Args:
        num_class: number of output classes of the classifier. Defaults to 100.
    """
    return VGG(make_layers(cfg['E'], batch_norm=True), num_class=num_class)


