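'''
Adversarial Training: generalization-gap evaluation.

For each listed run directory, loads the final checkpoint and the best-RA
checkpoint, measures the train/test gap of each with test_adv, and saves
the collected gaps to an Excel sheet.
'''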
import os
import sys
from numpy.core.numeric import outer 
import torch
import pickle
import argparse
import torch.optim
import torch.nn as nn
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision.models as models
from utils import *
from sparselearning.core import Masking, CosineDecay
from sparselearning.pruning_utils import check_sparsity
import pandas as pd


# Command-line interface of the generalization-gap evaluation script.
# Flags are grouped by purpose; most of them are forwarded untouched to the
# project helpers through the parsed `args` namespace.
parser = argparse.ArgumentParser(description='Generalization Gap')

########################## data setting ##########################
# NOTE: these options are required=True, so argparse never falls back to the
# `default=` values; the defaults only document a typical value.
parser.add_argument('--data', type=str, default='data/cifar10', help='location of the data corpus', required=True)
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset [cifar10, cifar100, tinyimagenet]', required=True)

########################## model setting ##########################
# Same remark as above: required=True makes the default unused.
parser.add_argument('--arch', type=str, default='resnet18', help='model architecture [resnet18, wideresnet, vgg16]', required=True)
parser.add_argument('--depth_factor', default=34, type=int, help='depth-factor of wideresnet')
parser.add_argument('--width_factor', default=10, type=int, help='width-factor of wideresnet')


########################## basic setting ##########################
parser.add_argument('--seed', default=1, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--resume', action="store_true", help="resume from checkpoint")
parser.add_argument('--resume_dir', help='The directory resume the trained models', default=None, type=str)
parser.add_argument('--pretrained', default=None, type=str, help='pretrained model')
parser.add_argument('--eval', action="store_true", help="evaluation pretrained model")
parser.add_argument('--print_freq', default=50, type=int, help='logging frequency during training')
parser.add_argument('--save_dir', help='The parent directory used to save the trained models', default=None, type=str)

parser.add_argument('--norm', default='linf', type=str, help='linf or l2')
parser.add_argument('--test_eps', default=8, type=float, help='epsilon of attack during testing')
parser.add_argument('--test_step', default=20, type=int, help='itertion number of attack during testing')
parser.add_argument('--test_gamma', default=2, type=float, help='step size of attack during testing')
# store_false: the attribute defaults to True and passing the flag turns it OFF.
parser.add_argument('--test_randinit', action='store_false', help='randinit usage flag (default: on)')

########################## training setting ##########################
parser.add_argument('--consistency', action='store_true', help='apply consistency regularization')
parser.add_argument('--robust_friendly', action='store_true', help='apply robust friendly dataset')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')

########################## Small Dense Test #############################
# store_true flags default to False; the help text now says so.
parser.add_argument('--small_dense', action='store_true', help='Enable small dense mode. Default: off.')
parser.add_argument('--small_dense_rate', type=float, default=0.8, help='The density of small density, support 0.05 0.1 0.2 0.6 0.8')

########################## SWA setting ##########################
parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=55, metavar='N', help='SWA start epoch number (default: 55)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N', help='SWA model collection frequency/cycle length in epochs (default: 1)')

########################## KD setting ##########################
parser.add_argument('--lwf', action='store_true', help='lwf usage flag (default: off)')
parser.add_argument('--t_weight1', type=str, default=None, required=False, help='pretrained weight for teacher1')
parser.add_argument('--t_weight2', type=str, default=None, required=False, help='pretrained weight for teacher2')
parser.add_argument('--coef_ce', type=float, default=0.3, help='coef for CE')
parser.add_argument('--coef_kd1', type=float, default=0.1, help='coef for KD1')
parser.add_argument('--coef_kd2', type=float, default=0.6, help='coef for KD2')
parser.add_argument('--temperature', type=float, default=2.0, help='temperature of knowledge distillation loss')
# The help text used to advertise 200 although the actual default is 0.
parser.add_argument('--lwf_start', type=int, default=0, metavar='N', help='start point of lwf (default: 0)')
parser.add_argument('--lwf_end', type=int, default=200, metavar='N', help='end point of lwf (default: 200)')

########################## static sparse setting ##########################
# store_true flag: off unless passed.
parser.add_argument('--static_sparse', action='store_true', help='Enable static sparse mode. Default: off.')
parser.add_argument('--sparse_type', type=str, default='rp', help='static sparse mask initialization. choose from: rp omp gmp tp snip')

def _evaluate_checkpoint(model_path, checkpoint_name, args):
    """Evaluate one checkpoint file with test_adv on the train and test loaders.

    Args:
        model_path: directory that contains the checkpoint file.
        checkpoint_name: file name of the checkpoint inside ``model_path``.
        args: parsed command-line namespace; forwarded to the project helpers
            and read here for ``args.gpu``.

    Returns:
        ``(train_metric, test_metric)``: the first value returned by
        ``test_adv`` for the train loader and for the test loader
        (presumably robust accuracy, judging by the ``RA`` naming; confirm
        against ``test_adv``).
    """
    # Loaders and model are rebuilt for every checkpoint so each evaluation
    # starts from a fresh setup; the other returned objects are not needed.
    (train_loader, _val_loader, test_loader, model,
     _swa_model, _teacher1, _teacher2) = setup_dataset_models(args)
    model.cuda()

    checkpoint = torch.load(
        os.path.join(model_path, checkpoint_name),
        map_location=torch.device('cuda:' + str(args.gpu)),
    )
    model.load_state_dict(checkpoint['state_dict'])
    criterion = nn.CrossEntropyLoss()

    train_metric, _ = test_adv(train_loader, model, criterion, args)
    test_metric, _ = test_adv(test_loader, model, criterion, args)
    return train_metric, test_metric


def get_generalization_gap(model_path, args):
    """Compute train/test gaps for the final and the best-RA checkpoint of a run.

    Args:
        model_path: run directory holding ``checkpoint.pth.tar`` and
            ``model_RA_best.pth.tar``.
        args: parsed command-line namespace forwarded to the project helpers.

    Returns:
        A one-row ``pandas.DataFrame`` with the columns
        ``file`` (base name of ``model_path``),
        ``generalization_gap`` (train minus test metric of the best checkpoint),
        ``final_gap`` (train minus test metric of the final checkpoint) and
        ``diff3`` (``generalization_gap - final_gap``).
    """
    print(model_path)

    # The final checkpoint is evaluated first, then the best one, matching
    # the original evaluation order.
    final_train_ra, final_test_ra = _evaluate_checkpoint(
        model_path, 'checkpoint.pth.tar', args)
    best_train_ra, best_test_ra = _evaluate_checkpoint(
        model_path, 'model_RA_best.pth.tar', args)

    best_gap = best_train_ra - best_test_ra
    final_gap = final_train_ra - final_test_ra

    return pd.DataFrame({
        'file': [os.path.basename(model_path)],
        'generalization_gap': [best_gap],
        'final_gap': [final_gap],
        'diff3': [best_gap - final_gap],
    })


if __name__ == '__main__':
    # Script entry point: evaluate every listed run and save the collected
    # gaps to ./statis/gap_batch_statis.xlsx.
    args = parser.parse_args()
    # Attack budget and step size are given as multiples of 1/255 on the
    # command line and converted to fractions here.
    args.test_eps = args.test_eps / 255
    args.test_gamma = args.test_gamma / 255

    torch.cuda.set_device(int(args.gpu))

    # A seed of 0 (falsy) skips seeding, as in the original script.
    if args.seed:
        print('set random seed = ', args.seed)
        setup_seed(args.seed)


    model_paths = [
        'res_0911_batch/flying_resnet18_cifar10_igq_T900_d0.1_dr0.5_gradient_p0.4_g0.05_b32_e5_r0.5_seed17ee5',
        'res_0911_batch/flying_resnet18_cifar10_igq_T900_d0.1_dr0.5_gradient_p0.4_g0.05_b64_e5_r0.5_seed10d55',
        'res_0911_batch/flying_resnet18_cifar10_igq_T900_d0.1_dr0.5_gradient_p0.4_g0.05_b256_e5_r0.5_seed1bd9a',
        'res_0911_batch/flying_resnet18_cifar10_igq_T900_d0.1_dr0.5_gradient_p0.4_g0.05_b512_e5_r0.5_seed1ed46'
    ]
    # Batch size paired positionally with each entry of model_paths.
    batchs = [32, 64, 256, 512]

    columns = ['file', 'generalization_gap', 'final_gap', 'diff3']

    # Collect the per-run frames and concatenate once: DataFrame.append was
    # removed in pandas 2.0, and a single concat avoids repeated copying.
    frames = []
    for b, path in zip(batchs, model_paths):
        if not os.path.exists(path):
            # Missing runs are still skipped, but no longer silently.
            print('skip missing model path: ', path)
            continue
        args.batch_size = b
        frames.append(get_generalization_gap(path, args))

    if frames:
        df_total = pd.concat(frames, ignore_index=True)
    else:
        # pd.concat rejects an empty list; keep the original empty table.
        df_total = pd.DataFrame(columns=columns)

    df_total = df_total.sort_values(by='file',ascending=False)

    print(df_total)
    file_name = 'gap_batch_statis.xlsx'
    save_dir = './statis'
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, file_name)
    print('save path: ', save_path)
    df_total.to_excel(save_path, index=False)