import d4rl
import gym
import jax
import optax
import wandb
from absl import app, flags
from ml_collections import config_flags
from tqdm import tqdm

from jaxrl5.agents import BCLearner, IQLLearner
from jaxrl5.data.d4rl_datasets import D4RLDataset
from jaxrl5.evaluation import evaluate
from jaxrl5.wrappers import wrap_gym
from jaxrl5.wrappers.wandb_video import WANDBVideo

# --- Experiment identity ----------------------------------------------------
flags.DEFINE_string(
    "env_name", default="halfcheetah-expert-v2", help="D4rl dataset name."
)
flags.DEFINE_integer(
    "seed", default=42, help="Random seed for learner and evaluator."
)

# --- Optimisation schedule --------------------------------------------------
flags.DEFINE_integer("batch_size", default=256, help="Batch size.")
flags.DEFINE_integer("max_steps", default=1_000_000, help="Training steps.")

# --- Evaluation and logging cadence -----------------------------------------
flags.DEFINE_integer("eval_episodes", default=10, help="Eval episodes.")
flags.DEFINE_integer("log_interval", default=1_000, help="Log interval.")
flags.DEFINE_integer("eval_interval", default=10_000, help="Eval interval.")
flags.DEFINE_boolean(
    "save_video", default=False, help="Save videos during evaluation."
)

# --- Dataset filtering (both default to None; main() filters the dataset
# only when at least one of them is set) -------------------------------------
flags.DEFINE_float("take_top", default=None, help="Take top N% trajectories.")
flags.DEFINE_float(
    "filter_threshold",
    default=None,
    help="Take trajectories with returns above the threshold.",
)

# --- Miscellaneous ----------------------------------------------------------
flags.DEFINE_boolean("tqdm", default=True, help="Use tqdm progress bar.")

# Hyperparameters come from an ml_collections config file, exposed as
# FLAGS.config.
config_flags.DEFINE_config_file(
    "config",
    "examples/states/configs/bc_config.py",
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)

# Shared handle through which the rest of the module reads parsed flag values.
FLAGS = flags.FLAGS


def _evaluate_and_log(agent, env, step):
    """Evaluate the agent's current policy and log the results to wandb.

    Args:
        agent: Learner passed to ``evaluate``.
        env: Evaluation environment; must expose ``get_normalized_score``.
        step: Training step under which the metrics are logged.
    """
    eval_info = evaluate(
        agent, env, FLAGS.eval_episodes, save_video=FLAGS.save_video
    )
    # Replace the raw return with the environment's normalized score, x100.
    eval_info["return"] = env.get_normalized_score(eval_info["return"]) * 100.0
    wandb.log({f"eval/{k}": v for k, v in eval_info.items()}, step=step)


def main(_):
    """Train an offline learner on a D4RL dataset and log metrics to wandb.

    Builds the environment and dataset named by ``--env_name``, optionally
    filters the dataset, creates the learner named by the config's
    ``model_cls`` entry, then runs ``--max_steps`` updates. Training metrics
    are logged every ``--log_interval`` steps and evaluation metrics every
    ``--eval_interval`` steps, plus once more after the last update.

    Args:
        _: Unused positional argument supplied by ``app.run``.

    Raises:
        ValueError: If the config's ``model_cls`` does not name anything
            available in this module.
    """
    wandb.init(project='jaxrl5_offline', group=FLAGS.env_name)
    wandb.config.update(FLAGS)

    env = gym.make(FLAGS.env_name)
    ds = D4RLDataset(env)
    env = wrap_gym(env)
    if FLAGS.save_video:
        env = WANDBVideo(env)

    if FLAGS.take_top is not None or FLAGS.filter_threshold is not None:
        ds.filter(take_top=FLAGS.take_top, threshold=FLAGS.filter_threshold)

    # Work on a plain dict copy so script-only keys can be popped before the
    # remainder is forwarded to the learner as keyword arguments.
    config_dict = dict(**FLAGS.config)
    cosine_decay = config_dict.pop("cosine_decay", False)
    if cosine_decay:
        # Swap the constant actor learning rate for a cosine schedule that
        # spans the whole training run.
        config_dict["actor_lr"] = optax.cosine_decay_schedule(
            config_dict["actor_lr"], FLAGS.max_steps
        )

    model_cls = config_dict.pop("model_cls")
    # The learner class is resolved by name among this module's globals
    # (e.g. the imported BCLearner / IQLLearner).
    learner_cls = globals().get(model_cls)
    if learner_cls is None:
        raise ValueError(
            f"Unknown model_cls {model_cls!r}: no such learner is available "
            "in this script."
        )
    agent = learner_cls.create(
        FLAGS.seed, env.observation_space, env.action_space, **config_dict
    )

    for i in tqdm(range(FLAGS.max_steps), smoothing=0.1, disable=not FLAGS.tqdm):
        # NOTE(review): only observations and actions are sampled; a learner
        # that reads other batch fields (IQLLearner is imported above) would
        # presumably need more keys -- confirm before using it here.
        sample = ds.sample_jax(FLAGS.batch_size, keys=["observations", "actions"])
        agent, info = agent.update(sample)

        if i % FLAGS.log_interval == 0:
            # Pull metrics back to the host before handing them to wandb.
            info = jax.device_get(info)
            wandb.log({f"train/{k}": v for k, v in info.items()}, step=i)

        if i % FLAGS.eval_interval == 0:
            _evaluate_and_log(agent, env, i)

    # The loop's last index is max_steps - 1, which the periodic check above
    # does not hit for eval_interval > 1, so the fully trained policy would
    # otherwise never be evaluated. Score it once more under step=max_steps.
    if FLAGS.max_steps > 0:
        _evaluate_and_log(agent, env, FLAGS.max_steps)


# Script entry point: when run directly (not imported), hand main to absl's
# app.run.
if __name__ == "__main__":
    app.run(main)
