import torch
from lib.util.logger import Logger
import tqdm
import numpy as np
from torch.autograd import Variable
import torchvision.models as models
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.mydata import Cifar100
from torch.utils.data import DataLoader
from lib.model.cifarnet import Net
import torch.optim as optim
from lib.util.mytoolbag import setup_seed
import time
from lib.model.resnet import ResNet18
from lib.model.densenet import DenseNet121
from lib.model.resnext import ResNeXt29_2x64d
from lib.model.vgg import VGG
from lib.model.cifarnet import Net
import argparse


# Integer constant; not referenced anywhere else in this script.
SIZE = 30
# Classification loss shared by the training step and the evaluation pass.
criterion = nn.CrossEntropyLoss()


def _run_train_epoch(loader, net, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader``; return ``(correct, loss_sum)``."""
    net.train()
    correct, loss_sum = 0, 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        # Accuracy is measured on this batch's outputs before the parameter update.
        correct += int((outputs.max(1)[1] == labels).sum().item())
        loss_sum += float(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return correct, loss_sum


def _evaluate(loader, net, loss_fn, device):
    """Score ``net`` on ``loader`` without updating it; return ``(correct, loss_sum)``."""
    net.eval()
    correct, loss_sum = 0, 0.0
    # Evaluation never calls backward(), so skip building the autograd graph
    # (otherwise every eval batch keeps activations alive for no reason).
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss_sum += float(loss_fn(outputs, labels))
            correct += int((outputs.max(1)[1] == labels).sum().item())
    return correct, loss_sum


def train_net(train_loader, net, optimizer, testloader, rd=50, scheduler=None, logger=None, loss_fn=None):
    """Train ``net`` for ``rd`` epochs, evaluating on ``testloader`` after every epoch.

    Args:
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        net: classifier module. Batches are moved to the device of its first
            parameter (CUDA if it has no parameters).
        optimizer: optimizer over ``net``'s parameters; stepped once per batch.
        testloader: DataLoader yielding ``(images, labels)`` evaluation batches.
        rd: number of epochs.
        scheduler: optional LR scheduler, stepped once per epoch.
        logger: optional object with an ``epoch_log2`` method. It is called once per
            epoch with the epoch number, train accuracy (%), per-batch mean train
            loss, test accuracy (%) and per-batch mean test loss.
        loss_fn: loss callable ``(outputs, labels) -> scalar``. Defaults to the
            module-level ``criterion``.

    Returns:
        ``(best_train_correct, best_test_correct)``: the highest per-epoch numbers of
        correctly classified training / test samples (counts, not percentages).
    """
    if loss_fn is None:
        loss_fn = criterion
    # Follow the model's device instead of hard-coding .cuda(); for a CUDA model this
    # behaves as before.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    best_test, best_train = 0, 0
    for epoch in range(1, rd + 1):
        started = time.time()
        train_acc, train_loss = _run_train_epoch(train_loader, net, optimizer, loss_fn, device)
        acc, test_loss = _evaluate(testloader, net, loss_fn, device)
        best_test = max(best_test, acc)
        print('epoch : %d  ' % epoch, end='')
        print('acc : %.1f ' % acc, end='')
        print(time.time() - started)
        if logger:
            logger.epoch_log2(epoch, train_acc / len(train_loader.dataset) * 100, train_loss / len(train_loader),
                              acc / len(testloader.dataset) * 100, test_loss / len(testloader))
        best_train = max(best_train, train_acc)
        if scheduler:
            scheduler.step()
    print(best_test)
    return best_train, best_test


def round1(i, now_set, test_data=None, rd=200, args=None, logger=None):
    """Seed, build the model selected by ``args.md``, and train it with SGD + MultiStepLR.

    Args:
        i: seed passed to ``setup_seed`` before any model is constructed.
        now_set: training DataLoader.
        test_data: evaluation DataLoader.
        rd: number of epochs.
        args: namespace whose ``md`` attribute selects the model:
            'v' -> VGG16 (lr 0.01), 'x' -> ResNeXt29_2x64d, 'r' -> ResNet18,
            'd' -> DenseNet121, 'f' -> base Net with lr 0.002, anything else ->
            base Net with lr 0.1. ``None`` (or no ``md`` attribute) selects the base Net.
        logger: forwarded to ``train_net``.

    Returns:
        Whatever ``train_net`` returns: ``(best_train_correct, best_test_correct)``.
    """
    setup_seed(i)
    # Previously ``args=None`` crashed on ``args.md``; fall back to the CLI default '_'.
    md = getattr(args, 'md', '_')
    # The base Net is still constructed first, exactly as before, so the seeded RNG stream
    # used by the other models' initialisation is unchanged. It is only moved to the GPU
    # if it is the model actually trained.
    net = Net(num_cls=100)
    lr = 0.1
    if md == 'f':
        lr = 0.002
    elif md == 'v':
        net = VGG('VGG16', num_cls=100)
        lr = 0.01
    elif md == 'x':
        net = ResNeXt29_2x64d(num_cls=100)
    elif md == 'r':
        net = ResNet18(num_cls=100)
    elif md == 'd':
        net = DenseNet121(num_cls=100)
    net = net.cuda()

    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    # LR is multiplied by 0.2 at epochs 60, 120 and 160 (train_net steps the scheduler once per epoch).
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    return train_net(now_set, net, optimizer, test_data, rd=rd, scheduler=scheduler, logger=logger)


def main():
    """Parse CLI flags, train one CIFAR-100 model, and record its accuracy.

    Flags:
        -b: batch size for both loaders (default 128).
        -md: model selector forwarded to ``round1`` via ``args`` (default '_').
        -pro: when non-zero, the dataset is built with ``Cifar100(pro=True)``.

    Prints the test/train accuracy (percent) and appends a summary line to the
    ``base_result`` logger.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', default=128, type=int)
    parser.add_argument('-md', default='_', type=str)
    parser.add_argument('-pro', default=0, type=int)
    args = parser.parse_args()
    print(args)

    data = Cifar100(pro=True) if args.pro else Cifar100()
    logger1 = Logger(name='1train-' + args.md + '-' + str(args.pro) + '-' + str(args.b))
    logger2 = Logger(name='base_result', tim=False)

    train_loader = data.train_loader(batch=args.b)
    # NOTE(review): the test split is wrapped with the same ``train_loader`` helper; whether
    # that helper shuffles or augments is not visible here — confirm it is appropriate for eval.
    test_loader = data.train_loader(data.test_set, batch=args.b)

    train_correct, test_correct = round1(20, train_loader, test_data=test_loader, args=args, logger=logger1)
    # round1 returns correct-sample counts; convert both to percentages of their split size
    # (previously the test count was divided by a hard-coded 100, i.e. assumed 10,000 samples).
    acc = [test_correct / len(test_loader.dataset) * 100]
    tacc = [train_correct / len(train_loader.dataset) * 100]

    print('test acc: ', np.mean(acc), np.std(acc), ' | train acc: ', np.mean(tacc), np.std(tacc))
    logger2.info('base-train-cifar100-' + 'md-' + args.md + '|b-' + str(args.b) + '|pro:' + str(args.pro) +
                 ' |test acc: ' + str(round(np.mean(acc), 2)) + '+' + str(round(np.std(acc), 3)) +
                 ' |train acc: ' + str(round(np.mean(tacc), 2)) + '+' + str(round(np.std(tacc), 3)) + '\n')
    logger2.info('----------------------------------------------------------------------------------')


if __name__ == '__main__':
    # Run the CLI entry point only when executed as a script, never on import.
    main()

"""
CUDA_VISIBLE_DEVICES=0 nohup python -u 1train.py -m tr -p o1.txt > o1o.out 2>&1 &
CUDA_VISIBLE_DEVICES=0 nohup python 1train.py -m rnd -p o2.txt > o2.out 2>&1 &
CUDA_VISIBLE_DEVICES=3 nohup python -u 1train.py -p o1.txt > o500.out 2>&1 &


srun -p NLP --gres=gpu:1 -N1 python -u 1train.py -m tr -p o12.txt > vd2o1o.out 2>&1 &





srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md a -b 128 > abase.out 2>&1 &
srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md v -b 128 > vbase.out 2>&1 &
srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md r -b 128 > rbase.out 2>&1 &
srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md x -b 128 > xbase.out 2>&1 &
srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md d -b 128 > dbase.out 2>&1 &

srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md a -b 128 -pro 1 > apbase.out 2>&1 &
srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md v -b 128 -pro 1 > vpbase.out 2>&1 &
srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md r -b 128 -pro 1 > rpbase.out 2>&1 &
srun -p NLP --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md x -b 128 -pro 1 > xpbase.out 2>&1 &
srun -p NLP --quotatype=spot --gres=gpu:1 -N1 python -u 9cifar100_bse.py -md d -b 128 -pro 1 > dpbase.out 2>&1 &

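Recorded results (presumably best test accuracy in %, run not specified): 43.38
Second result: 50.? (value left incomplete in the original notes)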
"""