import collections

from .losses import *
from .schedulers import *
from .utils import cycle
from .projections import linear_projection, projection_with_activation

class DynamicKeyDict:
    """Dict-like registry whose values may be built on lookup.

    A plain (non list/tuple) key is an ordinary lookup in the underlying dict.
    A list/tuple key is read as ``[name, *args]``: the value registered under
    ``name`` is called with ``*args`` and the call's result is returned. If the
    only extra element is a dict, it is passed as keyword arguments instead,
    i.e. ``[name, kwargs]`` returns ``store[name](**kwargs)``.

    Example::

        d = DynamicKeyDict({'add': lambda a, b=0: a + b})
        d['add']                # the function itself
        d[('add', 1, 2)]        # 3
        d[['add', {'a': 1}]]    # 1
    """

    def __init__(self, kv_dict):
        # Stored by reference (not copied): mutations through either the
        # original dict or this wrapper are visible to both.
        self.store = kv_dict

    def __getitem__(self, key):
        if not isinstance(key, (list, tuple)):
            return self.store[key]
        if not key:
            # An empty sequence names no entry; report it as a missing key
            # instead of leaking an IndexError from key[0].
            raise KeyError(key)
        name, *args = key
        if len(args) == 1 and isinstance(args[0], dict):
            # A single dict argument is treated as keyword arguments.
            return self.store[name](**args[0])
        return self.store[name](*args)

    def __setitem__(self, key, value):
        self.store[key] = value

    def __contains__(self, key):
        if isinstance(key, (list, tuple)):
            # Only the name part of a dynamic key is checked; an empty
            # sequence names nothing.
            return bool(key) and key[0] in self.store
        return key in self.store

    def __iter__(self):
        # Iterate over registered names. Without this, iter() would fall back
        # to the legacy sequence protocol (__getitem__(0), __getitem__(1), ...)
        # and fail with KeyError.
        return iter(self.store)

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.store)

# 'flsw' and 'cwsm' map to builders: they are meant to be looked up with a
# list/tuple key carrying their parameters (see the docstring below).
TEMPERATURE_SCHEDULER = DynamicKeyDict({
    'constant': constant_temperature_scheduler,
    'flsw': flsw_temperature_scheduler_builder,
    'cwsm': cwsm_temperature_scheduler_builder,
})
"""
(*custom dict*) used to dynamically adjust distillation temperature.

    * '**constant**' : Constant temperature.
    * '**flsw**' :  See `Preparing Lessons: Improve Knowledge Distillation with Better Supervision <https://arxiv.org/abs/1911.07471>`_. Needs parameters ``beta`` and ``gamma``.
    * '**cwsm**': See `Preparing Lessons: Improve Knowledge Distillation with Better Supervision <https://arxiv.org/abs/1911.07471>`_. Needs parameter ``beta``.

Unlike ``'constant'``, the ``'flsw'`` and ``'cwsm'`` schedulers need extra parameters. Pass a list whose first element is the scheduler name and whose remaining elements are the parameters, for example::

    #flsw
    distill_config = DistillationConfig(
        temperature_scheduler = ['flsw', 1, 2]  # beta=1, gamma=2
    )
    
    #cwsm
    distill_config = DistillationConfig(
        temperature_scheduler = ['cwsm', 1] # beta = 1
    )

"""



# Intermediate feature kinds; each is also a valid adaptor key (see below).
FEATURES = [
    'hidden',
    'attention',
]


ADAPTOR_KEYS = [
    'logits',
    'logits_mask',
    'losses',
    'inputs_mask',
    'labels',
    *FEATURES,
]
"""
(*list*) valid keys of the dict returned by the adaptor, includes:

    * '**logits**'
    * '**logits_mask**'
    * '**losses**'
    * '**inputs_mask**'
    * '**labels**'
    * '**hidden**'
    * '**attention**'
"""


# Name -> KD loss function between student and teacher outputs.
KD_LOSS_MAP = {
    'mse': kd_mse_loss,
    'ce': kd_ce_loss,
    'wsl': wsl_kd_loss,
}
"""
(*dict*) available KD losses

  * '**mse**' : mean squared error 
  * '**ce**': cross-entropy loss
"""

MATCH_LOSS_MAP = {
    'attention_mse_sum': att_mse_sum_loss,
    'attention_mse': att_mse_loss,
    'attention_ce_mean': att_ce_mean_loss,
    'attention_ce': att_ce_loss,
    'hidden_mse': hid_mse_loss,
    'cos': cos_loss,
    'pkd': pkd_loss,
    # 'gram' and 'fsp' are two names for the same function.
    'gram': fsp_loss,
    'fsp': fsp_loss,
    # 'mmd' and 'nst' are two names for the same function.
    'mmd': mmd_loss,
    'nst': mmd_loss,
}
"""
(*dict*) intermediate feature matching loss functions, includes:

* :func:`attention_mse_sum <textbrewer.losses.att_mse_sum_loss>`
* :func:`attention_mse <textbrewer.losses.att_mse_loss>`
* :func:`attention_ce_mean <textbrewer.losses.att_ce_mean_loss>`
* :func:`attention_ce <textbrewer.losses.att_ce_loss>`
* :func:`hidden_mse <textbrewer.losses.hid_mse_loss>`
* :func:`cos <textbrewer.losses.cos_loss>`
* :func:`pkd <textbrewer.losses.pkd_loss>`
* :func:`fsp <textbrewer.losses.fsp_loss>`, :func:`gram <textbrewer.losses.fsp_loss>`
* :func:`nst <textbrewer.losses.mmd_loss>`, :func:`mmd <textbrewer.losses.mmd_loss>`

See :ref:`intermediate_losses` for details.
"""

PROJ_MAP = {
    'linear': linear_projection,
    # The activated variants are produced here, once, at import time.
    'relu': projection_with_activation('ReLU'),
    'tanh': projection_with_activation('Tanh'),
}
"""
(*dict*) layers used to match the different dimensions of intermediate features

  * '**linear**' : linear layer, no activation
  * '**relu**' : ReLU activation
  * '**tanh**': Tanh activation
"""

WEIGHT_SCHEDULER = {
    'linear_decay': linear_decay_weight_scheduler,
    'linear_growth': linear_growth_weight_scheduler,
}
"""
(dict) Scheduler used to dynamically adjust KD loss weight and hard_label_loss weight.

  * '**linear_decay**' : decay from 1 to 0 during the whole training process.
  * '**linear_growth**' : grow from 0 to 1 during the whole training process.
"""

#TEMPERATURE_SCHEDULER = {'constant': constant_temperature_scheduler,
#                         'flsw_scheduler': flsw_temperature_scheduler_builder(1,1)}


# Registry of the preset maps, keyed by the ``map_name`` accepted by
# register_new(). Each value is the map object itself (not a copy), so
# registering through MAPS updates the module-level map.
MAPS = {'kd_loss': KD_LOSS_MAP,
        # Canonical snake_case key, consistent with the other entries.
        'match_loss': MATCH_LOSS_MAP,
        # Legacy mixed-case key, kept so existing callers keep working.
        'match_Loss': MATCH_LOSS_MAP,
        'projection': PROJ_MAP,
        'weight_scheduler': WEIGHT_SCHEDULER,
        'temperature_scheduler': TEMPERATURE_SCHEDULER}


def register_new(map_name, name, func):
    """Register ``func`` under ``name`` in the preset map selected by ``map_name``.

    Args:
        map_name: a key of ``MAPS`` selecting the target map, e.g.
            ``'kd_loss'``, ``'match_Loss'``, ``'projection'``,
            ``'weight_scheduler'`` or ``'temperature_scheduler'``.
        name: key under which ``func`` is stored; an existing entry with the
            same name is overwritten.
        func: the callable to register.

    Raises:
        AssertionError: if ``map_name`` is not a key of ``MAPS`` or ``func``
            is not callable.
    """
    # Checks are raised explicitly instead of via ``assert`` so they still run
    # under ``python -O``. AssertionError is kept for backward compatibility
    # with callers that already catch it.
    if map_name not in MAPS:
        raise AssertionError(
            "Unknown map name %r; expected one of %s" % (map_name, sorted(MAPS)))
    if not callable(func):
        raise AssertionError("Functions to be registered is not callable")
    MAPS[map_name][name] = func


'''
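Example -- add a new intermediate feature matching loss (``register_new('match_Loss', 'my_L1_loss', my_L1_loss)`` does the same):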
def my_L1_loss(feature_S, feature_T, mask=None):
    return (feature_S-feature_T).abs().mean()

MATCH_LOSS_MAP['my_L1_loss'] = my_L1_loss
'''
