import jax
from jax import numpy as jnp

from ProtLig_GPCRclassA.utils import tf_to_jax, tf_to_jraph_graph_reshape_pmap

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func, make_aux_loss_func
from ProtLig_GPCRclassA.amino_GNN.make_compute_metrics import make_compute_metrics

def make_train_masked_pmap_step(loss_func, init_rngs, reg_loss_func = None, aux_loss_func = None):
    """
    Build one masked-training step that is meant to run under ``jax.pmap``.

    The returned function is NOT pmapped (nor jitted) here. It averages loss and
    gradients with ``jax.lax.pmean(..., axis_name='num_devices')``, so the caller
    has to run it under a mapped axis of that name, e.g.
    ``jax.pmap(step, axis_name='num_devices')``.

    Parameters:
    -----------
    loss_func : callable
        Main loss; must accept ``logits`` and ``labels`` both positionally and
        as keyword arguments.

    init_rngs : Any
        Accepted for interface compatibility; not used by the step (the PRNG
        keys are taken from ``state.rngs``).

    reg_loss_func : callable, optional
        If given, ``reg_loss_func(params)`` is added to the loss.

    aux_loss_func : callable, optional
        If given, ``aux_loss_func(logits, labels)`` is added to the main loss.

    Returns:
    --------
    train_masked_pmap_step : callable
        ``(state, batch, input_ids_mask) -> (state, logits, loss, grads)`` where
        ``loss`` and ``grads`` are the cross-device means.
    """
    if aux_loss_func is not None:
        def _loss_func(logits, labels):
            return loss_func(logits, labels) + aux_loss_func(logits, labels)
    else:
        _loss_func = loss_func

    def train_masked_pmap_step(state, batch, input_ids_mask):
        """
        Run a forward/backward pass and apply the device-averaged gradients.

        ``state`` must expose ``replace``, ``rngs``, ``apply_fn``, ``params`` and
        ``apply_gradients``. ``batch[:-1]`` is passed to the model as input and
        ``batch[-1]`` is used as labels.
        """
        # Advance every PRNG key so the stochastic parts of the model differ between steps.
        # jax.tree_util.tree_map is used because the top-level ``jax.tree_map`` alias is deprecated.
        next_rngs = jax.tree_util.tree_map(lambda key: jax.random.split(key)[0], state.rngs)
        state = state.replace(rngs = next_rngs)

        def loss_fn(params):
            logits = state.apply_fn(params, batch[:-1], deterministic = False,
                                    input_ids_mask = input_ids_mask, rngs = state.rngs)
            loss_val = _loss_func(logits = logits, labels = batch[-1])
            if reg_loss_func is not None:
                # Resolved at trace time: the regulariser is either always or never part of the graph.
                loss_val = loss_val + reg_loss_func(params)
            return loss_val, logits

        grad_fn = jax.value_and_grad(loss_fn, has_aux = True)
        (loss, logits), grads = grad_fn(state.params)
        # Average over devices so that every replica applies the same update.
        grads = jax.lax.pmean(grads, axis_name='num_devices')
        loss = jax.lax.pmean(loss, axis_name='num_devices')
        state = state.apply_gradients(grads = grads) # This handles updates of opt_state and params
        return state, logits, loss, grads

    return train_masked_pmap_step



def make_train_masked_pmap_epoch(is_weighted, loss_option, init_rngs, logger, 
                        aux_loss_option = None, reg_loss_func = None, 
                        loader_output_type = 'jax', num_classes = 3, mask_token_id = None, 
                        cls_token_id = None, pad_token_id = None, sep_token_id = None, 
                        unk_token_id = None, eos_token_id = None, bos_token_id = None):
    """
    Helper function to create the pmapped, masked-token ``train_epoch`` function.

    Parameters:
    -----------
    is_weighted, loss_option, num_classes :
        Forwarded to ``make_loss_func``.

    init_rngs :
        Forwarded to ``make_train_masked_pmap_step``.

    logger :
        Accepted for interface compatibility; not used here.

    aux_loss_option : optional
        If not None, an auxiliary loss is built with ``make_aux_loss_func`` and
        added to the main loss inside the train step.

    reg_loss_func : callable, optional
        Forwarded to ``make_train_masked_pmap_step``.

    loader_output_type : str
        Only ``'tf'`` is supported. ``'jax'`` (the default) raises
        ``NotImplementedError``.

    cls_token_id, sep_token_id, unk_token_id, eos_token_id, bos_token_id : optional
        Token ids that must never be selected for masking. Ids left as None are skipped.

    mask_token_id, pad_token_id : optional
        Accepted for interface compatibility; not used here (padding positions
        are excluded through the attention mask instead).

    Returns:
    --------
    train_masked_pmap_epoch : callable
        ``(state, loader) -> (state, batch_losses)`` where ``batch_losses`` is
        the list of per-batch losses returned by the pmapped train step.

    Raises:
    -------
    NotImplementedError
        If ``loader_output_type == 'jax'``.
    ValueError
        If ``loader_output_type`` is any other unsupported value.
    """
    # Fail fast, before any loss function or pmapped callable is built.
    if loader_output_type == 'jax':
        raise NotImplementedError('jax loader not supported with pmap.')
    if loader_output_type != 'tf':
        raise ValueError('Unknown loader_output_type: {!r}. Only "tf" is supported with pmap.'.format(loader_output_type))

    mask_prob = 0.15 # probability of selecting a token position for masking
    special_token_ids = tuple(token_id for token_id in (cls_token_id, sep_token_id, unk_token_id,
                                                        eos_token_id, bos_token_id)
                              if token_id is not None)

    def get_masked_tokens_and_labels(input_ids, attn_mask, seq_mask_rng):
        """
        Return a copy of ``input_ids`` (the labels) and a boolean mask of the
        positions selected for masking. Positions where ``attn_mask`` is zero
        and positions holding a special token are never selected.
        """
        seq_label = input_ids.copy()
        input_ids_mask = jax.random.uniform(seq_mask_rng, shape = input_ids.shape) < mask_prob
        input_ids_mask = input_ids_mask & attn_mask.astype(bool)
        for token_id in special_token_ids:
            input_ids_mask = input_ids_mask & (input_ids != token_id)
        return seq_label, input_ids_mask
    # Every argument is expected to carry a leading device axis.
    get_masked_tokens_and_labels = jax.pmap(get_masked_tokens_and_labels)

    loss_func = make_loss_func(is_weighted = is_weighted, option = loss_option, num_classes = num_classes)
    if aux_loss_option is not None:
        aux_loss_func = make_aux_loss_func(option = aux_loss_option)
    else:
        aux_loss_func = None
    train_masked_pmap_step = make_train_masked_pmap_step(loss_func = loss_func, init_rngs = init_rngs, reg_loss_func = reg_loss_func, aux_loss_func = aux_loss_func)
    # The step averages over the 'num_devices' axis, so it has to be pmapped under that name.
    train_masked_pmap_step = jax.pmap(train_masked_pmap_step, axis_name='num_devices')

    # Case loader outputs tf.Tensor:
    def train_masked_pmap_epoch(state, loader):
        """
        Run one pass over ``loader`` and return the updated state together
        with the list of per-batch losses.
        """
        batch_losses = []
        for _, batch in loader.enumerate():
            batch = jax.tree_util.tree_map(tf_to_jax, batch)
            S, G, labels = batch
            G = tf_to_jraph_graph_reshape_pmap(G)
            # NOTE(review): assumes S[1] is the attention mask and S[2] (the last element) the input ids - confirm against the loader.
            input_ids_label, input_ids_mask = get_masked_tokens_and_labels(input_ids = S[2], 
                                                                    attn_mask = S[1], 
                                                                    seq_mask_rng = state.rngs['seq_mask'])
            S = S[:-1] # Discard input_ids
            labels.update({'_input_ids_label' : input_ids_label})
            # Combine back to batch:
            batch = (S, G, labels)
            state, _, loss, _ = train_masked_pmap_step(state, batch, input_ids_mask)
            batch_losses.append(loss)
        return state, batch_losses
    return train_masked_pmap_epoch