import matplotlib as mpl
mpl.use('Agg')

import os
import torch
import torch.nn as nn

from utils.model_normalization import Cifar100Wrapper
import utils.datasets as dl
import utils.models.model_factory_32 as factory
import utils.run_file_helpers as rh
from distutils.util import strtobool
import ssl_utils as ssl
import utils.train_types as tt

import argparse

# Command-line interface. Only the script-specific options are declared here;
# the shared options read further down (e.g. hps.gpu, hps.lr, hps.bs,
# hps.epochs, hps.augm, hps.train_type) are not declared in this script and are
# presumably added by the rh.parser_add_* helpers.
# NOTE(review): `strtobool` is imported from `distutils`, which was removed in
# Python 3.12 (PEP 632); this script needs a replacement on newer interpreters.
parser = argparse.ArgumentParser(description='Define hyperparameters.', prefix_chars='-')
parser.add_argument('--net', type=str, default='ResNet18', help='Resnet18, 34 or 50, WideResNet28')
parser.add_argument('--model_params', nargs='+', default=[])
parser.add_argument('--dataset', type=str, default='cifar100', help='cifar100')
parser.add_argument('--od_dataset', type=str, default='tinyImages',
                    help=('tinyImages or cifar100'))
# Parsed with strtobool, so textual booleans such as 'true'/'false'/'1'/'0' are accepted.
parser.add_argument('--exclude_cifar', dest='exclude_cifar', type=lambda x: bool(strtobool(x)),
                    default=True, help='whether to exclude cifar10 from tiny images')
# Used as 'LabelSmoothingEps' of the CEDATargeted variants; 0 is passed on as None.
parser.add_argument('--CEDA_label_smoothing', default=0, type=float,
                    help='Label smoothing epsilon for CEDATargeted training (0 disables it)')
parser.add_argument('--semi_ratio', type=int,
                    default=0, help='Fixed ratio or variable (0)')
parser.add_argument('--samples', type=int,
                    default=25000, help='Max additional samples per class')
# NOTE(review): not read directly in this script, and its help text duplicates
# --samples; confirm what this option controls.
parser.add_argument('--unlabeled_samples', type=int,
                    default=1_000_000, help='Max additional samples per class')
# Teacher identifier forwarded to ssl.get_tiny_cifar_partition.
parser.add_argument('--teacher', type=str,
                    default=None, help='Teacher density_model')
# Thresholds are kept as strings and forwarded unchanged to the loader factory.
parser.add_argument('--threshold', type=str,
                    default='0.980', help='TPR threshold')
# 'same' reuses --threshold, 'none' passes None, anything else is used as given.
parser.add_argument('--od_threshold', type=str,
                    default='same', help='OD threshold')

rh.parser_add_commons(parser)
rh.parser_add_adversarial_commons(parser)
# NOTE(review): 'cifar10' is passed although this script trains on CIFAR-100;
# confirm this is the intended norm preset.
rh.parser_add_adversarial_norms(parser, 'cifar10')
hps = parser.parse_args()
# Pick the compute device from the --gpu list:
#   several ids -> the lowest-numbered GPU hosts the model, and `device_ids`
#                  is kept for nn.DataParallel further down;
#   one id      -> that GPU;
#   no ids      -> CPU (with a warning).
if len(hps.gpu) > 1:
    device_ids = [int(gpu_id) for gpu_id in hps.gpu]
    device = torch.device(f'cuda:{min(device_ids)}')
elif len(hps.gpu) == 1:
    device = torch.device(f'cuda:{hps.gpu[0]}')
else:
    device = torch.device('cpu')
    print('Warning! Computing on CPU')

# parameters
# Fixed training hyperparameters (cf. https://arxiv.org/pdf/1906.09453.pdf).
t_obj = 'kl'  # passed as `train_obj` to the CEDA trainers
lr = hps.lr
bs = hps.bs
epochs = hps.epochs
lam = 1.0  # passed as `od_weight` to the CEDA trainers

network_name = hps.net.lower()
augm = hps.augm.lower()
exclude_cifar = hps.exclude_cifar
nesterov = hps.nesterov
od_dataset = hps.od_dataset
ceda_label_smoothing = hps.CEDA_label_smoothing
warmup_epochs = hps.warmup_epochs
test_epochs = hps.test_epochs

num_classes = 100

img_size = 32
model_root_dir = 'Cifar100Models'
logs_root_dir = 'Cifar100Logs'

# Build the (not yet wrapped) network; checkpoints and logs live in
# per-model-name subdirectories.
model, model_name, model_config = factory.build_model(network_name, num_classes)
model_dir = os.path.join(model_root_dir, model_name)
log_dir = os.path.join(logs_root_dir, model_name)

# Selection thresholds forwarded to ssl.get_tiny_cifar_partition, kept as the
# raw CLI strings. --od_threshold: 'same' reuses --threshold, 'none' passes
# None, anything else is forwarded as given.
class_tpr_min = hps.threshold
if hps.od_threshold == 'same':
    od_exclusion_threshold = class_tpr_min
elif hps.od_threshold == 'none':
    od_exclusion_threshold = None
else:
    od_exclusion_threshold = hps.od_threshold
cutout_window = 16  # passed as `cutout_window` to dl.get_80MTinyImages

# `epochs` is multiplied by this factor to obtain the trainers' total_epochs.
epoch_subdivs = 1

msda_config = rh.create_msda_config(hps)

dataset_classifications_path =  ssl.get_dataset_classification_dir('cifar100')

# load dataset
# Out-distribution batch size, scaled from the in-distribution batch size.
od_bs = int(hps.od_bs_factor * bs)
# NOTE(review): `tiny_dataset` is not read anywhere in this script.
if exclude_cifar:
    tiny_dataset = 'CIFAR10'
else:
    tiny_dataset = None
if hps.train_type in ['plainKL', 'CEDATargetedKL', 'CEDAKL', 'CEDATargeted', 'CEDATargetedKLEntropy']:
    # Teacher-based training: ssl.get_tiny_cifar_partition returns both the
    # train loader and an out-distribution loader, selected using
    # `teacher_model` and the classifications under `dataset_classifications_path`.
    samples_per_class = hps.samples
    verbose_exclude = False

    teacher_model = hps.teacher
    selection_model = None

    # Empty dicts handed to the loader factory (presumably filled with the
    # loader settings) and reported to the trainer via `loader_config`.
    ssl_config = {}
    id_config = {}
    od_config = {}

    if hps.train_type == 'plainKL':
        loader_config = {'SSL config': ssl_config, 'ID config': id_config}
    else:
        loader_config = {'SSL config': ssl_config, 'ID config': id_config, 'OD config': od_config}

    # The KL-trained types request soft labels with temperature calibration;
    # CEDATargeted requests hard labels without calibration.
    if hps.train_type in ['plainKL', 'CEDATargetedKL', 'CEDAKL']:
        soft_labels = True
        calibrate_temperature = True
    elif hps.train_type == 'CEDATargeted':
        soft_labels = False
        calibrate_temperature = False
    else:
        # 'CEDATargetedKLEntropy' passes the check above but is not implemented.
        raise NotImplementedError()

    train_loader, od_loader = ssl.get_tiny_cifar_partition(dataset_classifications_path, teacher_model, 'cifar100',
                                                           samples_per_class, True, semi_ratio=hps.semi_ratio,
                                                           class_tpr_min=class_tpr_min,
                                                           od_exclusion_threshold=od_exclusion_threshold,
                                                           calibrate_temperature=calibrate_temperature,
                                                           verbose_exclude=verbose_exclude,
                                                           soft_labels=soft_labels, batch_size=bs,
                                                           augm_type=augm, aa_magnitude=1.0, size=img_size,
                                                           exclude_cifar=exclude_cifar,
                                                           exclude_cifar10_1=exclude_cifar,
                                                           id_config_dict=id_config, od_config_dict=od_config,
                                                           ssl_config=ssl_config)
elif hps.train_type in ['CEDA', 'CEDAExtra']:
    # Plain CEDA: CIFAR-100 train loader plus an 80M tiny-images out-distribution loader.
    id_config = {}
    od_config = {}
    loader_config = {'ID config': id_config, 'OD config': od_config}

    train_loader = ssl.get_CIFAR100TrainValidation(train=True, batch_size=bs, augm_type=augm, config_dict=id_config)
    # Compare case-insensitively: the --od_dataset default is spelled 'tinyImages'
    # and is not lower-cased when read.
    if od_dataset.lower() == 'tinyimages':
        # NOTE(review): exclude_cifar is hard-coded to True here and ignores
        # --exclude_cifar; confirm this is intended.
        od_loader = dl.get_80MTinyImages(batch_size=od_bs, augm_type=augm, cutout_window=cutout_window, num_workers=1,
                                         exclude_cifar=True, exclude_cifar10_1=False, config_dict=od_config)
    else:
        raise ValueError('OD Dataset not supported')
else:
    # All other train types: CIFAR-100 train loader only, no out-distribution data.
    id_config = {}
    loader_config = {'ID config': id_config}

    train_loader = ssl.get_CIFAR100TrainValidation(train=True, batch_size=bs, augm_type=augm, config_dict=id_config)


# Evaluation loaders, both without augmentation: the non-train split returned by
# ssl.get_CIFAR100TrainValidation (presumably a held-out validation split) and
# the CIFAR-100 test set from dl.get_CIFAR100 as an extra evaluation loader.
test_loader = ssl.get_CIFAR100TrainValidation(augm_type='none', batch_size=bs, train=False)
extra_test_loader = dl.get_CIFAR100(augm_type='none', batch_size=bs, train=False)

# Optimizer / LR-scheduler configs are derived from the parsed CLI arguments.
scheduler_config, optimizer_config = rh.create_optim_scheduler_swa_configs(hps)
total_epochs = epoch_subdivs * epochs

# load old density_model
# Optionally resume a previous run. --continue_trained is read as
# [folder, checkpoint, start_epoch]: the checkpoint names 'final', 'best',
# 'final_swa' and 'best_swa' refer to files directly in the run folder; any
# other value refers to a file under its checkpoints/ subfolder.
if hps.continue_trained is not None:
    load_folder = hps.continue_trained[0]
    load_epoch = hps.continue_trained[1]
    start_epoch = int(hps.continue_trained[2])
    if load_epoch in ['final', 'best', 'final_swa', 'best_swa']:
        state_dict_file = f'{model_dir}/{load_folder}/{load_epoch}.pth'
        optimizer_dict_file = f'{model_dir}/{load_folder}/{load_epoch}_optim.pth'
    else:
        state_dict_file = f'{model_dir}/{load_folder}/checkpoints/{load_epoch}.pth'
        optimizer_dict_file = f'{model_dir}/{load_folder}/checkpoints/{load_epoch}_optim.pth'

    # The model weights are required; any failure here aborts the script.
    state_dict = torch.load(state_dict_file, map_location=device)

    # The optimizer state is optional: fall back to a fresh optimizer when it
    # cannot be loaded. Catch Exception rather than using a bare `except` so
    # that KeyboardInterrupt / SystemExit still propagate.
    try:
        optim_state_dict = torch.load(optimizer_dict_file, map_location=device)
    except Exception as err:
        print('Warning: Could not load Optim State - Restarting optim')
        print(f'  ({type(err).__name__}: {err})')
        optim_state_dict = None
    model.load_state_dict(state_dict)

    print(f'Continuing {load_folder} from epoch {load_epoch} - Starting training at epoch {start_epoch}')
else:
    start_epoch = 0
    optim_state_dict = None

# Wrap the network in Cifar100Wrapper (from utils.model_normalization) and move
# it to the primary device; with several GPU ids it is further replicated over
# `device_ids` with nn.DataParallel.
model = Cifar100Wrapper(model)
model = model.to(device)

if len(hps.gpu) > 1:
    model = nn.DataParallel(model, device_ids=device_ids)


#Train Type
# Build the trainer for --train_type. All trainers share the optimizer and
# scheduler configs, the MSDA config, the test schedule and the output
# directories. The *KL variants additionally pass clean_criterion='kl', and the
# CEDA trainers use t_obj / lam as train_obj / od_weight.
if hps.train_type == 'plain':
    trainer = tt.PlainTraining(model, optimizer_config, total_epochs, device, num_classes,
                               lr_scheduler_config=scheduler_config,
                               msda_config=msda_config, test_epochs=test_epochs,
                               saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type == 'CEDA':
    # NOTE(review): 'CEDAExtra' gets CEDA loaders in the dataset section but has
    # no trainer branch here, so it ends in the ValueError below; confirm
    # whether it should also map to CEDATraining.
    trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                              lr_scheduler_config=scheduler_config, msda_config=msda_config,
                              train_obj=t_obj, od_weight=lam, test_epochs=test_epochs,
                              saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type == 'CEDATargeted':
    # A label-smoothing value of 0 is passed to the trainer as None.
    CEDA_VARIANT = {'Type': 'CEDATargeted', 'LabelSmoothingEps' : None if ceda_label_smoothing == 0 else ceda_label_smoothing }
    trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                              CEDA_variant=CEDA_VARIANT,
                              lr_scheduler_config=scheduler_config, msda_config=msda_config,
                              train_obj=t_obj, od_weight=lam, test_epochs=test_epochs,
                              saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type == 'CEDATargetedKL':
    # Same variant as CEDATargeted, but with the KL clean criterion.
    CEDA_VARIANT = {'Type': 'CEDATargeted', 'LabelSmoothingEps' : None if ceda_label_smoothing == 0 else ceda_label_smoothing }
    trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                              CEDA_variant=CEDA_VARIANT, lr_scheduler_config=scheduler_config, msda_config=msda_config,
                              clean_criterion='kl', train_obj=t_obj, od_weight=lam,
                              test_epochs=test_epochs, saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type == 'CEDAKL':
    trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                              lr_scheduler_config=scheduler_config, msda_config=msda_config,
                              clean_criterion='kl', train_obj=t_obj, od_weight=lam,
                              test_epochs=test_epochs, saved_model_dir=model_dir, saved_log_dir=log_dir)
elif hps.train_type == 'plainKL':
    trainer = tt.PlainTraining(model, optimizer_config, total_epochs, device,  num_classes,
                               lr_scheduler_config=scheduler_config, msda_config=msda_config, clean_criterion='kl',
                               test_epochs=test_epochs, saved_model_dir=model_dir, saved_log_dir=log_dir)
else:
    raise ValueError('Train type {} is not supported'.format(hps.train_type))

# Debugging aid (disabled): torch.autograd.set_detect_anomaly(True)

# Enable cuDNN benchmark (autotuning) mode.
torch.backends.cudnn.benchmark = True

# Trainers that need out-distribution data also receive `od_loader`. The
# dataset section only creates it for the teacher-based and CEDA train types.
if trainer.requires_out_distribution():
    train_loaders, test_loaders = trainer.create_loaders_dict(
        train_loader,
        test_loader=test_loader,
        extra_test_loaders=[extra_test_loader],
        out_distribution_loader=od_loader,
    )
else:
    train_loaders, test_loaders = trainer.create_loaders_dict(
        train_loader,
        test_loader=test_loader,
        extra_test_loaders=[extra_test_loader],
    )

trainer.train(train_loaders, test_loaders, loader_config=loader_config,
              start_epoch=start_epoch, optim_state_dict=optim_state_dict)

