"""==================================================================================================="""
################### LIBRARIES ###################
### Basic Libraries
from cProfile import label
import warnings
warnings.filterwarnings("ignore")

import os, numpy as np, argparse
import time, random, datetime, pathlib
import getpass

from tqdm import tqdm

import parameters    as par


"""==================================================================================================="""
################### INPUT ARGUMENTS ###################
parser = argparse.ArgumentParser()

# Each helper in `parameters` takes the parser and returns it (presumably with
# its own group of command-line flags added). Apply them in this fixed order.
for _register_args in (par.basic_training_parameters,
                       par.batch_creation_parameters,
                       par.batchmining_specific_parameters,
                       par.loss_specific_parameters,
                       par.multimodal_parameters,
                       par.s2sd_parameters,
                       par.log_parameters):
    parser = _register_args(parser)

##### Read in parameters
opt = parser.parse_args()

"""==================================================================================================="""
### Load remaining libraries that need to be loaded after comet_ml
import torch, torch.nn as nn
from torch.nn import DataParallel
import architectures as archs
import datasampler   as dsamplers
import datasets      as datasets
import criteria      as criteria
import metrics       as metrics
import batchminer    as bmine
import evaluation    as eval
from utilities import misc
from utilities import logger

##### Print parameters
logger.print_args(opt)

"""==================================================================================================="""
# Wall-clock start of the whole run.
full_training_start_time = time.time()



"""==================================================================================================="""
# Data for the chosen dataset is read from a sub-folder of the given source path.
opt.source_path = os.path.join(opt.source_path, opt.dataset)

## Use slurm job id as checkpointing directory if available
# Only when no directory was given explicitly and --checkpoint_interval is not -1
# (presumably the "checkpointing disabled" value).
if "SLURM_JOBID" in os.environ and opt.checkpoint_dir is None and opt.checkpoint_interval != -1:
    opt.checkpoint_dir = "/checkpoint/{}/{}".format(getpass.getuser(), os.environ["SLURM_JOBID"])
    os.makedirs(opt.checkpoint_dir, exist_ok=True)

# Ensure the save path exists; the source path is not created here.
os.makedirs(opt.save_path, exist_ok=True)

# A batch is built from class sub-clusters of `samples_per_class` samples each,
# so the batch size must be a multiple of it. Raise explicitly: an `assert`
# would be silently stripped when Python runs with -O.
if opt.bs % opt.samples_per_class:
    raise ValueError('Batchsize needs to fit number of samples per class for distance sampling and margin/triplet loss!')

opt.pretrained = not opt.not_pretrained

"""==================================================================================================="""
################### GPU SETTINGS ###########################
# CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES are not modified by this script;
# the GPUs to use are taken from --gpu in the device setup below.



"""==================================================================================================="""
#################### SEEDS FOR REPROD. #####################
# Request deterministic cuDNN kernels, then seed every RNG with the same value.
torch.backends.cudnn.deterministic = True
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(opt.seed)



"""==================================================================================================="""
##################### NETWORK SETUP ##################
model = archs.select(opt.arch, opt)

##################### DEVICE SETUP ##################
# With GPU ids given, the first id becomes the main device and all ids are
# handed to DataParallel below; otherwise everything runs on the CPU.
# `opt.device_ids` is only set in the GPU case, and later code tests for it.
if len(opt.gpu) == 0:
    opt.device = torch.device('cpu')
else:
    num_devices = torch.cuda.device_count()
    device_ids  = opt.gpu
    if len(device_ids) > num_devices:
        raise Exception('#available gpu : {} < --device_ids : {}'.format(num_devices, len(device_ids)))
    opt.device     = torch.device('cuda:{}'.format(device_ids[0]))
    opt.device_ids = device_ids
#######################################################
# Optimizer parameter groups: a negative fc_lr means "one learning rate for
# everything"; otherwise the `last_linear` head gets its own rate opt.fc_lr.
if opt.fc_lr < 0:
    to_optim = [{'params': model.parameters(), 'lr': opt.lr, 'weight_decay': opt.decay}]
else:
    all_but_fc_params = [param for name, param in model.named_parameters() if 'last_linear' not in name]
    fc_params         = model.model.last_linear.parameters()
    to_optim          = [
        {'params': all_but_fc_params, 'lr': opt.lr,    'weight_decay': opt.decay},
        {'params': fc_params,         'lr': opt.fc_lr, 'weight_decay': opt.decay},
    ]

# Multi-GPU: wrap in DataParallel (the wrapped net is then reachable as `.module`).
if hasattr(opt, "device_ids"):
    model = DataParallel(model, device_ids=opt.device_ids).to(opt.device)
else:
    _ = model.to(opt.device)




"""============================================================================"""
#################### DATALOADER SETUPS ##################
datasets = datasets.select(opt.dataset, opt, opt.source_path)

# Non-shuffled loaders used for metric computation.
# NOTE(review): drop_last=True makes these loaders skip the final partial batch,
# so up to bs-1 samples of the evaluation/testing splits are never yielded
# (the validation loader below keeps every sample). Confirm this is intended.
dataloaders = {
    split: torch.utils.data.DataLoader(datasets[split], num_workers=opt.kernels, batch_size=opt.bs,
                                       shuffle=False, drop_last=True)
    for split in ('evaluation', 'testing')
}
if opt.use_tv_split:
    dataloaders['validation'] = torch.utils.data.DataLoader(datasets['validation'], num_workers=opt.kernels,
                                                            batch_size=opt.bs, shuffle=False)

# Training batches are composed by a custom batch sampler; samplers that keep a
# storage are initialised through create_storage() on the evaluation loader.
train_data_sampler = dsamplers.select(opt.data_sampler, opt,
                                      datasets['training'].image_dict, datasets['training'].image_list)
if train_data_sampler.requires_storage:
    train_data_sampler.create_storage(dataloaders['evaluation'], model, opt.device)

dataloaders['training'] = torch.utils.data.DataLoader(datasets['training'], num_workers=opt.kernels,
                                                      batch_sampler=train_data_sampler)

opt.n_classes = len(dataloaders['training'].dataset.avail_classes)

"""============================================================================"""
#################### CREATE LOGGING FILES ###############
sub_loggers = ['Train', 'Test', 'Model Grad'] + (['Val'] if opt.use_tv_split else [])

LOG = logger.LOGGER(opt, sub_loggers=sub_loggers, start_new=True, log_online=opt.log_online)

if opt.log_online:
    import wandb
    # NOTE: vars() returns the Namespace's own __dict__ (no copy), so the two
    # entries added here also become attributes of `opt`.
    hp = vars(opt)
    hp.update(job_id=logger.get_slurm_id(), run_date=datetime.datetime.now())
    run_name = logger.get_run_name(hp)
    wandb_id_file = pathlib.Path(opt.save_path) / f'wandb_{run_name}.txt'
    logger.init_or_resume_wandb_run(wandb_id_file, run_name=run_name, config=hp)



"""============================================================================"""
#################### LOSS SETUP ####################
# criteria.select returns the criterion plus an updated `to_optim` (it may add
# the criterion's own parameter groups).
batchminer          = bmine.select(opt.batch_mining, opt)
criterion, to_optim = criteria.select(opt.loss, opt, to_optim, batchminer)
criterion.to(opt.device)

# Samplers whose name contains 'criterion' get a handle on the loss object.
if 'criterion' in train_data_sampler.name:
    train_data_sampler.internal_criterion = criterion

# Optional auxiliary label-correlation loss.
if opt.loss_labelcorr:
    label_criterion, to_optim = criteria.select("labelcorr", opt, to_optim, batchminer)
    label_criterion.to(opt.device)

"""============================================================================"""
#################### OPTIM SETUP ####################
if opt.optim not in ('adam', 'sgd'):
    raise Exception('Optimizer <{}> not available!'.format(opt.optim))
optimizer = torch.optim.Adam(to_optim) if opt.optim == 'adam' else torch.optim.SGD(to_optim, momentum=0.9)

# 'step'    : MultiStepLR, multiply the lr by opt.gamma at each milestone in opt.tau.
# 'plateau' : ReduceLROnPlateau on a minimised value, factor opt.gamma, patience opt.patience.
# 'none'    : constant lr; no `scheduler` object is created at all.
if opt.scheduler == 'step':
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.tau, gamma=opt.gamma)
elif opt.scheduler == 'plateau':
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=opt.gamma, patience=opt.patience)
elif opt.scheduler != 'none':
    raise Exception('Scheduler <{}> not available!'.format(opt.scheduler))




"""============================================================================"""
#################### METRIC COMPUTER ####################
# Embedding dimensionality, presumably read by the rho-spectrum metric.
opt.rho_spectrum_embed_dim = opt.embed_dim
# opt.exclusive does not influence how the metric computer is built (the
# previous if/else on it constructed the identical object in both branches).
metric_computer = metrics.MetricComputer(opt.evaluation_metrics, opt)




"""============================================================================"""
################### Summary #########################
# Short human-readable description of the run, printed once before training.
miner_name = opt.batch_mining if criterion.REQUIRES_BATCHMINER else 'N/A'
summary = '\n'.join([
    'Dataset:\t {}'.format(opt.dataset.upper()),
    'Objective:\t {}'.format(opt.loss.upper()),
    'Batchminer:\t {}'.format(miner_name),
    'Backbone:\t {} (#weights: {})'.format(opt.arch.upper(), misc.gimme_params(model)),
])
print(summary)




"""============================================================================"""
################### SCRIPT MAIN ##########################
print('\n-----\n')

# Running state used by the training loop.
iter_count = 0   # optimisation steps taken so far
val_losses = []  # per-epoch validation losses (appended for the 'plateau' scheduler)
loss_args  = dict(batch=None, labels=None, batch_features=None, f_embed=None)

### If checkpointing directory available, load model, optimizer, (scheduler,) start_epoch
start_epoch = 0

if opt.checkpoint_dir is not None:
    CHECKPOINT_PATH = os.path.join(opt.checkpoint_dir, 'checkpoint.pth.tar')
    if os.path.exists(CHECKPOINT_PATH):
        # map_location lets a checkpoint written on one device be resumed on
        # another (different GPU ids, or CPU-only).
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=opt.device)
        # Under DataParallel the weights are stored for the wrapped `.module`.
        target_model = model.module if hasattr(opt, "device_ids") else model
        target_model.load_state_dict(checkpoint['model_state_dict'])
        criterion.load_state_dict(checkpoint['criterion_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Restore the LR-scheduler position so milestones/plateau counters continue
        # where they stopped; older checkpoints without this entry still load.
        if opt.scheduler != 'none' and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']
        val_losses  = checkpoint['val_losses']
        print('Resumed model training from epoch {} out of {} epochs'.format(start_epoch, opt.n_epochs))

if start_epoch == 0:
    """======================================="""
    ### Baseline metrics BEFORE any training step, logged as epoch -1:
    ### Test always, Validation with --use_tv_split, Train unless --no_train_metrics.
    opt.epoch = -1
    LOG.progress_saver['Train'].log('epochs', -1)

    print("\nLogging metrics prior to training...")
    _ = model.eval()
    pre_training_evals = [('\nComputing Testing Metrics...', 'testing', 'Test')]
    if opt.use_tv_split:
        pre_training_evals.append(('\nComputing Validation Metrics...', 'validation', 'Val'))
    if not opt.no_train_metrics:
        pre_training_evals.append(('\nComputing Training Metrics...', 'evaluation', 'Train'))

    for eval_message, eval_loader_key, eval_log_key in pre_training_evals:
        print(eval_message)
        eval.evaluate(opt.dataset, LOG, metric_computer, [dataloaders[eval_loader_key]], model, opt,
                      opt.evaltypes, opt.device, log_key=eval_log_key)

    LOG.update(all=True)

# ReduceLROnPlateau is driven by the validation loss, so fail before training
# starts rather than after the first full epoch.
if opt.scheduler == 'plateau' and not opt.use_tv_split:
    raise ValueError("Reduce On Plateau relies on validation metrics. Please run with 'use_tv_split' flag.")

for epoch in range(start_epoch, opt.n_epochs):
    epoch_start_time = time.time()

    # Optional full refresh of the sampler storage (skipped at epoch 0).
    if epoch > 0 and opt.data_idx_full_prec and train_data_sampler.requires_storage:
        train_data_sampler.full_storage_update(dataloaders['evaluation'], model, opt.device)

    opt.epoch = epoch
    ### Scheduling Changes
    # Report the learning rates actually in use. Reading them from the optimizer
    # works for every scheduler type; `scheduler.get_lr()` must not be called
    # outside `step()` (it can report an extra gamma decay at MultiStepLR
    # milestones), and `_last_lr` is a private attribute.
    if opt.scheduler != 'none':
        current_lrs = [group['lr'] for group in optimizer.param_groups]
        print('Running with learning rates {}...'.format(' | '.join('{}'.format(x) for x in current_lrs)))

    """======================================="""
    if train_data_sampler.requires_storage:
        train_data_sampler.precompute_indices()


    """======================================="""
    ### Train one epoch
    start = time.time()
    _ = model.train()

    loss_collect  = []
    data_iterator = tqdm(dataloaders['training'], desc='Epoch {} Training...'.format(epoch))

    for batch_idx, out in enumerate(data_iterator):
        # Labels/indices exactly as yielded by the loader (before the device move).
        class_labels, input_indices = out['labels'], out['idx']

        # Move every batch entry to the target device. The loop variable must not
        # be `batch_idx`, which is checked below for the last batch of the epoch.
        for key in out:
            out[key] = out[key].to(opt.device)

        model_args = out

        # Needed for MixManifold settings.
        if 'mix' in opt.arch: model_args['labels'] = class_labels
        embeds = model(**model_args)
        # Some architectures return `(embeds, (avg_features, features))`; for the
        # others the feature entries are passed to the criterion as None.
        avg_features, features = None, None
        if isinstance(embeds, tuple): embeds, (avg_features, features) = embeds

        ### Compute Loss
        # The criterion gets the un-wrapped network when DataParallel is used.
        base_model = model.module if hasattr(opt, "device_ids") else model
        loss_args['batch']              = embeds
        loss_args['labels']             = class_labels
        loss_args['f_embed']            = base_model.model.last_linear
        loss_args['model']              = base_model
        loss_args['batch_features']     = features
        loss_args['avg_batch_features'] = avg_features
        loss = criterion(**loss_args)

        if opt.loss_labelcorr:
            # For s2sd losses the proxies live on the wrapped `source_criterion`.
            loss_proxies = criterion.source_criterion.proxies if "s2sd" in opt.loss else criterion.proxies
            label_loss = label_criterion(batch=loss_args['batch'], labels=loss_args['labels'], proxies=loss_proxies)
            loss += opt.loss_labelcorr_w * label_loss

        ###
        optimizer.zero_grad()
        loss.backward()

        ### Compute Model Gradients and log them!
        # 'Grad L2' is the root-mean-square over all gradient entries,
        # 'Grad Max' the largest absolute entry. Skipped if no parameter has a grad.
        grad_parts = [p.grad.detach().cpu().numpy().flatten() for p in model.parameters() if p.grad is not None]
        if grad_parts:
            grads = np.concatenate(grad_parts)
            grad_l2, grad_max = np.sqrt(np.mean(np.square(grads))), np.max(np.abs(grads))
            LOG.progress_saver['Model Grad'].log('Grad L2',  grad_l2,  group='L2')
            LOG.progress_saver['Model Grad'].log('Grad Max', grad_max, group='Max')

        ### Update network weights!
        optimizer.step()

        loss_collect.append(loss.item())
        iter_count += 1

        if batch_idx == len(dataloaders['training']) - 1:
            data_iterator.set_description('Epoch (Train) {0}: Mean Loss [{1:.4f}]'.format(epoch, np.mean(loss_collect)))

        """======================================="""
        if train_data_sampler.requires_storage and train_data_sampler.update_storage:
            train_data_sampler.replace_storage_entries(embeds.detach().cpu(), input_indices)

    result_metrics = {'loss': np.mean(loss_collect)}

    ####
    LOG.progress_saver['Train'].log('epochs', epoch)
    for metricname, metricval in result_metrics.items():
        LOG.progress_saver['Train'].log(metricname, metricval)
    LOG.progress_saver['Train'].log('time', np.round(time.time()-start, 4))



    """======================================="""
    ### Evaluate Metric for Training & Test (& Validation)
    _ = model.eval()
    print('\nComputing Testing Metrics...')
    eval.evaluate(opt.dataset, LOG, metric_computer, [dataloaders['testing']],    model, opt, opt.evaltypes, opt.device, log_key='Test')
    if opt.use_tv_split:
        print('\nComputing Validation Metrics...')
        eval.evaluate(opt.dataset, LOG, metric_computer, [dataloaders['validation']], model, opt, opt.evaltypes, opt.device, log_key='Val')
    if not opt.no_train_metrics:
        print('\nComputing Training Metrics...')
        eval.evaluate(opt.dataset, LOG, metric_computer, [dataloaders['evaluation']], model, opt, opt.evaltypes, opt.device, log_key='Train')


    LOG.update(all=True)


    """======================================="""
    ### Learning Rate Scheduling Step
    if opt.scheduler != 'none' and opt.scheduler != 'plateau':
        scheduler.step()

    if opt.scheduler == 'plateau':
        val_loss = eval.loss_calc(datasets['validation'], model, opt, opt.device, criterion)
        val_losses.append(val_loss)
        print('\nMean Validation Loss: {}'.format(val_loss))
        scheduler.step(val_loss)

    print('Total Epoch Runtime: {0:4.2f}s'.format(time.time()-epoch_start_time))
    print('\n-----\n')

    """======================================="""
    ### Checkpoint model if at checkpointing interval
    # A non-positive interval disables periodic checkpoints: -1 is already treated
    # as "off" when choosing the SLURM checkpoint directory, `epoch % -1 == 0`
    # would otherwise be true every epoch, and `epoch % 0` raises.
    if opt.checkpoint_dir is not None and opt.checkpoint_interval > 0 and epoch % opt.checkpoint_interval == 0:
        checkpoint_state = {
            'model_state_dict': (model.module if hasattr(opt, 'device_ids') else model).state_dict(),
            'criterion_state_dict': criterion.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_losses': val_losses,
            'epoch': epoch+1}
        # Persist the LR scheduler too, so a resumed run continues the same schedule.
        if opt.scheduler != 'none':
            checkpoint_state['scheduler_state_dict'] = scheduler.state_dict()
        torch.save(checkpoint_state, CHECKPOINT_PATH)

    # Early stopping: stop once the best validation loss lies 10 or more entries back.
    if val_losses:
        min_epoch = np.argmin(val_losses)
        if (epoch - min_epoch) >= 10:
            # NOTE(review): this sets `n_epoch` (not `n_epochs`), so the configured
            # epoch count stored with the final model is unchanged — confirm intent.
            opt.n_epoch = epoch
            break



"""======================================================="""
### SAVE THE FINAL MODEL
# Store the plain (un-wrapped) network together with all run options except
# `device`.
if hasattr(opt, 'device_ids'):
    model = model.module

saved_opt       = {key: value for key, value in vars(opt).items() if key != "device"}
final_save_path = os.path.join(opt.save_path, 'model.pt')
torch.save({'opt': saved_opt, 'model_state_dict': model.state_dict()}, final_save_path)

# Marker file, written last, after the final model has been saved.
with open(os.path.join(opt.save_path, 'done'), 'w') as done_file:
    done_file.write('done')
