import torch
import tqdm
import numpy as np
from lib.util.logger import Logger
from torch.autograd import Variable
import torchvision.models as models
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.mydata import CifarData
from lib.dataset.mydata import MnistData
from torch.utils.data import DataLoader
import torch.optim as optim
from lib.util.mytoolbag import setup_seed
import time
from lib.model.mnistnet import ResNet18, ResNeXt29_2x64d, VGG, DenseNet121
from lib.model.mnistnet import MNISTNet as Net
import argparse


# Loss shared by train_net for both the training objective and the logged test loss.
criterion = nn.CrossEntropyLoss()
# NOTE(review): SIZE is not referenced anywhere else in the visible code — possibly a
# leftover constant; confirm no other module imports it before removing.
SIZE = 30


def _train_one_epoch(loader, net, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(correct_count, summed_batch_loss)`` accumulated over every batch.
    """
    net.train()
    correct, loss_sum = 0, 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        # argmax returns the first maximal index on ties, same as torch.max(...)[1].
        correct += (outputs.argmax(1) == labels).sum().item()
        loss_sum += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return correct, loss_sum


def _evaluate(loader, net, loss_fn, device):
    """Evaluate ``net`` on ``loader`` in eval mode without recording autograd history.

    Returns:
        ``(correct_count, summed_batch_loss)`` accumulated over every batch.
    """
    net.eval()
    correct, loss_sum = 0, 0.0
    # no_grad: evaluation never calls backward(), so building a graph only wastes memory.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss_sum += loss_fn(outputs, labels).item()
            correct += (outputs.argmax(1) == labels).sum().item()
    return correct, loss_sum


def train_net(train_loader, net, optimizer, testloader, rd=50, scheduler=None, logger=None,
              loss_fn=None, device=None):
    """Train ``net`` for ``rd`` epochs, evaluating on ``testloader`` after every epoch.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches; ``len(train_loader)`` and
            ``len(train_loader.dataset)`` are used to normalise the logged loss/accuracy.
        net: model to train.
        optimizer: optimizer over ``net``'s parameters, stepped once per training batch.
        testloader: iterable of ``(images, labels)`` evaluation batches, normalised the
            same way via ``len(testloader)`` / ``len(testloader.dataset)``.
        rd: number of epochs.
        scheduler: optional LR scheduler, stepped once at the end of every epoch.
        logger: optional object with ``epoch_log2(epoch, train_acc_pct, mean_train_loss,
            test_acc_pct, mean_test_loss)``, called once per epoch.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
        device: device that batches are moved to; defaults to the device of ``net``'s
            first parameter, or ``'cuda'`` when the model has no parameters (the
            previously hard-coded behaviour).

    Returns:
        ``(best_train_correct, best_test_correct)``: the largest number of correctly
        classified samples in any single epoch on the training and test data
        (raw counts, not percentages).
    """
    if loss_fn is None:
        loss_fn = criterion
    if device is None:
        first_param = next(iter(net.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    best_train_correct = 0
    best_test_correct = 0
    for epoch in range(1, rd + 1):
        start = time.time()
        train_correct, train_loss = _train_one_epoch(train_loader, net, optimizer, loss_fn, device)
        test_correct, test_loss = _evaluate(testloader, net, loss_fn, device)
        best_test_correct = max(best_test_correct, test_correct)
        print('epoch : %d  ' % epoch, end='')
        print('acc : %.1f ' % test_correct, end='')
        print(time.time() - start)
        if logger:
            logger.epoch_log2(epoch, train_correct / len(train_loader.dataset) * 100,
                              train_loss / len(train_loader),
                              test_correct / len(testloader.dataset) * 100,
                              test_loss / len(testloader))
        best_train_correct = max(best_train_correct, train_correct)
        if scheduler:
            scheduler.step()
    print(best_test_correct)
    return best_train_correct, best_test_correct


def round1(i, now_set, test_data=None, rd=50, args=None, logger=None):
    """Build a freshly seeded model and train it for one experimental round.

    Args:
        i: seed passed to ``setup_seed`` before the model is constructed.
        now_set: training loader forwarded to ``train_net``.
        test_data: evaluation loader forwarded to ``train_net``.
        rd: number of epochs; also the ``T_max`` of the cosine LR schedule.
        args: parsed CLI namespace whose ``md`` attribute selects the architecture:
            'v' -> VGG16, 'x' -> ResNeXt29_2x64d, 'r' -> ResNet18, 'd' -> DenseNet121,
            anything else (or ``args is None`` / no ``md`` attribute) -> MNISTNet.
        logger: optional per-epoch logger forwarded to ``train_net``.

    Returns:
        Whatever ``train_net`` returns: ``(best_train_correct, best_test_correct)``.
    """
    setup_seed(i)
    # Build only the selected architecture, so no throwaway model is allocated on the GPU.
    # This also means the chosen model's initialisation draws RNG state directly after
    # setup_seed, rather than after an unused MNISTNet has been constructed first.
    builders = {
        'v': lambda: VGG('VGG16'),
        'x': ResNeXt29_2x64d,
        'r': ResNet18,
        'd': DenseNet121,
    }
    arch = getattr(args, 'md', None)
    net = builders.get(arch, Net)().cuda()
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=rd)
    return train_net(now_set, net, optimizer, test_data, rd=rd, scheduler=scheduler, logger=logger)


def main():
    """Command-line entry point: run ``-r`` seeded training rounds on MNIST and log a summary.

    Flags:
        -m   subset mode: 'tr' uses ``get_tr_suf`` with a grad-norm file, 'rnd' uses
             ``get_rnd_suf``; any other value keeps the default ``data.train_loader(batch=64)``.
        -p   parsed but not used by this function.
        -b   subset size passed to the subset getters.
        -r   number of rounds (round index doubles as the seed for ``round1``).
        -md  architecture code forwarded to ``round1`` via ``args``.
        -tr  selects which grad-norm file ``get_tr_suf`` reads.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', default='rnd', type=str)
    parser.add_argument('-p', default='mnist_output.txt')
    parser.add_argument('-b', default=500, type=int)
    parser.add_argument('-r', default=5, type=int)
    parser.add_argument('-md', default='?', type=str)
    parser.add_argument('-tr', default='?', type=str)
    args = parser.parse_args()

    # -tr code -> grad-norm file; unknown codes fall back to the FNN file.
    grad_norm_files = {
        'x': './grad_norm/tr_mnist_ResNeXt2x64.txt',
        'r': './grad_norm/tr_mnist_res18.txt',
        'd': './grad_norm/tr_mnist_dense121_d.txt',
        'v': './grad_norm/tr_mnist_vgg16_d.txt',
    }
    path1 = grad_norm_files.get(args.tr, './grad_norm/tr_mnist_fnn_d.txt')
    print(args)
    print(path1)

    data = MnistData()
    epoch_logger = Logger(name='mni-train_' + args.md + '_' + args.tr + str(args.b))
    summary_logger = Logger(name='train_mnist_result', tim=False)
    md = data.train_loader(batch=64)
    test_accs, train_accs = [], []
    for seed in range(args.r):
        lo = 5000 - args.b
        if args.m == 'tr':
            print(lo, lo + args.b)
            md = data.get_tr_suf(size=args.b, l=lo, r=args.b + lo, path=path1)
        elif args.m == 'rnd':
            md = data.get_rnd_suf(size=args.b)
        test_set = data.test_set
        n_test = len(test_set)
        print(len(test_set))
        test_loader = data.train_loader(test_set, batch=100)
        best_train, best_test = round1(seed, md, test_data=test_loader, args=args, logger=epoch_logger)
        # Convert best correct-counts into percentages of the respective dataset sizes.
        test_accs.append(best_test / (n_test / 100))
        train_accs.append(best_train / (len(md.dataset) / 100))
        print('test acc: ', sum(test_accs) / len(test_accs), np.std(test_accs),
              ' | train acc: ', np.mean(train_accs), np.std(train_accs))

    summary = ''.join([
        'm-', args.m, '|md-', args.md, '|tr-', args.tr, '|b-', str(args.b),
        ' |test acc: ', str(round(np.mean(test_accs), 2)), '+', str(round(np.std(test_accs), 3)),
        ' |train acc: ', str(round(np.mean(train_accs), 2)), '+', str(round(np.std(train_accs), 3)),
    ])
    summary_logger.info(summary)
    summary_logger.info('----------------------------------------------------------------------------------')


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not when imported.
    main()