"""==================================================================================================="""
################### LIBRARIES ###################
### Basic Libraries
import warnings
warnings.filterwarnings("ignore")

import os, sys, numpy as np, argparse, imp, datetime, pandas as pd, copy
import time, pickle as pkl, random, json, collections, itertools as it

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

from tqdm import tqdm

### DML-specific Libraries
import parameters    as par
import utilities.misc as misc



"""==================================================================================================="""
################### INPUT ARGUMENTS ###################
# Build the command-line interface by letting every parameter group defined in
# `parameters` register its arguments on one shared parser.
parser = argparse.ArgumentParser()

for register_group in (
    par.basic_training_parameters,
    par.batch_creation_parameters,
    par.batchmining_specific_parameters,
    par.loss_specific_parameters,
    par.log_parameters,
    par.parade_parameters,
):
    parser = register_group(parser)

##### Parse the command line into the global option namespace
opt = parser.parse_args()

"""==================================================================================================="""
### Load Remaining Libraries that need to be loaded after comet_ml
import torch, torch.nn as nn, torch.distributed as distributed, torch.multiprocessing as mp
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP
import architectures as archs
import datasampler   as dsamplers
import datasets      as datasets
import criteria      as criteria
import metrics       as metrics
import batchminer    as bmine
import evaluation    as eval
from utilities import misc
from utilities import logger

# Echo the parsed options so that each run documents its own configuration.
logger.print_args(opt)

"""==================================================================================================="""
# Reference timestamp taken before any setup work starts.
full_training_start_time = time.time()



"""==================================================================================================="""
# Both the data source and the save location live in a per-dataset subfolder.
opt.source_path = opt.source_path + '/' + opt.dataset
opt.save_path   = opt.save_path + '/' + opt.dataset

# The batch is built from groups of `samples_per_class` samples, so the batch
# size has to be a multiple of that group size.
assert opt.bs % opt.samples_per_class == 0, 'Batchsize needs to fit number of samples per class for distance sampling and margin/triplet loss!'

# The CLI exposes the negated flag; the rest of the code reads the positive one.
opt.pretrained = not opt.not_pretrained
opt.method     = 'saparade'


"""==================================================================================================="""
################### GPU SETTINGS ###########################
# Enumerate GPUs in PCI bus order and expose only the requested devices to
# this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# if not opt.use_data_parallel:
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, opt.gpu))



"""==================================================================================================="""
#################### SEEDS FOR REPROD. #####################
# Request deterministic cuDNN kernels, then seed every random number generator
# in use (NumPy, Python, torch CPU, torch CUDA) with the same user-chosen seed.
torch.backends.cudnn.deterministic = True
for seed_fn in (np.random.seed, random.seed, torch.manual_seed,
                torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(opt.seed)



"""==================================================================================================="""
##################### NETWORK SETUP ##################
#NOTE: Networks that can be used: 'bninception, resnet50, resnet101, alexnet...'
#>>>>  see import pretrainedmodels; pretrainedmodels.model_names
# Pick the compute device(s).
#   opt.device     : device that holds the model and receives the input batches.
#   opt.device_ids : only set when GPUs are used; its presence is what later
#                    code checks to decide whether to wrap networks in DataParallel.
if len(opt.gpu) > 0:
    num_devices = torch.cuda.device_count()
    if num_devices < len(opt.gpu):
        raise Exception(
            '#available gpu : {} < --device_ids : {}'
                .format(num_devices, len(opt.gpu)))
    # When CUDA_VISIBLE_DEVICES has been restricted to exactly the requested
    # GPUs, CUDA renumbers the visible devices as 0..n-1. The physical ids in
    # opt.gpu are then no longer valid ordinals (e.g. `--gpu 2` exposes a single
    # device that must be addressed as cuda:0), so use the logical indices.
    # NOTE(review): assumes CUDA was not initialised before the environment
    # variable was set - confirm no earlier import touches the GPU.
    requested = ','.join(str(gpu_id) for gpu_id in opt.gpu)
    if os.environ.get("CUDA_VISIBLE_DEVICES") == requested:
        device_ids = list(range(len(opt.gpu)))
    else:
        device_ids = opt.gpu
    opt.device = torch.device('cuda:{}'.format(device_ids[0]))
    opt.device_ids = device_ids
else:
    opt.device = torch.device('cpu')
    
# Choose the multi-feature variant of the requested backbone family.
if 'resnet' in opt.arch:
    mfeat_net = 'multifeature_resnet50'
else:
    mfeat_net = 'multifeature_bninception'
model = archs.select(mfeat_net, opt)
opt.network_feature_dim = model.feature_dim

print('{} Setup for {} with {} batchmining on {} complete with #weights: {}'.format(opt.loss.upper(), opt.arch.upper(), opt.batch_mining.upper(), opt.dataset.upper(), misc.gimme_params(model)))

# Optimizer parameter groups: a negative fc_lr means a single learning rate for
# the whole network; otherwise the `last_linear` layer gets its own rate.
if opt.fc_lr >= 0:
    all_but_fc_params = [param for name, param in model.named_parameters() if 'last_linear' not in name]
    fc_params         = model.model.last_linear.parameters()
    to_optim = [
        {'params': all_but_fc_params, 'lr': opt.lr,    'weight_decay': opt.decay},
        {'params': fc_params,         'lr': opt.fc_lr, 'weight_decay': opt.decay},
    ]
else:
    to_optim = [{'params': model.parameters(), 'lr': opt.lr, 'weight_decay': opt.decay}]

#####
# Optional auxiliary networks.
# Self-similarity branch: a second network with the same architecture that
# starts as an exact copy of the main model. Its parameters are not added to
# the optimizer here.
if 'selfsimilarity' in opt.parade_features:
    selfsim_model = archs.select(mfeat_net, opt)
    selfsim_model.load_state_dict(model.state_dict())

# Separate sensitive-attribute network: plain (non multi-feature) variant of the
# backbone, i.e. the part of the architecture name after 'multifeature_'. It is
# optimised together with the main model.
if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
    # Fixed: the architectures module is imported as `archs`; the previous
    # `arch.select` raised a NameError as soon as this option was enabled.
    attr_model = archs.select(mfeat_net.split('_')[-1], opt)
    to_optim += [{'params': attr_model.parameters(), 'lr':opt.lr, 'weight_decay':opt.decay}]

#####
# Move every network that is in use onto the training device. When GPU ids were
# recorded on `opt`, each network is additionally wrapped in DataParallel.
if not hasattr(opt, "device_ids"):
    model.to(opt.device)
    if 'selfsimilarity' in opt.parade_features:
        selfsim_model.to(opt.device)
    if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
        attr_model.to(opt.device)
else:
    ### Set up data parallel
    model = DataParallel(model, device_ids=opt.device_ids)
    model = model.to(opt.device)
    if 'selfsimilarity' in opt.parade_features:
        selfsim_model = DataParallel(selfsim_model, device_ids=opt.device_ids)
        selfsim_model = selfsim_model.to(opt.device)
    if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
        attr_model = DataParallel(attr_model, device_ids=opt.device_ids)
        attr_model = attr_model.to(opt.device)




"""============================================================================"""
#################### DATALOADER SETUPS ##################
# NOTE: from here on `datasets` refers to the mapping returned by select(),
# no longer to the imported module of the same name.
datasets    = datasets.select(opt.dataset, opt, opt.source_path)
dataloaders = {}

# Evaluation-type splits are read sequentially in fixed-size batches.
for split in ('evaluation', 'evaluation_train', 'testing'):
    dataloaders[split] = torch.utils.data.DataLoader(datasets[split], num_workers=opt.kernels, batch_size=opt.bs, shuffle=False)

# The training split is batched by a dedicated sampler. The sampler is created
# first; only afterwards are the auxiliary augmentations switched on.
train_data_sampler = dsamplers.select(opt.data_sampler, opt, datasets['training'].image_dict, datasets['training'].image_list)
datasets['training'].include_aux_augmentations = True
dataloaders['training'] = torch.utils.data.DataLoader(datasets['training'], num_workers=opt.kernels, batch_sampler=train_data_sampler)

opt.n_classes = len(dataloaders['training'].dataset.avail_classes)




"""============================================================================"""
#################### CREATE LOGGING FILES ###############
# One sub-logger per group of values written during training: per-epoch train
# statistics, test metrics and model gradient statistics.
sub_loggers = ['Train', 'Test', 'Model Grad']
LOG = logger.LOGGER(
    opt,
    sub_loggers=sub_loggers,
    start_new=True,
)


"""============================================================================"""
#################### LOSS SETUP ####################
# Every criteria.select() call returns the criterion together with the
# (possibly extended) list of optimizer parameter groups, so `to_optim` is
# threaded through all calls.
criterion_dict = {}
batchminer     = bmine.select(opt.batch_mining, opt)

# One instance of the main loss per discriminative feature head.
for key in opt.parade_features:
    if 'discriminative' not in key:
        continue
    criterion_dict[key], to_optim = criteria.select(opt.loss, opt, to_optim, batchminer)

# Adversarial separation is only needed if decorrelation pairs were requested.
if len(opt.parade_decorrelations) > 0:
    criterion_dict['separation'], to_optim = criteria.select('adversarial_separation', opt, to_optim, None)

# Self-similarity objective, selected by name through opt.parade_ssl.
if 'selfsimilarity' in opt.parade_features:
    criterion_dict['selfsimilarity'], to_optim = criteria.select(opt.parade_ssl, opt, to_optim, None)


#############
# Criterion for the 'shared' feature head: the main loss combined with the
# batchminer that corresponds to the requested sharing mode.
if 'shared' in opt.parade_features:
    if opt.parade_sharing not in ('standard', 'random', 'full'):
        raise Exception('Sharing method {} not available!'.format(opt.parade_sharing))

    if opt.parade_sharing == 'standard':
        shared_batchminer        = bmine.select('shared_neg_distance', opt)
        _sharing_miner           = shared_batchminer
    elif opt.parade_sharing == 'random':
        random_shared_batchminer = bmine.select('random_distance', opt)
        _sharing_miner           = random_shared_batchminer
    else:
        full_shared_batchminer   = bmine.select('shared_full_distance', opt)
        _sharing_miner           = full_shared_batchminer

    criterion_dict['shared'], to_optim = criteria.select(opt.loss, opt, to_optim, _sharing_miner)

#############
# Criterion for the 'intra' feature head; 'random' is the only supported mode.
if 'intra' in opt.parade_features:
    if opt.parade_intra != 'random':
        raise Exception('Intra-Feature method {} not available!'.format(opt.parade_intra))
    intra_batchminer = bmine.select('intra_random', opt)
    criterion_dict['intra'], to_optim = criteria.select(opt.loss, opt, to_optim, intra_batchminer)
        
##############
# Criterion for the sensitive-attribute features; 'random' is the only
# supported mining mode.
if 'sensitive' in opt.parade_features:
    if opt.parade_sensitive != 'random':
        raise Exception('Sensitive attribute method {} not available!'.format(opt.parade_sensitive))
    sensitive_batchminer = bmine.select('random', opt)
    criterion_dict['sensitive'], to_optim = criteria.select(opt.loss, opt, to_optim, sensitive_batchminer)

#############

# Move all criteria to the training device.
for key, criterion in criterion_dict.items():
    criterion.to(opt.device)

# The self-similarity criterion builds its memory queue from the key network
# and the training loader. It is handed the bare network, i.e. the DataParallel
# wrapper is stripped when one is in use.
if 'selfsimilarity' in criterion_dict:
    key_network = selfsim_model.module if hasattr(opt, 'device_ids') else selfsim_model
    criterion_dict['selfsimilarity'].create_memory_queue(key_network, dataloaders['training'], opt.device, opt_key='selfsimilarity')


"""============================================================================"""
#################### OPTIM SETUP ####################
# Adam over all collected parameter groups; the learning rates are multiplied
# by opt.gamma at each epoch listed in opt.tau.
optimizer = torch.optim.Adam(to_optim)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=opt.tau,
    gamma=opt.gamma,
)


"""============================================================================"""
#################### METRIC COMPUTER ####################
# The metric computer reads the embedding size from opt.rho_spectrum_embed_dim.
opt.rho_spectrum_embed_dim = opt.embed_dim
metric_computer = metrics.paradeMetricComputer(opt.evaluation_metrics, opt)


"""============================================================================"""
################### SCRIPT MAIN ##########################
print('\n-----\n')

### Resume from `<checkpoint_dir>/checkpoint.pth.tar` if checkpointing is enabled
### and such a file exists: restores the networks, criteria, optimizer and the
### epoch to continue from. Without a checkpoint, training starts at epoch 0.
start_epoch = 0
if opt.checkpoint_dir is not None:
    CHECKPOINT_PATH = '{}/checkpoint.pth.tar'.format(opt.checkpoint_dir)
    if os.path.exists(CHECKPOINT_PATH):
        # Map all stored tensors onto the current training device, so resuming
        # does not depend on the device layout of the run that wrote the file.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=opt.device)
        if hasattr(opt, "device_ids"):
            # Networks are wrapped in DataParallel; state dicts were saved from
            # (and are loaded into) the wrapped `.module`.
            model.module.load_state_dict(checkpoint['model_state_dict'])
            if 'selfsimilarity' in opt.parade_features:
                selfsim_model.module.load_state_dict(checkpoint['selfsim_model_state_dict'])
            if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
                attr_model.module.load_state_dict(checkpoint['attr_model_state_dict'])
        else:
            model.load_state_dict(checkpoint['model_state_dict'])
            if 'selfsimilarity' in opt.parade_features:
                selfsim_model.load_state_dict(checkpoint['selfsim_model_state_dict'])
            if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
                attr_model.load_state_dict(checkpoint['attr_model_state_dict'])
        for key in criterion_dict:
            criterion_dict[key].load_state_dict(checkpoint['criterion_{}_state_dict'.format(key)])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        # The optimizer state already carries the decayed learning rates, but the
        # freshly built scheduler would count epochs from 0 again and hit its
        # milestones too late. Align its epoch counter with the resumed epoch.
        scheduler.last_epoch = start_epoch
        print('Resumed model training from epoch {} out of {} epochs'.format(start_epoch, opt.n_epochs))

# Main loop. Each epoch: (optionally) refresh the sampler, train on all batches,
# log the training statistics, evaluate on the test split, step the learning
# rate scheduler and (optionally) write a checkpoint.
iter_count = 0
for epoch in range(start_epoch, opt.n_epochs):
    opt.epoch = epoch
    ### Report the learning rates in use. They are read from the optimizer
    ### because scheduler.get_lr() is only valid inside scheduler.step() and
    ### misreports the value at milestone epochs.
    if opt.scheduler!='none': print('Running with learning rates {}...'.format(' | '.join('{}'.format(group['lr']) for group in optimizer.param_groups)))

    """======================================="""
    if train_data_sampler.requires_storage:
        train_data_sampler.precompute_indices()

    """======================================="""
    ### Train one epoch
    start = time.time()
    _ = model.train()

    loss_collect = {'train':[], 'separation':[]}
    data_iterator = tqdm(dataloaders['training'], desc='Epoch {} Training...'.format(epoch))

    for i,(class_labels, input, input_indices) in enumerate(data_iterator):

        # The model returns a pair; the first entry is a dict mapping feature
        # names to embeddings (it is iterated with .items() below).
        features  = model(input.to(opt.device))
        features, direct_features = features

        ###################
        if 'selfsimilarity' in criterion_dict:
            with torch.no_grad():
                ### Use shuffleBN to avoid information bleeding making samples interdependent.
                forward_shuffle, backward_reorder = criterion_dict['selfsimilarity'].shuffleBN(len(features['selfsimilarity']))
                selfsim_key_features              = selfsim_model(input[forward_shuffle].to(opt.device))
                if isinstance(selfsim_key_features, tuple): selfsim_key_features = selfsim_key_features[0]
                selfsim_key_features              = selfsim_key_features['selfsimilarity'][backward_reorder]

        if 'sensitive' in criterion_dict:
            if opt.parade_sensitive_separate:
                attr_features = attr_model(input.to(opt.device))
                # The network may return a tuple (the same check is applied to the
                # key network above); keep only its first entry. Fixed: this used
                # to assign to the undefined name `attr_key_features`.
                if isinstance(attr_features, tuple): attr_features = attr_features[0]
                # NOTE(review): no_grad around a plain assignment does not detach
                # the tensor - confirm whether a detach was intended here.
                with torch.no_grad():
                    features['sensitive'] = attr_features
            else:
                attr_features = features['sensitive']
            # Targets for the sensitive loss: the selected attribute of every
            # sample in the batch, shaped like the class labels.
            attr_targets = torch.LongTensor(data_iterator.iterable.dataset.get_attribute(input_indices, attributes=[opt.parade_sensitive_attribute])).view_as(class_labels)



        ################### Accumulate the total loss over all active objectives
        loss = 0.
        for key, feature in features.items():
            if 'discriminative' in key:
                loss_discr = criterion_dict[key](feature, class_labels)
                loss = loss + loss_discr
        if 'selfsimilarity' in criterion_dict:
            loss_selfsim = criterion_dict['selfsimilarity'](features['selfsimilarity'], selfsim_key_features)
            loss = loss + opt.parade_alpha_ssl*loss_selfsim
        if 'shared' in features:
            loss_shared = criterion_dict['shared'](features['shared'], class_labels)
            loss = loss + opt.parade_alpha_shared*loss_shared
        if 'intra' in features:
            loss_intra = criterion_dict['intra'](features['intra'], class_labels)
            loss = loss + opt.parade_alpha_intra*loss_intra
        if 'sensitive' in opt.parade_features:
            loss_sensitive = criterion_dict['sensitive'](attr_features, attr_targets)
            # With a separate attribute network the sensitive loss is
            # back-propagated on its own (below) instead of being added here.
            if not opt.parade_sensitive_separate:
                loss = loss + opt.parade_alpha_sensitive*loss_sensitive
        if 'separation' in criterion_dict:
            loss_adv = criterion_dict['separation'](features)
            loss     = loss + loss_adv

        optimizer.zero_grad()
        if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
            loss_sensitive.backward()
        loss.backward()


        ### Compute Model Gradients and log them!
        grads              = np.concatenate([p.grad.detach().cpu().numpy().flatten() for p in model.parameters() if p.grad is not None])
        grad_l2, grad_max  = np.mean(np.sqrt(np.mean(np.square(grads)))), np.mean(np.max(np.abs(grads)))
        LOG.progress_saver['Model Grad'].log('Grad L2',  grad_l2,  group='L2')
        LOG.progress_saver['Model Grad'].log('Grad Max', grad_max, group='Max')

        ### Update network weights!
        optimizer.step()

        ###
        loss_collect['train'].append(loss.item())
        if 'separation' in criterion_dict:
            loss_collect['separation'].append(loss_adv.item())

        if 'selfsimilarity' in criterion_dict:
            ### Update Key Network: exponential moving average of the main model's parameters
            for model_par, key_model_par in zip(model.parameters(), selfsim_model.parameters()):
                momentum = criterion_dict['selfsimilarity'].momentum
                key_model_par.data.copy_(key_model_par.data*momentum + model_par.data*(1-momentum))

            ###
            criterion_dict['selfsimilarity'].update_memory_queue(selfsim_key_features)

        ###
        iter_count += 1

        # On the last batch, show the epoch's mean loss in the progress bar.
        if i==len(dataloaders['training'])-1: data_iterator.set_description('Epoch (Train) {0}: Mean Loss [{1:.4f}]'.format(epoch, np.mean(loss_collect['train'])))

        """======================================="""
        if train_data_sampler.requires_storage and train_data_sampler.update_storage:
            # NOTE(review): `features` is a dict at this point, so .detach() will
            # fail for samplers that use storage - confirm which embedding the
            # sampler expects before enabling this path.
            train_data_sampler.replace_storage_entries(features.detach().cpu(), input_indices)



    result_metrics = {'loss': np.mean(loss_collect['train'])}

    ####
    LOG.progress_saver['Train'].log('epochs', epoch)
    for metricname, metricval in result_metrics.items():
        LOG.progress_saver['Train'].log(metricname, metricval)
    LOG.progress_saver['Train'].log('time', np.round(time.time()-start, 4))



    """======================================="""
    ### Evaluate -
    _ = model.eval()
    test_dataloaders = [dataloaders['testing']]

    eval.evaluate(opt.dataset, LOG, metric_computer, test_dataloaders, model, opt, opt.evaltypes, opt.device)


    LOG.update(all=True)


    """======================================="""
    ### Learning Rate Scheduling Step
    if opt.scheduler != 'none':
        scheduler.step()

    print('\n-----\n')


    """======================================="""
    ### Checkpoint model if at checkpointing interval.
    ### The directory check comes first so that the interval is never evaluated
    ### when checkpointing is disabled.
    if opt.checkpoint_dir is not None and epoch % opt.checkpoint_interval == 0:
        save_dict = {
            'optimizer_state_dict': optimizer.state_dict(),
            # Stored as the next epoch to run, i.e. the epoch to resume from.
            'epoch': epoch+1,
            'opt': vars(opt)
        }
        for key in criterion_dict:
            save_dict['criterion_{}_state_dict'.format(key)] = criterion_dict[key].state_dict()
        if hasattr(opt, 'device_ids'):
            # Save the bare networks, not the DataParallel wrappers.
            save_dict['model_state_dict'] = model.module.state_dict()
            if 'selfsimilarity' in opt.parade_features:
                save_dict['selfsim_model_state_dict'] = selfsim_model.module.state_dict()
            if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
                save_dict['attr_model_state_dict'] = attr_model.module.state_dict()
        else:
            save_dict['model_state_dict'] = model.state_dict()
            if 'selfsimilarity' in opt.parade_features:
                save_dict['selfsim_model_state_dict'] = selfsim_model.state_dict()
            if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
                save_dict['attr_model_state_dict'] = attr_model.state_dict()
        # Make sure the target directory exists so a finished epoch is not lost
        # to a failing save.
        os.makedirs(opt.checkpoint_dir, exist_ok=True)
        torch.save(save_dict, CHECKPOINT_PATH)


"""======================================================="""
### SAVE THE FINAL MODEL
# Strip the DataParallel wrappers (if any) so that the names refer to the bare
# networks from here on.
if hasattr(opt, 'device_ids'):
    model = model.module
    if 'selfsimilarity' in opt.parade_features:
        selfsim_model = selfsim_model.module
    if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
        attr_model = attr_model.module

# Write the final weights together with the run options, if a target path was
# given. Auxiliary networks are included only when they were in use.
if opt.final_save_path is not None:
    save_dict = {}
    save_dict['opt'] = vars(opt)
    save_dict['model_state_dict'] = model.state_dict()
    if 'selfsimilarity' in opt.parade_features:
        save_dict['selfsim_model_state_dict'] = selfsim_model.state_dict()
    if 'sensitive' in opt.parade_features and opt.parade_sensitive_separate:
        save_dict['attr_model_state_dict'] = attr_model.state_dict()
    torch.save(save_dict, opt.final_save_path)
