import torch
from lib.util.logger import Logger
import tqdm
import numpy as np
from torch.autograd import Variable
import torchvision.models as models
import torch.nn as nn
import random
from lib.util.mytoolbag import cal_para, get_gradient_tensor, multi_tensor_gra
from lib.dataset.mydata import CifarData
from torch.utils.data import DataLoader
from lib.model.cifarnet import Net
import torch.optim as optim
from lib.util.mytoolbag import setup_seed
import time
from lib.model.resnet import ResNet18
from lib.model.densenet import DenseNet121
from lib.model.resnext import ResNeXt29_2x64d
from lib.model.vgg import VGG
from lib.model.vit import Vit
from lib.model.effv2 import Effnet
from lib.model.cifarnet import Net
import argparse


# Loss function shared by the training and evaluation loops in train_net.
criterion = nn.CrossEntropyLoss()
# NOTE(review): SIZE is not referenced in the visible part of this file —
# confirm whether another module relies on it before removing.
SIZE = 30


def train_net(train_loader, net, optimizer, testloader, rd=50, scheduler=None, logger=None):
    """Train ``net`` for ``rd`` epochs, evaluating on ``testloader`` after each epoch.

    Batches are moved to the GPU with ``.cuda()``, so ``net`` is expected to
    already live on the GPU. The module-level ``criterion`` is used as the loss
    for both training and evaluation.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` training batches; must
            support ``len()`` and expose ``.dataset`` when ``logger`` is given.
        net: Model to optimise.
        optimizer: Optimizer already bound to ``net``'s parameters.
        testloader: Iterable of ``(images, labels)`` evaluation batches, with
            the same ``len()`` / ``.dataset`` requirements as ``train_loader``.
        rd: Number of epochs to run.
        scheduler: Optional LR scheduler; stepped once at the end of each epoch.
        logger: Optional logger; ``epoch_log2`` is called once per epoch with
            the epoch number, train accuracy (%), mean train loss per batch,
            test accuracy (%) and mean test loss per batch.

    Returns:
        ``(best_train_correct, best_test_correct)``: the largest number of
        correctly classified training samples and test samples seen in any
        epoch. These are raw counts, not percentages, and the two maxima may
        come from different epochs. Both are ``0`` when ``rd`` is ``0``.
    """
    best_test_correct = 0
    best_train_correct = 0
    for epoch in range(1, rd + 1):
        epoch_start = time.time()
        train_correct, train_loss, test_loss = 0, 0, 0

        net.train()
        # The context manager closes the bar so consecutive bars do not pile up.
        with tqdm.tqdm(total=len(train_loader)) as pbar:
            for inputs, labels in train_loader:
                inputs, labels = inputs.cuda(), labels.cuda()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                predicted = torch.max(outputs, 1)[1].cpu().numpy()
                train_correct += (predicted == labels.cpu().numpy()).sum()
                train_loss += float(loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(1)

        test_correct = 0
        net.eval()
        # Evaluation needs no gradients; no_grad avoids building an autograd
        # graph for every test batch.
        with torch.no_grad(), tqdm.tqdm(total=len(testloader)) as pbar:
            for images, labels in testloader:
                images = images.cuda()
                labels = labels.cuda()
                outputs = net(images)
                test_loss += float(criterion(outputs, labels))
                predicted = torch.max(outputs, 1)[1].cpu().numpy()
                test_correct += (predicted == labels.cpu().numpy()).sum()
                pbar.update(1)

        best_test_correct = max(best_test_correct, test_correct)
        # 'acc' printed here is the raw count of correct test predictions.
        print('epoch : %d  ' % epoch, end='')
        print('acc : %.1f ' % test_correct, end='')
        print(time.time() - epoch_start)
        if logger:
            logger.epoch_log2(epoch, train_correct / len(train_loader.dataset) * 100,
                              train_loss / len(train_loader),
                              test_correct / len(testloader.dataset) * 100,
                              test_loss / len(testloader))
        best_train_correct = max(best_train_correct, train_correct)
        if scheduler:
            # Epoch-level stepping: schedulers passed in must be sized in epochs.
            scheduler.step()
    print(best_test_correct)
    return best_train_correct, best_test_correct


def round1(i, now_set, test_data=None, rd=20, args=None, logger=None):
    """Build a model/optimizer/scheduler for one run and train it with train_net.

    Args:
        i: Run index, passed to ``setup_seed`` (presumably seeds the RNGs —
            confirm against ``lib.util.mytoolbag``).
        now_set: Training data loader forwarded to ``train_net``.
        test_data: Evaluation data loader forwarded to ``train_net``.
        rd: NOTE(review): currently ignored — every configuration trains for a
            fixed 20 epochs, as in the original code. Kept for interface
            compatibility.
        args: Parsed CLI namespace (required despite the ``None`` default).
            ``args.md`` selects the setup: ``'ef'`` builds ``Effnet``, anything
            else builds ``Vit``; ``'vt'`` uses Adam + LinearLR, anything else
            uses AdamW + OneCycleLR.
        logger: Optional logger forwarded to ``train_net``.

    Returns:
        The ``(best_train_correct, best_test_correct)`` tuple from ``train_net``.
    """
    setup_seed(i)
    # Build only the requested model; constructing a ViT on the GPU just to
    # discard it wastes time and GPU memory.
    if args.md == 'ef':
        net = Effnet().cuda()
    else:
        net = Vit().cuda()

    # Fixed epoch budget for every configuration (overrides the ``rd`` argument).
    rd = 20
    if args.md == 'vt':
        optimizer = optim.Adam(net.parameters(), lr=0.00005, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)
    else:
        optimizer = optim.AdamW(net.parameters(), lr=0.001, weight_decay=0.005)
        # train_net steps the scheduler once per epoch, so the one-cycle policy
        # must span ``rd`` steps. Sizing it in batches (epochs * steps_per_epoch)
        # would leave the LR stuck at the very start of the warm-up phase.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=rd)

    return train_net(now_set, net, optimizer, test_data, rd=rd, scheduler=scheduler, logger=logger)


def main():
    """Run repeated CIFAR-10 training rounds from CLI options and log the summary.

    Options:
        -m   subset selection mode: 'rnd' draws a fresh subset via
             ``get_rnd_suf`` every run, 'tr' uses ``get_tr_suf`` once up front.
        -p   output path (parsed but not used in this function).
        -b   subset size.
        -r   number of runs.
        -t   batch size for the subset and evaluation loaders.
        -md  model key forwarded to ``round1`` ('vt' also selects 224px data,
             anything else 384px).
        -tr  key selecting the grad-norm file passed to ``get_tr_suf``.

    After each run the running mean/std of test and train accuracy (in percent)
    is printed; the final figures are written through the result logger.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-m', default='rnd', type=str)
    cli.add_argument('-p', default='output.txt')
    cli.add_argument('-b', default=500, type=int)
    cli.add_argument('-r', default=5, type=int)
    cli.add_argument('-t', default=32, type=int)
    cli.add_argument('-md', default='_', type=str)
    cli.add_argument('-tr', default='_', type=str)
    args = cli.parse_args()

    # Map the -tr key to its grad-norm file; unknown keys fall back to the FNN file.
    grad_norm_files = {
        'x': './grad_norm/tr_cifar_train_ResNeXt29_2x64d.txt',
        'r': './grad_norm/tr_cifar10_res18.txt',
        'd': './grad_norm/tr_cifar_denesnet.txt',
        'v': './grad_norm/tr_cifar_train_vgg.txt',
    }
    path1 = grad_norm_files.get(args.tr, './grad_norm/tr_cifar10_fnn.txt')
    print(args)
    print(path1)

    data = CifarData(size=224 if args.md == 'vt' else 384)
    logger1 = Logger(name=f'1train-transfer-{args.md}-{args.tr}-{args.b!s}')
    logger2 = Logger(name='tf_cifar10_result', tim=False)

    md = data.train_loader(batch=64)
    test_pcts, train_pcts = [], []
    rand1 = 5000 - args.b
    if args.m == 'tr':
        print(rand1, rand1 + args.b)
        md = data.get_tr_suf(size=args.b, l=rand1, r=args.b + rand1, path=path1, batch=args.t)

    for run_idx in range(args.r):
        if args.m == 'rnd':
            md = data.get_rnd_suf(size=args.b, batch=args.t)
        print(len(md.dataset), rand1)

        test_set = data.test_set
        n_test = len(test_set)
        print(n_test)
        test_loader = data.train_loader(test_set, batch=args.t)

        best_train, best_test = round1(run_idx, md, test_data=test_loader, args=args, logger=logger1)
        # round1 returns raw correct counts; convert them to percentages.
        test_pcts.append(best_test / (n_test / 100))
        train_pcts.append(best_train / (len(md.dataset) / 100))
        print('test acc: ', sum(test_pcts) / len(test_pcts), np.std(test_pcts),
              ' | train acc: ', np.mean(train_pcts), np.std(train_pcts))

    summary = (
        f'm-{args.m}|md-{args.md}|tr-{args.tr}|b-{args.b!s}'
        f' |test acc: {round(np.mean(test_pcts), 2)!s}+{round(np.std(test_pcts), 3)!s}'
        f' |train acc: {round(np.mean(train_pcts), 2)!s}+{round(np.std(train_pcts), 3)!s}'
    )
    logger2.info(summary)
    logger2.info('----------------------------------------------------------------------------------')


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
