from functools import partial
from typing import NamedTuple
import jax
import haiku as hk
import optax
import jax.numpy as jnp

class GDState(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState


class TrainerGD(object):
    def __init__(trainer, model, optimizer, loss_fn, cfg):
        trainer.rng, model_init_rng = jax.random.split(
            jax.random.PRNGKey(cfg["seed"]), 2
        )
        trainer.model = model
        trainer.optimizer = optimizer
        trainer.loss_fn = loss_fn
        trainer.cfg = cfg

        params = trainer.model.init(model_init_rng, loss_fn.dummy_data())

        num_params = jax.tree_util.tree_reduce(lambda s, n: n.size + s, params, 0)
        print(f"Number of parameters: {num_params / 1e3:.02f}k")

        opt_state = trainer.optimizer.init(params)

        trainer.gd_state = GDState(params=params, opt_state=opt_state)

    @partial(jax.jit, static_argnums=(0))
    def _do_step(trainer, state, rng):
        params, opt_state = state
        (loss, auxs), grads = jax.value_and_grad(trainer.loss_fn, has_aux=True)(
            params, rng
        )
        grads = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, copy=True, nan=0.0, posinf=0, neginf=0), grads)
        
        updates, opt_state = trainer.optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        log_dict = dict()
        log_dict["loss"] = loss
        log_dict.update({k: v.mean() for k, v in auxs.items()})

        return GDState(params, opt_state), log_dict

    @partial(jax.jit, static_argnums=(0, 3))
    def _do_steps(trainer, gd_state, epoch_rng, num_steps):
        gd_state, log_metric = jax.lax.scan(
            trainer._do_step,
            gd_state,
            (jax.random.split(epoch_rng, num=num_steps)),
        )
        return gd_state, log_metric
    
    def train_iter(trainer, num_steps):
        trainer.rng, epoch_rng = jax.random.split(trainer.rng)

        trainer.gd_state, log_metric = trainer._do_steps(
            trainer.gd_state, epoch_rng, num_steps
        )

        return log_metric

    def get_params(trainer):
        return trainer.gd_state.params
