import torch
import tqdm
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.mydata import CifarData
from lib.dataset.mydata import Cifar100
from torch.utils.data import DataLoader
from lib.model.resnet_nbn import ResNet34, ResNet18
from lib.dataset.imagenet import ImageNet, Indata
from lib.dataset.myset import MyDataset
from lib.model.cifarnet import ImgNet as Net
import torchvision.models as models
from lib.model.resnext_nbn import ResNeXt29_8x64d, ResNeXt29_2x64d
from lib.model.densenet import DenseNet121
import torch.optim as optim
import time
from lib.util.mytoolbag import setup_seed
from lib.util.t426_train_net import train_net_o


# Module-level settings. Neither name is referenced by any function in this
# file; presumably kept for other importers — TODO confirm before removing.
SIZE = 30
criterion = nn.CrossEntropyLoss()


def cal_ntk_c(now_set, net, optimizer, tot_para, num_cls=10):
    """Return the smallest eigenvalue of a per-logit gradient Gram matrix.

    Each sample of ``now_set`` is fed through ``net`` one at a time
    (``batch_size=1``). For each of the first ``num_cls`` output logits, the
    vector returned by ``get_gradient_tensor`` is collected. This is presumably
    the flattened parameter gradient of that logit; confirm in
    ``lib.util.mytoolbag``. The symmetric matrix of all pairwise dot products
    (presumably the empirical NTK Gram matrix) is built, printed, and its
    minimum eigenvalue returned.

    Args:
        now_set: dataset yielding ``(inputs, labels)`` pairs; labels are unused.
        net: model whose parameters live on CUDA (inputs are moved with
            ``.cuda()``).
        optimizer: passed through unchanged to ``get_gradient_tensor``.
        tot_para: passed through as ``tot=`` to ``get_gradient_tensor``
            (presumably the parameter count from ``cal_para``).
        num_cls: number of output logits differentiated per sample. Defaults
            to 10, the previously hard-coded value.

    Returns:
        The minimum eigenvalue of the Gram matrix, as a NumPy float.
    """
    ntk_loader = DataLoader(now_set, batch_size=1)
    vectors = []
    for inputs, _labels in ntk_loader:
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
        outputs = net(inputs.cuda())
        # With batch_size=1 this flattens (1, C) or (C,) outputs to (C,).
        logits = outputs.reshape(-1)
        for k in range(num_cls):
            vectors.append(get_gradient_tensor(net, optimizer, logits[k], tot=tot_para))

    siz = len(vectors)
    mat = np.zeros((siz, siz), dtype=float)
    # The matrix is symmetric: compute the upper triangle and mirror it.
    for i in range(siz):
        for j in range(i, siz):
            # float() yields a host scalar whether dot() returns a NumPy scalar
            # or a (possibly CUDA) 0-d torch tensor.
            mat[i, j] = mat[j, i] = float(vectors[i].dot(vectors[j]))
    print(mat)
    # eigvalsh is the solver for symmetric/Hermitian matrices (real eigenvalues).
    return np.min(np.linalg.eigvalsh(mat))


def cal_tr(now_set, net, optimizer, tot_para, p_bar=None, num_cls=10):
    """Return the sum of squared gradient-vector norms over ``now_set``.

    Each sample is fed through ``net`` alone (``batch_size=1``). For each of
    the first ``num_cls`` output logits, ``vec = get_gradient_tensor(...)`` is
    computed and ``vec.dot(vec)`` is accumulated. If ``vec`` is the per-logit
    parameter gradient (confirm in ``lib.util.mytoolbag``), the result is the
    trace of the empirical NTK.

    Args:
        now_set: dataset yielding ``(inputs, labels)`` pairs; labels are unused.
        net: model whose parameters live on CUDA (inputs are moved with
            ``.cuda()``).
        optimizer: passed through unchanged to ``get_gradient_tensor``.
        tot_para: passed through as ``tot=`` to ``get_gradient_tensor``.
        p_bar: accepted for interface compatibility; not used.
        num_cls: number of output logits differentiated per sample. Defaults
            to 10, the previously hard-coded value.

    Returns:
        The accumulated sum: ``0`` if ``now_set`` is empty, otherwise whatever
        type ``vec.dot(vec)`` produces (callers convert with ``float()``).
    """
    ntk_loader = DataLoader(now_set, batch_size=1)
    tmp = 0
    for inputs, _labels in ntk_loader:
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
        outputs = net(inputs.cuda())
        # Select logits, not batch rows. The old ``outputs[k]`` indexed the
        # batch axis (size 1), unlike the ``outputs[0][k]`` used in cal_ntk_c.
        # For a (1, C) output that made k >= 1 out of range. Flattening handles
        # both (1, C) and (C,) outputs because batch_size=1.
        logits = outputs.reshape(-1)
        for k in range(num_cls):
            vec = get_gradient_tensor(net, optimizer, logits[k], tot=tot_para)
            tmp += vec.dot(vec)
    return tmp


def round1(net, now_set, p_bar=None):
    """Run one measurement pass of ``net`` over ``now_set``.

    A fresh SGD optimizer (lr=0.1, momentum=0.9, weight_decay=5e-4) is built on
    every call and passed through to ``cal_tr`` together with ``cal_para(net)``.

    Returns:
        ``(czz, tr)``: ``czz`` is always ``0`` because the ``cal_ntk_c``
        minimum-eigenvalue computation is switched off; ``tr`` is the value
        returned by ``cal_tr``.
    """
    param_count = cal_para(net)
    sgd = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    trace = cal_tr(now_set, net, sgd, param_count, p_bar=p_bar)
    # Eigenvalue term disabled; re-enable via cal_ntk_c(now_set, net, sgd, param_count).
    return 0, trace


def main():
    """Compute ``round1``'s trace for every training image and append it to a log.

    A single ``Net(num_cls=10)`` instance is moved to CUDA and reused for all
    images. Each image is wrapped in a one-element ``Indata`` dataset using
    ``data.train_transform``. One ``float(tr)`` value per line is appended to
    ``logs/record1/tr_imagenet_fnn.txt``.
    """
    import os

    log_path = 'logs/record1/tr_imagenet_fnn.txt'
    # Create the log directory first so a fresh checkout does not crash only
    # after the first (expensive) trace has been computed.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    data = ImageNet()
    net = Net(num_cls=10).cuda()
    # ``with`` guarantees the progress bar is closed even if a sample fails.
    with tqdm.tqdm(total=len(data.train_pic)) as p_bar:
        for pic, lab in zip(data.train_pic, data.train_lab):
            _czz, tr = round1(net, Indata([pic], [lab], data.train_transform))
            # Reopen in append mode per sample so each result reaches disk even
            # if a later sample crashes; ``with`` closes it on error too.
            with open(log_path, 'a') as f2:
                f2.write(str(float(tr)) + '\n')
            p_bar.update(1)


if __name__ == '__main__':
    main()


