from typing import Any
import itertools

import jax
import jax.numpy as jnp
import flax
import flax.jax_utils
import numpy as np
import optax
from orbax.checkpoint import CheckpointManager
import tqdm
import wandb

from .train_state import TrainState
from multinav.utils.utils import flatten_for_wandb


def train_step(
    batch, train_state: TrainState, rng: jax.random.PRNGKey, step: jax.Array
):
    """Run one optimisation step on a single device shard.

    Uses collectives over the ``"num_devices"`` axis, so it must be called
    under ``jax.pmap(..., axis_name="num_devices")`` (or an equivalent
    transform that binds that axis name).

    Args:
        batch: Per-device batch, forwarded unchanged to the model's ``loss``
            method.
        train_state: State providing ``apply_fn``, ``params``,
            ``target_params`` and ``opt_state``.
        rng: Per-device PRNG key; split into the ``"dropout"`` and
            ``"sample"`` RNG streams used by the loss.
        step: Current step counter, forwarded to the loss as ``step=``.

    Returns:
        A ``(train_state, info)`` tuple: the updated state and the auxiliary
        metrics returned by the loss (averaged across devices), extended with
        the ``"lr"``, ``"grad_norm"`` and ``"param_norm"`` entries.
    """
    # Split three ways (as before) so the dropout/sample keys stay the same;
    # the first key is a carry that this function never uses or returns.
    _, dropout_key, key_sample = jax.random.split(rng, 3)

    def loss_fn(params):
        # Must return ``(loss, aux)`` because the gradient is taken with
        # ``has_aux=True``. Only ``params`` is differentiated; the target
        # parameters enter as a separate variable collection.
        return train_state.apply_fn(
            {"params": params, "target_params": train_state.target_params},
            batch,
            rngs={"dropout": dropout_key, "sample": key_sample},
            train=True,
            step=step,
            method="loss",
        )

    grads, info = jax.grad(loss_fn, has_aux=True)(train_state.params)

    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="num_devices")

    # Exponential moving average of the online parameters, computed from the
    # parameters *before* this step's gradient update is applied.
    # ``jax.tree_util.tree_map`` is used instead of the deprecated top-level
    # ``jax.tree_map`` alias, which newer JAX releases no longer provide.
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * 0.005 + tp * 0.995,
        train_state.params,
        train_state.target_params,
    )
    train_state = train_state.apply_gradients(
        grads=grads,
        target_params=new_target_params,
    )
    # Re-average the parameters themselves so replicas cannot drift apart.
    train_state = train_state.replace(
        params=jax.lax.pmean(train_state.params, axis_name="num_devices")
    )

    info = jax.lax.pmean(info, axis_name="num_devices")
    # NOTE(review): reading ``hyperparams`` presumably relies on the optimizer
    # being wrapped with ``optax.inject_hyperparams`` - confirm at the call site.
    info["lr"] = train_state.opt_state.hyperparams["learning_rate"]
    info["grad_norm"] = optax.global_norm(grads)
    info["param_norm"] = optax.global_norm(train_state.params)

    return train_state, info


def do_train_loop(
    *,
    step: int,
    num_steps: int,
    checkpoint_manager: CheckpointManager,
    model: TrainState,
    sharded_rngs: jax.Array,
    train_data: Any,
    device_list: Any,
    epoch: int,
    total_epochs: int,
    log_interval: int,
    save_interval: int,
):
    """Run up to ``num_steps`` pmapped training steps over ``train_data``.

    Args:
        step: Global step counter at entry; incremented once per batch.
        num_steps: Maximum number of batches to take from ``train_data``.
        checkpoint_manager: Used to save parameters; saving is skipped when
            this is ``None``.
        model: Train state passed to (and returned by) the pmapped
            ``train_step``.
        sharded_rngs: Per-device PRNG keys; split once per batch.
        train_data: Iterable of batches. Each batch is indexed with
            ``"time_to_goal"`` and ``"reached_goal"`` when logging.
        device_list: Devices handed to ``jax.pmap`` and to
            ``flax.jax_utils.replicate``.
        epoch: Zero-based epoch index, used only in the progress-bar label.
        total_epochs: Total epoch count, used only in the progress-bar label.
        log_interval: Log to wandb every this many steps; ``None`` or ``0``
            disables logging.
        save_interval: Save a checkpoint every this many steps; ``None`` or
            ``0`` disables saving.

    Returns:
        A ``(model, step, sharded_rngs)`` tuple holding the updated train
        state, the advanced step counter and the advanced per-device keys.
    """
    pmap_train_step = jax.pmap(train_step, axis_name="num_devices", devices=device_list)
    # Build the key-splitting wrapper once instead of on every iteration, and
    # pin it to the same devices as the train step. With ``out_axes=1`` the
    # result has the two new keys on the leading axis, so unpacking it below
    # yields two per-device key arrays.
    pmap_split_rngs = jax.pmap(
        jax.random.split,
        axis_name="num_devices",
        out_axes=1,
        devices=device_list,
    )

    for batch in tqdm.tqdm(
        itertools.islice(train_data, num_steps),
        dynamic_ncols=True,
        desc=f"Training Epoch {epoch + 1}/{total_epochs}",
    ):
        sharded_rngs, sharded_keys = pmap_split_rngs(sharded_rngs)
        model, info = pmap_train_step(
            batch,
            model,
            sharded_keys,
            flax.jax_utils.replicate(jnp.asarray(step), devices=device_list),
        )

        step += 1
        # Truthiness check: both ``None`` and ``0`` mean "disabled", and a
        # zero interval can no longer raise ZeroDivisionError in the modulo.
        if log_interval and step % log_interval == 0:
            # Pull metrics to the host and collapse the per-device axis.
            # ``jax.tree_util.tree_map`` replaces the deprecated top-level
            # ``jax.tree_map`` alias, which newer JAX releases no longer provide.
            host_info = jax.tree_util.tree_map(np.mean, jax.device_get(info))
            data_info = {
                "time_to_goal": wandb.Histogram(np.asarray(batch["time_to_goal"])),
                "reached_goal_frac": np.mean(batch["reached_goal"]),
            }
            flattened = flatten_for_wandb({"train": host_info, "data": data_info})
            wandb.log(flattened, step=step)
        if (
            save_interval
            and step % save_interval == 0
            and checkpoint_manager is not None
        ):
            # Only one (unreplicated) host copy of the parameters is saved.
            checkpoint_manager.save(
                step, {"params": jax.device_get(flax.jax_utils.unreplicate(model.params))}
            )

    return model, step, sharded_rngs
