import torch
import tqdm
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.mydata import CifarData, MnistData
from lib.dataset.mydata import Cifar100
from torch.utils.data import DataLoader
from lib.model.resnet import ResNet34, ResNet18
from lib.dataset.myset import MyDataset
from lib.model.cifarnet import Net
from lib.model.mnistnet import MNISTNet
from lib.model.mnistnet import DenseNet121 as MDenseNet121
from lib.model.mresnext_nbn import ResNeXt29_2x64d as MResNeXt
from lib.model.mnistnet import VGG as MVGG
from lib.model.mresnet_nbn import ResNet18 as MResNet18
import argparse
import torch.optim as optim
import time
from lib.util.mytoolbag import setup_seed
from lib.util.t426_train_net import train_net_o


# Module-level cross-entropy loss object. It is not referenced by any function
# in the visible part of this file.
criterion = nn.CrossEntropyLoss()
# NOTE(review): SIZE is not used anywhere in the visible part of this file;
# its intended meaning cannot be determined from here -- confirm before removing.
SIZE = 30


def cal_ntk_c(now_set, net, optimizer, tot_para):
    """Return the smallest eigenvalue of the gradient Gram matrix of ``net``.

    The dataset is read in batches of two.  For every batch only the first
    row of the network output is used: each of its first 10 entries is handed
    to ``get_gradient_tensor`` and the returned vector is collected.  A
    symmetric matrix of all pairwise dot products between the collected
    vectors is then built, printed, and its minimum eigenvalue returned.

    Args:
        now_set: dataset accepted by ``DataLoader``; items must unpack into
            an ``(inputs, labels)`` pair.  Labels are ignored.
        net: model called on the CUDA-resident inputs.
        optimizer: forwarded unchanged to ``get_gradient_tensor``.
        tot_para: forwarded to ``get_gradient_tensor`` as ``tot``
            (presumably the total parameter count -- confirm in mytoolbag).

    Returns:
        The minimum of ``np.linalg.eigvalsh`` applied to the Gram matrix.
    """
    grads = []
    for batch in DataLoader(now_set, batch_size=2):
        images, _ = batch
        logits = net(Variable(images).cuda())
        # Only sample 0 of the batch contributes, one vector per output index.
        grads.extend(
            get_gradient_tensor(net, optimizer, logits[0][cls], tot=tot_para)
            for cls in range(10)
        )

    count = len(grads)
    gram = np.zeros((count, count), dtype=float)
    # Fill the upper triangle and mirror it, since the dot product is symmetric.
    for row, g_row in enumerate(grads):
        for col in range(row, count):
            value = g_row.dot(grads[col])
            gram[row][col] = value
            gram[col][row] = value
    print(gram)
    return np.min(np.linalg.eigvalsh(gram))


def cal_tr(now_set, net, optimizer, tot_para, p_bar=None):
    """Accumulate the self dot products of per-output gradient vectors.

    The dataset is read in batches of two.  For every batch only the first
    row of the network output is used: each of its first 10 entries is handed
    to ``get_gradient_tensor`` and the dot product of the returned vector
    with itself is added to a running total.

    Args:
        now_set: dataset accepted by ``DataLoader``; items must unpack into
            an ``(inputs, labels)`` pair.  Labels are ignored.
        net: model called on the CUDA-resident inputs.
        optimizer: forwarded unchanged to ``get_gradient_tensor``.
        tot_para: forwarded to ``get_gradient_tensor`` as ``tot``
            (presumably the total parameter count -- confirm in mytoolbag).
        p_bar: accepted for call compatibility; not used here.

    Returns:
        The accumulated sum; the plain integer ``0`` when the dataset yields
        no batches.
    """
    total = 0
    for batch in DataLoader(now_set, batch_size=2):
        images, _ = batch
        logits = net(Variable(images).cuda())
        # Only sample 0 of the batch contributes, one vector per output index.
        for cls in range(10):
            grad = get_gradient_tensor(net, optimizer, logits[0][cls], tot=tot_para)
            total += grad.dot(grad)
    return total


def round1(net, now_set, p_bar=None):
    """Run one measurement round for ``net`` on ``now_set``.

    A fresh SGD optimizer (lr=0.1, momentum=0.9, weight_decay=5e-4) is built
    over the network parameters and passed, together with the value returned
    by ``cal_para(net)``, to ``cal_tr``.

    Args:
        net: model whose parameters are handed to the optimizer.
        now_set: dataset forwarded to ``cal_tr``.
        p_bar: forwarded to ``cal_tr`` as ``p_bar``.

    Returns:
        A ``(czz, tr)`` pair.  ``czz`` is always ``0`` because the
        ``cal_ntk_c`` computation is currently disabled; ``tr`` is the
        result of ``cal_tr``.
    """
    param_total = cal_para(net)
    sgd = optim.SGD(
        net.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
    )
    trace_value = cal_tr(now_set, net, sgd, param_total, p_bar=p_bar)
    # The eigenvalue slot is a constant placeholder while cal_ntk_c is off.
    return 0, trace_value

def main():
    """Compute the ``round1`` trace value for every MNIST training sample.

    Command line:
        -md  model selector: ``r`` (MResNet18), ``v`` (MVGG 'VGG16'),
             ``x`` (MResNeXt), ``d`` (MDenseNet121).  Any other value,
             including the default ``'?'``, selects ``MNISTNet``.

    For each item of ``MnistData().train_set`` a single-sample ``MyDataset``
    is built and passed to ``round1``; the resulting trace value is appended
    as one line to the log file that belongs to the selected model.  Elapsed
    time is printed every 10 samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-md', default='?', type=str)
    args = parser.parse_args()
    print(args)

    data = MnistData()

    # Resolve the model constructor and its log file once, up front.
    # Previously the default log file was opened (and never explicitly closed)
    # on every iteration before being replaced by the model-specific one,
    # which leaked a handle and created/touched an unrelated file; a default
    # MNISTNet was also always built on the GPU even when it was immediately
    # discarded.
    variants = {
        'r': (MResNet18, 'logs/record1/tr_mnist_res18.txt'),
        'v': (lambda: MVGG('VGG16'), 'logs/record1/tr_mnist_vgg16.txt'),
        'x': (MResNeXt, 'logs/record1/tr_mnist_ResNeXt2x64.txt'),
        'd': (MDenseNet121, 'logs/record1/tr_mnist_dense121.txt'),
    }
    build_net, log_path = variants.get(
        args.md, (MNISTNet, 'logs/record1/tr_mnist_fnn.txt'))
    net = build_net().cuda()

    bg = time.time()
    for coin, (pic, lab) in enumerate(data.train_set, start=1):
        _czz, tr = round1(net, MyDataset([pic], [lab]))
        # Re-open in append mode per sample so every result is on disk even
        # if a long run is interrupted; ``with`` guarantees the handle is
        # closed when the write fails.
        with open(log_path, 'a') as f2:
            f2.write(str(float(tr)) + '\n')
        if coin % 10 == 0:
            # NOTE(review): the 'resneXt' label is hard-coded and printed for
            # every model choice; kept unchanged to preserve existing output.
            print('resneXt', coin, ' ', time.time() - bg)



# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()


