"""

This code is based on the implementation from the FRePo repository.
Source: https://github.com/yongchaoz/FRePo

"""

import os
import logging
from typing import Any, Sequence

import numpy as np

import flax.linen as nn
import jax.scipy as sp


from .metrics import mean_squared_loss, get_metrics, top5_accuracy, soft_cross_entropy_loss, \
    cross_entropy_loss
import jax
import jax.numpy as jnp

import optax
from flax.training import train_state, checkpoints

from .metrics import pred_acurracy, top5_accuracy

# Loose alias for array-valued annotations; it is just `Any`, so nothing is
# enforced at runtime.
Array = Any


class TrainState(train_state.TrainState):
    """
    Flax ``train_state.TrainState`` extended with extra bookkeeping fields.

    Attributes:
        epoch (int): Current epoch.
        best_val_acc (float): Best validation accuracy.
        batch_stats (Any): Collection used to store an exponential moving
                           average of the batch statistics.
        ema_hidden, ema_average (Any): Parameter copies used for weight
            averaging. The create_train_state* helpers initialize both to the
            initial params, and eval_step evaluates with ``ema_average`` when
            ``use_ema=True``.
        ema_hidden_batch, ema_average_batch (Any): Batch-statistics
            counterparts of the two fields above (only set when the model has
            batch norm).
        ema_count (int): presumably the number of averaging updates applied;
            not read in this part of the file -- TODO confirm.
    """
    epoch: int
    best_val_acc: float
    batch_stats: Any = None
    # Weight-averaging state; initialized to the initial params by the
    # create_train_state* helpers.
    ema_hidden: Any = None
    ema_average: Any = None
    ema_hidden_batch: Any = None
    ema_average_batch: Any = None
    ema_count: int = 0


@jax.jit
def _bias_correction(moment, decay, count):
    """Perform bias correction. This becomes a no-op as count goes to infinity."""
    bias_correction = 1 - decay ** count
    return jax.tree_util.tree_map(lambda t: t / bias_correction.astype(t.dtype), moment)


@jax.jit
def _update_moment(updates, moments, decay, order):
    """Compute the exponential moving average of the `order`-th moment."""
    return jax.tree_util.tree_map(
        lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)


class EMA():
    """Exponential moving average of a pytree, with optional zero-debiasing.

    Call ``initialize`` before the first update; each call then blends the
    new value into the running state.
    """

    def __init__(self, decay, debias: bool = True):
        """Store the EMA configuration; running state stays empty until ``initialize``.

        References: https://github.com/deepmind/optax/blob/master/optax/_src/transform.py
                    https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/moving_averages.py#L39#L137
        Args:
          decay: Decay rate, expected in ``[0, 1)``. Values near 1 average over
            a long horizon; values near 0 track the latest value closely.
          debias: If True, divide the raw average by ``1 - decay ** count``.
        """
        self.decay, self.debias = decay, debias
        self.hidden = self.average = self.count = None

    def initialize(self, state, hidden=None, average=None, count=None):
        """Reset the running state.

        Without ``hidden``, the state is zeros shaped like ``state`` and the
        count is a scalar int32 zero. Otherwise ``hidden``/``average``/``count``
        are adopted as-is (``average`` must then be given too).
        """
        if hidden is None:
            self.average = jax.tree_util.tree_map(jnp.zeros_like, state)
            self.hidden = jax.tree_util.tree_map(jnp.zeros_like, state)
            self.count = jnp.zeros([], jnp.int32)
            return
        assert average is not None, 'hidden and average should both be None or not None'
        self.hidden, self.average, self.count = hidden, average, count

    def __call__(self, value, update_stats: bool = True) -> jnp.ndarray:
        """Fold ``value`` into the moving average and return the result.

        Args:
          value: Pytree matching the structure given to ``initialize``.
          update_stats: When False, the new average is computed and returned
            but the stored hidden/average/count are left untouched.
        Returns:
          The (debiased, if enabled) moving average after this update.
        """
        new_count = self.count + 1
        new_hidden = _update_moment(value, self.hidden, self.decay, order=1)
        new_average = _bias_correction(new_hidden, self.decay, new_count) if self.debias else new_hidden

        if update_stats:
            self.hidden, self.average, self.count = new_hidden, new_average, new_count

        return new_average

    @property
    def ema(self):
        """Most recent (possibly debiased) average."""
        return self.average


class AVG():
    """Equal-weight running average of a pytree of arrays.

    Unlike ``EMA``, every value passed to ``__call__`` gets the same weight.
    Starting from a zero state (``initialize`` without explicit state), the
    average after k updates is the arithmetic mean of those k values.
    """

    def __init__(self):
        """Create an empty running average; call ``initialize`` before use."""
        self.hidden = None
        self.average = None
        self.count = None

    @staticmethod
    def _as_dict(tree):
        """Return ``tree`` as a plain mutable dict.

        Non-dict inputs are assumed to provide ``.unfreeze()`` (e.g. a flax
        FrozenDict).
        """
        return tree if isinstance(tree, dict) else tree.unfreeze()

    def initialize(self, state, hidden=None, average=None, count=None):
        """Reset the running state.

        Without ``hidden``, the state is zeros shaped like ``state`` and the
        count is a scalar int32 zero. Otherwise ``hidden``/``average``/``count``
        are adopted (``average`` must then be given too). The stored trees are
        always converted to plain dicts.
        """
        if hidden is not None:
            assert average is not None, 'hidden and average should both be None or not None'
            self.hidden = hidden
            self.average = average
            self.count = count
        else:
            self.average = jax.tree_util.tree_map(jnp.zeros_like, state)
            self.hidden = jax.tree_util.tree_map(jnp.zeros_like, state)
            self.count = jnp.zeros([], jnp.int32)

        self.hidden = self._as_dict(self.hidden)
        self.average = self._as_dict(self.average)

    def __call__(self, value, update_stats: bool = True) -> dict:
        """Fold ``value`` into the running average and return the new average.

        Each leaf becomes ``avg * (k / (k + 1)) + value / (k + 1)``, where ``k``
        is the stored count.

        Args:
          value: Pytree with the same structure as the stored state. Both plain
            dicts and objects providing ``.unfreeze()`` are accepted. Previously
            ``.unfreeze()`` was called unconditionally, so a plain dict raised
            AttributeError.
          update_stats: When False, the new average is returned but the stored
            hidden/average/count are left untouched.
        Returns:
          The updated average as a dict pytree.
        """
        count = self.count + 1
        value = self._as_dict(value)
        hidden = jax.tree_util.tree_map(
            lambda h, v: h * (self.count / count) + v / count, self.hidden, value)

        if update_stats:
            self.hidden = hidden
            self.average = hidden
            self.count = count

        return hidden

    @property
    def avg(self):
        """Most recent running average."""
        return self.average


def restore_checkpoint(state, path):
    """
    Load a checkpoint from ``path`` into the structure of ``state``.

    Flax restores the checkpoint with the largest step number in ``path``.
    When checkpoints were written by ``save_checkpoint`` keyed by a metric,
    that is the best-scoring one.
    Args:
        state (train_state.TrainState): Template state giving the target structure.
        path (str): Checkpoint directory.
    Returns:
        (train_state.TrainState): The restored training state.
    """
    restored_state = checkpoints.restore_checkpoint(ckpt_dir=path, target=state)
    return restored_state


def save_checkpoint(state, path, step_or_metric=None, keep=1):
    """
    Write ``state`` to the checkpoint directory ``path``.

    On multi-device runs only process 0 writes. It saves element ``[0]`` of
    every leaf, which assumes a replicated (pmap-style) state -- confirm with
    callers. Other processes return without writing.
    Args:
        state (train_state.TrainState): Training state.
        path (str): Path to the checkpoint directory.
        step_or_metric (int or float): Number identifying the checkpoint. When
            omitted, ``state.step`` is used and an existing checkpoint with
            that step is overwritten.
        keep (int): Number of checkpoints flax keeps in ``path``.
    """
    if jax.device_count() > 1:
        if jax.process_index() != 0:
            return
        state = jax.device_get(jax.tree_util.tree_map(lambda leaf: leaf[0], state))

    if step_or_metric is not None:
        checkpoints.save_checkpoint(path, state, step_or_metric, keep=keep)
        return
    checkpoints.save_checkpoint(path, state, int(state.step), keep=keep, overwrite=True)


def make_chunky_prediction(x, predict_fn, chunk_size=1000):
    """Apply ``predict_fn`` to ``x`` in consecutive slices along axis 0.

    Args:
        x: Array indexed along its first axis.
        predict_fn: Callable mapping a slice of ``x`` to an array of outputs.
        chunk_size: Maximum number of rows passed to ``predict_fn`` per call;
            the final chunk may be smaller.

    Returns:
        ``np.vstack`` of the per-chunk outputs. For an empty ``x``, the
        (vstacked) result of ``predict_fn`` on the empty input. Previously
        empty input crashed inside ``np.vstack([])``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError('chunk_size must be positive, got {}'.format(chunk_size))

    num_rows = x.shape[0]
    if num_rows == 0:
        # No chunks to iterate; let predict_fn decide the shape of an empty result.
        return np.vstack([predict_fn(x)])

    # Same slices as "full chunks + trailing remainder": the last slice is
    # clipped at num_rows automatically.
    outputs = [predict_fn(x[start:start + chunk_size])
               for start in range(0, num_rows, chunk_size)]
    return np.vstack(outputs)


def process_batch(batch, use_pmap=False, dtype=jnp.float32, diff_aug=None, rng=None):
    """Cast a batch to ``dtype``, optionally augment it and shard it for pmap.

    Args:
        batch: Either a dict with 'image' and 'label' entries or a sequence
            whose first two items are (image, label).
        use_pmap: Reshape the leading axis into ``(jax.device_count(), -1)``.
        dtype: Target dtype for both images and labels.
        diff_aug: Optional callable ``diff_aug(rng, image)`` applied to the images.
        rng: Key forwarded to ``diff_aug``.
    Returns:
        (image, label) tuple.
    Raises:
        ValueError: For any other batch type.
    """
    if isinstance(batch, dict):
        raw_image, raw_label = batch['image'], batch['label']
    elif isinstance(batch, Sequence):
        raw_image, raw_label = batch[0], batch[1]
    else:
        raise ValueError('Unknown Type {}'.format(type(batch)))

    image = raw_image.astype(dtype)
    label = raw_label.astype(dtype)

    if diff_aug is not None:
        image = diff_aug(rng, image)

    if not use_pmap:
        return image, label

    devices = jax.device_count()

    def _split_across_devices(arr):
        # [devices * per_device, ...] -> [devices, per_device, ...]; jax.pmap
        # maps over the new leading axis.
        return jnp.reshape(arr, (devices, -1) + arr.shape[1:])

    return _split_across_devices(image), _split_across_devices(label)


def compute_metrics(logits, labels, loss_type):
    """
    Compute mean loss, top-1 accuracy and top-5 accuracy for one batch.
    Args:
        logits (tensor): Logits, shape [B, num_classes].
        labels (tensor): Labels, shape [B, num_classes] or [B,]. 2-D labels are
            reduced with argmax before the accuracy metrics.
        loss_type (callable): ``loss_type(logits, labels)``; its result is
            averaged with ``.mean()`` (labels are passed before the argmax).
    Returns:
        (dict): Keys 'loss', 'accuracy' and 'top5accuracy'.
    """
    mean_loss = loss_type(logits, labels).mean()
    class_ids = labels.argmax(1) if labels.ndim == 2 else labels
    return {
        'loss': mean_loss,
        'accuracy': pred_acurracy(logits, class_ids).mean(),
        'top5accuracy': top5_accuracy(logits, class_ids).mean(),
    }

def cross_entropy_loss2(logits, labels):
    """
    Hard-label cross entropy against log-probabilities averaged over axis 0.

    ``log_softmax(logits)`` is averaged over the leading axis of ``logits``
    (presumably an ensemble / sample axis -- confirm with callers). ``labels``
    are turned into one-hot targets via argmax over the last axis.
    Args:
        logits (tensor): Logits with classes on the last axis; axis 0 is averaged.
        labels (tensor): Labels, shape [..., num_classes].
    Returns:
        (tensor): Loss with the class axis summed out, e.g. shape [B] for
        logits [S, B, C] and labels [B, C].
    """
    num_classes = labels.shape[-1]
    hard_targets = jax.nn.one_hot(labels.argmax(-1), num_classes=num_classes)
    averaged_log_probs = nn.log_softmax(logits, axis=-1).mean(axis=0)
    return -(hard_targets * averaged_log_probs).sum(axis=-1)

def soft_cross_entropy_loss2(logits, labels):
    """
    Soft-label cross entropy against log-probabilities averaged over axis 0.

    ``labels`` are treated as unnormalized scores and turned into a target
    distribution with a softmax over the last axis. ``log_softmax(logits)``
    is averaged over the leading axis of ``logits``.
    Args:
        logits (tensor): Logits with classes on the last axis; axis 0 is averaged.
        labels (tensor): Label scores, shape [..., num_classes].
    Returns:
        (tensor): Loss with the class axis summed out.
    """
    target_probs = nn.softmax(labels)
    averaged_log_probs = nn.log_softmax(logits, axis=-1).mean(axis=0)
    return -(target_probs * averaged_log_probs).sum(axis=-1)

def compute_metrics2(logits, labels, loss_type):
    """
    Compute loss and accuracy for predictions stacked along a leading axis.

    Log-probabilities are averaged over axis 0 of ``logits`` (presumably an
    ensemble / sample axis -- confirm with callers) before the accuracy
    metrics are computed. The loss uses the matching ``*_loss2`` helper,
    which averages the same way.
    Args:
        logits (tensor): Logits with classes on the last axis; axis 0 is averaged.
        labels (tensor): Labels, shape [B, num_classes] (the ``*_loss2``
            helpers treat the last axis as classes).
        loss_type: Either ``cross_entropy_loss`` or ``soft_cross_entropy_loss``;
            selects ``cross_entropy_loss2`` / ``soft_cross_entropy_loss2``.
    Returns:
        (dict): Keys 'loss', 'accuracy' and 'top5accuracy'.
    Raises:
        ValueError: If ``loss_type`` is any other callable. Previously this
            surfaced as an UnboundLocalError on ``loss``.
    """
    if loss_type is cross_entropy_loss:
        loss = cross_entropy_loss2(logits, labels).mean()
    elif loss_type is soft_cross_entropy_loss:
        loss = soft_cross_entropy_loss2(logits, labels).mean()
    else:
        raise ValueError('Unsupported loss_type for compute_metrics2: {}'.format(loss_type))

    if labels.ndim == 2:
        labels = labels.argmax(1)

    # Computed once and shared by both accuracy metrics.
    mean_log_probs = nn.log_softmax(logits, axis=-1).mean(axis=0)
    accuracy = pred_acurracy(mean_log_probs, labels).mean()
    top5accuracy = top5_accuracy(mean_log_probs, labels).mean()

    metrics = {'loss': loss, 'accuracy': accuracy, 'top5accuracy': top5accuracy}
    return metrics

def initialized(key, img_size, img_channels, model, has_bn=False):
    """Initialize ``model`` on a dummy all-ones input.

    The dummy input has shape ``(1, img_size, img_size, img_channels)`` and
    dtype ``model.dtype``; ``model.init`` runs under ``jax.jit``.

    Args:
        key: PRNG key, split into the 'params' and 'dropout' rngs.
        img_size: Height and width of the dummy input.
        img_channels: Number of channels of the dummy input.
        model: Object exposing ``init`` and ``dtype``.
        has_bn: Also return the 'batch_stats' collection.
    Returns:
        ``params``, or ``(params, batch_stats)`` when ``has_bn``.
    """
    params_key, dropout_key = jax.random.split(key)
    dummy_batch = jnp.ones((1, img_size, img_size, img_channels), model.dtype)

    @jax.jit
    def _jitted_init(*args):
        return model.init(*args)

    variables = _jitted_init({'params': params_key, 'dropout': dropout_key}, dummy_batch)
    if not has_bn:
        return variables['params']
    return variables['params'], variables['batch_stats']


def create_learning_rate_fn(base_learning_rate, steps_per_epoch, num_epochs, warmup_epochs):
    """Build a warmup + cosine learning-rate schedule.

    The rate rises linearly from 0 to ``base_learning_rate`` over
    ``warmup_epochs`` epochs. It then follows an optax cosine decay over the
    remaining ``max(num_epochs - warmup_epochs, 1)`` epochs.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    decay_steps = max(num_epochs - warmup_epochs, 1) * steps_per_epoch
    return optax.join_schedules(
        schedules=[
            optax.linear_schedule(init_value=0., end_value=base_learning_rate,
                                  transition_steps=warmup_steps),
            optax.cosine_decay_schedule(init_value=base_learning_rate,
                                        decay_steps=decay_steps),
        ],
        boundaries=[warmup_steps],
    )

def _make_optimizer(config, learning_rate_fn, mask, make_ivon):
    """Build the optax transformation named by ``config.optimizer``.

    ``make_ivon`` is a zero-argument callable returning the IVON optimizer. It
    is only invoked when ``config.optimizer == 'ivon'``, so IVON-specific
    config fields (e.g. ``beta_1``) are only read when IVON is selected.
    NOTE(review): ``ivon`` is not among the imports at the top of this file;
    presumably it is defined further down -- confirm.

    Raises:
        ValueError: If ``config.optimizer`` is not a supported name.
    """
    name = config.optimizer
    if name == 'sgd':
        return optax.sgd(learning_rate=learning_rate_fn,
                         momentum=config.momentum,
                         nesterov=True)
    if name == 'adam':
        return optax.adam(learning_rate=learning_rate_fn)
    if name == 'adabelief':
        # Build AdaBelief as the name says (this option used to build Adam).
        return optax.adabelief(learning_rate=learning_rate_fn)
    if name == 'adamw':
        return optax.adamw(learning_rate=learning_rate_fn, weight_decay=config.weight_decay, mask=mask)
    if name == 'lamb':
        return optax.lamb(learning_rate=learning_rate_fn, weight_decay=config.weight_decay)
    if name == 'ivon':
        return make_ivon()
    raise ValueError('Unknown Optimizer Type {}'.format(config.optimizer))


def _init_train_state(rng, config, model, tx, has_bn, params, batch_stats):
    """Create a TrainState whose EMA fields start equal to the initial weights.

    When ``params`` is None the model is initialized from ``rng`` via
    ``initialized`` (which also supplies ``batch_stats`` when ``has_bn``).
    Without ``has_bn`` the batch-norm fields keep their TrainState defaults
    (None), and any ``batch_stats`` argument is ignored.
    """
    if has_bn:
        if params is None:
            params, batch_stats = initialized(rng, config.img_size, config.img_channels, model, has_bn=True)
        bn_fields = {'batch_stats': batch_stats,
                     'ema_hidden_batch': batch_stats,
                     'ema_average_batch': batch_stats}
    else:
        if params is None:
            params = initialized(rng, config.img_size, config.img_channels, model, has_bn=False)
        bn_fields = {}
    return TrainState.create(apply_fn=model.apply,
                             params=params,
                             tx=tx,
                             ema_hidden=params,
                             ema_average=params,
                             epoch=0, best_val_acc=0.0,
                             **bn_fields)


def _with_ivon_opt_state(state, config):
    """Return ``(state, tx.init(params))`` for IVON, otherwise just ``state``.

    IVON training (``train_step2``) carries the optimizer state separately
    from the TrainState.
    """
    if config.optimizer == 'ivon':
        return state, state.tx.init(state.params)
    return state


def create_train_state2(rng, config, model, learning_rate_fn, has_bn=False, params=None, batch_stats=None, mask=None):
    """Create initial training state (IVON settings: ess=50000, hess_init=3000).

    Returns:
        ``(state, opt_state)`` when ``config.optimizer == 'ivon'``, else ``state``.
    """
    tx = _make_optimizer(config, learning_rate_fn, mask, lambda: ivon(
        learning_rate=learning_rate_fn, ess=50000, hess_init=3000.0,
        beta1=config.beta_1, beta2=config.beta_2, weight_decay=config.weight_decay))
    state = _init_train_state(rng, config, model, tx, has_bn, params, batch_stats)
    return _with_ivon_opt_state(state, config)


def create_train_state4(rng, config, model, learning_rate_fn, has_bn=False, params=None, batch_stats=None, mask=None):
    """Create initial training state (IVON settings: ess=50000, hess_init=1000).

    Returns:
        ``(state, opt_state)`` when ``config.optimizer == 'ivon'``, else ``state``.
    """
    tx = _make_optimizer(config, learning_rate_fn, mask, lambda: ivon(
        learning_rate=learning_rate_fn, ess=50000, hess_init=1000.0,
        beta1=config.beta_1, beta2=config.beta_2, weight_decay=config.weight_decay))
    state = _init_train_state(rng, config, model, tx, has_bn, params, batch_stats)
    return _with_ivon_opt_state(state, config)


def create_train_state3(rng, config, model, learning_rate_fn, has_bn=False, params=None, batch_stats=None, mask=None):
    """Create initial training state (IVON settings: ess=50000, hess_init=5000).

    Returns:
        ``(state, opt_state)`` when ``config.optimizer == 'ivon'``, else ``state``.
    """
    tx = _make_optimizer(config, learning_rate_fn, mask, lambda: ivon(
        learning_rate=learning_rate_fn, ess=50000, hess_init=5000.0,
        beta1=config.beta_1, beta2=config.beta_2, weight_decay=config.weight_decay))
    state = _init_train_state(rng, config, model, tx, has_bn, params, batch_stats)
    return _with_ivon_opt_state(state, config)


def create_train_state(rng, config, model, learning_rate_fn, has_bn=False, params=None, batch_stats=None, mask=None):
    """Create initial training state.

    IVON here uses fixed settings (ess=10, hess_init=0.1, beta1=0.999,
    beta2=0.99999). Unlike the numbered variants, this always returns just the
    state, even for IVON.
    """
    tx = _make_optimizer(config, learning_rate_fn, mask, lambda: ivon(
        learning_rate=learning_rate_fn, ess=10, hess_init=0.1,
        beta1=0.999, beta2=0.99999, weight_decay=config.weight_decay))
    return _init_train_state(rng, config, model, tx, has_bn, params, batch_stats)


def train_step_lb(state, batch, rng, loss_type, l2_reg=0.0, has_feat=False, has_bn=False, use_pmap=True, gamma=1.0):
    """One optimization step with the mean data loss scaled by ``gamma``.

    Args:
        state: TrainState holding params, apply_fn and optimizer (and
            batch_stats when ``has_bn``).
        batch: Dict with 'image' and 'label' entries.
        rng: Key passed to the model as the 'dropout' rng.
        loss_type: ``loss_type(logits, labels)``, averaged with ``.mean()``.
        l2_reg: Adds ``l2_reg * 0.5 * sum(w ** 2)`` over leaves with ndim > 1.
        has_feat: Model returns ``(logits, features)``; features are ignored.
        has_bn: Read and update the 'batch_stats' collection.
        use_pmap: pmean gradients over axis 'batch' when >1 device exists.
        gamma: Multiplier on the mean data loss (not on the l2 term).
    Returns:
        (new_state, metrics). Metrics come from ``compute_metrics``, i.e.
        without ``gamma`` or the l2 term.
    """
    images = batch['image']
    labels = batch['label']

    def objective(params):
        model_vars = {'params': params}
        if has_bn:
            model_vars['batch_stats'] = state.batch_stats
        outputs, updated_vars = state.apply_fn(model_vars, images, rngs={'dropout': rng},
                                               train=True, mutable=['batch_stats'])
        if has_feat:
            logits, _ = outputs
        else:
            logits = outputs

        total = gamma * loss_type(logits, labels).mean()
        if l2_reg > 0.0:
            squared_norm = sum([jnp.sum(w ** 2)
                                for w in jax.tree_util.tree_leaves(params) if w.ndim > 1])
            total = total + l2_reg * 0.5 * squared_norm
        return total, (updated_vars, logits)

    (_, (new_model_state, logits)), grads = jax.value_and_grad(objective, has_aux=True)(state.params)
    # axis_name must match the one used when pmap-ing this step.
    if use_pmap and jax.device_count() > 1:
        grads = jax.lax.pmean(grads, axis_name='batch')

    metrics = compute_metrics(logits, labels, loss_type)
    if has_bn:
        return state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats']), metrics
    return state.apply_gradients(grads=grads), metrics


def train_step_lb2(state, m, batch, rng, loss_type, l2_reg=0.0, has_feat=False, has_bn=False, use_pmap=True, gamma=1.0):
    """Training step where the ``Dense_0`` kernel is replaced by a fixed ``m``.

    Otherwise this is the same as ``train_step_lb``. The forward pass and the
    l2 penalty use ``m`` in place of ``params['Dense_0']['kernel']``, so that
    kernel receives a zero gradient; ``m`` is a closed-over constant and is
    not differentiated.

    Args:
        state: TrainState holding params, apply_fn and optimizer (and
            batch_stats when ``has_bn``).
        m: Replacement value for ``params['Dense_0']['kernel']``.
        batch: Dict with 'image' and 'label' entries.
        rng: Key passed to the model as the 'dropout' rng.
        loss_type: ``loss_type(logits, labels)``, averaged with ``.mean()``.
        l2_reg: Adds ``l2_reg * 0.5 * sum(w ** 2)`` over leaves with ndim > 1
            (including ``m``).
        has_feat: Model returns ``(logits, features)``; features are ignored.
        has_bn: Read and update the 'batch_stats' collection.
        use_pmap: pmean gradients over axis 'batch' when >1 device exists.
        gamma: Multiplier on the mean data loss.
    Returns:
        (new_state, metrics).
    """
    def _with_fixed_kernel(params):
        # Build a patched copy instead of assigning into ``params``. The
        # in-place write failed on immutable (FrozenDict) params and mutated
        # the pytree handed in by value_and_grad. dict() accepts any Mapping.
        patched = dict(params)
        dense = dict(patched['Dense_0'])
        dense['kernel'] = m
        patched['Dense_0'] = dense
        return patched

    def loss_fn(params):
        params = _with_fixed_kernel(params)

        if has_bn:
            variables = {'params': params, 'batch_stats': state.batch_stats}
        else:
            variables = {'params': params}

        if has_feat:
            (logits, _), new_model_state = state.apply_fn(variables, batch['image'], rngs={'dropout': rng},
                                                          train=True, mutable=['batch_stats'])
        else:
            logits, new_model_state = state.apply_fn(variables, batch['image'], rngs={'dropout': rng}, train=True,
                                                     mutable=['batch_stats'])

        loss = gamma * loss_type(logits, batch['label']).mean()
        if l2_reg > 0.0:
            weight_penalty_params = jax.tree_util.tree_leaves(params)
            weight_l2 = sum([jnp.sum(x ** 2)
                             for x in weight_penalty_params if x.ndim > 1])
            weight_penalty = l2_reg * 0.5 * weight_l2
            loss = loss + weight_penalty
        return loss, (new_model_state, logits)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    aux, grads = grad_fn(state.params)
    # axis_name must match the one used when pmap-ing this step.
    if jax.device_count() > 1 and use_pmap:
        grads = jax.lax.pmean(grads, axis_name='batch')

    new_model_state, logits = aux[1]
    metrics = compute_metrics(logits, batch['label'], loss_type)

    if has_bn:
        new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
    else:
        new_state = state.apply_gradients(grads=grads)

    return new_state, metrics



def train_step(state, batch, rng, loss_type, l2_reg=0.0, has_feat=False, has_bn=False, use_pmap=True):
    """One optimization step on ``batch``.

    Args:
        state: TrainState holding params, apply_fn and optimizer (and
            batch_stats when ``has_bn``).
        batch: Dict with 'image' and 'label' entries.
        rng: Key passed to the model as the 'dropout' rng.
        loss_type: ``loss_type(logits, labels)``, averaged with ``.mean()``.
        l2_reg: Adds ``l2_reg * 0.5 * sum(w ** 2)`` over leaves with ndim > 1.
        has_feat: Model returns ``(logits, features)``; features are ignored.
        has_bn: Read and update the 'batch_stats' collection.
        use_pmap: pmean gradients over axis 'batch' when >1 device exists.
    Returns:
        (new_state, metrics). Metrics come from ``compute_metrics`` (no l2 term).
    """
    images = batch['image']
    labels = batch['label']

    def objective(params):
        model_vars = {'params': params}
        if has_bn:
            model_vars['batch_stats'] = state.batch_stats
        outputs, updated_vars = state.apply_fn(model_vars, images, rngs={'dropout': rng},
                                               train=True, mutable=['batch_stats'])
        if has_feat:
            logits, _ = outputs
        else:
            logits = outputs

        total = loss_type(logits, labels).mean()
        if l2_reg > 0.0:
            squared_norm = sum([jnp.sum(w ** 2)
                                for w in jax.tree_util.tree_leaves(params) if w.ndim > 1])
            total = total + l2_reg * 0.5 * squared_norm
        return total, (updated_vars, logits)

    (_, (new_model_state, logits)), grads = jax.value_and_grad(objective, has_aux=True)(state.params)
    # axis_name must match the one used when pmap-ing this step.
    if use_pmap and jax.device_count() > 1:
        grads = jax.lax.pmean(grads, axis_name='batch')

    metrics = compute_metrics(logits, labels, loss_type)
    if has_bn:
        return state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats']), metrics
    return state.apply_gradients(grads=grads), metrics


def train_step2(state, optstate, batch, rng, loss_type, l2_reg=0.0, has_feat=False, has_bn=False, use_pmap=True):
    """One IVON training step with Monte-Carlo weight sampling.

    For each MC sample, weights are drawn with ``state.tx.sampled_params``.
    The loss gradient at that sample is accumulated into ``optstate`` with
    ``state.tx.accumulate``. ``state.tx.step`` then produces the update that is
    applied to ``state.params``.

    Args:
        state: TrainState whose ``tx`` is the IVON optimizer.
        optstate: IVON optimizer state, carried outside the TrainState.
        batch: Dict with 'image' and 'label' entries.
        rng: Key used for dropout and, after ``fold_in``, for the MC weight
            noise. If None, the noise keys fall back to ``PRNGKey(0)``.
        loss_type: ``loss_type(logits, labels)``, averaged with ``.mean()``.
        l2_reg: Adds ``l2_reg * 0.5 * sum(w ** 2)`` over leaves with ndim > 1.
        has_feat: Model returns ``(logits, features)``; features are ignored.
        has_bn: Read and update the 'batch_stats' collection.
        use_pmap: Accepted for signature compatibility only. Gradients are NOT
            pmean-ed here, so multi-device data parallelism is not supported.
    Returns:
        (new_state, optstate, metrics, sigma). ``sigma`` is
        ``rsqrt(ess * (weight_decay + hess))`` for ``params['Conv_0']['kernel']``.
    """
    def loss_fn(params):
        if has_bn:
            variables = {'params': params, 'batch_stats': state.batch_stats}
        else:
            variables = {'params': params}

        if has_feat:
            (logits, _), new_model_state = state.apply_fn(variables, batch['image'], rngs={'dropout': rng},
                                                          train=True, mutable=['batch_stats'])
        else:
            logits, new_model_state = state.apply_fn(variables, batch['image'], rngs={'dropout': rng}, train=True,
                                                     mutable=['batch_stats'])

        loss = loss_type(logits, batch['label']).mean()
        if l2_reg > 0.0:
            weight_penalty_params = jax.tree_util.tree_leaves(params)
            weight_l2 = sum([jnp.sum(x ** 2)
                             for x in weight_penalty_params if x.ndim > 1])
            weight_penalty = l2_reg * 0.5 * weight_l2
            loss = loss + weight_penalty
        return loss, (new_model_state, logits)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    train_mcsamples = 1
    # Fresh weight noise every step. The keys used to come from PRNGKey(0) on
    # every call, which drew identical noise each step. fold_in keeps the MC
    # keys distinct from the dropout key `rng`.
    if rng is None:
        mc_keys = jax.random.split(jax.random.PRNGKey(0), train_mcsamples + 1)[1:]
    else:
        mc_keys = jax.random.split(jax.random.fold_in(rng, 1), train_mcsamples)

    for key in mc_keys:
        psample, noise = state.tx.sampled_params(key, state.params, optstate)
        aux, grads = grad_fn(psample)
        optstate = state.tx.accumulate(grads, optstate, noise)
    updates, optstate = state.tx.step(optstate, state.params)
    params = optax.apply_updates(state.params, updates)

    # Model state / logits of the last MC sample.
    new_model_state, logits = aux[1]
    metrics = compute_metrics(logits, batch['label'], loss_type)

    # Mirror TrainState.apply_gradients: bump the step counter and, with batch
    # norm, keep the running statistics up to date.
    if has_bn:
        new_state = state.replace(step=state.step + 1, params=params,
                                  batch_stats=new_model_state['batch_stats'])
    else:
        new_state = state.replace(step=state.step + 1, params=params)

    sigma = jax.tree_util.tree_map(lambda h: jax.lax.rsqrt(optstate[0].ess * (optstate[0].weight_decay + h)), optstate[0].hess)

    return new_state, optstate, metrics, sigma['Conv_0']['kernel']


def eval_step(state, batch, loss_type, has_feat=False, has_bn=False, use_ema=False):
    """Evaluate ``state`` on one batch with ``train=False``.

    Args:
        state: TrainState. With ``use_ema``, its ``ema_average`` (and
            ``ema_average_batch``) replace ``params`` (and ``batch_stats``).
        batch: Dict with 'image' and 'label' entries.
        loss_type: Loss callable forwarded to ``compute_metrics``.
        has_feat: Model returns ``(logits, features)``; features are ignored.
        has_bn: Pass a 'batch_stats' collection to the model.
        use_ema: Evaluate the averaged weights instead of the live ones.
    Returns:
        Metrics dict from ``compute_metrics``.
    """
    model_vars = {'params': state.ema_average if use_ema else state.params}
    if has_bn:
        model_vars['batch_stats'] = state.ema_average_batch if use_ema else state.batch_stats

    outputs = state.apply_fn(model_vars, batch['image'], train=False, mutable=False)
    if has_feat:
        logits, _ = outputs
    else:
        logits = outputs
    return compute_metrics(logits, batch['label'], loss_type)

def predict_m(x_proto, y_proto, gamma, rho):
    """Closed-form ridge-regression weights fit on prototype features.

    Uses the dual (n_hat x n_hat) form:
        m = gamma/n_hat * X^T (gamma/n_hat * X X^T + |rho| I)^{-1} Y
    with X = ``x_proto`` (n_hat rows) and Y = ``y_proto``.
    """
    num_proto = x_proto.shape[0]
    scale = gamma / num_proto
    gram = x_proto.dot(x_proto.T)
    system = scale * gram + jnp.abs(rho) * jnp.eye(gram.shape[0])
    dual_coef = sp.linalg.solve(system, y_proto, assume_a="pos")
    return scale * x_proto.T @ dual_coef

def predict_mV(x_proto, y_proto, gamma, rho):
    """Gaussian posterior mean and covariance for linear weights on prototypes.

    With natural parameters
        lambda_1 = gamma/n_hat * X^T Y
        lambda_2 = -rho/2 I - gamma/(2 n_hat) X^T X
    this returns
        V = -1/2 inv(lambda_2) = inv(rho I + gamma/n_hat X^T X)
        m = V lambda_1 = gamma/n_hat X^T inv(gamma/n_hat X X^T + |rho| I) Y
    The mean is computed in the dual (n_hat x n_hat) form, which equals
    ``V @ lambda_1`` for rho > 0 and matches ``predict_m``. Previously the
    ridge term here was divided by n_hat, so ``m`` disagreed with the
    returned ``V``.

    Returns:
        (m, V)
    """
    n_hat = x_proto.shape[0]
    pTp = x_proto.T @ x_proto
    lambda_2 = -rho / 2 * jnp.identity(x_proto.shape[1]) - gamma / (2 * n_hat) * pTp
    k_pp = x_proto.dot(x_proto.T)
    k_pp_reg = gamma / n_hat * k_pp + jnp.abs(rho) * jnp.eye(k_pp.shape[0])

    m = gamma / n_hat * x_proto.T @ sp.linalg.solve(k_pp_reg, y_proto, assume_a="pos")

    V = 1 / 2 * sp.linalg.inv(-lambda_2)
    return m, V


def eval_step_cross3(x_proto, y_proto, state, batch, loss_type, has_feat=False, has_bn=False, use_ema=False, gamma=100.0, rho=100):
    """Evaluate one batch using a closed-form head fitted on prototype features.

    When ``has_feat`` is True, ``state.apply_fn`` returns ``(logits, features)``.
    The network's logits are replaced by the following steps:

    * A head ``m = predict_m(feat_proto, y_proto, gamma, rho)`` is fitted on the
      prototypes ``x_proto``.
    * ``VV`` is the per-example variance ``diag(feat @ C @ feat.T)`` with
      ``C = (rho * I + gamma/n_hat * feat_proto.T @ feat_proto)^{-1}``. It is
      evaluated through the Woodbury identity.
    * The logits are ``feat @ m / sqrt(1 + pi/8 * VV * 1e-8)``.

    Otherwise the model's own logits are used.

    Args:
        x_proto: Prototype inputs fed through the network to obtain features.
        y_proto: Prototype targets used to fit the head.
        state: Train state providing ``apply_fn`` and the parameter collections.
        batch: Mapping with ``'image'`` inputs and ``'label'`` targets.
        loss_type: Loss function forwarded to ``compute_metrics``.
        has_feat, has_bn, use_ema: Model-output / variable-selection flags.
        gamma, rho: Head-fitting hyper-parameters (``rho`` should be > 0 here).

    Returns:
        Tuple ``(metrics, variables['params']['Dense_0'], m)``. ``m`` is None when
        ``has_feat`` is False, because no head is fitted in that case.
    """
    if use_ema:
        variables = {'params': state.ema_average}
        if has_bn:
            variables['batch_stats'] = state.ema_average_batch
    else:
        variables = {'params': state.params}
        if has_bn:
            variables['batch_stats'] = state.batch_stats

    if has_feat:
        _, feat = state.apply_fn(variables, batch['image'], train=False, mutable=False)
        _, feat_proto = state.apply_fn(variables, x_proto, train=False, mutable=False)
        m = predict_m(feat_proto, y_proto, gamma, rho)
        n_hat = x_proto.shape[0]
        feature_prod = feat @ feat_proto.T
        gram_reg = (rho * n_hat / gamma) * jnp.eye(n_hat) + feat_proto @ feat_proto.T
        # Only the diagonal of feature_prod @ inv(gram_reg) @ feature_prod.T is
        # needed. A solve plus a row-wise reduction avoids both the explicit
        # inverse and a (batch, batch) intermediate.
        solved = sp.linalg.solve(gram_reg, feature_prod.T)
        VV = 1 / rho * (jnp.sum(feat * feat, axis=1) - jnp.sum(feature_prod * solved.T, axis=1))
        pred = feat @ m
        # NOTE(review): the 1e-8 factor keeps this rescaling close to 1 unless VV is very large.
        temp = jnp.expand_dims(jnp.sqrt(1 + jnp.pi / 8 * VV * 1e-8), axis=-1)
        logits = pred / temp
    else:
        logits = state.apply_fn(variables, batch['image'], train=False, mutable=False)
        # No head is fitted without features; keep ``m`` bound for the return.
        m = None
    return compute_metrics(logits, batch['label'], loss_type), variables['params']['Dense_0'], m


def eval_step_cross2(x_proto, y_proto, state, batch, loss_type, has_feat=False, has_bn=False, use_ema=False, gamma=100.0, rho=1):
    """Evaluate one batch using a closed-form head fitted on prototype features.

    When ``has_feat`` is True, ``state.apply_fn`` returns ``(logits, features)``.
    The network's logits are replaced by the following steps:

    * A head ``m = predict_m(feat_proto, y_proto, gamma, rho)`` is fitted on the
      prototypes ``x_proto``.
    * ``VV`` is the per-example variance ``diag(feat @ C @ feat.T)`` with
      ``C = (rho * I + gamma/n_hat * feat_proto.T @ feat_proto)^{-1}``. It is
      evaluated through the Woodbury identity.
    * The logits are ``feat @ m / sqrt(1 + pi/8 * VV * 1e-8)``.

    Otherwise the model's own logits are used.

    Args:
        x_proto: Prototype inputs fed through the network to obtain features.
        y_proto: Prototype targets used to fit the head.
        state: Train state providing ``apply_fn`` and the parameter collections.
        batch: Mapping with ``'image'`` inputs and ``'label'`` targets.
        loss_type: Loss function forwarded to ``compute_metrics``.
        has_feat, has_bn, use_ema: Model-output / variable-selection flags.
        gamma, rho: Head-fitting hyper-parameters (``rho`` should be > 0 here).

    Returns:
        The result of ``compute_metrics(logits, batch['label'], loss_type)``.
    """
    if use_ema:
        variables = {'params': state.ema_average}
        if has_bn:
            variables['batch_stats'] = state.ema_average_batch
    else:
        variables = {'params': state.params}
        if has_bn:
            variables['batch_stats'] = state.batch_stats

    if has_feat:
        _, feat = state.apply_fn(variables, batch['image'], train=False, mutable=False)
        _, feat_proto = state.apply_fn(variables, x_proto, train=False, mutable=False)
        m = predict_m(feat_proto, y_proto, gamma, rho)
        n_hat = x_proto.shape[0]
        feature_prod = feat @ feat_proto.T
        gram_reg = (rho * n_hat / gamma) * jnp.eye(n_hat) + feat_proto @ feat_proto.T
        # Only the diagonal of feature_prod @ inv(gram_reg) @ feature_prod.T is
        # needed. A solve plus a row-wise reduction avoids both the explicit
        # inverse and a (batch, batch) intermediate.
        solved = sp.linalg.solve(gram_reg, feature_prod.T)
        VV = 1 / rho * (jnp.sum(feat * feat, axis=1) - jnp.sum(feature_prod * solved.T, axis=1))
        pred = feat @ m
        # NOTE(review): the 1e-8 factor keeps this rescaling close to 1 unless VV is very large.
        temp = jnp.expand_dims(jnp.sqrt(1 + jnp.pi / 8 * VV * 1e-8), axis=-1)
        logits = pred / temp
    else:
        logits = state.apply_fn(variables, batch['image'], train=False, mutable=False)
    return compute_metrics(logits, batch['label'], loss_type)


def eval_step_cross(x_proto, y_proto, state, batch, loss_type, has_feat=False, has_bn=False, use_ema=False, gamma=100.0, rho=1.):
    """Evaluate one batch using a head mean/covariance fitted on prototype features.

    When ``has_feat`` is True, ``state.apply_fn`` returns ``(logits, features)``.
    ``(m, V) = predict_mV(feat_proto, y_proto, gamma, rho)`` is computed from the
    prototype features. The batch logits become
    ``feat @ m / sqrt(1 + pi/8 * VV * 1e-8)``, where
    ``VV = diag(feat @ V @ feat.T)``. Otherwise the model's own logits are used.

    Args:
        x_proto: Prototype inputs fed through the network to obtain features.
        y_proto: Prototype targets used to fit the head.
        state: Train state providing ``apply_fn`` and the parameter collections.
        batch: Mapping with ``'image'`` inputs and ``'label'`` targets.
        loss_type: Loss function forwarded to ``compute_metrics``.
        has_feat, has_bn, use_ema: Model-output / variable-selection flags.
        gamma, rho: Hyper-parameters forwarded to ``predict_mV``.

    Returns:
        The result of ``compute_metrics(logits, batch['label'], loss_type)``.
    """
    if use_ema:
        variables = {'params': state.ema_average}
        if has_bn:
            variables['batch_stats'] = state.ema_average_batch
    else:
        variables = {'params': state.params}
        if has_bn:
            variables['batch_stats'] = state.batch_stats

    if has_feat:
        _, feat = state.apply_fn(variables, batch['image'], train=False, mutable=False)
        _, feat_proto = state.apply_fn(variables, x_proto, train=False, mutable=False)
        m, V = predict_mV(feat_proto, y_proto, gamma, rho)
        pred = feat @ m
        # diag(feat @ V @ feat.T) without materialising the (batch, batch) product.
        VV = jnp.einsum('bi,ij,bj->b', feat, V, feat)
        # NOTE(review): the 1e-8 factor keeps this rescaling close to 1 unless VV is very large.
        temp = jnp.expand_dims(jnp.sqrt(1 + jnp.pi / 8 * VV * 1e-8), axis=-1)
        logits = pred / temp
    else:
        logits = state.apply_fn(variables, batch['image'], train=False, mutable=False)
    return compute_metrics(logits, batch['label'], loss_type)

def eval_step2(state, state_list, batch, loss_type, has_feat=False, has_bn=False, use_ema=False):
    """Evaluate an ensemble of parameter samples on one batch.

    Each element of ``state_list`` is a parameter pytree. It is run through
    ``state.apply_fn`` in inference mode, and the per-sample logits are stacked
    along a new leading axis. The metrics depend on ``loss_type``:

    * ``mean_squared_loss``: ``compute_metrics`` is applied to the sample-averaged
      logits.
    * ``cross_entropy_loss`` / ``soft_cross_entropy_loss``: ``compute_metrics2``
      is applied to the stacked per-sample logits.

    Args:
        state: Train state providing ``apply_fn`` and (BN / EMA) collections.
        state_list: Iterable of parameter pytrees to ensemble.
        batch: Mapping with ``'image'`` inputs and ``'label'`` targets.
        loss_type: One of the loss functions listed above.
        has_feat: When True, ``apply_fn`` returns ``(logits, features)``.
        has_bn: When True, batch statistics from ``state`` are passed as well.
        use_ema: When True, the single EMA parameter set on ``state`` is used for
            every ensemble member instead of the samples.

    Returns:
        Tuple ``(metrics, stacked_logits)``.

    Raises:
        ValueError: If ``state_list`` is empty or ``loss_type`` is unsupported.
    """
    variables_list = []
    for psample in state_list:
        if use_ema:
            # Only one EMA parameter set is tracked on ``state``.
            variables = {'params': state.ema_average}
            if has_bn:
                variables['batch_stats'] = state.ema_average_batch
        else:
            # Use the sampled parameters; BN models reuse the batch statistics
            # stored on ``state``. Previously the BN path ignored ``psample``, so
            # every ensemble member evaluated ``state.params``.
            variables = {'params': psample}
            if has_bn:
                variables['batch_stats'] = state.batch_stats
        variables_list.append(variables)

    if not variables_list:
        raise ValueError('eval_step2 requires a non-empty state_list.')

    logits_list = []
    for variables in variables_list:
        output = state.apply_fn(variables, batch['image'], train=False, mutable=False)
        if has_feat:
            logits, _ = output
        else:
            logits = output
        logits_list.append(logits)
    logits_list_whole = jnp.stack(logits_list)

    if loss_type == mean_squared_loss:
        metrics = compute_metrics(logits_list_whole.mean(axis=0), batch['label'], loss_type)
    elif loss_type == cross_entropy_loss or loss_type == soft_cross_entropy_loss:
        metrics = compute_metrics2(logits_list_whole, batch['label'], loss_type)
    else:
        raise ValueError('Unsupported loss_type for eval_step2: {}'.format(loss_type))
    return metrics, logits_list_whole

def pred_step(state, x, has_feat=False, has_bn=False, use_ema=False):
    """Run the model in ``state`` on ``x`` in inference mode and return its logits.

    Args:
        state: Train state providing ``apply_fn`` and the parameter collections.
        x: Model input.
        has_feat: When True, ``apply_fn`` returns ``(logits, features)`` and only
            the logits are returned.
        has_bn: When True, batch statistics are passed alongside the params.
        use_ema: When True, the EMA parameters (and EMA batch stats) are used.

    Returns:
        The logits produced by ``state.apply_fn``.
    """
    model_vars = {'params': state.ema_average if use_ema else state.params}
    if has_bn:
        model_vars['batch_stats'] = state.ema_average_batch if use_ema else state.batch_stats

    output = state.apply_fn(model_vars, x, train=False, mutable=False)
    if not has_feat:
        return output
    logits, _ = output
    return logits


def save_logit(output_dir, state, x_train, x_test, chunk_size=500):
    """Predict logits for the train and test inputs and save them to disk.

    The arrays are written with ``np.savez`` to ``<output_dir>/pred_logit.npz``
    under the keys ``'train'`` and ``'test'``.

    Args:
        output_dir: Directory that receives ``pred_logit.npz``.
        state: Train state whose model produces the logits (via ``pred_step``).
        x_train: Training inputs.
        x_test: Test inputs.
        chunk_size: Chunk size forwarded to ``make_chunky_prediction``.

    Returns:
        Tuple ``(logit_train, logit_test)``.
    """
    def _forward(x):
        return pred_step(state, x)

    jitted_forward = jax.jit(_forward)
    logits_by_split = {
        split: make_chunky_prediction(inputs, jitted_forward, chunk_size=chunk_size)
        for split, inputs in (('train', x_train), ('test', x_test))
    }

    with open('{}/pred_logit.npz'.format(output_dir), 'wb') as f:
        np.savez(f, **logits_by_split)

    return logits_by_split['train'], logits_by_split['test']


def load_model_state(model, ckpt_path, config, key=0):
    """Build a train state for ``model`` and restore it from ``ckpt_path``.

    The path is checked before the (potentially expensive) state initialisation,
    so a missing checkpoint fails fast.

    Args:
        model: Model definition passed to ``create_train_state``.
        ckpt_path: Checkpoint location passed to ``checkpoints.restore_checkpoint``.
        config: Configuration passed to ``create_train_state``.
        key: Integer seed for ``jax.random.PRNGKey`` used at initialisation.

    Returns:
        The restored train state.

    Raises:
        ValueError: If ``ckpt_path`` does not exist.
    """
    if not os.path.exists(ckpt_path):
        raise ValueError('Checkpoint path {} does not exists!'.format(ckpt_path))
    # The constant learning-rate schedule only serves as a template for restoring.
    state = create_train_state(jax.random.PRNGKey(key), config, model, lambda x: 0.01)
    return checkpoints.restore_checkpoint(ckpt_path, state)


def load_random_state(model, config, key=0):
    """Return a freshly initialised (untrained) train state for ``model``.

    Args:
        model: Model definition passed to ``create_train_state``.
        config: Configuration passed to ``create_train_state``.
        key: Integer seed for ``jax.random.PRNGKey``.
    """
    rng = jax.random.PRNGKey(key)
    return create_train_state(rng, config, model, lambda _step: 0.01)


def load_teacher_model(model, ckpt_path, config):
    """Restore a checkpoint and return a jitted function mapping inputs to logits.

    Args:
        model: Model definition passed to ``load_model_state``.
        ckpt_path: Checkpoint location passed to ``load_model_state``.
        config: Configuration passed to ``load_model_state``.

    Returns:
        A ``jax.jit``-compiled callable ``x -> pred_step(restored_state, x)``.
    """
    teacher_state = load_model_state(model, ckpt_path, config)

    def _teacher_logits(x):
        return pred_step(teacher_state, x)

    return jax.jit(_teacher_logits)


def load_logit(output_dir):
    """Load the train/test logits stored in ``<output_dir>/pred_logit.npz``.

    Args:
        output_dir: Directory containing ``pred_logit.npz`` with ``'train'`` and
            ``'test'`` arrays.

    Returns:
        Tuple ``(logit_train, logit_test)`` of NumPy arrays.
    """
    # ``np.load`` on an .npz returns an NpzFile that keeps the underlying file
    # open; the context manager closes it once both arrays have been read.
    with np.load('{}/pred_logit.npz'.format(output_dir)) as pred_logit:
        logit_train = pred_logit['train']
        logit_test = pred_logit['test']
    logging.info('Load logit from %s!', output_dir)
    return logit_train, logit_test
