import argparse
import distutils.util
import os
import sys
import yaml
import math
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
#from torch.optim.lr_scheduler import _LRScheduler

# Machine-specific paths come from a YAML file. The relative path is resolved
# against the current working directory, not this script's location.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    # root_dir: main() reads the dataset from <root_dir>/tinyimagenet/data.
    root_dir = yamlfile['root_dir']
    # src_dir: project source tree; appended to sys.path below.
    src_dir = yamlfile['src_dir']

# These appends must run before the project imports that follow, which are
# resolved from src_dir (and its attack/ and models/ subfolders).
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf, TrainRecorder
#from cifar_utils import transform_train, transform_test, Cifardata, DistillCifardata, WarmUpLR, ModelwNorm, \
#    transform_train_aug
#from cifar100.models.model_selector import get_network
from tinyimagenet_utils import transform_train, transform_train_aug, transform_test, TINdata, DistillTINdata, WarmUpLR, \
    ModelwNorm
from tinyimagenet.models.model_selector import get_network

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def split_train(trainloader, model, criterion, optimizer, epoch, args, warmup_scheduler):
    """Train ``model`` for one epoch over ``trainloader``.

    Gradients are accumulated over ``args.batch_step`` consecutive batches
    before each optimizer update (``batch_step == 1`` means one update per
    batch). The last batches of the epoch always trigger a final update, even
    when they form a partial group.

    While ``epoch <= args.warmup``, ``warmup_scheduler`` is stepped once per
    optimizer update. That matches how ``get_opt_and_lrsch`` sizes it:
    ``ceil(num_iter / batch_step) * warmup`` updates.

    Args:
        trainloader: DataLoader yielding ``(inputs, targets)`` batches.
        model: network to train; switched to train mode here.
        criterion: loss function applied to ``(outputs, targets)``.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: 1-based epoch index, compared against ``args.warmup``.
        args: namespace providing ``batch_step`` and ``warmup``.
        warmup_scheduler: per-update warmup LR scheduler.

    Returns:
        ``(mean loss, mean top-1 accuracy / 100)`` over the epoch, weighted by
        batch size. ``accuracy()`` is assumed to return percentages.
    """
    model.train()

    batch_step = max(1, int(args.batch_step))
    num_batches = len(trainloader)

    losses = AverageMeter()
    top1 = AverageMeter()

    optimizer.zero_grad()
    for batch_ind, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device, torch.float)
        targets = targets.to(device, torch.long)

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Scale so the accumulated gradient averages over the accumulated batches.
        (loss / batch_step).backward()

        is_last_batch = batch_ind + 1 == num_batches
        if (batch_ind + 1) % batch_step == 0 or is_last_batch:
            if epoch <= args.warmup:
                # Set this update's warmup LR before it is applied.
                warmup_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        # measure accuracy and record loss (top-5 is computed but not tracked)
        prec1, _ = accuracy(outputs.detach(), targets, topk=(1, 5))
        batch_size = inputs.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item() / 100.0, batch_size)

    return (losses.avg, top1.avg)


def split_test(testloader, model, criterion):
    """Evaluate ``model`` on ``testloader`` without tracking gradients.

    Args:
        testloader: DataLoader yielding ``(inputs, targets)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss function applied to ``(outputs, targets)``.

    Returns:
        ``(mean loss, mean top-1 accuracy / 100)`` over the loader, weighted by
        batch size. ``accuracy()`` is assumed to return percentages.
    """
    model.eval()

    losses = AverageMeter()
    top1 = AverageMeter()

    # No autograd graph is needed for evaluation; this avoids holding
    # activations for backward and cuts memory use considerably.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            outputs = model(inputs)

            loss = criterion(outputs, targets)

            # measure accuracy and record loss (top-5 is computed but not tracked)
            prec1, _ = accuracy(outputs, targets, topk=(1, 5))
            batch_size = inputs.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item() / 100.0, batch_size)
    return (losses.avg, top1.avg)


def split_save_checkpoint(state, is_best, acc, split_name, checkpoint):
    """Write ``state`` to ``<checkpoint>/<split_name>/model_last.pth.tar`` unless a better one exists.

    Args:
        state: object passed to ``torch.save`` (a dict carrying ``best_acc``).
        is_best: accepted for call compatibility; not used.
        acc: accuracy compared against the stored file's ``best_acc``.
        split_name: sub-directory name for this split model.
        checkpoint: root checkpoint directory (created if missing).

    An existing file is kept, and ``state`` discarded, when its stored
    ``best_acc`` is strictly greater than ``acc``. Otherwise it is overwritten.
    """
    split_dir = os.path.join(checkpoint, split_name)
    os.makedirs(split_dir, exist_ok=True)
    filepath = os.path.join(split_dir, 'model_last.pth.tar')
    if os.path.exists(filepath):
        # Only best_acc is needed. Loading onto the CPU avoids materialising the
        # stored model/optimizer tensors on the GPU, and also works on CPU-only hosts.
        previous_best = torch.load(filepath, map_location='cpu')['best_acc']
        if previous_best > acc:
            return
    torch.save(state, filepath)


def get_opt_and_lrsch(args, model, num_epoch, num_iter, warmup):
    """Create the optimizer, the epoch-level LR scheduler and the warmup scheduler.

    HiViT variants use AdamW with cosine annealing over ``num_epoch`` epochs.
    Every other supported architecture uses SGD with momentum and a step decay
    (LR x0.2 at epochs 60, 120 and 160).

    Args:
        args: namespace; ``args.model`` selects the recipe and
            ``args.batch_step`` divides ``num_iter`` when sizing the warmup.
        model: network whose parameters are optimized.
        num_epoch: total training epochs (``T_max`` of the cosine schedule).
        num_iter: batches per epoch.
        warmup: number of warmup epochs.

    Returns:
        ``(optimizer, train_scheduler, warmup_scheduler)``.

    Raises:
        KeyError: ``args.model`` is neither a HiViT variant nor in the SGD
            milestone table.
    """
    arch = args.model
    if arch not in ('hivit_tiny', 'hivit_small', 'hivit_base'):
        sgd_archs = (
            'mobilenetv3_small_100', 'mobilenetv3_small_50',
            'resnet18', 'resnet34', 'resnet50', 'resnet152',
            'resnetl18', 'resnets18', 'resnett18', 'resnetw18', 'resnetxw18', 'resnet318',
            'vgg11_bn', 'vgg19_bn', 'svgg19_bn',
        )
        milestone_table = {name: [60, 120, 160] for name in sgd_archs}
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        # learning rate decay: multiply by gamma at each milestone epoch
        train_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestone_table[arch], gamma=0.2
        )
    else:
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
        train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, num_epoch, eta_min=0.0
        )

    # The warmup spans `warmup` epochs of optimizer updates (batches / batch_step).
    warmup_iters = math.ceil(num_iter / args.batch_step) * warmup
    warmup_scheduler = WarmUpLR(optimizer, warmup_iters)
    return optimizer, train_scheduler, warmup_scheduler


def _select_split_subset(train_data, train_label, model_idx, non_model):
    """Return the ``(data, class_label)`` subset that sub-model ``model_idx`` trains on.

    ``train_label[:, 0]`` is used as the class label. The last ``non_model``
    columns are treated as the ids of sub-models the sample is withheld from.
    A sample is kept only if ``model_idx`` does not appear among those columns.
    This is the vectorized form of a per-row ``model_idx not in row[-non_model:]`` test.
    """
    # With non_model == 0 the slice `-0:` spans the whole row, just as a
    # per-row `row[-0:]` slice would.
    withheld = (train_label[:, -non_model:] == model_idx).any(axis=1)
    keep = ~withheld
    return train_data[keep], train_label[keep, 0]


def main():
    """Train the K split-AI sub-models on TinyImageNet and checkpoint each one.

    Sub-model ``i`` trains only on samples whose last ``L`` label columns do not
    contain ``i``. Each sub-model is evaluated on the test split every epoch.
    Its last-epoch weights are then saved together with the best test accuracy
    seen during training.
    """
    parser = argparse.ArgumentParser(description='setting for tinyimagenet')
    parser.add_argument('--K', type=int, default=25, help='total sub-models in split-ai')
    parser.add_argument('--L', type=int, default=10, help='non_model for each sample in split-ai')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--model', type=str, default='mobilenetv3_small_50')
    parser.add_argument('--split_epochs', type=int, default=200,
                        help='training epochs for each single model in split-ai')
    parser.add_argument('--batch_step', type=int, default=1, help='batch accumulation steps')
    parser.add_argument('--print_epoch_splitai', type=int, default=5,
                        help='print splitai single model training stats per print_epoch_splitai during splitai training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=1, help='number workers')
    # NOTE: currently unused -- the network below is always built with 200 classes.
    parser.add_argument('--num_class', type=int, default=100, help='num class')

    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--run_idx', type=int, default=1, help='idx running')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))
    # The checkpoint written after the epoch loop needs at least one epoch, and
    # the warmup sizing divides by batch_step.
    if args.split_epochs < 1:
        parser.error('--split_epochs must be >= 1')
    if args.batch_step < 1:
        parser.error('--batch_step must be >= 1')

    split_model = args.K
    non_model = args.L
    split_epochs = args.split_epochs
    batch_size = args.batch_size
    print_epoch_splitai = args.print_epoch_splitai
    load_name = str(split_model) + '_' + str(non_model)
    warmup = args.warmup
    num_worker = args.num_worker
    aug_dir = 'aug' if args.data_aug else 'no_aug'

    dataset_path = os.path.join(root_dir, 'tinyimagenet', 'data')
    partition_dir = os.path.join(dataset_path, 'partition')
    defender_dir = os.path.join(partition_dir, 'K_L', load_name, 'defender')
    checkpoint_path = os.path.join(args.save_path, 'tinyimagenet', args.model, 'K_L', load_name)
    checkpoint_path_splitai = os.path.join(checkpoint_path, 'split_ai', aug_dir, str(args.run_idx))
    checkpoint_path_selena = os.path.join(checkpoint_path, 'selena', aug_dir, str(args.run_idx))
    print(checkpoint_path, checkpoint_path_selena)

    train_data_tr_attack = np.load(os.path.join(partition_dir, 'tr_data.npy'))
    train_data_te_attack = np.load(os.path.join(partition_dir, 'te_data.npy'))
    train_label_tr_attack = np.load(os.path.join(defender_dir, 'tr_label.npy'))
    train_label_te_attack = np.load(os.path.join(defender_dir, 'te_label.npy'))
    train_data = np.concatenate((train_data_tr_attack, train_data_te_attack), axis=0)
    train_label = np.concatenate((train_label_tr_attack, train_label_te_attack), axis=0)
    test_data = np.load(os.path.join(partition_dir, 'test_data.npy'))
    test_label = np.load(os.path.join(partition_dir, 'test_label.npy'))
    # Only the reference labels are needed (for the sanity print below).
    ref_label = np.load(os.path.join(partition_dir, 'ref_label.npy'))

    # print first 20 labels for each subset, for checking with other experiments
    print(train_label_tr_attack[:20, 0])
    print(train_label_te_attack[:20, 0])
    print(test_label[:20])
    print(ref_label[:20])

    # --data_aug also selects the checkpoint sub-directory (aug/no_aug), so it
    # must control the transform the split models are actually trained with.
    train_transform = transform_train_aug if args.data_aug else transform_train

    testset = TINdata(test_data, test_label, transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_worker)

    criterion = nn.CrossEntropyLoss().to(device, torch.float)

    split_test_accs = []
    for i in range(split_model):
        model = get_network(arch=args.model, num_classes=200).to(device, torch.float)

        split_train_data, split_train_label = _select_split_subset(train_data, train_label, i, non_model)
        # print first 20 labels for each subset i in splitai, for later checking
        print("split model: {:d},# of data: {:d}".format(i, len(split_train_data)))
        print(split_train_label[:20])

        split_trainset = TINdata(split_train_data, split_train_label, train_transform)
        split_trainloader = torch.utils.data.DataLoader(split_trainset, batch_size=batch_size, shuffle=True,
                                                        num_workers=num_worker)

        # Size the warmup on the loader split_train actually iterates. The full
        # training loader is larger, so warmup would end before reaching the base LR.
        iter_per_epoch = len(split_trainloader)
        optimizer, train_scheduler, warmup_scheduler = get_opt_and_lrsch(
            args, model, split_epochs, iter_per_epoch, warmup
        )

        split_best_acc = 0
        saved_epoch = 0
        for epoch in range(1, split_epochs + 1):
            if epoch > 1:
                train_scheduler.step(epoch)

            _, split_train_acc = split_train(
                split_trainloader, model, criterion, optimizer, epoch, args, warmup_scheduler
            )
            _, split_test_acc = split_test(testloader, model, criterion)

            split_is_best = split_test_acc > split_best_acc
            if split_is_best:
                split_best_acc = split_test_acc
                saved_epoch = epoch

            if epoch % print_epoch_splitai == 0:
                print('Epoch: [{:d} | {:d}]: train acc:{:.4f}, test acc: {:.4f}. '.format(
                    epoch, split_epochs, split_train_acc, split_test_acc))
            sys.stdout.flush()

        # Saves the LAST-epoch weights, tagged with the best test accuracy seen.
        split_save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'acc': split_test_acc,
            'best_acc': split_best_acc,
            'optimizer': optimizer.state_dict(),
        }, split_is_best, split_best_acc, split_name=str(i), checkpoint=checkpoint_path_splitai)

        # saved_epoch is the epoch with the best test accuracy, which may differ
        # from the epoch whose weights were saved above.
        print("model {:d} final saved epoch {:d}: {:.4f}".format(i, saved_epoch, split_best_acc))
        split_test_accs.append(split_best_acc)

    print("For a single model, test accuracy: {:.4f}".format(np.mean(split_test_accs)))


# Start training only when run as a script; importing the module does not train.
if __name__ == '__main__':
    main()
