import jax
from jax import vmap
import optax
from model_vipo import *
from jaxrl_m.typing import *
from jaxrl_m.common import TrainState, target_update
from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
import jaxrl_m.examples.mujoco.d4rl_utils as d4rl_utils  # a local self defined / reused module
import wandb
from pprint import pprint
from absl import app, flags
from util_vipo import (
    l1_loss,
    l2_loss,
    var_loss,
    msew_loss,
    nll_loss,
    sample_from_norm,
)


def main(_):
    """Train an ensemble dynamics model on a D4RL dataset and log results to wandb.

    The script loads the ``halfcheetah-expert-v2`` dataset, builds an
    ``EnsembledDynamics`` model, minimises the ensemble-averaged negative
    log-likelihood of ``[next_observations, reward]`` and, every
    ``eval_every`` steps, logs the latest training metrics plus one sampled
    prediction next to its ground truth in a wandb table.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``absl.app.run`` -- verify against the absl documentation).
    """
    import tqdm

    # Run configuration (values unchanged from the original hard-coded ones).
    env_name = "halfcheetah-expert-v2"
    num_steps = int(1e5)
    batch_size = 256
    eval_every = int(5e3)

    # dataset
    env = d4rl_utils.make_env(env_name)
    dataset = d4rl_utils.get_dataset(env)

    # Create wandb logger
    wandb_config = {
        "project": "vipo_test",
        "group": "networks_test",
        "name": f"vipo_ensembles_{env_name}",
    }
    setup_wandb(default_wandb_config(), **wandb_config)

    # model def: a single transition is enough to read off the feature sizes.
    one_sample = dataset.sample(1)
    for k in one_sample:
        print(k, jax.tree.map(lambda x: x.shape, one_sample[k]))
    action_dim = one_sample["actions"].shape[-1]
    state_dim = one_sample["observations"].shape[-1]
    # Rewards are appended to the target as `rewards[..., None]`, i.e. one
    # scalar per transition, so they contribute exactly one column. (Reading
    # `rewards.shape[-1]` would return the batch size, which only happened to
    # be 1 for this single-element sample.)
    reward_dim = 1
    ensemble_dynamics_def = EnsembledDynamics(
        ensemble_size=3,
        hidden_dims=[256, 256],
        obs_dim=state_dim,
        action_dim=action_dim,
    )

    # Parameter initialisation. The 3-way split is kept so that the random
    # streams match earlier runs; the third key is intentionally unused.
    key = jax.random.PRNGKey(0)
    key, init_key, _unused_key = jax.random.split(key, 3)
    (out, params) = ensemble_dynamics_def.init_with_output(
        init_key,
        one_sample["observations"],
        one_sample["actions"],
    )
    means, log_stds = out
    pprint(jax.tree.map(lambda x: x.shape, params))
    print(means.shape)
    print(log_stds.shape)

    params = params["params"]

    dynamics = TrainState.create(
        ensemble_dynamics_def,
        tx=optax.adamw(3e-4),
        params=params,
    )

    @jax.jit
    def update(dynamics: TrainState, batch: Batch) -> tuple[TrainState, InfoDict]:
        """Run one optimisation step and return the new state with its metrics."""

        def loss_fn(params):
            (means, logstds) = dynamics(
                batch["observations"],
                batch["actions"],
                params=params,
            )
            # Regression target: next observation with the reward appended.
            gt = jnp.concatenate(
                [batch["next_observations"], batch["rewards"][..., None]], axis=-1
            )
            # Map each loss over the leading axis of the model outputs (one
            # slice per ensemble member, presumably); the target is shared.
            dyn_nll_loss = vmap(nll_loss, in_axes=(0, 0, None))(means, logstds, gt)
            dyn_mean_loss = vmap(msew_loss, in_axes=(0, 0, None))(means, logstds, gt)
            dyn_var_loss = vmap(var_loss)(logstds)
            # `dyn_loss` is reported only; the optimised objective is the NLL.
            dyn_loss = dyn_mean_loss.mean() + dyn_var_loss.mean()
            return dyn_nll_loss.mean(), {
                "dyn_loss": dyn_loss,
                "dyn_nll_loss": dyn_nll_loss.mean(),
                "dyn_mean_loss": dyn_mean_loss.mean(),
                "dyn_var_loss": dyn_var_loss.mean(),
            }

        new_dynamics, info = dynamics.apply_loss_fn(loss_fn=loss_fn, has_aux=True)
        return new_dynamics, info

    eval_table = wandb.Table(
        columns=[
            "step",
            "type",
            "l1_loss",
            "l2_loss",
        ]
        + [f"dim_{i}" for i in range(state_dim + reward_dim)],
    )

    # train
    for i in tqdm.trange(num_steps):
        batch = dataset.sample(batch_size)
        dynamics, info = update(dynamics, batch)

        if (i + 1) % eval_every == 0:
            wandb.log(info, step=i + 1)
            test_data = dataset.sample(1)
            # Evaluate against the same target the model is trained on
            # (next observation + reward), not the current observation.
            gt = jnp.concatenate(
                [test_data["next_observations"], test_data["rewards"][..., None]],
                axis=-1,
            )
            key, subkey = jax.random.split(key)
            pred_means, pred_logstds = dynamics(
                test_data["observations"],
                test_data["actions"],
            )
            # Average the outputs over axis 0 before drawing one sample.
            pred = sample_from_norm(
                pred_means.mean(axis=0),
                pred_logstds.mean(axis=0),
                subkey,
            )
            l1 = l1_loss(pred, gt)
            l2 = l2_loss(pred, gt)
            # Two rows per evaluation: the ground truth and the sampled
            # prediction, both tagged with the same losses.
            eval_table.add_data(
                i + 1,
                "gt",
                l1,
                l2,
                *gt[0].tolist(),
            )
            eval_table.add_data(
                i + 1,
                "pred",
                l1,
                l2,
                *pred[0].tolist(),
            )

    wandb.log({"eval": eval_table})


# Script entry point: hand `main` to absl's application runner when this file
# is executed directly (not when it is imported).
if __name__ == "__main__":
    app.run(main)
