import os
from typing import Tuple

import gym
import numpy as np
import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter

import wrappers
from dataset_utils import Batch, D4RLDataset, ReplayBuffer, split_into_trajectories
from evaluation import evaluate
from learner import Learner

# Command-line configuration for the pretraining + fine-tuning script.
# All values below are read through the global absl FLAGS object in main().
FLAGS = flags.FLAGS

# --- Experiment identity and output -------------------------------------
flags.DEFINE_string("env_name", "halfcheetah-expert-v2", "Environment name.")
# save_dir receives both the TensorBoard event files (under tb/<seed>) and the
# "<seed>.txt" file of evaluation returns written by main().
flags.DEFINE_string("save_dir", "./tmp/", "Tensorboard logging dir.")
flags.DEFINE_integer("seed", 42, "Random seed.")

# --- Logging / evaluation cadence ---------------------------------------
flags.DEFINE_integer("eval_episodes", 100, "Number of episodes used for evaluation.")
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
# eval_interval applies to online steps (i >= 1); pretraining_eval_interval
# applies to pretraining steps (i < 1). NOTE(review): both carry the same help
# text, which does not tell them apart in --help output.
flags.DEFINE_integer("eval_interval", 100000, "Eval interval.")
flags.DEFINE_integer("pretraining_eval_interval", 100000, "Eval interval.")

# --- Optimisation budget ------------------------------------------------
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
flags.DEFINE_integer("max_steps", int(1e6), "Number of training steps.")
flags.DEFINE_integer("num_pretraining_steps", int(1e6), "Number of pretraining steps.")
# NOTE(review): the help text promises a fallback to max_steps, but the default
# is 2000000, so the `FLAGS.replay_buffer_size or FLAGS.max_steps` fallback in
# main() only triggers when the flag is explicitly set to 0.
flags.DEFINE_integer(
    "replay_buffer_size", 2000000, "Replay buffer size (=max_steps if unspecified)."
)
# Forwarded as the second argument of ReplayBuffer.initialize_with_dataset().
flags.DEFINE_integer(
    "init_dataset_size", None, "Offline data size (uses all data if unspecified)."
)
flags.DEFINE_boolean("tqdm", True, "Use tqdm progress bar.")
# Learner hyperparameters; the resulting config is expanded into keyword
# arguments of Learner(...) in main(). lock_config=False leaves the config
# unlocked after parsing.
# NOTE(review): the default config targets antmaze while the default env_name
# is a halfcheetah task -- confirm this pairing is intentional.
config_flags.DEFINE_config_file(
    "config",
    "configs/antmaze_finetune_config.py",
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)

# --- Action selection (forwarded to agent.sample_actions / evaluate) -----
flags.DEFINE_integer("num_actions_to_sample", 1, "Number of actions to sample")
flags.DEFINE_float("fixed_action_noise", -1, "Fixed action noise")

# --- Critic settings (forwarded to the Learner constructor) --------------
flags.DEFINE_integer("critic_ensemble_size", 2, "Number of critics in the ensemble")
flags.DEFINE_integer("critic_training_iterations", 1, "Number of training iterations")

# --- Update-to-data ratio ------------------------------------------------
# Number of agent.update() calls performed per iteration of the main loop.
flags.DEFINE_integer(
    "consecutive_training_iters", 1, "Number of consecutive training iterations"
)
# Number of env steps collected per main-loop iteration; the loop counter and
# the progress bar also advance by this amount each iteration.
flags.DEFINE_integer(
    "consecutive_env_steps", 1, "Number of consecutive environment collection steps"
)
# Forwarded to agent.sample_actions() during online data collection.
flags.DEFINE_float("optimism_parameter", 0, "Optimism parameter")


def normalize(dataset):
    """Rescale ``dataset.rewards`` in place by the spread of episode returns.

    The dataset is split into trajectories, the undiscounted return of every
    trajectory is computed, and all rewards are divided by
    ``max_return - min_return`` and then multiplied by 1000.

    Args:
        dataset: Object exposing ``observations``, ``actions``, ``rewards``,
            ``masks``, ``dones_float`` and ``next_observations``. Only
            ``rewards`` is modified, and it is modified in place.

    Raises:
        ValueError: If the dataset yields no trajectories, or if every
            trajectory has the same return (the scale would be a division by
            zero and would fill the rewards with inf/NaN).
    """
    trajs = split_into_trajectories(
        dataset.observations,
        dataset.actions,
        dataset.rewards,
        dataset.masks,
        dataset.dones_float,
        dataset.next_observations,
    )

    def compute_returns(traj):
        # Each step is a 6-tuple whose third entry is the reward.
        episode_return = 0
        for _, _, rew, _, _, _ in traj:
            episode_return += rew

        return episode_return

    if not trajs:
        raise ValueError("Cannot normalize rewards: dataset has no trajectories.")

    # Compute every return exactly once; only the extremes are needed, so a
    # full sort of the trajectories is unnecessary.
    returns = [compute_returns(traj) for traj in trajs]
    return_range = max(returns) - min(returns)
    if return_range == 0:
        raise ValueError(
            "Cannot normalize rewards: all trajectories have the same return."
        )

    # In-place ops so that callers holding a reference to the array see the
    # rescaled rewards.
    dataset.rewards /= return_range
    dataset.rewards *= 1000.0


def make_env_and_dataset(env_name: str, seed: int) -> Tuple[gym.Env, D4RLDataset]:
    """Create the wrapped, seeded environment and its offline dataset.

    Args:
        env_name: Gym/D4RL environment id. It is also used to decide which
            reward preprocessing is applied to the dataset.
        seed: Seed applied to the environment and to its action and
            observation spaces.

    Returns:
        A ``(env, dataset)`` tuple. For halfcheetah / walker2d / hopper tasks
        the dataset rewards have been rescaled in place by ``normalize``.
    """
    env = gym.make(env_name)

    env = wrappers.EpisodeMonitor(env)
    env = wrappers.SinglePrecision(env)

    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    dataset = D4RLDataset(env)

    # AntMaze rewards are intentionally left untouched here (no
    # `dataset.rewards -= 1.0`); they are shifted per batch instead.
    # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
    # but I found no difference between (x - 0.5) * 4 and x - 1.0
    #
    # Dispatch on the `env_name` argument rather than the global flag so the
    # function behaves correctly for whatever environment it was asked for.
    locomotion_tasks = ("halfcheetah", "walker2d", "hopper")
    if "antmaze" not in env_name and any(
        task in env_name for task in locomotion_tasks
    ):
        normalize(dataset)

    return env, dataset


def _train(summary_writer):
    """Build env, buffer and agent, then run pretraining and fine-tuning.

    The loop counter ``i`` starts at ``-FLAGS.num_pretraining_steps`` and runs
    up to and including ``FLAGS.max_steps``. While ``i < 1`` only gradient
    updates on the dataset-initialised replay buffer are performed; from
    ``i >= 1`` on, ``FLAGS.consecutive_env_steps`` transitions are collected
    before each round of ``FLAGS.consecutive_training_iters`` updates.

    NOTE: ``i`` advances by ``FLAGS.consecutive_env_steps`` per iteration, so
    logging/evaluation only fire on iterations where ``i`` lands exactly on a
    multiple of the corresponding interval.

    Args:
        summary_writer: TensorBoard writer receiving training and evaluation
            statistics. It is flushed but not closed here.
    """
    env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed)

    action_dim = env.action_space.shape[0]
    replay_buffer = ReplayBuffer(
        env.observation_space, action_dim, FLAGS.replay_buffer_size or FLAGS.max_steps
    )
    replay_buffer.initialize_with_dataset(dataset, FLAGS.init_dataset_size)

    kwargs = dict(FLAGS.config)
    agent = Learner(
        FLAGS.seed,
        env.observation_space.sample()[np.newaxis],
        env.action_space.sample()[np.newaxis],
        **kwargs,
        critic_ensemble_size=FLAGS.critic_ensemble_size,
        critic_training_iterations=FLAGS.critic_training_iterations,
    )

    eval_returns = []
    observation, done = env.reset(), False
    # Pre-bind so the logging block below cannot hit a NameError when
    # consecutive_training_iters is 0 and no update has produced stats.
    update_info = {}

    # Use negative indices for pretraining steps.
    with tqdm.tqdm(
        total=FLAGS.num_pretraining_steps + FLAGS.max_steps,
        smoothing=0.1,
        disable=not FLAGS.tqdm,
    ) as pbar:
        i = -FLAGS.num_pretraining_steps
        while i <= FLAGS.max_steps:
            if i >= 1:
                # Collect samples
                for _ in range(FLAGS.consecutive_env_steps):
                    action, action_info = agent.sample_actions(
                        observation,
                        num_actions_to_sample=FLAGS.num_actions_to_sample,
                        fixed_action_noise=FLAGS.fixed_action_noise,
                        optimism_parameter=FLAGS.optimism_parameter,
                    )
                    action = np.clip(action, -1, 1)
                    next_observation, reward, done, info = env.step(action)

                    # mask stays 1.0 unless the episode ended without a
                    # "TimeLimit.truncated" entry in info, i.e. time-limit
                    # truncation is not stored as a terminal transition.
                    if not done or "TimeLimit.truncated" in info:
                        mask = 1.0
                    else:
                        mask = 0.0

                    replay_buffer.insert(
                        observation, action, reward, mask, float(done), next_observation
                    )
                    observation = next_observation

                    if done:
                        observation, done = env.reset(), False
                        for k, v in info["episode"].items():
                            summary_writer.add_scalar(
                                f"training/{k}", v, info["total"]["timesteps"]
                            )
            else:
                # Pretraining: no environment interaction, hence no action
                # statistics to merge into the update info.
                action_info = {}

            # Training
            for _ in range(FLAGS.consecutive_training_iters):
                batch = replay_buffer.sample(FLAGS.batch_size)
                if "antmaze" in FLAGS.env_name:
                    # AntMaze rewards are shifted by -1 per batch instead of
                    # in the stored dataset.
                    batch = Batch(
                        observations=batch.observations,
                        actions=batch.actions,
                        rewards=batch.rewards - 1,
                        masks=batch.masks,
                        next_observations=batch.next_observations,
                    )
                update_info = agent.update(batch, step=i)
                update_info.update(action_info)

            if i % FLAGS.log_interval == 0:
                for k, v in update_info.items():
                    # np.ndim also accepts plain Python scalars, which have
                    # no `.ndim` attribute.
                    if np.ndim(v) == 0:
                        summary_writer.add_scalar(f"training/{k}", v, i)
                    else:
                        summary_writer.add_histogram(f"training/{k}", v, i)
                summary_writer.flush()

            if (i >= 1 and i % FLAGS.eval_interval == 0) or (
                i < 1 and i % FLAGS.pretraining_eval_interval == 0
            ):
                eval_stats = evaluate(agent, env, FLAGS.eval_episodes, 1, 0, 0)
                eval_stats_ours = evaluate(
                    agent,
                    env,
                    FLAGS.eval_episodes,
                    FLAGS.num_actions_to_sample,
                    FLAGS.fixed_action_noise,
                    0,
                )

                for k, v in eval_stats.items():
                    summary_writer.add_scalar(f"evaluation/average_{k}s", v, i)
                for k, v in eval_stats_ours.items():
                    summary_writer.add_scalar(f"evaluation_ours/average_{k}s", v, i)
                summary_writer.flush()

                eval_returns.append((i, eval_stats["return"]))
                np.savetxt(
                    os.path.join(FLAGS.save_dir, f"{FLAGS.seed}.txt"),
                    eval_returns,
                    fmt=["%d", "%.1f"],
                )

                # evaluate() was handed the training env. Assuming it steps or
                # resets that env (TODO confirm against evaluation.evaluate),
                # the cached `observation` no longer matches the env state, so
                # start a fresh episode before collecting further transitions.
                observation, done = env.reset(), False

            i += FLAGS.consecutive_env_steps
            pbar.update(FLAGS.consecutive_env_steps)
    # The `with` block has already closed the progress bar.


def main(_):
    """Entry point for absl: set up logging and run the training loop.

    Creates ``FLAGS.save_dir``, opens a TensorBoard writer under
    ``<save_dir>/tb/<seed>`` and guarantees the writer is closed even if
    training raises.
    """
    os.makedirs(FLAGS.save_dir, exist_ok=True)
    summary_writer = SummaryWriter(
        os.path.join(FLAGS.save_dir, "tb", str(FLAGS.seed)), write_to_disk=True
    )
    try:
        _train(summary_writer)
    finally:
        # Release the event-file handle and flush anything still pending.
        summary_writer.close()


# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls main().
if __name__ == "__main__":
    app.run(main)
