import optax
from jax import numpy as jnp
import jax

# from ProtLig_GPCRclassA.losses import focal_loss

def make_loss_func(is_weighted, option = 'cross_entropy', num_classes = 3):
    """Build the main loss function, averaged over samples.

    Args:
        is_weighted: If True, each per-sample loss is multiplied by
            ``labels['_main_sample_weight']`` before taking the mean.
        option: ``'cross_entropy'`` or ``'l2_loss'``.
        num_classes: Only used by ``'cross_entropy'``. With 2 classes a
            sigmoid binary cross-entropy is applied to squeezed logits;
            otherwise a softmax cross-entropy against one-hot encoded labels.

    Returns:
        ``main_loss_func(logits, labels)`` returning a scalar. Both arguments
        are dicts; predictions and targets are read under ``'_main_label'``.

    Raises:
        ValueError: If ``option`` is not one of the supported values.
            Raised here, rather than later when the returned function is
            first called with an undefined inner loss.
    """
    if option == 'cross_entropy':
        if num_classes == 2:
            def _main_loss_func(logits, labels):
                # jnp.squeeze drops every size-1 axis (e.g. a trailing
                # single-logit axis) so logits line up with the labels.
                logits = jnp.squeeze(logits['_main_label'])
                labels = jnp.asarray(labels['_main_label'], dtype = jnp.float32)
                return optax.sigmoid_binary_cross_entropy(logits, labels)
        else:
            def _main_loss_func(logits, labels):
                labels = jnp.asarray(labels['_main_label'], dtype = jnp.float32)
                one_hot_labels = jax.nn.one_hot(labels, num_classes = num_classes)
                return optax.softmax_cross_entropy(logits['_main_label'], one_hot_labels)

    elif option == 'l2_loss':
        def _main_loss_func(logits, labels):
            predictions = jnp.asarray(logits['_main_label'])
            targets = jnp.asarray(labels['_main_label'], dtype = jnp.float32)
            # optax.l2_loss broadcasts its inputs: a (B, 1) prediction against
            # (B,) targets would silently produce a (B, B) loss. Drop only a
            # trailing singleton axis so the error stays per-sample.
            if predictions.ndim == targets.ndim + 1 and predictions.shape[-1] == 1:
                predictions = jnp.squeeze(predictions, axis = -1)
            return optax.l2_loss(predictions, targets)

    else:
        raise ValueError('Unknown main loss option: {!r}'.format(option))

    # Weighted:
    if is_weighted:
        def main_loss_func(logits, labels):
            loss_val = _main_loss_func(logits, labels)
            weighted_loss_val = labels['_main_sample_weight'] * loss_val
            return jnp.mean(weighted_loss_val)
    else:
        def main_loss_func(logits, labels):
            loss_val = _main_loss_func(logits, labels)
            return jnp.mean(loss_val)

    return main_loss_func



def make_aux_loss_func(option):
    """Build an auxiliary loss function selected by ``option``.

    Supported options:
        'aux_broadness_weight': L2 loss between squeezed
            ``logits['broadness_weight']`` and ``labels['broadness_weight']``,
            averaged and scaled by 1.0.
        'aux_Pyrfume': per-label weighted sigmoid binary cross-entropy on
            ``'Pyrfume_values'``. It is averaged over the last axis, masked by
            ``'Pyrfume_values_mask'``, weighted by ``'Pyrfume_mol_weight'``,
            normalised by the number of masked-in entries and scaled by 10.0.
        'aux_MLM': softmax cross-entropy between ``logits['_mlm_logits']``
            and one-hot ``labels['_input_ids_label']``, averaged and scaled
            by 1.0.
        'aux_conc_derivative': placeholder; the returned function raises
            NotImplementedError because this loss is computed in the train
            step instead.

    Returns:
        ``aux_loss_func(logits, labels)`` returning a scalar. Both arguments
        are dicts keyed as described above.

    Raises:
        NotImplementedError: For ``'aux_broadness_weight_Pyrfume'``, for which
            no loss function is defined.
        ValueError: If ``option`` is not recognised.
    """
    if option == 'aux_broadness_weight':
        alpha = 1.0
        def aux_loss_func(logits, labels):
            _logits_broadness_weight = jnp.squeeze(logits['broadness_weight'])
            _labels_broadness_weight = jnp.asarray(labels['broadness_weight'], dtype = jnp.float32)
            aux_loss_val = optax.l2_loss(_logits_broadness_weight, _labels_broadness_weight)
            return alpha * jnp.mean(aux_loss_val)

    elif option == 'aux_Pyrfume':
        alpha = 10.0
        def aux_loss_func(logits, labels):
            _logits_odor = logits['Pyrfume_values']
            _labels_odor = jnp.asarray(labels['Pyrfume_values'], dtype = jnp.float32)
            # NOTE(review): mask presumably marks which samples carry Pyrfume
            # labels -- confirm against the data pipeline.
            mask_odor = jnp.asarray(labels['Pyrfume_values_mask'], dtype = bool)
            _weight_odor = jnp.asarray(labels['Pyrfume_weight'], dtype = jnp.float32)
            _weight_mol = jnp.asarray(labels['Pyrfume_mol_weight'], dtype = jnp.float32)
            aux_loss_odor = _weight_odor * optax.sigmoid_binary_cross_entropy(_logits_odor, _labels_odor)
            aux_loss_odor = jnp.mean(aux_loss_odor, axis = -1) * mask_odor # mean through multilabels
            aux_loss_odor = _weight_mol * aux_loss_odor
            # If the mask is all False, the numerator is 0 and an unclamped
            # denominator would give 0 / 0 = NaN, poisoning loss and gradients.
            n_valid = jnp.maximum(jnp.sum(mask_odor), 1)
            return alpha * jnp.sum(aux_loss_odor / n_valid)

    elif option == 'aux_MLM':
        alpha = 1.0
        def aux_loss_func(logits, labels):
            _logits_mlm = logits['_mlm_logits']
            _labels_mlm = jnp.asarray(labels['_input_ids_label'], dtype = jnp.float32)
            _one_hot_labels_mlm = jax.nn.one_hot(_labels_mlm, num_classes = _logits_mlm.shape[-1])
            aux_loss_val = optax.softmax_cross_entropy(_logits_mlm, _one_hot_labels_mlm)
            return alpha * jnp.mean(aux_loss_val)

    elif option == 'aux_conc_derivative':
        def aux_loss_func(logits, labels):
            raise NotImplementedError('moved to train step instead...')

    elif option == 'aux_broadness_weight_Pyrfume':
        # The previous code referenced an undefined `num_classes` and never
        # defined a loss function for this option, so it could only fail.
        raise NotImplementedError(
            'Option: {} has no auxiliary loss function implemented'.format(option))

    else:
        raise ValueError('Unknown auxiliary loss option: {!r}'.format(option))

    return aux_loss_func




    # raise NotImplementedError('see comment below:')
    # """
    # (OK) - Split loss_func into two parts: the main loss and the auxiliary loss.
    # (OK) - Put the main sample weight into the label dictionary.
    # - Report the auxiliary loss in the metrics output.
    # """


    # # Weighted:
    # if is_weighted:
    #     def loss_func(logits, labels):
    #         labels, sample_weights = labels
    #         loss_val = _loss_func(logits, labels)
    #         weighted_loss_val = sample_weights * loss_val # <---- change sample weight!! 
    #         return jnp.mean(weighted_loss_val)
    # else:
    #     def loss_func(logits, labels):
    #         loss_val = _loss_func(logits, labels)
    #         return jnp.mean(loss_val)

    

    #     # Binary:
    #     if num_classes == 2:
    #         if is_weighted:
    #             def loss_func(logits, labels):
    #                 labels, sample_weights = labels
    #                 logits = jnp.squeeze(logits)
    #                 labels = jnp.asarray(labels, dtype = jnp.float32)
    #                 loss_val = optax.sigmoid_binary_cross_entropy(logits, labels)
    #                 weighted_loss_val = sample_weights * loss_val
    #                 return jnp.mean(weighted_loss_val)
    #         else:
    #             def loss_func(logits, labels):
    #                 logits = jnp.squeeze(logits)
    #                 labels = jnp.asarray(labels, dtype = jnp.float32)
    #                 loss_val = optax.sigmoid_binary_cross_entropy(logits, labels)
    #                 return jnp.mean(loss_val)
    #     # Multiclass:
    #     else:
    #         if is_weighted:
    #             def loss_func(logits, labels):
    #                 labels, sample_weights = labels
    #                 labels = jnp.asarray(labels, dtype = jnp.float32)
    #                 # logits = jnp.squeeze(logits)
    #                 one_hot_labels = jax.nn.one_hot(labels, num_classes = num_classes)
    #                 loss_val = optax.softmax_cross_entropy(logits, one_hot_labels)
    #                 weighted_loss_val = sample_weights * loss_val
    #                 return jnp.mean(weighted_loss_val)
    #         else:
    #             def loss_func(logits, labels):
    #                 labels = jnp.asarray(labels, dtype = jnp.float32)
    #                 # logits = jnp.squeeze(logits)
    #                 one_hot_labels = jax.nn.one_hot(labels, num_classes = num_classes)
    #                 loss_val = optax.softmax_cross_entropy(logits, one_hot_labels)
    #                 return jnp.mean(loss_val)

    # elif option == 'focal':
    #     if is_weighted:
    #         def loss_func(logits, labels):
    #             labels, sample_weights = labels
    #             logits = jnp.squeeze(logits)
    #             labels = jnp.asarray(labels, dtype = jnp.float32)
    #             loss_val = focal_loss(logits, labels)
    #             weighted_loss_val = sample_weights * loss_val
    #             return jnp.mean(weighted_loss_val)
    #     else:
    #         def loss_func(logits, labels):
    #             logits = jnp.squeeze(logits)
    #             labels = jnp.asarray(labels, dtype = jnp.float32)
    #             loss_val = focal_loss(logits, labels)
    #             return jnp.mean(loss_val)
    # return loss_func