import os
from absl import app, flags
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import tqdm
import gym
from model_vipo import Normalizer, NormalizerState
import algo_vipo as learner
import dynamics_vipo as vipo_dynamics
from hyper_vipo import hyperparameters, get_default_config
from util_vipo import (
    get_tuned_dataset,
    merge_batch,
    get_params_shape,
    split_dataset,
    get_termination_fn,
)
from buffer_vipo import OfflineReplayBuffer

import jaxrl_m.examples.mujoco.d4rl_utils as d4rl_utils
from jaxrl_m.wandb import setup_wandb, default_wandb_config, get_flag_dict
import wandb
from jaxrl_m.evaluation import supply_rng, evaluate, flatten, EpisodeMonitor
from jaxrl_m.dataset import Dataset
from ml_collections import config_flags
import pickle
from flax.training import checkpoints


# Command-line flags for the training script. `main` reads them through FLAGS.
FLAGS = flags.FLAGS
flags.DEFINE_string("env_name", "walker2d-medium-expert-v2", "Environment name.")

flags.DEFINE_string("save_dir", None, "Logging dir.")

# The default seed is drawn once, when this module is loaded, so every run gets
# a different seed unless --seed is passed explicitly.
flags.DEFINE_integer("seed", np.random.choice(1000000), "Random seed.")
flags.DEFINE_integer("eval_episodes", 10, "Number of episodes used for evaluation.")
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
flags.DEFINE_integer("eval_interval", 10000, "Eval interval.")
flags.DEFINE_integer("save_interval", 25000, "Save interval.")
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
flags.DEFINE_integer("max_model_steps", int(2e5), "Number of training dynamics steps.")
flags.DEFINE_integer("max_steps", int(1e6), "Number of training steps.")
flags.DEFINE_integer("start_steps", int(1e4), "Number of initial exploration steps.")

# Weights & Biases run settings, exposed as --wandb.<field> overrides.
# NOTE(review): "{env_name}" is presumably filled in from the flag values by
# setup_wandb -- confirm against jaxrl_m.wandb.
wandb_config = default_wandb_config()
wandb_config.update(
    {
        "project": "d4rl_test",
        "group": "vipo_test",
        "name": "vipo_{env_name}",
    }
)

config_flags.DEFINE_config_dict("wandb", wandb_config, lock_config=False)
config_flags.DEFINE_config_dict("config", get_default_config(), lock_config=False)

# Ad-hoc hyperparameter overrides. `main` applies these to FLAGS.config last,
# i.e. after the per-environment hyperparameters.
overlay = {
    # add temporary hyperparameter overrides here
}


def main(_):
    """Train the dynamics model, then train the planner agent on top of it.

    Steps, in order:
      1. Merge the per-environment hyperparameters and ``overlay`` into
         ``FLAGS.config`` and start the wandb run.
      2. If ``--save_dir`` is given, create a run-specific sub-directory and
         pickle the flag dictionary into it.
      3. Load the environment and its dataset, build the real/fake replay
         buffers and the action/observation normalizer statistics.
      4. Train the dynamics learner for ``FLAGS.max_model_steps`` updates.
      5. Train the agent for ``FLAGS.max_steps`` updates, mixing batches from
         the real buffer with rollouts produced through the dynamics learner.

    Args:
        _: Positional argument supplied by ``app.run``; unused.
    """
    env_name = FLAGS.env_name
    # Precedence (lowest to highest): defaults, per-env hyperparameters, overlay.
    env_cfg = hyperparameters.get(env_name, {})
    FLAGS.config.update(env_cfg)
    FLAGS.config.update(overlay)

    # Create wandb logger
    setup_wandb(FLAGS.config.to_dict(), **FLAGS.wandb)

    if FLAGS.save_dir is not None:
        # Nest the output directory per project / experiment so runs don't collide.
        FLAGS.save_dir = os.path.join(
            FLAGS.save_dir,
            wandb.run.project,
            wandb.config.exp_prefix,
            wandb.config.experiment_id,
        )
        os.makedirs(FLAGS.save_dir, exist_ok=True)
        print(f"Saving config to {FLAGS.save_dir}/config.pkl")
        with open(os.path.join(FLAGS.save_dir, "config.pkl"), "wb") as f:
            pickle.dump(get_flag_dict(), f)

    # Independent PRNG streams for the dynamics learner, the rollouts and the agent.
    key = jax.random.PRNGKey(FLAGS.seed)
    key, dynamics_key, rollout_key, sac_key = jax.random.split(key, 4)

    env = d4rl_utils.make_env(env_name)
    dataset = d4rl_utils.get_dataset(env)
    reward_tune = FLAGS.config.reward_tune
    dataset = get_tuned_dataset(dataset, reward_tune)

    # A single un-batched transition, used as the template for the fake buffer
    # and for the normalizer state shapes.
    example_batch = dataset.sample(1)
    example_transition = dict(
        observations=example_batch["observations"][0],
        actions=example_batch["actions"][0],
        rewards=example_batch["rewards"][0],
        masks=example_batch["masks"][0],
        dones_float=example_batch["dones_float"][0],
        next_observations=example_batch["next_observations"][0],
    )

    # Each agent update mixes real and model-generated samples; the real share
    # is truncated to an integer and the fake share takes the remainder.
    batch_size = FLAGS.batch_size
    real_ratio = FLAGS.config.real_ratio
    real_batch_size = int(batch_size * real_ratio)
    fake_batch_size = batch_size - real_batch_size

    rollout_length = FLAGS.config.rollout_length
    rollout_batch_size = FLAGS.config.rollout_batch_size
    rollout_retain_epochs = FLAGS.config.rollout_retain_epochs
    fake_size = rollout_length * rollout_batch_size * rollout_retain_epochs

    real_replay_buffer = OfflineReplayBuffer.create_from_existing_dataset(dataset)
    fake_replay_buffer = OfflineReplayBuffer.create(example_transition, size=fake_size)

    #########################################
    # dynamics dataset
    #########################################
    normalizer = Normalizer()

    action_norm_state = NormalizerState(
        mean=jnp.zeros_like(example_transition["actions"]),
        std=jnp.ones_like(example_transition["actions"]),
        num_points=0,
    )
    obs_norm_state = NormalizerState(
        mean=jnp.zeros_like(example_transition["observations"]),
        std=jnp.ones_like(example_transition["observations"]),
        num_points=0,
    )

    # Normalizer statistics are updated from the full dataset, before the
    # train/holdout split below.
    action_norm_state = normalizer.update_stats(dataset["actions"], action_norm_state)
    obs_norm_state = normalizer.update_stats(dataset["observations"], obs_norm_state)
    train_dataset, holdout_dataset = split_dataset(dataset, FLAGS.config.holdout_ratio)

    #########################################
    # dynamics training
    #########################################
    dynamics = vipo_dynamics.create_learner(
        key=dynamics_key,
        normalizer=normalizer,
        action_norm_state=action_norm_state,
        obs_norm_state=obs_norm_state,
        example_batch=example_batch,
        termination_fn=get_termination_fn(env_name),
        **FLAGS.config,
    )

    for i in tqdm.tqdm(range(1, FLAGS.max_model_steps + 1), smoothing=0.1, dynamic_ncols=True):
        batch = train_dataset.sample(batch_size)
        dynamics, update_info = dynamics.update(batch)

        if i % FLAGS.log_interval == 0:
            train_metrics = {f"dynamics_training/{k}": v for k, v in update_info.items()}
            wandb.log(train_metrics, step=i)

        if i % FLAGS.eval_interval == 0:
            # use holdout_dataset for evaluation
            eval_info = dynamics.evaluate(holdout_dataset._dict)
            eval_metrics = {f"dynamics_evaluation/{k}": v for k, v in eval_info.items()}
            wandb.log(eval_metrics, step=i)

        if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
            # Dynamics checkpoints keep flax's default "checkpoint_" prefix.
            checkpoints.save_checkpoint(FLAGS.save_dir, dynamics, i)

    # Dynamics training is finished at this point.
    # To train only the dynamics model, return here. The agent phase below
    # relies on the planner implemented in algo_vipo.py.

    #########################################
    # agent training
    #########################################
    agent = learner.create_learner(
        sac_key,
        example_batch["observations"],
        example_batch["actions"],
        **FLAGS.config,
    )

    # Agent metrics are logged after the dynamics metrics on the wandb step axis.
    global_step = FLAGS.max_model_steps

    # The agent's step counter restarts at 1, while the dynamics phase may
    # already have written checkpoints up to step max_model_steps into the same
    # directory. flax's save_checkpoint (overwrite=False) rejects a step that is
    # not newer than the latest checkpoint sharing its prefix, so agent
    # checkpoints get a prefix of their own.
    agent_ckpt_prefix = "agent_checkpoint_"

    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1, dynamic_ncols=True):

        #########################################
        # collect data & rollouts
        #########################################
        # Roll out on the very first step so the fake buffer is never empty
        # when it is sampled below, then every rollout_freq steps.
        if i == 1 or i % FLAGS.config.rollout_freq == 0:
            init_obs = real_replay_buffer.sample(rollout_batch_size)["observations"]
            rollout_key, subkey = jax.random.split(rollout_key)
            rollout_batch, rollout_info = agent.rollout(dynamics, init_obs, rollout_length, subkey)
            fake_replay_buffer.add_batch(rollout_batch)
            rollout_metrics = {f"rollout/{k}": v for k, v in rollout_info.items()}
            wandb.log(rollout_metrics, step=global_step + i)

        real_batch = real_replay_buffer.sample(real_batch_size)
        fake_batch = fake_replay_buffer.sample(fake_batch_size)

        #########################################
        # train & evaluate agent
        #########################################

        agent, update_info = agent.update(dynamics, real_batch, fake_batch)

        if i % FLAGS.log_interval == 0:
            train_metrics = {f"planner_training/{k}": v for k, v in update_info.items()}
            wandb.log(train_metrics, step=global_step + i)

            # Per-field batch means for the real and fake batches, plus their
            # difference (real minus fake).
            real_batch_info = {f"batch/real_{k}": v.mean() for k, v in real_batch.items()}
            fake_batch_info = {f"batch/fake_{k}": v.mean() for k, v in fake_batch.items()}
            diff_batch_info = {f"batch/diff_{k}": v.mean() - fake_batch[k].mean() for k, v in real_batch.items()}
            wandb.log(
                {
                    **real_batch_info,
                    **fake_batch_info,
                    **diff_batch_info,
                },
                step=global_step + i,
            )

        if i % FLAGS.eval_interval == 0:
            policy_fn = partial(supply_rng(agent.sample_actions), temperature=0.0)
            eval_info = evaluate(policy_fn, env, num_episodes=FLAGS.eval_episodes)

            eval_metrics = {f"planner_evaluation/{k}": v for k, v in eval_info.items()}
            wandb.log(eval_metrics, step=global_step + i)

        if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
            checkpoints.save_checkpoint(FLAGS.save_dir, agent, i, prefix=agent_ckpt_prefix)


if __name__ == "__main__":
    # Script entry point: absl parses the command-line flags, then calls main.
    app.run(main)
