from typing import Any, Dict, Tuple
import itertools
import chex
import einops

import jax
import flax
import numpy as np
import optax
from orbax.checkpoint import CheckpointManager
import tqdm
import wandb
from multinav.utils.metrics import compute_best_modes

from multinav.utils.trajectory import rollout_trajectories

from .train_state import TrainState
from multinav.utils.utils import flatten_for_wandb
from multinav.model import MultiNavModel

# from multinav.utils.visualization import visualize
from multinav.visualization import Visualization


def val_step(batch, train_state: TrainState, step: jax.Array):
    """Evaluate the model's ``loss`` method on ``batch`` and return its second output.

    The model is applied with ``train=False`` and a fixed ``"sample"`` PRNG key,
    so repeated calls on the same inputs use identical randomness.

    Args:
        batch: One validation batch, forwarded unchanged to ``apply_fn``.
        train_state: Holds ``apply_fn`` and the ``params`` to evaluate.
        step: Current training step, forwarded to the model.

    Returns:
        Element ``[1]`` of the ``loss`` method's output. This is presumably the
        metrics/info pytree (the caller averages it per leaf); confirm against
        ``MultiNavModel.loss``.
    """
    current_params = train_state.params
    # The live parameters are supplied for BOTH collections; no separate
    # target copy is passed. NOTE(review): confirm this is intended for eval.
    variables = {"params": current_params, "target_params": current_params}
    # Constant key so evaluation-time sampling is reproducible across calls.
    eval_rngs = {"sample": jax.random.PRNGKey(0)}
    loss_outputs = train_state.apply_fn(
        variables,
        batch,
        train=False,
        step=step,
        method="loss",
        rngs=eval_rngs,
    )
    return loss_outputs[1]


def viz_step(
    batch, train_state: TrainState, step: jax.Array
) -> Tuple[Dict[str, float], Dict[str, Visualization]]:
    """Run the model's ``eval_visualize`` method on ``batch`` in eval mode.

    Args:
        batch: One validation batch, forwarded unchanged to ``apply_fn``.
        train_state: Holds ``apply_fn`` and the ``params`` to evaluate.
        step: Current training step, forwarded to the model.

    Returns:
        Whatever ``eval_visualize`` returns. Per the annotation, this is a pair
        of (scalar metrics dict, dict of ``Visualization`` objects).
    """
    current_params = train_state.params
    # The live parameters are supplied for BOTH collections; no separate
    # target copy is passed. NOTE(review): confirm this is intended for eval.
    variables = {"params": current_params, "target_params": current_params}
    # Constant key so evaluation-time sampling is reproducible across calls.
    eval_rngs = {"sample": jax.random.PRNGKey(0)}
    outputs = train_state.apply_fn(
        variables,
        batch,
        train=False,
        step=step,
        method="eval_visualize",
        rngs=eval_rngs,
    )
    return outputs


def _mean_over_batches(*per_batch):
    """Average one metric leaf across all validation batches.

    If any batch holds a 0-d value, the values are stacked so each batch
    contributes one entry. Otherwise they are concatenated along axis 0 and the
    mean is taken over every element (not a mean of per-batch means).
    """
    if any(np.ndim(x) == 0 for x in per_batch):
        return np.mean(np.stack(per_batch))
    return np.mean(np.concatenate(per_batch))


def _merge_nested(base, extra):
    """Recursively merge mapping ``extra`` into a new dict seeded from ``base``.

    Keys present in only one input are copied through. When both inputs hold a
    mapping under the same key, those mappings are merged recursively. When
    they collide on non-mapping leaves, the value from ``base`` is kept.
    Recursing there (as the previous implementation did) would call ``.items()``
    on an array and crash.
    """
    from collections.abc import Mapping

    merged = dict(base)
    for key, value in extra.items():
        if key not in merged:
            merged[key] = value
        elif isinstance(merged[key], Mapping) and isinstance(value, Mapping):
            merged[key] = _merge_nested(merged[key], value)
        # Otherwise this is a leaf collision: keep ``base``'s value.
    return merged


def do_validation_loop(
    *,
    val_dataset_name: str,
    val_data_iter: Any,
    model: TrainState,
    model_config: MultiNavModel.Config,
    device: jax.Device,
    step: int,
):
    """Run validation over ``val_data_iter`` and log the results to wandb.

    Every batch goes through the jitted ``val_step``, and the per-batch info
    pytrees are averaged leaf-wise. A single ``viz_step`` is then run on the
    final batch. Its metrics are merged into the averages; on a name clash the
    full-set average wins. Everything is logged at ``step`` under ``"val"``
    (keyed by metric, then ``val_dataset_name``) and ``"viz"`` (each
    ``Visualization`` rendered via ``.visualize()``), flattened with
    ``flatten_for_wandb``.

    Args:
        val_dataset_name: Name used as a key for this dataset in logged metrics.
        val_data_iter: Iterable of batches whose leaves expose ``_numpy()``.
            These are presumably TF eager tensors; confirm against the loader.
        model: Train state passed to the step functions.
        model_config: Not used in this function.
        device: Device the step functions are jitted for.
        step: Training step forwarded to the model and used as the wandb step.

    Returns:
        None. Returns early without logging if the iterator yields no batches.
    """
    jit_val_step = jax.jit(val_step, device=device)
    jit_viz_step = jax.jit(viz_step, device=device)

    per_batch_infos = []
    last_batch = None
    for raw_batch in val_data_iter:
        last_batch = jax.tree_util.tree_map(lambda x: x._numpy(), raw_batch)
        per_batch_infos.append(jit_val_step(last_batch, model, step=step))

    if not per_batch_infos:
        return

    val_metrics = jax.tree_util.tree_map(_mean_over_batches, *per_batch_infos)

    # Only one viz step, for sanity's sake: reuse the final validation batch.
    verbose, viz = jit_viz_step(last_batch, model, step=step)
    val_metrics = _merge_nested(val_metrics, verbose)

    wandb.log(
        flatten_for_wandb(
            {
                "val": {k: {val_dataset_name: v} for k, v in val_metrics.items()},
                "viz": jax.tree_util.tree_map(
                    lambda v: {val_dataset_name: v.visualize()},
                    viz,
                    is_leaf=lambda x: isinstance(x, Visualization),
                ),
            },
        ),
        step=step,
    )
