import matplotlib as mpl

mpl.use('Agg')

import torch
import torch.nn as nn

import os

from utils.model_normalization import ImageNetWrapper
from utils.models.models_224x224.resnet_224 import resnet50
from utils.models.models_224x224.efficientnet import EfficientNet
import utils.datasets as dl
import utils.train_types as tt
import utils.train_types.schedulers as schedulers
import utils.train_types.optimizers as optimizers
import utils.models.model_factory_224 as factory
import ssl_utils as ssl

import argparse
from distutils.util import strtobool

# Command-line hyperparameters; the parsed namespace `hps` drives the rest of
# this script.
parser = argparse.ArgumentParser(description='Define hyperparameters.', prefix_chars='-')

parser.add_argument('--gpu', '--list', nargs='+', type=int, default=[0],
                    help='GPU indices, if more than 1 parallel modules will be called')
parser.add_argument('--net', type=str, default='tresnetm', help='Architecture')
parser.add_argument('--dataset', type=str, default='test_set',
                    help="In-distribution dataset; only 'test_set' (ImageNet100) is supported")
parser.add_argument('--rho', type=float, default=1, help='Kernel falloff')
parser.add_argument('--lr', type=float, default=0.1, help='Learning rate for training')
parser.add_argument('--train_bs', type=int, default=128, help='Training batch size')
parser.add_argument('--od_weight', type=float, default=1,
                    help='Weight (lam) of the out-distribution term in CEDA/ACET')
parser.add_argument('--lr_smart', type=int, default=0, help='Epoch wide attack lr adaptation')
parser.add_argument('--decay', type=float, default=5e-4, help='Weight decay for the model')
parser.add_argument('--cycle_length', type=int, default=220, help='Total number of training epochs')
parser.add_argument('--train_type', type=str, default='plain',
                    help='Training type: plain, adversarial, trades, AdvACET, tradesceda, CEDA, ACET')
parser.add_argument('--steps', type=int, default=7, help='steps in PGD attack')
parser.add_argument('--id_pgd', type=str,
                    default='argmin', help='PGD variation for InDistribution/Adversarial attacks: pgd, argmin, monotone')
parser.add_argument('--od_pgd', type=str,
                    default='argmin', help='PGD variation for Out Distribution attacks: pgd, argmin, monotone')
parser.add_argument('--train_clean', dest='train_clean', type=lambda x: bool(strtobool(x)),
                    default=False, help='Passed as train_clean to the adversarial trainers')
parser.add_argument('--cal_obj', type=str, default='kl',
                    help='CCAT divergence: kl, bhattacharyya, renyi_alpha')
# Adjacent string literals are concatenated, so each fragment carries its own
# separator.
parser.add_argument('--adv_obj', type=str, default='crossEntropy',
                    help=('Objective to optimize in the inner loop of adversarial training: '
                          'logitsDiff | crossEntropy'))
parser.add_argument('--a_obj', type=str, default='KL',
                    help=('only for ACET; what objective the adversary has: '
                          'conf | log_conf | entropy | KL | bhattacharyya'))
parser.add_argument('--t_obj', type=str, default='KL',
                    help=('for ACET/CEDA; what should the second term in the training objective be: '
                          'conf | log_conf | entropy | KL | bhattacharyya'))
parser.add_argument('--tiny', type=str, default='shuffle',
                    help='Shuffling method of tiny images: linear, randomOffset, shuffle')
parser.add_argument('--augm', type=str, default='default',
                    help='Augmentation type: test, default, autoaugment')
parser.add_argument('--norm', type=str, default='l2',
                    help='Threat model: l2, l1 or linf')
parser.add_argument('--continue', dest='continue_trained', nargs=2, type=str,
                    default=None, help='Run folder and epoch (or "final") of a model to continue training from')
parser.add_argument('--od_dataset', type=str, default='openImages',
                    help='Out-distribution dataset for CEDA/ACET-style training')
parser.add_argument('--od_bs_factor', default=1, type=float,
                    help='OD batch size factor (relative to the training batch size)')
parser.add_argument('--nesterov', dest='nesterov', type=lambda x: bool(strtobool(x)),
                    default=True, help='Nesterov SGD')
parser.add_argument('--stochastic_depth', type=float, default=0.0,
                    help='Drop-path rate passed to the model factory (0 disables it)')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='Dropout rate passed to the model factory (0 disables it)')

hps = parser.parse_args()
# Select the compute device. With several GPUs the lowest index is the primary
# device; `device_ids` is always defined so the nn.DataParallel wrapping later
# in the script can rely on it.
device_ids = [int(i) for i in hps.gpu]
if not device_ids:
    device = torch.device('cpu')
    print('Warning! Computing on CPU')
else:
    device = torch.device(f'cuda:{min(device_ids)}')

# parameters
# https://arxiv.org/pdf/1906.09453.pdf
# The parser exposes the batch size as --train_bs, the number of epochs as
# --cycle_length and the out-distribution weight as --od_weight; there are no
# `bs`, `epochs` or `lam` attributes on `hps`.
lr = hps.lr
bs = hps.train_bs
epochs = hps.cycle_length
attack_steps = hps.steps
cal_div = hps.cal_obj
rho = hps.rho
lam = hps.od_weight
id_pgd = hps.id_pgd.lower()
od_pgd = hps.od_pgd.lower()
adv_obj = hps.adv_obj.lower()
lr_smart = bool(hps.lr_smart)
train_clean = bool(hps.train_clean)
tiny_shuffle = hps.tiny.lower()
network_name = hps.net.lower()
augm = hps.augm.lower()
norm = hps.norm.lower()
nesterov = hps.nesterov
od_dataset = hps.od_dataset.lower()


# Build the network for 224x224 inputs with 100 classes and derive the
# per-architecture checkpoint and log directories.
img_size = 224
num_classes = 100

model_root_dir = 'ImageNet100Models'
logs_root_dir = 'ImageNet100Logs'

# Forward optional regularisation settings only when they are enabled (> 0).
optional_model_args = (
    ('drop_path_rate', hps.stochastic_depth),
    ('drop_rate', hps.dropout),
)
model_arguments = {arg_name: rate for arg_name, rate in optional_model_args if rate > 0.0}

model, model_name, model_config = factory.build_model(network_name, num_classes, **model_arguments)
model_dir = os.path.join(model_root_dir, model_name)
log_dir = os.path.join(logs_root_dir, model_name)

# Out-distribution batch size is a multiple (--od_bs_factor) of the training
# batch size, truncated to an int.
od_bs = int(hps.od_bs_factor * bs)

# Thresholds handed to the Open Images partitioning helper (kept as strings).
class_tpr_min = '0.900'
od_exclusion_threshold = '0.900'

dataset_classifications_path = 'DatasetClassifications/OpenImagesImageNet100/'


# Build the training loader(s). `loader_config` collects the config dicts the
# loader factories fill in; it is passed on to trainer.train().
if hps.train_type in ['CEDA', 'ACET', 'AdvACET', 'ADVACET', 'CEDAExtra']:
    # These training types get both an in-distribution and an
    # out-distribution loader.
    id_config = {}
    od_config = {}
    loader_config = {'ID config': id_config, 'OD config': od_config}

    if hps.dataset == 'test_set':
        train_loader = ssl.get_ImageNet100_trainVal(train=True, batch_size=bs, shuffle=True,
                                                    augm_type=augm, size=img_size, config_dict=id_config)
    else:
        raise ValueError(f'Dataset {hps.dataset} not supported')

    # 'openimages' is the lower-cased default of --od_dataset; 'ssl_set' stays
    # accepted as an alias.
    if od_dataset in ('openimages', 'ssl_set'):
        od_loader = dl.get_openImages('train', batch_size=od_bs, shuffle=True, augm_type=augm, size=img_size,
                                      exclude_dataset=True, config_dict=od_config)
    else:
        raise ValueError(f'Out-distribution dataset {hps.od_dataset} not supported')
elif hps.train_type in ['CEDATargetedKL', 'CEDAKL', 'CEDATargetedKLEntropy']:
    # Train on a teacher-labelled partition of Open Images.
    # NOTE(review): no `od_loader` is defined on this path and these train
    # types have no trainer branch below — confirm this path is still used.
    soft_labels = True
    samples_per_class = 1200
    unlabeled_ratio = 10
    calibrate_temperature = True
    verbose_exclude = False

    teacher_model = 'CEDA_25-08-2020_10:22:12'

    ssl_config = {}
    id_config = {}
    od_config = {}
    loader_config = {'SSL config': ssl_config, 'ID config': id_config, 'OD config': od_config}

    train_loader, tiny_train = ssl.get_openImages_partition(dataset_classifications_path, teacher_model, samples_per_class,
                                                            True, class_tpr_min=class_tpr_min, od_exclusion_threshold=od_exclusion_threshold,
                                                            calibrate_temperature=calibrate_temperature, verbose_exclude=verbose_exclude,
                                                            soft_labels=soft_labels, batch_size=bs, augm_type=augm, size=img_size,
                                                            exclude_imagenet100=True, id_config_dict=id_config, od_config_dict=od_config,
                                                            ssl_config=ssl_config)

else:
    # Only an in-distribution loader is needed.
    id_config = {}
    loader_config = {'ID config': id_config}
    if hps.dataset == 'test_set':
        train_loader = ssl.get_ImageNet100_trainVal(train=True, batch_size=bs, shuffle=True, augm_type=augm,
                                                    size=img_size, config_dict=id_config)
    else:
        raise ValueError(f'Dataset {hps.dataset} not supported')

# Evaluation loader: the train=False split of get_ImageNet100_trainVal,
# without augmentation.
test_loader = ssl.get_ImageNet100_trainVal(
    train=False,
    batch_size=bs,
    augm_type='none',
    size=img_size,
)

# The learning-rate schedule is configured together with the optimizer config
# (cosine annealing), which is the only value of `scheduler_config` the
# trainers ever receive.

# PGD step counts: out-distribution (ACET) attacks use a fixed 10 steps,
# in-distribution attacks use --steps.
acet_steps = 10
adv_steps = attack_steps

# Perturbation radii used by the adversarial / TRADES attack configs.
inf_eps = 8 / 255
l2_eps = 1.75
l1_eps = 100

# Optionally resume a previous run: --continue <run folder> <epoch|final>.
if hps.continue_trained is None:
    start_epoch = 0
    optim_state_dict = None
else:
    load_folder, load_epoch = hps.continue_trained
    run_dir = f'{model_dir}/{load_folder}'

    if load_epoch != 'final':
        state_dict_file = f'{run_dir}/checkpoints/{load_epoch}.pth'
        optimizer_dict_file = f'{run_dir}/checkpoints/{load_epoch}_optim.pth'
        # Resume with the epoch after the stored checkpoint.
        start_epoch = int(load_epoch) + 1
    else:
        state_dict_file = f'{run_dir}/final.pth'
        optimizer_dict_file = f'{run_dir}/final_optim.pth'
        # The final checkpoint has no epoch in its name: take the value of the
        # last 'cycle_length:' line in the run's config.txt.
        with open(f'{run_dir}/config.txt', 'r') as config_file:
            parsed_epochs = [int(entry.split()[1]) for entry in config_file if 'cycle_length:' in entry]
        start_epoch = parsed_epochs[-1] if parsed_epochs else -1
        if start_epoch == -1:
            raise ValueError('Could not read cycle_length from config file')

    state_dict = torch.load(state_dict_file, map_location=device)
    optim_state_dict = torch.load(optimizer_dict_file, map_location=device)
    model.load_state_dict(state_dict)

    print(f'Continuing {load_folder} from epoch {start_epoch}')

# Wrap in ImageNetWrapper (from utils.model_normalization) and move to the
# primary device; replicate over all listed GPUs when more than one is given.
model = ImageNetWrapper(model).to(device)

use_data_parallel = len(hps.gpu) > 1
if use_data_parallel:
    model = nn.DataParallel(model, device_ids=device_ids)

# SGD with momentum 0.9; weight decay (--decay) and Nesterov (--nesterov) come
# from the command line.
optimizer_config = optimizers.create_optimizer_config(
    'SGD', lr, momentum=0.9, weight_decay=hps.decay, nesterov=nesterov
)

# Cosine annealing over all training iterations (1e-5 is the second argument,
# presumably the final lr). Batches of 256 or more additionally get a warm-up
# spanning 5 epochs' worth of iterations.
iters_per_epoch = len(train_loader)
cosine_kwargs = {'warmup_length': iters_per_epoch * 5} if bs >= 256 else {}
scheduler_config = schedulers.create_cosine_annealing_scheduler_config(
    epochs, 1e-5, iters_per_epoch * epochs, **cosine_kwargs
)

# Instantiate the trainer selected by --train_type. Attack configs depend on
# --norm: 'l2'/'2' and 'l1'/'1' select those threat models, anything else
# falls back to L-inf. Each norm check must be a single if/elif/else chain,
# otherwise the L-inf `else` overwrites an already-built L2 config.
if hps.train_type == 'plain':
    trainer = tt.PlainTraining(model, optimizer_config, epochs, device, lr_scheduler_config=scheduler_config,
                               saved_model_dir=model_dir, saved_log_dir=log_dir, test_epochs=1)
elif hps.train_type == 'adversarial':
    # https://arxiv.org/pdf/1906.09453.pdf
    if norm in ['l2', '2']:
        attack_config = tt.AdversarialTraining.create_id_attack_config(l2_eps, adv_steps, 0.7, 'l2', pgd=id_pgd,
                                                                       normalize_gradient=True)
    elif norm in ['l1', '1']:
        attack_config = tt.AdversarialTraining.create_id_attack_config(l1_eps, adv_steps, 70, 'l1', pgd=id_pgd,
                                                                       normalize_gradient=True)
    else:
        attack_config = tt.AdversarialTraining.create_id_attack_config(inf_eps, adv_steps, 2. / 255, 'inf', pgd=id_pgd,
                                                                       normalize_gradient=True,
                                                                       momentum=0.0)

    trainer = tt.AdversarialTraining(model, attack_config, optimizer_config, epochs, device,
                                     train_clean=train_clean, attack_loss=adv_obj, lr_scheduler_config=scheduler_config,
                                     saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type in ['trades', 'TRADES']:
    if norm in ['l2', '2']:
        attack_config = tt.TRADESTraining.create_id_attack_config(l2_eps, adv_steps, 0.7, 'l2', pgd=id_pgd,
                                                                  normalize_gradient=True)
    else:
        # NOTE(review): this branch builds the config via AdversarialTraining,
        # the L2 branch via TRADESTraining — confirm both are equivalent.
        attack_config = tt.AdversarialTraining.create_id_attack_config(inf_eps, adv_steps, 2. / 255, 'inf', pgd=id_pgd,
                                                                       normalize_gradient=True, momentum=0.0)
    trainer = tt.TRADESTraining(model, attack_config, optimizer_config, epochs, device,
                                lr_scheduler_config=scheduler_config, trades_weight=3,
                                saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type in ['AdvACET', 'ADVACET']:
    # https://arxiv.org/pdf/1906.09453.pdf
    if norm in ['l2', '2']:
        id_attack_config = tt.AdversarialACET.create_id_attack_config(3.5, adv_steps, 0.7, 'l2', pgd=id_pgd,
                                                                      normalize_gradient=True)
        od_attack_config = tt.AdversarialACET.create_od_attack_config(7.0, acet_steps, 1.0, norm='l2', pgd=od_pgd,
                                                                      normalize_gradient=True)
    elif norm in ['l1', '1']:
        id_attack_config = tt.AdversarialACET.create_id_attack_config(100, adv_steps, 70, 'l1', pgd=id_pgd,
                                                                      normalize_gradient=True)
        od_attack_config = tt.AdversarialACET.create_od_attack_config(400, acet_steps, 100, norm='l1', pgd=od_pgd,
                                                                      normalize_gradient=True)
    else:
        id_attack_config = tt.AdversarialACET.create_id_attack_config(8 / 255, adv_steps, 2 / 255, 'inf', pgd=id_pgd,
                                                                      normalize_gradient=True, noise='normal_0.001')
        od_attack_config = tt.AdversarialACET.create_od_attack_config(12 / 255, acet_steps, 2 / 255, norm='inf',
                                                                      pgd=od_pgd,
                                                                      normalize_gradient=True, noise='normal_0.001')
    trainer = tt.AdversarialACET(model, id_attack_config, od_attack_config, optimizer_config, epochs, device,
                                 train_clean=train_clean, attack_loss=adv_obj, lr_scheduler_config=scheduler_config,
                                 train_obj=hps.t_obj, attack_obj=hps.a_obj,
                                 saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type in ['tradesceda', 'TRADESCEDA']:
    if norm in ['l2', '2']:
        id_attack_config = tt.TRADESCEDATraining.create_id_attack_config(0.5, adv_steps, 0.1, norm=2, pgd=id_pgd,
                                                                         normalize_gradient=True)
        od_attack_config = tt.TRADESCEDATraining.create_od_attack_config(1.0, acet_steps, 0.1, norm=2, pgd=od_pgd,
                                                                         normalize_gradient=True)
    else:
        id_attack_config = tt.TRADESCEDATraining.create_id_attack_config(8 / 255, adv_steps, 2 / 255, 'inf', pgd=id_pgd,
                                                                         normalize_gradient=True, noise='normal_0.001')
        od_attack_config = tt.TRADESCEDATraining.create_od_attack_config(8 / 255, acet_steps, 2 / 255, norm='inf',
                                                                         pgd=od_pgd,
                                                                         normalize_gradient=True, noise='normal_0.001')
    trainer = tt.TRADESCEDATraining(model, id_attack_config, od_attack_config, optimizer_config, epochs, device,
                                    lr_scheduler_config=scheduler_config, train_obj=hps.t_obj, id_trades_weight=6,
                                    od_trades_weight=10,
                                    saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type == 'CEDA':
    trainer = tt.CEDATraining(model, optimizer_config, epochs, device, lr_scheduler_config=scheduler_config,
                              train_obj=hps.t_obj, lam=lam, test_epochs=1, saved_model_dir=model_dir,
                              saved_log_dir=log_dir)
elif hps.train_type == 'ACET':
    # L2 distance between cifar10 and mnist is about 14 on average
    if norm in ['l2', '2']:
        # NOTE(review): the L2 config is built via AdversarialACET, the L-inf
        # one via ACETTraining — confirm both are equivalent.
        od_attack_config = tt.AdversarialACET.create_od_attack_config(7.0, acet_steps, 1.0, norm='l2', pgd=od_pgd,
                                                                      normalize_gradient=True)
    else:
        od_attack_config = tt.ACETTraining.create_od_attack_config(12 / 255, acet_steps, 2 / 255, norm='inf',
                                                                   pgd=od_pgd,
                                                                   normalize_gradient=True)
    trainer = tt.ACETTraining(model, od_attack_config, optimizer_config, epochs, device, lam=lam,
                              lr_scheduler_config=scheduler_config, train_obj=hps.t_obj, attack_obj=hps.a_obj,
                              saved_model_dir=model_dir, saved_log_dir=log_dir, test_epochs=1)
else:
    raise ValueError('Train type {} is not supported'.format(hps.train_type))

# Debugging aid (disabled): torch.autograd.set_detect_anomaly(True)

find_lr = False  # not read anywhere else in this script

# run training
torch.backends.cudnn.benchmark = True

# Every trainer gets the test loader; trainers that report needing an
# out-distribution additionally receive the OD loader.
loader_kwargs = {'test_loader': test_loader}
if trainer.requires_out_distribution():
    loader_kwargs['out_distribution_loader'] = od_loader

train_loaders, test_loaders = trainer.create_loaders_dict(train_loader, **loader_kwargs)
trainer.train(train_loaders, test_loaders, loader_config=loader_config, start_epoch=start_epoch,
              optim_state_dict=optim_state_dict)


