import torch
import tqdm
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.mydata import CifarData
from lib.dataset.mydata import Cifar100
from torch.utils.data import DataLoader
from lib.model.resnet_nbn import ResNet34, ResNet18
from lib.dataset.myset import MyDataset
from lib.model.cifarnet import Net
import torchvision.models as models
from lib.model.resnext_nbn import ResNeXt29_8x64d, ResNeXt29_2x64d
from lib.model.densenet import DenseNet121
import torch.optim as optim
import time
import argparse
from lib.util.mytoolbag import setup_seed
from lib.model.vgg import VGG
from lib.util.t426_train_net import train_net_o


# Module-level constants; neither is referenced by the functions in this script.
SIZE = 30
criterion = nn.CrossEntropyLoss()


def cal_ntk_c(now_set, net, optimizer, tot_para, num_cls=10):
    """Return the smallest eigenvalue of the empirical NTK Gram matrix on ``now_set``.

    Every sample of ``now_set`` is fed through ``net`` on its own (batch size 1).
    For each of the first ``num_cls`` output logits a gradient vector is
    obtained from ``get_gradient_tensor``. The Gram matrix of all those vectors
    (pairwise ``dot`` products) is printed, and its minimum eigenvalue is
    returned.

    Args:
        now_set: dataset yielding ``(input, label)`` pairs; labels are ignored.
        net: model to differentiate; inputs are moved with ``.cuda()``, so it
            must accept CUDA tensors.
        optimizer: forwarded to ``get_gradient_tensor``.
        tot_para: forwarded to ``get_gradient_tensor`` as ``tot``.
        num_cls: number of output logits per sample to differentiate
            (default 10, the value that used to be hard-coded).

    Returns:
        The minimum eigenvalue of the symmetric Gram matrix.

    Raises:
        ValueError: if no gradient vectors were produced (empty ``now_set``
            or ``num_cls == 0``).
    """
    ntk_loader = DataLoader(now_set, batch_size=1)
    vectors = []
    for inputs, _labels in ntk_loader:
        # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.
        outputs = net(inputs.cuda())
        for k in range(num_cls):
            vectors.append(get_gradient_tensor(net, optimizer, outputs[0][k], tot=tot_para))
    if not vectors:
        raise ValueError('cal_ntk_c: no gradient vectors (empty dataset or num_cls == 0)')

    siz = len(vectors)
    mat = np.zeros((siz, siz), dtype=float)
    # The Gram matrix is symmetric: fill the upper triangle and mirror it.
    for i in range(siz):
        for j in range(i, siz):
            mat[i, j] = mat[j, i] = vectors[i].dot(vectors[j])
    print(mat)
    # eigvalsh is the symmetric/Hermitian eigen-solver, valid for `mat` by construction.
    return np.min(np.linalg.eigvalsh(mat))


def cal_tr(now_set, net, optimizer, tot_para, p_bar=None, num_cls=10):
    """Return the trace of the empirical NTK Gram matrix on ``now_set``.

    The trace is the sum of ``vec.dot(vec)`` over the gradient vectors that
    ``get_gradient_tensor`` returns for every sample (fed one at a time) and
    every one of the first ``num_cls`` output logits. Only diagonal terms are
    needed, so the full Gram matrix is never built.

    Args:
        now_set: dataset yielding ``(input, label)`` pairs; labels are ignored.
        net: model to differentiate; inputs are moved with ``.cuda()``, so it
            must accept CUDA tensors.
        optimizer: forwarded to ``get_gradient_tensor``.
        tot_para: forwarded to ``get_gradient_tensor`` as ``tot``.
        p_bar: accepted for interface compatibility; currently unused.
        num_cls: number of output logits per sample to differentiate
            (default 10, the value that used to be hard-coded).

    Returns:
        The accumulated sum of squared gradient norms. Its type is whatever
        ``vec.dot(vec)`` yields, or the int ``0`` when nothing was accumulated.
    """
    ntk_loader = DataLoader(now_set, batch_size=1)
    total = 0
    for inputs, _labels in ntk_loader:
        # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.
        outputs = net(inputs.cuda())
        for k in range(num_cls):
            vec = get_gradient_tensor(net, optimizer, outputs[0][k], tot=tot_para)
            total += vec.dot(vec)
    return total


def round1(net, now_set, p_bar=None, args=None):
    """Compute NTK statistics of ``net`` on ``now_set``.

    Args:
        net: model to analyse.
        now_set: dataset of ``(input, label)`` pairs.
        p_bar: optional progress bar, forwarded to ``cal_tr``.
        args: parsed CLI namespace. When given, ``setup_seed(args.seed)`` is
            called first. With the default ``None``, seeding is skipped
            instead of raising ``AttributeError``.

    Returns:
        ``(czz, tr)``. ``czz`` is the placeholder ``0`` because the
        minimum-eigenvalue computation (``cal_ntk_c``) is disabled. ``tr`` is
        the value returned by ``cal_tr``.
    """
    if args is not None:
        setup_seed(args.seed)
    tot_para = cal_para(net)
    # No optimisation step is taken here; the optimizer is only handed to
    # get_gradient_tensor (presumably to zero gradients -- TODO confirm).
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    czz = 0  # cal_ntk_c(now_set, net, optimizer, tot_para) -- disabled
    tr = cal_tr(now_set, net, optimizer, tot_para, p_bar=p_bar)
    return czz, tr


def _build_net(md):
    """Instantiate the model selected by the ``-md`` flag and move it to the GPU."""
    # Looked up at call time so the project model classes are resolved lazily.
    builders = {
        'r': ResNet18,
        'v': lambda: VGG('VGG16'),
        'x': ResNeXt29_2x64d,
        'd': DenseNet121,
    }
    return builders.get(md, Net)().cuda()


def _log_path(md, mode, seed):
    """Return the log file that per-sample traces are appended to."""
    stems = {'r': 'res18', 'v': 'vgg16', 'x': 'ResNeXt2x64', 'd': 'dense121'}
    stem = stems.get(md)
    if mode == 'l':
        return 'logs/record1/tr_cifar10_' + (stem or 'fnn') + '_tmp' + str(seed) + '.txt'
    if stem is None:
        # NOTE(review): unlike every other log path this one is '../../'-relative;
        # kept unchanged, but confirm it is intentional.
        return '../../logs/record1/tr_cifar10_fnn.txt'
    return 'logs/record1/tr_cifar10_' + stem + '.txt'


def main():
    """Append the NTK trace of each CIFAR-10 training sample to a log file.

    CLI flags:
        -md   architecture: 'r' ResNet18, 'v' VGG16, 'x' ResNeXt29_2x64d,
              'd' DenseNet121, anything else the small ``Net`` (default '?').
        -m    'l' stops after ``lim + 1`` samples and logs to a per-seed
              ``*_tmp<seed>.txt`` file; any other value processes the whole
              training set (default 't').
        -seed random seed (default 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-md', default='?', type=str)
    parser.add_argument('-m', default='t', type=str)
    parser.add_argument('-seed', default=1, type=int)
    args = parser.parse_args()
    print(args)
    data = CifarData()

    # Seed before building the model so its random initialisation depends on
    # -seed. Previously the net was created before any seeding happened.
    setup_seed(args.seed)
    # Build only the requested model; previously Net() was always allocated
    # on the GPU first and then thrown away.
    net = _build_net(args.md)

    # In 'l' mode the loop breaks once coin > lim, i.e. after lim + 1 samples.
    lim = 300 if args.m == 'l' else None
    log_path = _log_path(args.md, args.m, args.seed)

    bg = time.time()
    coin = 0
    # Open the log once. Flushing after each record keeps partial results on
    # disk if the run is interrupted.
    with open(log_path, 'a') as f2:
        for pic, lab in data.train_set:
            _czz, tr = round1(net, MyDataset([pic], [lab]), args=args)
            f2.write(str(float(tr)) + '\n')
            f2.flush()
            coin += 1
            if lim is not None and coin > lim:
                break
            if coin % 10 == 0:
                print(coin, ' ', time.time() - bg)



if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()


