"""Starter code from the RLPD repository https://github.com/ikostrikov/rlpd"""
#! /usr/bin/env python
import os
import pickle

import d4rl
import d4rl.gym_mujoco
import d4rl.locomotion
import dmcgym
import gym
import numpy as np
import tqdm
from absl import app, flags

try:
    from flax.training import checkpoints
except:
    print("Not loading checkpointing functionality.")
from ml_collections import config_flags

import wandb
from expo.agents import EXPOLearner
from expo.agents import SACLearner
from expo.data import ReplayBuffer
from expo.data import RoboReplayBuffer
from expo.data.d4rl_datasets import D4RLDataset

import mimicgen
from robomimic.utils.dataset import SequenceDataset
from expo.data.robomimic_datasets import (
    process_robomimic_dataset, get_mimicgen_env, get_robomimic_env, RoboD4RLDataset, 
    ENV_TO_HORIZON_MAP, MIMICGEN_ENV_TO_HORIZON_MAP, OBS_KEYS
)
import cloudpickle as pickle

try:
    from expo.data.binary_datasets import BinaryDataset
except:
    print("not importing binary dataset")
from expo.evaluation import evaluate, evaluate_diffusion, evaluate_robo
from expo.wrappers import wrap_gym

# Global absl flag registry; values are parsed from the command line by app.run.
FLAGS = flags.FLAGS

flags.DEFINE_string("project_name", "expo", "wandb project name.")
flags.DEFINE_string("env_name", "halfcheetah-expert-v2", "Robomimic or MimicGen task name.")
flags.DEFINE_float("offline_ratio", 0.5, "Offline ratio.")
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_integer("eval_episodes", 100, "Number of episodes used for evaluation.")
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
flags.DEFINE_integer("eval_interval", 10000, "Evaluation interval during online training.")
flags.DEFINE_integer(
    "offline_eval_interval", 50000, "Evaluation interval during offline pretraining."
)
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
flags.DEFINE_integer("max_steps", int(1e6), "Number of training steps.")
flags.DEFINE_integer(
    "start_training",
    int(1e4),
    "Number of initial online steps that take uniform-random actions before "
    "agent updates start.",
)
flags.DEFINE_integer(
    "num_data", 0, "Value forwarded to RoboD4RLDataset(num_data=...)."
)
flags.DEFINE_string(
    "dataset_dir",
    "halfcheetah-expert-v2",
    "Robomimic data source: 'ph' or 'mh' ('' also selects 'mh') for the HDF5 "
    "demos, otherwise a path to a pickled dataset. Ignored for MimicGen tasks.",
)
flags.DEFINE_integer("pretrain_steps", 0, "Number of offline updates.")
flags.DEFINE_boolean("tqdm", True, "Use tqdm progress bar.")
flags.DEFINE_boolean("save_video", False, "Save videos during evaluation.")
flags.DEFINE_boolean("checkpoint_model", False, "Save agent checkpoint on evaluation.")
flags.DEFINE_boolean(
    "checkpoint_buffer", False, "Save agent replay buffer on evaluation."
)
flags.DEFINE_integer("utd_ratio", 1, "Update to data ratio.")
flags.DEFINE_boolean(
    "binary_include_bc", True, "Whether to include BC data in the binary datasets."
)
flags.DEFINE_boolean(
    "pretrain_edit", False, "Whether to pretrain edit policy."
)
flags.DEFINE_boolean(
    "pretrain_q", False, "Whether to pretrain Q-function."
)

# main() reads FLAGS.log_dir, but no flag above defined it. Register it here
# unless an imported module already did (re-defining a flag raises).
if "log_dir" not in FLAGS:
    flags.DEFINE_string(
        "log_dir", "logs", "Root directory for model checkpoints and replay buffers."
    )

config_flags.DEFINE_config_file(
    "config",
    "configs/sac_config.py",
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)


def _interleave_rows(first, second):
    """Merge two arrays along axis 0 so rows of both sources are spread through the result.

    The output takes ``first``'s dtype; trailing dimensions must match.
    """
    n_first, n_second = first.shape[0], second.shape[0]
    out = np.empty((n_first + n_second, *first.shape[1:]), dtype=first.dtype)
    if n_first - n_second in (0, 1):
        # Strict alternation first, second, first, ... fits exactly.
        out[0::2] = first
        out[1::2] = second
        return out
    # Sizes differ (e.g. offline_ratio != 0.5), so strict alternation cannot fit.
    # Give each row the centre of its equal-width slot in [0, 1) and order by
    # that position. Every contiguous chunk then keeps roughly the
    # n_first:n_second proportion. Ties keep `first` before `second`.
    pos_first = (np.arange(n_first) + 0.5) / max(n_first, 1)
    pos_second = (np.arange(n_second) + 0.5) / max(n_second, 1)
    order = np.argsort(np.concatenate([pos_first, pos_second]), kind="stable")
    out[:] = np.concatenate([first, second])[order]
    return out


def combine(one_dict, other_dict):
    """Merge two (possibly nested) batches of arrays row-wise into one batch.

    For every key of ``one_dict``, the leaf arrays of both dicts are merged
    along axis 0, and nested dicts are merged recursively. Rows are mixed so
    that any contiguous slice of the result contains both sources:
    - when the sizes are equal, or ``one_dict`` has exactly one extra row,
      rows strictly alternate (one, other, one, ...);
    - otherwise each source is spread evenly in proportion to its size.

    Keys present only in ``other_dict`` are dropped.

    Args:
        one_dict: batch whose keys drive the merge (in main: the offline batch).
        other_dict: batch with at least the same keys (in main: the online batch).

    Returns:
        A new dict with the structure of ``one_dict``; leaf arrays take
        ``one_dict``'s dtype.

    Raises:
        KeyError: if ``other_dict`` lacks a key present in ``one_dict``.
    """
    combined = {}
    for k, v in one_dict.items():
        if isinstance(v, dict):
            combined[k] = combine(v, other_dict[k])
        else:
            combined[k] = _interleave_rows(v, other_dict[k])
    return combined


def _load_pickled_dataset(path):
    """Load a pickled transition dict from ``path`` and squeeze its rewards/terminals.

    NOTE(review): unpickling executes arbitrary code embedded in the file; only
    point this at dataset files you trust.
    """
    with open(path, "rb") as handle:
        dataset = pickle.load(handle)
    # Drop singleton axes from rewards/terminals
    # (presumably stored with a trailing size-1 axis; TODO confirm with the writer).
    dataset["rewards"] = dataset["rewards"].squeeze()
    dataset["terminals"] = dataset["terminals"].squeeze()
    return dataset


def _load_robomimic_hdf5(env_name, quality):
    """Load the robomimic low-dim HDF5 demos of split ``quality`` ('ph'/'mh') as a transition dict."""
    seq_dataset = SequenceDataset(
        hdf5_path=f"./robomimic/datasets/{env_name}/{quality}/low_dim_v141.hdf5",
        obs_keys=OBS_KEYS,
        dataset_keys=("actions", "rewards", "dones"),
        hdf5_cache_mode="all",
        load_next_obs=True,
    )
    return process_robomimic_dataset(seq_dataset)


def _setup_data_and_envs():
    """Build the offline dataset and the train/eval envs for FLAGS.env_name.

    Robomimic tasks (keys of ENV_TO_HORIZON_MAP) choose their data by
    FLAGS.dataset_dir:
    - 'ph' loads the ph HDF5 demos;
    - '' or 'mh' loads the mh HDF5 demos;
    - anything else is treated as a path to a pickled dataset.

    Every other task is treated as MimicGen and loads
    ./mimicgen/datasets/<env>/dataset.pkl.

    Returns:
        (ds, example_observation, example_action, env, eval_env, max_traj_len).
        The example arrays are the first dataset row with a leading axis of size 1.
    """
    env_name = FLAGS.env_name
    if env_name in ENV_TO_HORIZON_MAP:
        # The env is always built from the ph file, whichever data source is used.
        dataset_path = f"./robomimic/datasets/{env_name}/ph/low_dim_v141.hdf5"
        if FLAGS.dataset_dir not in ("", "mh", "ph"):
            dataset = _load_pickled_dataset(FLAGS.dataset_dir)
        else:
            quality = "ph" if FLAGS.dataset_dir == "ph" else "mh"
            dataset = _load_robomimic_hdf5(env_name, quality)
        make_env = get_robomimic_env
        horizon_map = ENV_TO_HORIZON_MAP
    else:
        dataset_path = f"./mimicgen/datasets/source/{env_name}.hdf5"
        dataset = _load_pickled_dataset(f"./mimicgen/datasets/{env_name}/dataset.pkl")
        make_env = get_mimicgen_env
        horizon_map = MIMICGEN_ENV_TO_HORIZON_MAP

    ds = RoboD4RLDataset(env=None, num_data=FLAGS.num_data, custom_dataset=dataset)
    example_observation = ds.dataset_dict["observations"][0][np.newaxis]
    example_action = ds.dataset_dict["actions"][0][np.newaxis]
    env = make_env(dataset_path, example_action, env_name)
    eval_env = make_env(dataset_path, example_action, env_name)
    max_traj_len = horizon_map[env_name]
    return ds, example_observation, example_action, env, eval_env, max_traj_len


def _log_metrics(prefix, metrics, step):
    """Log each ``metrics`` entry to wandb as ``<prefix>/<key>`` at ``step`` (no-op if empty)."""
    if metrics:
        wandb.log({f"{prefix}/{k}": v for k, v in metrics.items()}, step=step)


def _pretrain(agent, ds, eval_env, max_traj_len):
    """Run FLAGS.pretrain_steps offline updates on ``ds`` and return the updated agent.

    Logs update metrics every FLAGS.log_interval steps and evaluates every
    FLAGS.offline_eval_interval steps. The pretraining step is used as the
    wandb step.
    """
    for i in tqdm.tqdm(
        range(FLAGS.pretrain_steps), smoothing=0.1, disable=not FLAGS.tqdm
    ):
        batch = dict(ds.sample(FLAGS.batch_size * FLAGS.utd_ratio))
        if "antmaze" in FLAGS.env_name and "rewards" in batch:
            # Antmaze tasks train on rewards shifted by -1.
            batch["rewards"] = batch["rewards"] - 1

        agent, update_info = agent.update_offline(
            batch, FLAGS.utd_ratio, FLAGS.pretrain_q, FLAGS.pretrain_edit
        )

        if i % FLAGS.log_interval == 0:
            _log_metrics("offline-training", update_info, step=i)

        if i % FLAGS.offline_eval_interval == 0:
            eval_info = evaluate_robo(
                agent,
                eval_env,
                max_traj_len=max_traj_len,
                num_episodes=FLAGS.eval_episodes,
            )
            _log_metrics("offline-evaluation", eval_info, step=i)
    return agent


def _save_checkpoints(agent, replay_buffer, step, chkpt_dir, buffer_dir):
    """Best-effort save of the agent and the replay buffer.

    A ``None`` directory skips that save. Failures are reported and then
    swallowed on purpose, so a failed write never aborts a long training run.
    """
    if chkpt_dir is not None:
        try:
            checkpoints.save_checkpoint(
                chkpt_dir, agent, step=step, keep=20, overwrite=True
            )
        except Exception as err:  # also covers `checkpoints` having failed to import
            print(f"Could not save model checkpoint: {err!r}")

    if buffer_dir is not None:
        try:
            with open(os.path.join(buffer_dir, "buffer"), "wb") as f:
                # `pickle` is cloudpickle in this module. NOTE(review): cloudpickle
                # does not appear to expose HIGHEST_PROTOCOL, which made the old
                # call always fail. Its default protocol is documented as
                # pickle.HIGHEST_PROTOCOL, so no protocol argument is needed.
                pickle.dump(replay_buffer, f)
        except Exception as err:
            print(f"Could not save agent buffer: {err!r}")


def main(_):
    """Optionally pretrain offline, then fine-tune the agent online.

    All settings come from FLAGS. Online updates merge offline dataset samples
    with replay-buffer samples according to FLAGS.offline_ratio. Metrics go to
    wandb. With --checkpoint_model / --checkpoint_buffer, the agent / replay
    buffer are saved under FLAGS.log_dir at every evaluation.

    Raises:
        ValueError: if FLAGS.offline_ratio is outside [0, 1].
    """
    if not 0.0 <= FLAGS.offline_ratio <= 1.0:
        raise ValueError(
            f"offline_ratio must be in [0, 1], got {FLAGS.offline_ratio}"
        )

    wandb.init(project=FLAGS.project_name)
    wandb.config.update(FLAGS)

    chkpt_dir = buffer_dir = None
    # Only read FLAGS.log_dir when something will actually be written to disk.
    if FLAGS.checkpoint_model or FLAGS.checkpoint_buffer:
        exp_prefix = f"s{FLAGS.seed}_{FLAGS.pretrain_steps}pretrain"
        if getattr(FLAGS.config, "critic_layer_norm", False):
            exp_prefix += "_LN"
        log_dir = os.path.join(FLAGS.log_dir, exp_prefix)

        if FLAGS.checkpoint_model:
            chkpt_dir = os.path.join(log_dir, "checkpoints")
            os.makedirs(chkpt_dir, exist_ok=True)
        if FLAGS.checkpoint_buffer:
            buffer_dir = os.path.join(log_dir, "buffers")
            os.makedirs(buffer_dir, exist_ok=True)

    ds, example_observation, example_action, env, eval_env, max_traj_len = (
        _setup_data_and_envs()
    )

    kwargs = dict(FLAGS.config)
    model_cls = kwargs.pop("model_cls")
    # `model_cls` names a learner class imported into this module (e.g. EXPOLearner).
    agent = globals()[model_cls].create(
        FLAGS.seed, example_observation.squeeze(), example_action.squeeze(), **kwargs
    )

    replay_buffer = RoboReplayBuffer(
        example_observation.squeeze(), example_action.squeeze(), FLAGS.max_steps
    )
    replay_buffer.seed(FLAGS.seed)

    agent = _pretrain(agent, ds, eval_env, max_traj_len)

    total_batch = FLAGS.batch_size * FLAGS.utd_ratio
    observation, done = env.reset(), False
    for i in tqdm.tqdm(
        range(FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm
    ):
        step = i + FLAGS.pretrain_steps  # wandb step continues after pretraining
        if i < FLAGS.start_training:
            action = np.random.uniform(-1, 1, size=(example_action.shape[1],))
        else:
            action, agent = agent.sample_actions(observation)
        next_observation, reward, done, info = env.step(action)

        # Keep bootstrapping through time-limit truncations; only a real
        # termination zeroes the mask.
        mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0

        replay_buffer.insert(
            dict(
                observations=observation,
                actions=action,
                rewards=reward,
                masks=mask,
                dones=done,
                next_observations=next_observation,
            )
        )
        observation = next_observation

        if done:
            observation, done = env.reset(), False
            # Not every env wrapper populates info["episode"]; skip logging if absent.
            _log_metrics("training", info.get("episode", {}), step=step)

        if i >= FLAGS.start_training:
            online_batch = replay_buffer.sample(
                int(total_batch * (1 - FLAGS.offline_ratio))
            )
            offline_batch = ds.sample(int(total_batch * FLAGS.offline_ratio))
            # Merge offline and online rows into a single update batch.
            batch = combine(offline_batch, online_batch)

            agent, update_info = agent.update(batch, FLAGS.utd_ratio)

            if i % FLAGS.log_interval == 0:
                _log_metrics("training", update_info, step=step)

        if i % FLAGS.eval_interval == 0:
            eval_info = evaluate_robo(
                agent,
                eval_env,
                max_traj_len=max_traj_len,
                num_episodes=FLAGS.eval_episodes,
                save_video=FLAGS.save_video,
            )
            _log_metrics("evaluation", eval_info, step=step)
            _save_checkpoints(agent, replay_buffer, i, chkpt_dir, buffer_dir)


if __name__ == "__main__":
    # Script entry point: hand control to absl's app.run, which invokes main.
    app.run(main)
