import matplotlib as mpl
mpl.use('Agg')

import os
import torch
import torch.nn as nn

from utils.model_normalization import Cifar10Wrapper
import utils.datasets as dl
import utils.models.model_factory_32 as factory
import utils.run_file_helpers as rh
from distutils.util import strtobool
import ssl_utils as ssl
import utils.train_types as tt

import ssl_estimation
import classify_other_dataset

import argparse
import subprocess

def _str_to_bool(value):
    """Convert a command line string such as 'true'/'false'/'1'/'0' to a bool."""
    return bool(strtobool(value))


# Command line interface of the iterative teacher/student training script.
parser = argparse.ArgumentParser(description='Define hyperparameters.', prefix_chars='-')
parser.add_argument('--net', type=str, default='ResNet18', help='Resnet18, 34 or 50, WideResNet28')
parser.add_argument('--model_params', nargs='+', default=[])
parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10 or semi-cifar10')
parser.add_argument('--cifar_subset', type=int, default=0, help='Use subset of X cifar images')
parser.add_argument('--od_dataset', type=str, default='tinyImages')
parser.add_argument('--od_validation_set', type=str, default=None,
                    help='OD validation set (defaults to --od_dataset)')
parser.add_argument('--exclude_cifar', dest='exclude_cifar', type=_str_to_bool,
                    default=True, help='whether to exclude cifar10 from tiny images')
parser.add_argument('--CEDA_label_smoothing', default=0, type=float,
                    help='Label smoothing epsilon for the CEDATargeted variants (0 disables it)')
parser.add_argument('--semi_ratio', type=int,
                    default=0, help='Fixed ratio or variable (0)')
parser.add_argument('--start_samples', type=int,
                    default=2_000, help='Start additional samples per class (first iteration)')
parser.add_argument('--additional_samples', type=int,
                    default=2_000, help='Samples per class added in every further iteration')
parser.add_argument('--iterations', type=int,
                    default=5, help='Number of training iterations')
parser.add_argument('--unlabeled_samples', type=int,
                    default=1_000_000,
                    help='Number of unlabeled samples passed to threshold estimation and pool classification')
parser.add_argument('--od_validation_samples', type=int,
                    default=2_000, help='Number of OD validation samples passed to threshold estimation')
parser.add_argument('--teacher', type=str,
                    default=None, help='Teacher density_model')
# Thresholds stay strings: they are forwarded unchanged to the ssl_utils partition helpers.
parser.add_argument('--threshold', type=str,
                    default='0.980', help='TPR threshold')
parser.add_argument('--od_threshold', type=str,
                    default='same',
                    help="OD threshold: 'same' (reuse --threshold), 'none' (no OD exclusion threshold) "
                         "or an explicit value")
parser.add_argument('--calibrate_temperature', dest='calibrate_temperature', type=_str_to_bool,
                    default=True, help='whether to use temperature calibration')

# Options shared with the other run files (defined in utils.run_file_helpers).
rh.parser_add_commons(parser)
rh.parser_add_adversarial_commons(parser)
rh.parser_add_adversarial_norms(parser, 'cifar10')

hps = parser.parse_args()
#
# Pick the compute device from the requested GPU ids:
#   no id       -> CPU (with a warning)
#   one id      -> that CUDA device
#   several ids -> the lowest id becomes the main device, all ids are kept in device_ids
if len(hps.gpu) > 1:
    device_ids = [int(gpu_id) for gpu_id in hps.gpu]
    device = torch.device(f'cuda:{min(device_ids)}')
elif len(hps.gpu) == 1:
    device_ids = None
    device = torch.device(f'cuda:{hps.gpu[0]}')
else:
    device_ids = None
    device = torch.device('cpu')
    print('Warning! Computing on CPU')

# The first teacher is the predefined model given on the command line.
teacher_model = hps.teacher
selection_model = None

# Training parameters
# https://arxiv.org/pdf/1906.09453.pdf
t_obj = 'kl'
lr, bs, epochs = hps.lr, hps.bs, hps.epochs
lam = 1.0

# Command line options copied into module-level names (lower-cased where needed)
network_name = hps.net.lower()
augm = hps.augm.lower()
exclude_cifar, nesterov = hps.exclude_cifar, hps.nesterov
od_dataset = hps.od_dataset
# Without an explicit validation set, the OD dataset itself is used.
od_validation_set = hps.od_validation_set if hps.od_validation_set is not None else hps.od_dataset
ceda_label_smoothing = hps.CEDA_label_smoothing
warmup_epochs, test_epochs = hps.warmup_epochs, hps.test_epochs

num_classes = 10

# Image size and root folders for saved models and logs
img_size = 32
model_root_dir, logs_root_dir = 'Cifar10Models', 'Cifar10Logs'

# Checkpoint name of the teacher: 'best_avg' if hps.ema is set, otherwise 'best'
model_checkpoint = 'best_avg' if hps.ema else 'best'

# Train types for which the teacher-based data partition below is implemented,
# grouped by the value of the soft_labels flag handed to the partition helpers.
soft_label_train_types = ('plainKL', 'CEDATargetedKL', 'CEDAKL')
hard_label_train_types = ('CEDATargeted',)

# Iterative training: in every iteration the current teacher is used to estimate
# thresholds and classify the unlabeled pool, a fresh model is trained on the
# resulting partition, and the name returned by the trainer becomes the next teacher.
for iteration in range(hps.iterations):
    # Fail fast: reject unsupported train types before the expensive threshold
    # estimation and pool classification are started.
    if hps.train_type not in soft_label_train_types + hard_label_train_types:
        raise NotImplementedError(f'Train type {hps.train_type} is not supported')

    print('Estimating thresholds:')
    ssl_estimation.main(hps.gpu, teacher_model, network_name, model_checkpoint, 'cifar10', hps.cifar_subset,
                        od_dataset, od_validation_set,
                        hps.unlabeled_samples, hps.od_validation_samples, thresholds_only=True)

    print('Classifying unlabeled pool:')
    classify_other_dataset.main(hps.gpu, teacher_model, network_name, model_checkpoint, 'cifar10',
                                hps.cifar_subset, od_dataset, hps.unlabeled_samples)

    # A new model is built for every iteration.
    model, model_name, model_config = factory.build_model(network_name, num_classes)
    model_dir = os.path.join(model_root_dir, model_name)
    log_dir = os.path.join(logs_root_dir, model_name)

    # NOTE(review): od_bs and cutout_window are not referenced anywhere below in this script.
    od_bs = int(hps.od_bs_factor * bs)
    cutout_window = 16

    # Thresholds are forwarded to the partition helpers exactly as given on the command line.
    class_tpr_min = hps.threshold
    if hps.od_threshold == 'same':
        od_exclusion_threshold = class_tpr_min
    elif hps.od_threshold == 'none':
        od_exclusion_threshold = None
    else:
        od_exclusion_threshold = hps.od_threshold

    dataset_classifications_path = ssl.get_dataset_classification_dir('cifar10')
    epoch_subdivs = 1

    msda_config = rh.create_msda_config(hps)

    # The number of additional samples per class grows linearly with the iteration.
    samples_per_class = int(hps.start_samples + iteration * hps.additional_samples)

    print('\n\n#################################')
    print(f'Iteration {iteration} - Samples {samples_per_class}')
    print('#################################\n\n')

    calibrate_temperature = hps.calibrate_temperature
    verbose_exclude = False

    # These dicts are handed to the partition helpers and afterwards passed to the
    # trainer as loader_config.
    ssl_config = {}
    id_config = {}
    od_config = {}

    loader_config = {'SSL config': ssl_config, 'ID config': id_config}
    if hps.train_type != 'plainKL':
        loader_config['OD config'] = od_config

    soft_labels = hps.train_type in soft_label_train_types

    # load dataset
    if od_dataset == 'tinyImages':
        if hps.cifar_subset > 0:
            raise NotImplementedError('use tinyImages_subset ')
        train_loader, od_loader = ssl.get_tiny_cifar_partition(dataset_classifications_path, teacher_model,
                                                               'cifar10', samples_per_class, False,
                                                               semi_ratio=hps.semi_ratio,
                                                               class_tpr_min=class_tpr_min,
                                                               od_exclusion_threshold=od_exclusion_threshold,
                                                               calibrate_temperature=calibrate_temperature,
                                                               verbose_exclude=verbose_exclude,
                                                               soft_labels=soft_labels, batch_size=bs,
                                                               augm_type=augm, aa_magnitude=1.0, size=img_size,
                                                               exclude_cifar=exclude_cifar,
                                                               exclude_cifar10_1=exclude_cifar,
                                                               id_config_dict=id_config,
                                                               od_config_dict=od_config,
                                                               ssl_config=ssl_config)
    else:
        train_loader, od_loader = ssl.get_cifar_subset_plus_od_partition(teacher_model, 'cifar10', od_dataset,
                                                                         hps.cifar_subset // 10,
                                                                         samples_per_class,
                                                                         hps.unlabeled_samples,
                                                                         semi_ratio=hps.semi_ratio,
                                                                         class_tpr_min=class_tpr_min,
                                                                         od_exclusion_threshold=od_exclusion_threshold,
                                                                         calibrate_temperature=calibrate_temperature,
                                                                         verbose_exclude=verbose_exclude,
                                                                         soft_labels=soft_labels,
                                                                         batch_size=bs, augm_type=augm,
                                                                         size=img_size,
                                                                         aa_magnitude=1.0,
                                                                         id_config_dict=id_config,
                                                                         od_config_dict=od_config,
                                                                         ssl_config=ssl_config)

    if hps.cifar_subset <= 0:
        test_loader = dl.get_CIFAR10_1(batch_size=bs)
    else:
        # Integer division, consistent with the training partition above
        # (true division would pass a float here).
        test_loader = ssl.get_CIFAR10_subset('val', hps.cifar_subset // 10, batch_size=bs, augm_type='none',
                                             shuffle=True, size=img_size)
    extra_test_loader = dl.get_CIFAR10(train=False, batch_size=bs, augm_type='none')

    scheduler_config, optimizer_config = rh.create_optim_scheduler_swa_configs(hps)
    total_epochs = epochs * epoch_subdivs

    # Optionally continue from a stored model: continue_trained = (folder, epoch, start epoch).
    # NOTE(review): this runs in every iteration, not only the first - confirm that is intended.
    if hps.continue_trained is not None:
        load_folder = hps.continue_trained[0]
        load_epoch = hps.continue_trained[1]
        start_epoch = int(hps.continue_trained[2])
        if load_epoch in ['final', 'best', 'final_swa', 'best_swa']:
            state_dict_file = f'{model_dir}/{load_folder}/{load_epoch}.pth'
            optimizer_dict_file = f'{model_dir}/{load_folder}/{load_epoch}_optim.pth'
        else:
            state_dict_file = f'{model_dir}/{load_folder}/checkpoints/{load_epoch}.pth'
            optimizer_dict_file = f'{model_dir}/{load_folder}/checkpoints/{load_epoch}_optim.pth'

        state_dict = torch.load(state_dict_file, map_location=device)

        optim_state_dict = None
        if hps.continue_optim:
            # Best effort: a missing or unreadable optimizer state only restarts the optimizer.
            # Exception (not a bare except) keeps Ctrl-C and SystemExit working.
            try:
                optim_state_dict = torch.load(optimizer_dict_file, map_location=device)
            except Exception as err:
                print(f'Warning: Could not load Optim State - Restarting optim ({err})')

        model.load_state_dict(state_dict)

        print(f'Continuing {load_folder} from epoch {load_epoch} - Starting training at epoch {start_epoch}')
    else:
        start_epoch = 0
        optim_state_dict = None

    model = Cifar10Wrapper(model).to(device)

    if len(hps.gpu) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)

    # Train Type
    # Keyword arguments shared by all trainers; the KL variants add clean_criterion='kl'.
    # 'plain' and 'CEDA' are kept in the dispatch but are currently rejected at the
    # top of the loop because no data partition is implemented for them.
    trainer_kwargs = {'lr_scheduler_config': scheduler_config, 'msda_config': msda_config,
                      'test_epochs': test_epochs, 'saved_model_dir': model_dir, 'saved_log_dir': log_dir}
    if hps.train_type in ('plainKL', 'CEDATargetedKL', 'CEDAKL'):
        trainer_kwargs['clean_criterion'] = 'kl'

    if hps.train_type in ('plain', 'plainKL'):
        trainer = tt.PlainTraining(model, optimizer_config, total_epochs, device, num_classes,
                                   **trainer_kwargs)
    elif hps.train_type in ('CEDA', 'CEDAKL', 'CEDATargeted', 'CEDATargetedKL'):
        if hps.train_type in ('CEDATargeted', 'CEDATargetedKL'):
            # A label smoothing value of 0 is passed on as None.
            trainer_kwargs['CEDA_variant'] = {
                'Type': 'CEDATargeted',
                'LabelSmoothingEps': None if ceda_label_smoothing == 0 else ceda_label_smoothing}
        trainer = tt.CEDATraining(model, optimizer_config, total_epochs, device, num_classes,
                                  train_obj=t_obj, od_weight=lam, **trainer_kwargs)
    else:
        raise ValueError('Train type {} is not supported'.format(hps.train_type))

    ##DEBUG:
    # torch.autograd.set_detect_anomaly(True)

    torch.backends.cudnn.benchmark = True

    # The out-distribution loader is only handed over if the trainer asks for it.
    loaders_kwargs = {'test_loader': test_loader, 'extra_test_loaders': [extra_test_loader]}
    if trainer.requires_out_distribution():
        loaders_kwargs['out_distribution_loader'] = od_loader
    train_loaders, test_loaders = trainer.create_loaders_dict(train_loader, **loaders_kwargs)

    next_teacher_name = trainer.train(train_loaders, test_loaders, loader_config=loader_config,
                                      start_epoch=start_epoch, optim_state_dict=optim_state_dict)

    # The value returned by the trainer is used as the teacher of the next iteration.
    teacher_model = next_teacher_name
