"""Bsuite experiment entry."""
import os
import random
import time
from typing import Dict, List, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import requests
import wandb
from absl import app, flags, logging
from acme import EnvironmentLoop
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.specs import make_environment_spec
from acme.utils.loggers import Logger

from internal.loggers import generate_experiment_name, logger_fn
from internal.notifier import DiscordNotif
from internal.tracer import PyTracer
from rosmo.agent.muzero.actor import MuZeroEvalActor
from rosmo.agent.muzero.learning import SAMPLING_METHOD, MuZeroLearner
from rosmo.agent.muzero.network import Networks
from rosmo.data.bsuite.dataset import env_loader
from rosmo.types import Array
from rosmo.util.env_loop_observer import (
  EvaluationLoop,
  ExtendedEnvLoopObserver,
  LearningStepObserver,
)

# Report the default JAX backend and the number of visible devices once, at
# import time. The public `jax.default_backend()` / `jax.device_count()` calls
# are used rather than reaching into the internal `jax.lib.xla_bridge` module.
platform = jax.default_backend()
num_devices = jax.device_count()
# `logging.warn` is a deprecated alias; `logging.warning` is the supported name.
logging.warning(f"Compute platform: {platform} with {num_devices} devices.")

# Webhook used for run notifications; empty string when the variable is unset.
notif_url = os.getenv("NOTIF_WEBHOOK_URL", "")
notif = DiscordNotif(url=notif_url)

# ===== Command-line flags. ===== #
FLAGS = flags.FLAGS

# --- Experiment bookkeeping. --- #
flags.DEFINE_string(name="user", default="", help="Wandb user id.")
flags.DEFINE_string(name="project", default="", help="Wandb project id.")
flags.DEFINE_boolean(name="use_wb", default=True, help="Enable wb logging.")
flags.DEFINE_string(
  name="exp_id", default=None, help="Experiment id.", required=True
)
flags.DEFINE_boolean(name="debug", default=True, help="Debug run.")
flags.DEFINE_boolean(name="profile", default=False, help="Profiling run.")

# --- Data and seeding. --- #
flags.DEFINE_integer(
  name="run_number", default=1, help="Run number of RLU dataset."
)
flags.DEFINE_integer(
  name="data_percentage",
  default=100,
  help="Percentage of data used for training.",
  lower_bound=0,
  upper_bound=100,
)
# The default seed is taken from the wall clock when the module is imported.
flags.DEFINE_integer(
  name="seed", default=int(time.time()), help="Random seed."
)

# --- Algorithm selection. --- #
flags.DEFINE_boolean(
  name="use_bc", default=False, help="Run behavior cloning baseline."
)
flags.DEFINE_boolean(
  name="use_qf",
  default=False,
  help="Learn Q function for advantage estimation.",
)
flags.DEFINE_enum(
  name="improvement_op",
  default=None,
  enum_values=["mcts", "cmpo", "mcts_mpo"],
  help="Policy improvement method.",
  required=True,
)
flags.DEFINE_enum(
  name="dynamics_scaling",
  default=None,
  enum_values=["1", "0.5", "0.25", "0.125", "0.0625"],
  help="Dynamics network channel scaling.",
)
flags.DEFINE_boolean(
  name="renormalize", default=False, help="Renormalize state."
)
flags.DEFINE_boolean(
  name="noisy_obs", default=False, help="Use noisy obs for stochasticity."
)
flags.DEFINE_float(
  name="noise_scale", default=0., help="Observation noise scale."
)

# --- Training and search hyper-parameters. --- #
flags.DEFINE_integer(
  name="unroll_steps", default=5, help="Unroll steps of the dynamics."
)
flags.DEFINE_integer(
  name="batch_size", default=128, help="Batch size for training."
)
flags.DEFINE_integer(
  name="num_simulations", default=4, help="Simulation budget."
)
flags.DEFINE_enum(
  name="sampling",
  default=None,
  enum_values=SAMPLING_METHOD,
  help="How to sample actions.",
)
flags.DEFINE_integer(
  name="pessimism", default=0, help="Use pessimism for the model."
)
flags.DEFINE_enum(
  name="behavior",
  default="",
  enum_values=["", "exp", "bin", "exp_only", "bin_only"],
  help="Type of behavior loss.",
)
flags.DEFINE_integer(
  name="num_pessimism_samples",
  default=1,
  help="Contrastive sampling budget.",
)
flags.DEFINE_float(
  name="pessimism_weight", default=0, help="Reward given for ood actions."
)
flags.DEFINE_integer(
  name="search_depth", default=None, help="MCTS search depth."
)
flags.DEFINE_boolean(
  name="safe_q_values",
  default=False,
  help="Fill unvisited node's q_value as node value.",
)
flags.DEFINE_float(
  name="dynamics_noise",
  default=0.0,
  help="Inject noise into dynamics if not zero.",
)
flags.DEFINE_float(
  name="behavior_coef", default=0.2, help="Behavior loss coefficient."
)

# --- Run targets. --- #
flags.DEFINE_string(
  name="checkpoint", default=None, help="Checkpoint to resume."
)
flags.DEFINE_string(
  name="game", default=None, help="Game name to run.", required=True
)
flags.DEFINE_bool(name="local", default=False, help="Local jobs.")
flags.DEFINE_string(
  name="run_set", default="", help="Run set name for experiment grouping."
)
flags.DEFINE_integer(
  name="dynamics_size", default=None, help="dynamic layer size"
)

# Root directory of the offline datasets (`main` appends a suffix for local
# jobs).
DATA_DIR = "/mnt_central/datasets/rl_unplugged/tensorflow_datasets"

# Default hyper-parameters. `get_config` copies this dict and layers the
# command-line flags on top of it.
CONFIG = dict(
  td_steps=3,
  num_bins=20,
  encoder_layers=[64, 64, 32],
  dynamics_layers=[32],
  prediction_layers=[32],
  output_init_scale=0.0,
  discount_factor=0.997**4,
  evaluate_episodes=10,
  clipping_threshold=1.0,
  log_interval=500,
  learning_rate=7e-4,
  warmup_steps=1_000,
  learning_rate_decay=0.1,
  weight_decay=1e-4,
  max_grad_norm=5.0,
  target_update_interval=200,
  value_coef=0.25,
  policy_coef=1.0,
  eval_period=1_000,
  total_steps=200_000,
)


# ===== Learner. ===== #
def get_learner(config, networks, data_iterator, logger) -> MuZeroLearner:
  """Build the MuZero learner.

  Args:
    config: Experiment configuration handed to the learner.
    networks: Network bundle handed to the learner.
    data_iterator: Passed to the learner as its `demonstrations`.
    logger: Logger the learner is constructed with.

  Returns:
    The constructed `MuZeroLearner`.
  """
  return MuZeroLearner(
    networks, demonstrations=data_iterator, config=config, logger=logger
  )


# ===== Eval Actor-Env Loop. ===== #
def get_actor_env_eval_loop(config, networks, environment, observers,
                            logger) -> Tuple[MuZeroEvalActor, EnvironmentLoop]:
  """Build the evaluation actor and the loop that runs it.

  Args:
    config: Experiment configuration handed to the actor.
    networks: Network bundle handed to the actor.
    environment: Environment the evaluation loop runs on.
    observers: Observers attached to the evaluation loop.
    logger: Logger the evaluation loop is constructed with.

  Returns:
    The `(actor, eval_loop)` pair; the loop is created with
    `should_update=False`.
  """
  eval_actor = MuZeroEvalActor(networks, config)
  loop = EvaluationLoop(
    actor=eval_actor,
    environment=environment,
    observers=observers,
    logger=logger,
    should_update=False,
  )
  return eval_actor, loop


def get_env_loop_observers() -> List[ExtendedEnvLoopObserver]:
  """Return the observers attached to the evaluation loop.

  Currently this is a single `LearningStepObserver`.
  """
  return [LearningStepObserver()]


# ===== Network ===== #
def get_networks(config, environment) -> Networks:
  """Build the representation, transition and prediction networks.

  Layer sizes are read from `config` ("encoder_layers", "dynamics_layers",
  "prediction_layers", "num_bins"); the number of actions comes from the
  environment spec. The output heads use a variance-scaling initializer whose
  scale is `config["output_init_scale"]` (0.0 when the key is absent).

  Args:
    config: Experiment configuration dict.
    environment: Environment from which the spec, and hence the dummy inputs
      used for parameter initialisation, are derived.

  Returns:
    A `Networks` bundle holding the three feed-forward networks and the
    environment spec.
  """
  env_spec = make_environment_spec(environment)
  logging.info(env_spec)
  action_space_size = env_spec.actions.num_values
  # Honour the configured head-initialisation scale instead of duplicating it
  # as a hard-coded constant; the fallback keeps the previous behaviour.
  output_init_scale = config.get("output_init_scale", 0.0)

  def _representation_fun(observations: Array) -> Array:
    """Flatten the observations and encode them with an MLP."""
    network = hk.Sequential(
      [hk.Flatten(), hk.nets.MLP(config["encoder_layers"])]
    )
    state = network(observations)
    return state

  representation = hk.without_apply_rng(hk.transform(_representation_fun))

  def _transition_fun(action: Array, state: Array) -> Array:
    """Predict the next state from a one-hot action and the current state."""
    action = hk.one_hot(action, action_space_size)
    network = hk.nets.MLP(config["dynamics_layers"])
    # The state is viewed as 2-D (rows x features) before being concatenated
    # with the one-hot action along the feature axis.
    sa = jnp.concatenate(
      [jnp.reshape(state, (-1, state.shape[-1])), action], axis=-1
    )
    # NOTE(review): `.squeeze()` removes every size-1 axis, including a batch
    # axis of length 1 - confirm that callers rely on this.
    next_state = network(sa).squeeze()
    return next_state

  transition = hk.without_apply_rng(hk.transform(_transition_fun))

  def _prediction_fun(state: Array) -> Array:
    """Return the (policy, reward, value) head outputs for `state`."""
    network = hk.nets.MLP(config["prediction_layers"], activate_final=True)
    head_state = network(state)
    output_init = hk.initializers.VarianceScaling(scale=output_init_scale)
    head_policy = hk.nets.MLP([action_space_size], w_init=output_init)
    head_value = hk.nets.MLP([config["num_bins"]], w_init=output_init)
    head_reward = hk.nets.MLP([config["num_bins"]], w_init=output_init)

    return (
      head_policy(head_state),
      head_reward(head_state),
      head_value(head_state),
    )

  prediction = hk.without_apply_rng(hk.transform(_prediction_fun))

  # Batched dummy inputs, used only to initialise parameters.
  dummy_action = utils.add_batch_dim(
    jnp.array(env_spec.actions.generate_value())
  )

  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))

  def _dummy_state(key):
    """Encode the dummy observation with freshly initialised parameters."""
    encoder_params = representation.init(key, dummy_obs)
    dummy_state = representation.apply(encoder_params, dummy_obs)
    return dummy_state

  return Networks(
    representation_network=networks_lib.FeedForwardNetwork(
      lambda key: representation.init(key, dummy_obs), representation.apply
    ),
    transition_network=networks_lib.FeedForwardNetwork(
      lambda key: transition.init(key, dummy_action, _dummy_state(key)),
      transition.apply,
    ),
    prediction_network=networks_lib.FeedForwardNetwork(
      lambda key: prediction.init(key, _dummy_state(key)),
      prediction.apply,
    ),
    environment_specs=env_spec,
  )


# ===== Configurations ===== #
def get_config(game_name: str) -> Dict:
  """Assemble the experiment configuration for `game_name`.

  Starts from a shallow copy of `CONFIG`, overlays the command-line flags,
  derives the full experiment name and, on a best-effort basis, stores the
  instance description returned by the `http://metadata/...` endpoint under
  the `"tpu_vm"` key.

  Args:
    game_name: Name of the game; stored in the config and appended to the
      experiment id to form the experiment name.

  Returns:
    The configuration dict. `"tpu_vm"` is absent when the metadata endpoint
    cannot be queried.
  """
  config = CONFIG.copy()
  config["benchmark"] = "bsuite"

  # Algorithm label: the first matching branch wins, "OS" is the fallback.
  config["algo"] = "OS"
  if FLAGS.use_bc:
    config["algo"] = "BC"
  elif FLAGS.use_qf:
    config["algo"] = "Q-MPO"
  elif FLAGS.improvement_op == "mcts":
    config["algo"] = "MZU"
  elif FLAGS.improvement_op == "mcts_mpo":
    config["algo"] = "MCTS-Q"

  config["run_set"] = FLAGS.run_set
  config["seed"] = FLAGS.seed
  # -1 marks "no dynamics scaling requested".
  config["dynamics_scaling"] = (
    float(FLAGS.dynamics_scaling) if FLAGS.dynamics_scaling else -1
  )
  config["renormalize"] = FLAGS.renormalize
  config["behavior_coef"] = FLAGS.behavior_coef
  config["sampling"] = FLAGS.sampling or "exact"
  config["pessimism"] = FLAGS.pessimism
  config["behavior"] = FLAGS.behavior
  config["num_pessimism_samples"] = FLAGS.num_pessimism_samples
  config["pessimism_weight"] = FLAGS.pessimism_weight
  config["game_name"] = game_name
  config["run_number"] = FLAGS.run_number
  # `ckpt_number` is not declared among this file's flags. Fall back to None
  # instead of raising AttributeError when no other module defines it.
  config["ckpt_number"] = getattr(FLAGS, "ckpt_number", None)
  config["unroll_steps"] = FLAGS.unroll_steps
  config["num_simulations"] = FLAGS.num_simulations
  config["search_depth"] = FLAGS.search_depth
  config["safe_q_values"] = FLAGS.safe_q_values
  config["dynamics_noise"] = FLAGS.dynamics_noise
  config["use_bc"] = FLAGS.use_bc
  config["use_qf"] = FLAGS.use_qf
  config["noisy_obs"] = FLAGS.noisy_obs
  config["noise_scale"] = FLAGS.noise_scale
  config["improvement_op"] = FLAGS.improvement_op
  config["data_percentage"] = FLAGS.data_percentage
  # Debug runs always use a small fixed batch.
  config["batch_size"] = 16 if FLAGS.debug else FLAGS.batch_size
  if FLAGS.dynamics_size:
    # Rebinds the key, so the list shared with `CONFIG` is never mutated.
    config["dynamics_layers"] = [32, FLAGS.dynamics_size, 32]
  exp_full_name = generate_experiment_name(f"{FLAGS.exp_id}_{game_name}")
  config["exp_full_name"] = exp_full_name
  try:
    # Best-effort lookup: the timeout keeps start-up from hanging when the
    # endpoint accepts the connection but never answers.
    config["tpu_vm"] = requests.get(
      "http://metadata/computeMetadata/v1/instance/description",
      headers={
        "Metadata-Flavor": "Google"
      },
      timeout=5,
    ).text
  except requests.exceptions.RequestException as err:
    # Covers connection errors and timeouts alike; the key is simply omitted.
    logging.debug("Instance metadata unavailable: %s", err)
  logging.info(f"Configs: {config}")
  return config


# ===== Misc. ===== #
def get_logger_fn(
  exp_full_name: str,
  job_name: str,
  is_eval: bool = False,
  config: Optional[Dict] = None,
) -> Logger:
  """Create the logger for one job of the experiment.

  Data saving is requested only for evaluation loggers, and both data saving
  and wandb logging are switched off for debug and profiling runs.

  Args:
    exp_full_name: Experiment name passed to the logger factory.
    job_name: Label of the logger (e.g. "learner" or "evaluator").
    is_eval: Whether the logger belongs to the evaluator.
    config: Optional experiment configuration forwarded to the factory.

  Returns:
    The logger produced by `logger_fn`.
  """
  # Debug and profiling runs neither persist data nor report to wandb.
  is_real_run = not (FLAGS.debug or FLAGS.profile)
  return logger_fn(
    exp_name=exp_full_name,
    label=job_name,
    config=config,
    save_data=is_eval and is_real_run,
    use_wb=is_real_run and FLAGS.use_wb,
    use_tb=False,
    use_sota=False,
  )


def main(_):
  """Train the learner on the offline dataset and evaluate it periodically.

  Seeds the Python and NumPy RNGs, builds the config, environment, dataset,
  networks, learner and evaluation loop, optionally initialises wandb and the
  notifier, then runs the training loop. In debug mode the loop performs one
  learner step and one evaluation and stops. In profile mode it traces steps
  110 to 210 and stops. The loggers are closed even when training raises.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  logging.info(f"Debug mode: {FLAGS.debug}")
  random.seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)

  # ===== Profiler. ===== #
  profile_dir = "./profile"
  os.makedirs(profile_dir, exist_ok=True)
  tracer = PyTracer("./", FLAGS.exp_id, with_jax=True)

  # ===== Setup. ===== #
  cfg = get_config(FLAGS.game)

  data_dir = DATA_DIR
  if FLAGS.local:
    data_dir += "/bsuite-v2"

  env, dataloader = env_loader(
    env_name=FLAGS.game,
    dataset_dir=data_dir,
    data_percentage=cfg["data_percentage"],
    batch_size=cfg["batch_size"],
    trajectory_length=cfg["td_steps"] + cfg["unroll_steps"] + 1,
    noisy_obs=FLAGS.noisy_obs,
    noise_scale=FLAGS.noise_scale
  )
  networks = get_networks(cfg, env)

  # ===== Essentials. ===== #
  learner = get_learner(
    cfg,
    networks,
    dataloader,
    get_logger_fn(
      cfg["exp_full_name"],
      "learner",
      config=cfg,
    ),
  )
  observers = get_env_loop_observers()
  actor, eval_loop = get_actor_env_eval_loop(
    cfg,
    networks,
    env,
    observers,
    get_logger_fn(cfg["exp_full_name"], "evaluator", is_eval=True, config=cfg),
  )
  evaluate_episodes = 2 if FLAGS.debug else cfg["evaluate_episodes"]

  # ===== Restore. ===== #
  init_step = 0
  if not (FLAGS.debug or FLAGS.profile) and FLAGS.use_wb:
    wb_name = cfg["exp_full_name"]
    wandb.init(
      project=FLAGS.project,
      entity=FLAGS.user,
      name=wb_name,
      config=cfg,
      sync_tensorboard=False,
    )

    # Send the "started" message now and queue the "finished" one, which is
    # delivered by the final `notif.execute()` below.
    notif.register(f"[Experiment started] {wb_name}")
    notif.execute()
    notif.register(f"[Experiment finished] {wb_name}")

  try:
    # ===== Training Loop. ===== #
    for i in range(init_step + 1, cfg["total_steps"]):
      learner.step()
      for ob in observers:
        ob.step()

      if FLAGS.debug or (i + 1) % cfg["eval_period"] == 0:
        # Evaluate with the learner's latest parameters.
        actor.update_params(learner.save().params)
        eval_loop.run(evaluate_episodes)

      if FLAGS.debug:
        # A debug run is a single step plus one evaluation.
        break

      # Time (and, in profile mode, trace) the 100 steps from 110 to 210.
      if i == 110:
        start = time.perf_counter()
        if FLAGS.profile:
          tracer.start()
          logging.info("Start tracing...")
      if i == 210:
        logging.info(f"100 steps took {time.perf_counter() - start} seconds")
        if FLAGS.profile:
          tracer.stop_and_save()
          break
  finally:
    # ===== Cleanup. ===== #
    # Runs on failure too, so the loggers are closed before the error
    # propagates.
    learner._logger.close()
    eval_loop._logger.close()
    del env, networks, dataloader, learner, observers, actor, eval_loop

  # ===== Notif. ===== #
  # Reached only when training completed without raising.
  notif.execute()


if __name__ == "__main__":
  # Script entry point: hand `main` to absl's app runner.
  app.run(main)
