import torch
import tqdm
import numpy as np
from torch.autograd import Variable
import torchvision.models as models
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.mydata import CifarData, Cifar100
from torch.utils.data import DataLoader
from lib.model.densenet import DenseNet121 as Net
import torch.optim as optim
from lib.util.mytoolbag import setup_seed
import time
from lib.model.densenet import densenet_cifar
import argparse
from lib.util.logger import Logger
from transformers import ViTFeatureExtractor, ViTModel
from lib.model.vit import Vit

# Loss shared by the training and evaluation passes in train_net.
criterion = nn.CrossEntropyLoss()
# NOTE(review): SIZE is not referenced anywhere in the visible part of this file.
SIZE = 30


def train_net(train_loader, net, optimizer, testloader, rd=50, scheduler=None, logger=None):
    """Train ``net`` for ``rd`` epochs, evaluating on ``testloader`` after each one.

    Every batch is moved to the GPU with ``.cuda()``, so a CUDA device is
    required. The loss is the module-level ``criterion``.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` training batches. Must
            support ``len()``; ``.dataset`` is read when ``logger`` is given.
        net: Model invoked as ``net(inputs)``; predictions are the arg-max of
            its output along dim 1.
        optimizer: Optimizer stepped once per training batch.
        testloader: Iterable of ``(images, labels)`` evaluation batches.
        rd: Number of epochs to run.
        scheduler: Optional LR scheduler, stepped once per epoch.
        logger: Optional logger; ``epoch_log2`` is called once per epoch with
            the epoch number, train accuracy (%), mean train loss per batch,
            test accuracy (%) and mean test loss per batch.

    Returns:
        Tuple ``(best_train_correct, best_test_correct)``: the largest number
        of correctly classified samples observed in any single epoch on the
        training and test data respectively. These are counts, not percentages.
    """
    best_test_correct = 0
    best_train_correct = 0
    for epoch in range(1, rd + 1):
        started = time.time()
        train_correct, train_loss, test_loss = 0, 0, 0

        net.train()
        # The context manager closes the bar at the end of the epoch instead
        # of leaving one unclosed bar behind per epoch.
        with tqdm.tqdm(total=len(train_loader)) as p_bar:
            for inputs, labels in train_loader:
                inputs, labels = inputs.cuda(), labels.cuda()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                predicted = torch.max(outputs, 1)[1].detach().cpu().numpy()
                train_correct += (predicted == labels.detach().cpu().numpy()).sum()
                train_loss += float(loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                p_bar.update(1)

        test_correct = 0
        net.eval()
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time without changing the computed outputs.
        with torch.no_grad():
            for images, labels in testloader:
                images = images.cuda()
                labels = labels.cuda()
                outputs = net(images)
                test_loss += float(criterion(outputs, labels))
                predicted = torch.max(outputs, 1)[1].cpu().numpy()
                test_correct += (predicted == labels.cpu().numpy()).sum()

        best_test_correct = max(best_test_correct, test_correct)
        # 'acc' printed here is the raw count of correct test predictions.
        print('epoch : %d  ' % epoch, end='')
        print('acc : %.1f ' % test_correct, end='')
        print(time.time() - started)
        if logger:
            logger.epoch_log2(epoch, train_correct / len(train_loader.dataset) * 100,
                              train_loss / len(train_loader),
                              test_correct / len(testloader.dataset) * 100,
                              test_loss / len(testloader))
        best_train_correct = max(best_train_correct, train_correct)
        if scheduler:
            scheduler.step()
    print(best_test_correct)
    return best_train_correct, best_test_correct


def round1(i, now_set, data, rd=50, args=None, logger=None):
    """Run one seeded training round with a freshly constructed ViT.

    Args:
        i: Seed passed to ``setup_seed`` before the model is built.
        now_set: Training loader handed to ``train_net``.
        data: Data provider; its ``train_loader`` method is called with
            ``data.test_set`` to build the evaluation loader.
        rd: Number of epochs forwarded to ``train_net``.
        args: Parsed CLI namespace; ``args.b`` is used as the evaluation
            batch size. NOTE: it is read unconditionally, so the ``None``
            default cannot actually be used.
        logger: Optional logger forwarded to ``train_net``.

    Returns:
        Whatever ``train_net`` returns for this round.
    """
    setup_seed(i)

    # NOTE(review): the head is built with 100 classes — confirm this matches
    # the dataset the caller supplies.
    model = Vit(num_cls=100).cuda()
    eval_loader = data.train_loader(data_set=data.test_set, batch=args.b)

    adam = optim.Adam(model.parameters(), lr=5e-5)
    # LinearLR is used with its library defaults.
    lr_schedule = optim.lr_scheduler.LinearLR(adam)

    return train_net(
        now_set,
        model,
        adam,
        eval_loader,
        rd=rd,
        scheduler=lr_schedule,
        logger=logger,
    )


def main():
    """Parse CLI options, run ``-r`` training rounds and log aggregate accuracy.

    Each round builds a fresh training loader, trains via ``round1`` with the
    round index as seed, and converts the returned best correct-counts into
    percentages. The running mean/std is printed after every round and a
    summary line is written to the ``base_result`` logger at the end.
    """
    parser = argparse.ArgumentParser()
    # The option list was previously a stray string literal after this call;
    # it now lives in the help text. '-m' and '-p' are parsed but not read
    # anywhere in this function.
    parser.add_argument('-m', default='None', help='far, near, None')
    parser.add_argument('-p', default='o1.txt')
    parser.add_argument('-b', default=32, type=int, help='batch size')
    parser.add_argument('-r', default=1, type=int, help='number of training rounds')
    args = parser.parse_args()

    data = CifarData(size=224)
    logger1 = Logger(name='vit-cifar10-base')
    logger2 = Logger(name='base_result', tim=False)

    acc, tacc = [], []
    for i in range(args.r):
        md = data.train_loader(batch=args.b)
        acct, acce = round1(i, md, data=data, args=args, rd=5, logger=logger1)
        # NOTE(review): dividing the correct-count by 100 yields a percentage
        # only if the test set holds exactly 10,000 samples — confirm against
        # CifarData before reusing this script with another dataset.
        acce /= 100
        # Train accuracy is normalised by the real dataset size.
        acct /= len(md.dataset) / 100
        acc.append(acce)
        tacc.append(acct)
        print('test acc: ', sum(acc) / len(acc), np.std(acc), ' | train acc: ', np.mean(tacc), np.std(tacc))
    logger2.info('vit-cifar10' +
                 ' |test acc: ' + str(round(np.mean(acc), 2)) + '+' + str(round(np.std(acc), 3)) +
                 ' |train acc: ' + str(round(np.mean(tacc), 2)) + '+' + str(round(np.std(tacc), 3)) + '\n')
    logger2.info('----------------------------------------------------------------------------------')


# Run the experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()
