from flax.training.train_state import TrainState
from flax import core
from functools import partial

from typing import Any
import jax
from jax import jit
import jax.numpy as jnp
from flax.training import checkpoints


class SSLTrainState(TrainState):
    """Train state extended with an EMA target network and mutable collections.

    On top of the fields inherited from `TrainState`, this keeps a second set
    of parameters (`target_params`) that `update_ema` moves towards `params`
    via an exponential moving average.
    """

    # Parameters of the target network; updated only through `update_ema`.
    target_params: core.FrozenDict[str, Any]
    # `batch_stats` collection of the online network.
    batch_stats: Any
    # `batch_stats` collection of the target network.
    tg_batch_stats: Any
    # Optional `direct_pred` collection; stays None when unused.
    direct_pred: Any = None

    def update_ema(self, current_step, total_steps, tau_base=0.996):
        """Update the exponential moving average of the parameters.

        The decay `tau` follows a cosine schedule: it equals `tau_base` at
        `current_step == 0` and reaches 1.0 at `current_step == total_steps`.

        Args:
            current_step: Current training step.
            total_steps: Step count at which `tau` reaches 1.0; must be non-zero.
            tau_base: Initial decay value of the schedule.

        Returns:
            A copy of this state with
            `target_params = tau * target_params + (1 - tau) * params`.
        """
        # cosine schedule for tau
        tau = 1.0 - (1.0 - tau_base) * 0.5 * (1 + jnp.cos(jnp.pi * current_step / total_steps))
        # `jax.tree_util.tree_map` is used instead of the deprecated
        # `jax.tree_map` alias so this keeps working on newer JAX releases.
        new_target_params = jax.tree_util.tree_map(
            lambda target, online: tau * target + (1.0 - tau) * online,
            self.target_params,
            self.params,
        )
        return self.replace(target_params=new_target_params)

    @classmethod
    def create(cls, *, apply_fn, params, target_params, tx, **kwargs):
        """Creates a new instance with `step=0` and initialized `opt_state`.

        Args:
            apply_fn: Stored unchanged on the state as `apply_fn`.
            params: Online parameters; the optimizer state is initialised from them.
            target_params: Initial parameters of the target network.
            tx: Optimizer object providing `init(params)`.
            **kwargs: Remaining fields (e.g. `batch_stats`, `tg_batch_stats`,
                `direct_pred`), forwarded to the constructor.

        Returns:
            A new `cls` instance.
        """
        opt_state = tx.init(params)
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            target_params=target_params,
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )

@partial(jit, static_argnums=(2, 3))
def train_step(state, batch, loss_fn, parallel):
    """Train for a single step.

    Args:
        state: Train state providing `params`, `apply_gradients` and `replace`.
        batch: Input batch, passed through to `loss_fn` unchanged.
        loss_fn: Static callable `loss_fn(params, state, batch)` returning
            `(loss, (updates, tg_updates, lep_metrics))`. Its `__name__` must
            contain at least one underscore: the second `_`-separated token is
            used as the prefix of the logged training-loss key.
        parallel: Static flag. When true, gradients, metrics and the
            `run_corr` entries of `updates['direct_pred']` are averaged with
            `jax.lax.pmean` over the axis named `'device'`, so the function
            must run in a context where that axis is bound.

    Returns:
        `(state, metrics)`: the updated state and a dict of scalar metrics.

    Raises:
        ValueError: If a key of `lep_metrics['loss']` contains neither
            `'embd'` nor `'proj'`.
    """

    # compute gradients
    v_and_g_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (updates, tg_updates, lep_metrics)), grads = v_and_g_fn(state.params, state, batch)

    # training loss also includes lep losses, which are not used for training the encoder
    lep_loss = sum(lep_metrics['loss'].values())
    loss = loss - lep_loss

    # log metrics
    pretty_loss_name = loss_fn.__name__.split('_')[1]
    metrics = {f'{pretty_loss_name}_train_loss': loss}

    # log lep losses and accuracies
    for key, key_loss in lep_metrics['loss'].items():
        if 'embd' in key:
            metrics['online_lep_train_loss'] = key_loss
            metrics['online_lep_train_acc'] = lep_metrics['acc'][key]
        elif 'proj' in key:
            metrics['online_lep_proj_train_loss'] = key_loss
            metrics['online_lep_proj_train_acc'] = lep_metrics['acc'][key]
        else:
            raise ValueError(f'Unknown key {key} in train lep_metrics')

    if parallel:
        grads = jax.lax.pmean(grads, axis_name='device')
        metrics = jax.tree_util.tree_map(lambda x: jax.lax.pmean(x, axis_name='device'), metrics)
        if 'direct_pred' in updates:
            # Unfreeze exactly once, before the loop. Calling `.unfreeze()` per
            # iteration fails on the second key (the result is already a plain
            # dict) and fails immediately when `updates` is a plain dict.
            # `core.unfreeze` accepts both FrozenDict and dict.
            updates = core.unfreeze(updates)
            for key in updates['direct_pred']:
                updates['direct_pred'][key]['run_corr'] = jax.lax.pmean(
                    updates['direct_pred'][key]['run_corr'], axis_name='device'
                )

    # update state
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates['batch_stats'])
    if tg_updates is not None:
        state = state.replace(tg_batch_stats=tg_updates['batch_stats'])
    if 'direct_pred' in updates:
        state = state.replace(direct_pred=updates['direct_pred'])

    return state, metrics

def save_state(state, args, epoch):
    """Save `state` as the checkpoint for `epoch`.

    The state is wrapped as `{'model': state}` and written under
    `checkpoints/<args.model_name>` with `step=epoch`. Existing checkpoints
    are not overwritten (`overwrite=False`) and `keep=2` is passed to the
    checkpoint writer.
    """
    checkpoint_dir = f'checkpoints/{args.model_name}'
    checkpoints.save_checkpoint(
        ckpt_dir=checkpoint_dir,
        target={'model': state},
        step=epoch,
        overwrite=False,
        keep=2,
    )


def load_state(state, name, epoch):
    """Restore a train state from `checkpoints/<name>` at step `epoch`.

    `state` is wrapped as `{'model': state}` and passed as the restore
    target; the value found under the `'model'` key of the restored
    checkpoint is returned.
    """
    restored = checkpoints.restore_checkpoint(
        f'checkpoints/{name}',
        target={'model': state},
        step=epoch,
    )
    return restored['model']
