import torch
import tqdm
import numpy as np
from torch.autograd import Variable
import torchvision.models as models
import torch.nn as nn
from torchvision import transforms
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.imagenet import ImageNet
from torch.utils.data import DataLoader
import torch.optim as optim
from lib.util.mytoolbag import setup_seed
import time
from lib.model.resnet import ResNet18
from lib.model.densenet import DenseNet121
from lib.model.resnext import ResNeXt29_2x64d
from lib.util.logger import Logger
from lib.model.vit import Vit
from lib.model.vgg import VGG
from lib.model.cifarnet import ImgNet as Net
from lib.model.effv2 import Effnet
import argparse


# Classification loss shared by the training and evaluation loops below.
criterion = nn.CrossEntropyLoss()
# NOTE(review): SIZE is not referenced anywhere in the visible code — confirm
# it is unused elsewhere before removing it.
SIZE = 30

# Preprocessing pipeline: PIL image / ndarray -> float tensor, then per-channel
# normalisation with the widely used ImageNet mean/std. No augmentation is
# applied here.
former = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)


def _train_epoch(train_loader, net, optimizer):
    """Run one optimisation pass of ``net`` over ``train_loader``.

    Returns:
        ``(correct, loss_sum)``: number of correctly classified training
        samples and the sum of the per-batch losses.
    """
    net.train()
    correct, loss_sum = 0, 0.0
    with tqdm.tqdm(total=len(train_loader)) as pbar:
        for inputs, labels in train_loader:
            # as_tensor avoids the copy + UserWarning that torch.tensor() emits
            # when the batch is already a tensor, yet still accepts arrays.
            inputs = torch.as_tensor(inputs, dtype=torch.float32).cuda()
            labels = labels.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            correct += (torch.max(outputs, 1)[1] == labels).sum().item()
            loss_sum += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.update(1)
    return correct, loss_sum


def _evaluate(net, testloader):
    """Evaluate ``net`` on ``testloader`` without tracking gradients.

    Returns:
        ``(correct, loss_sum)``: number of correctly classified samples and
        the sum of the per-batch losses.
    """
    net.eval()
    correct, loss_sum = 0, 0.0
    # Evaluation never calls backward(); building the autograd graph would
    # only waste GPU memory.
    with torch.no_grad(), tqdm.tqdm(total=len(testloader)) as pbar:
        for images, labels in testloader:
            images = torch.as_tensor(images, dtype=torch.float32).cuda()
            labels = labels.cuda()
            outputs = net(images)
            loss_sum += criterion(outputs, labels).item()
            correct += (torch.max(outputs, 1)[1] == labels).sum().item()
            pbar.update(1)
    return correct, loss_sum


def train_net(train_loader, net, optimizer, testloader, rd=50, scheduler=None, logger=None):
    """Train ``net`` for ``rd`` epochs, evaluating it periodically.

    Evaluation on ``testloader`` runs on epochs 1, 6, 11, ... (every 5th
    epoch starting at the first) and on every epoch after epoch 80.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches that also has
            a ``dataset`` attribute (e.g. a ``DataLoader``). Batches are moved
            to CUDA.
        net: model to train (expected to already live on the GPU).
        optimizer: optimizer over ``net``'s parameters.
        testloader: evaluation loader with the same batch format, or ``None``
            to skip evaluation entirely.
        rd: number of epochs.
        scheduler: optional LR scheduler, stepped once at the end of each epoch.
        logger: optional object exposing
            ``epoch_log2(epoch, train_acc_pct, train_loss, test_acc_pct, test_loss)``.
            Test metrics are logged as 0 on epochs that were not evaluated.

    Returns:
        ``(acctrain, accl)``: the best per-epoch count of correct training
        predictions and the best count of correct test predictions (raw
        counts, not percentages).
    """
    accl = 0      # best test-correct count over evaluated epochs
    acctrain = 0  # best train-correct count over all epochs
    for i in range(rd):
        bg = time.time()
        epoch = i + 1
        train_acc, train_loss = _train_epoch(train_loader, net, optimizer)

        acc, test_loss = 0, 0.0
        if testloader is not None and (epoch > 80 or i % 5 == 0):
            acc, test_loss = _evaluate(net, testloader)
            accl = max(accl, acc)
            print('epoch : %d  ' % epoch, end='')
            print('acc : %.1f ' % acc, end='')
            print(time.time() - bg)

        if logger:
            if testloader is not None:
                test_pct = acc / len(testloader.dataset) * 100
                test_loss_avg = test_loss / len(testloader)
            else:
                test_pct, test_loss_avg = 0.0, 0.0
            logger.epoch_log2(epoch, train_acc / len(train_loader.dataset) * 100,
                              train_loss / len(train_loader), test_pct, test_loss_avg)
        acctrain = max(acctrain, train_acc)
        if scheduler:
            scheduler.step()
    print(accl)
    return acctrain, accl


def round1(i, now_set, test_data=None, rd=90, args=None, logger=None):
    """Seed, build a model with SGD + step-decay schedule, and train it.

    Args:
        i: seed passed to ``setup_seed`` before the model is constructed.
        now_set: training loader forwarded to ``train_net``.
        test_data: evaluation loader forwarded to ``train_net``.
        rd: number of training epochs.
        args: parsed CLI namespace. ``args.md`` selects the architecture:
            'v' VGG-16, 'x' ResNeXt-50 32x4d, 'r' ResNet-18, 'd' DenseNet-121;
            any other value (or ``args`` being ``None``) uses the project
            ``Net``. 'v' and 'f' use a learning rate of 0.01, all others 0.1.
        logger: optional per-epoch logger forwarded to ``train_net``.

    Returns:
        The ``(acctrain, accl)`` tuple returned by ``train_net``.
    """
    setup_seed(i)
    arch = getattr(args, 'md', None)
    builders = {
        'v': models.vgg16,
        'x': models.resnext50_32x4d,
        'r': models.resnet18,
        'd': models.densenet121,
    }
    # Only the selected model is instantiated: building a throwaway Net first
    # wasted GPU memory and consumed RNG draws after seeding, so non-default
    # architectures did not start from a state determined by the seed alone.
    net = builders.get(arch, Net)().cuda()

    lr = 0.01 if arch in ('v', 'f') else 0.1
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    # LR is multiplied by 0.1 at epochs 30 and 60.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)

    return train_net(now_set, net, optimizer, test_data, rd=rd, scheduler=scheduler, logger=logger)


def cal_rnd():
    """Run ``-r`` seeded training rounds on ImageNet and report accuracy.

    Command-line options:
        -r   number of rounds; round k uses seed k (default 1).
        -md  model selector forwarded to ``round1`` as ``args.md`` (default '?').

    After every round the running mean/std of test and train accuracy
    (percent) is printed; the final summary is written to the
    ``base_imagenet_result`` logger.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-r', default=1, type=int)
    cli.add_argument('-md', default='?', type=str)
    args = cli.parse_args()
    print(args)

    data = ImageNet()
    epoch_logger = Logger(name='img-base-' + args.md + '-')
    summary_logger = Logger(name='base_imagenet_result', tim=False)
    test_accs = []
    train_accs = []

    for seed in range(args.r):
        train_loader = data.train_loader(data_set=data.total_train(), batch=32)
        test_set = data.test_set
        n_test = len(test_set)
        print(n_test)
        test_loader = data.train_loader(test_set, batch=256)
        train_correct, test_correct = round1(seed, train_loader, test_data=test_loader,
                                             args=args, logger=epoch_logger)
        # Convert raw correct counts into percentages.
        test_accs.append(test_correct / (n_test / 100))
        train_accs.append(train_correct / (len(train_loader.dataset) / 100))
        print('test acc: ', sum(test_accs) / len(test_accs), np.std(test_accs),
              ' | train acc: ', np.mean(train_accs), np.std(train_accs))

    def mean_pm_std(values):
        # "<mean rounded to 2>+<std rounded to 3>"
        return str(round(np.mean(values), 2)) + '+' + str(round(np.std(values), 3))

    summary_logger.info('md-' + args.md
                        + ' |test acc: ' + mean_pm_std(test_accs)
                        + ' |train acc: ' + mean_pm_std(train_accs))
    summary_logger.info('----------------------------------------------------------------------------------')

def main():
    """Script entry point: run the ImageNet baseline rounds."""
    cal_rnd()


# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
