### Standard DML criteria
from criteria import triplet, margin, oproxy, proxynca, npair
from criteria import lifted, contrastive, softmax
from criteria import angular, snr, histogram, arcface
from criteria import softtriplet, multisimilarity, quadruplet

from criteria import omultiproxy
from criteria import multiproxynca, multiproxyncanegative, multiproxyncapowerset, multisupervised
from criteria import s2sd
from criteria import labelcorr
### Basic Libs
import copy

##### Losses
# Registry: loss name (the string passed to `select`) -> criterion module implementing it.
losses = dict(
    triplet=triplet,
    margin=margin,
    proxynca=proxynca,
    oproxy=oproxy,
    npair=npair,
    angular=angular,
    contrastive=contrastive,
    lifted=lifted,
    snr=snr,
    multisimilarity=multisimilarity,
    histogram=histogram,
    softmax=softmax,
    softtriplet=softtriplet,
    s2sd=s2sd,
    arcface=arcface,
    quadruplet=quadruplet,
    multiproxynca=multiproxynca,
    omultiproxy=omultiproxy,
    multiproxyncanegative=multiproxyncanegative,
    multiproxyncapowerset=multiproxyncapowerset,
    multisupervised=multisupervised,
    labelcorr=labelcorr,
)

## Multi-label losses. `select` treats every other registered key (except the
## doubly usable ones below) as a multi-class loss that requires `opt.exclusive`.
multiloss_keys = [
    'multiproxynca',
    'omultiproxy',
    'multiproxyncanegative',
    'multiproxyncapowerset',
    'multisupervised',
    'labelcorr',
]

## Losses that `select` accepts regardless of `opt.exclusive` (no task-type check).
doubly_usable_keys = ['s2sd']

"""================================================================================================="""
def select(loss, opt, to_optim, batchminer=None):
    """Build the criterion registered under ``loss`` and register its optimizable parameters.

    Args:
        loss: Key into the module-level ``losses`` registry.
        opt: Options object. ``opt.exclusive`` is read to check that the loss matches
            the task type, and the whole object is forwarded to the criterion
            constructor as ``opt``.
        to_optim: Collection of optimizer parameter-group dicts. When the loss module
            sets ``REQUIRES_OPTIM``, it is extended with ``+=`` (in place when it is
            a list) and then returned.
        batchminer: Batch-miner object. It is required when the loss module sets
            ``REQUIRES_BATCHMINER``, and its ``name`` must be listed in the module's
            ``ALLOWED_MINING_OPS``. Otherwise it is ignored.

    Returns:
        Tuple ``(criterion, to_optim)``.

    Raises:
        NotImplementedError: ``loss`` is not a registered key.
        ValueError: the loss does not fit the ``opt.exclusive`` setting, or a
            required batch miner is missing or not allowed for this loss.
    """
    if loss not in losses:
        raise NotImplementedError('Loss {} not implemented!'.format(loss))

    # Losses that are not doubly usable must match the task type chosen via
    # opt.exclusive: exclusive -> multi-class losses only, otherwise multi-label only.
    if loss not in doubly_usable_keys:
        is_multilabel_loss = loss in multiloss_keys
        if not is_multilabel_loss and not opt.exclusive:
            raise ValueError('Loss {} intended for multi-class use, not multi-label use. Please set exclusive flag.'.format(loss))
        if is_multilabel_loss and opt.exclusive:
            raise ValueError('Loss {} intended for multi-label use, not multi-class use. Please drop exclusive flag.'.format(loss))

    loss_lib = losses[loss]

    loss_par_dict = {'opt': opt}
    if loss_lib.REQUIRES_BATCHMINER:
        # ValueError (a subclass of Exception) so callers can catch config errors narrowly.
        if batchminer is None:
            raise ValueError('Loss {} requires one of the following batch mining methods: {}'.format(loss, loss_lib.ALLOWED_MINING_OPS))
        if batchminer.name not in loss_lib.ALLOWED_MINING_OPS:
            raise ValueError('{}-mining not allowed for {}-loss!'.format(batchminer.name, loss))
        loss_par_dict['batchminer'] = batchminer

    criterion = loss_lib.Criterion(**loss_par_dict)

    if loss_lib.REQUIRES_OPTIM:
        # A criterion may supply its own parameter groups. Otherwise all of its
        # parameters are registered with the criterion's own learning rate.
        # NOTE(review): parameters() presumably comes from torch.nn.Module — confirm.
        optim_dict_list = getattr(criterion, 'optim_dict_list', None)
        if optim_dict_list is not None:
            to_optim += optim_dict_list
        else:
            to_optim += [{'params': criterion.parameters(), 'lr': criterion.lr}]

    return criterion, to_optim
