"""Experiment entry for CQL baseline (BSuite)."""
import os
import pickle
import time
from typing import Dict, List, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import optax
import wandb
from absl import app, flags, logging
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.specs import make_environment_spec
from acme.utils.loggers import Logger

from internal.loggers import generate_experiment_name, logger_fn
from internal.notifier import DiscordNotif
from internal.tracer import PyTracer
from rosmo.agent.cql_discrete.actor import CQLEvalActor
from rosmo.agent.cql_discrete.learning import CQLLearner
from rosmo.agent.cql_discrete.network import CQLNetworks
from rosmo.data.bsuite.dataset import env_loader
from rosmo.types import Array
from rosmo.util.env_loop_observer import (
  EvaluationLoop,
  ExtendedEnvLoopObserver,
  LearningStepObserver,
)

# Query the compute backend through the public JAX API. `jax.lib.xla_bridge`
# is a private module that recent JAX releases deprecate and remove;
# `jax.default_backend()` returns the same platform string.
platform = jax.default_backend()
num_devices = jax.device_count()
logging.warning(f"Compute platform: {platform} with {num_devices} devices.")

# Discord notifier; the webhook URL falls back to "" when NOTIF_WEBHOOK_URL
# is unset.
notif_url = os.environ.get("NOTIF_WEBHOOK_URL", "")
notif = DiscordNotif(url=notif_url)

# ===== Flags. ===== #
FLAGS = flags.FLAGS

# Run identity and mode.
flags.DEFINE_string(
  name="exp_id", default=None, help="Experiment id.", required=True
)
flags.DEFINE_boolean(name="debug", default=True, help="Debug run.")
flags.DEFINE_boolean(name="profile", default=False, help="Profiling run.")

# Dataset selection.
flags.DEFINE_integer(
  name="run_number", default=1, help="Run number of RLU dataset."
)
flags.DEFINE_integer(
  name="data_percentage",
  default=100,
  help="Percentage of data used for training.",
  lower_bound=0,
  upper_bound=100,
)

# Training, resume and environment selection.
flags.DEFINE_integer(
  name="seed", default=int(time.time()), help="Random seed."
)
flags.DEFINE_float(
  name="minq_weight", default=1.0, help="Weight for Min-Q loss."
)
flags.DEFINE_boolean(
  name="use_original", default=False, help="Use original CQL networks."
)
flags.DEFINE_string(
  name="checkpoint", default=None, help="Checkpoint to resume."
)
flags.DEFINE_string(
  name="game", default=None, help="Game name to run.", required=True
)

# Experiment grouping and job placement.
flags.DEFINE_string(
  name="run_set", default="", help="Run set name for experiment grouping."
)
flags.DEFINE_boolean(name="local", default=False, help="Local jobs.")

# Root of the RL Unplugged TFDS datasets; main() appends "/bsuite-v2" when
# --local is set.
DATA_DIR = "/mnt_central/datasets/rl_unplugged/tensorflow_datasets"

CONFIG = dict(
  # Number of atoms (bins) per action in the Q-network output.
  num_bins=20,
  # Hidden-layer widths of the encoder and prediction MLPs.
  encoder_layers=[64, 64, 32],
  prediction_layers=[32],
  # Init scale for the Q-head output layer.
  output_init_scale=0.0,
  discount=0.99,
  batch_size=128 * 6,  # 6 to match unroll_length
  # Adam optimizer settings.
  learning_rate=1e-3,
  adam_epsilon=0.0003125,
  huber_param=1.0,
  # Presumably the evaluation-time exploration epsilon consumed by the
  # eval actor via the config -- TODO confirm in CQLEvalActor.
  epsilon_eval=0.001,
  target_update_period=250,
  grad_updates_per_batch=1,
  log_interval=400,
  # Periods and limits below are counted in training-loop iterations.
  save_period=10_000,
  eval_period=2_000,
  evaluate_episodes=2,
  total_steps=200_000,
)


# ===== Learner. ===== #
def get_learner(config, networks, data_iterator, logger) -> CQLLearner:
  """Construct the CQL learner described by ``config``.

  Args:
    config: Experiment config; reads ``seed``, ``target_update_period``,
      ``num_bins``, ``minq_weight``, ``huber_param``, ``batch_size``,
      ``learning_rate``, ``adam_epsilon``, ``grad_updates_per_batch`` and
      ``log_interval``.
    networks: Networks passed straight through to ``CQLLearner``.
    data_iterator: Batch iterator handed to the learner as ``iterator``.
    logger: Logger the learner reports to.

  Returns:
    A ``CQLLearner`` using an Adam optimizer.
  """
  learner_kwargs = dict(
    random_key=jax.random.PRNGKey(config["seed"]),
    target_update_period=config["target_update_period"],
    # The learner calls the number of distribution bins "atoms".
    num_atoms=config["num_bins"],
    minq_weight=config["minq_weight"],
    huber_param=config["huber_param"],
    batch_size=config["batch_size"],
    iterator=data_iterator,
    optimizer=optax.adam(
      learning_rate=config["learning_rate"],
      eps=config["adam_epsilon"],
    ),
    grad_updates_per_batch=config["grad_updates_per_batch"],
    logger=logger,
    log_interval=config["log_interval"],
  )
  return CQLLearner(networks, **learner_kwargs)


# ===== Eval Actor-Env Loop. ===== #
def get_actor_env_eval_loop(
  config, networks: CQLNetworks, environment, observers, logger
) -> Tuple[CQLEvalActor, EvaluationLoop]:
  """Build the evaluation actor and the loop that runs it in ``environment``.

  Args:
    config: Experiment config dict, forwarded to ``CQLEvalActor``.
    networks: CQL networks as returned by ``get_networks`` (not a learner).
    environment: Environment to evaluate in; its spec supplies the number of
      discrete actions given to the actor.
    observers: Environment-loop observers attached to the evaluation loop.
    logger: Logger passed to the evaluation loop.

  Returns:
    ``(actor, eval_loop)``. The loop is created with ``should_update=False``.
  """
  environment_spec = make_environment_spec(environment)
  actor = CQLEvalActor(
    networks,
    environment_spec.actions.num_values,
    config,
  )
  eval_loop = EvaluationLoop(
    environment=environment,
    actor=actor,
    logger=logger,
    should_update=False,
    observers=observers,
  )
  return actor, eval_loop


def get_env_loop_observers() -> List[ExtendedEnvLoopObserver]:
  """Create the environment-loop observers used by the training run.

  Returns:
    A new list containing a single ``LearningStepObserver``.
  """
  return [LearningStepObserver()]


# ===== Network ===== #
def get_networks(config, environment) -> CQLNetworks:
  """Build the Q-network for ``environment``.

  The Haiku model is a flatten + MLP encoder (``config["encoder_layers"]``),
  then an MLP with final activation (``config["prediction_layers"]``), then a
  single-layer head with ``num_bins * num_actions`` outputs. The head output
  is reshaped to ``(B, num_bins, num_actions)``.

  Args:
    config: Experiment config. Reads ``num_bins``, ``encoder_layers``,
      ``prediction_layers`` and, if present, ``output_init_scale``
      (default ``0.0``).
    environment: Environment whose spec defines the observation structure and
      the number of discrete actions.

  Returns:
    ``CQLNetworks`` holding a ``FeedForwardNetwork``:
    - ``init(key)`` uses a batch-of-one zero observation as dummy input.
    - ``apply(params, obs)`` returns
      ``{"q_dist": (B, num_bins, |A|), "q_value": (B, |A|)}``.
    - ``q_value`` is the mean of ``q_dist`` over the bin axis.
  """
  env_spec = make_environment_spec(environment)
  logging.info(env_spec)
  num_actions = env_spec.actions.num_values
  num_atoms = config["num_bins"]
  # Use the configured head init scale instead of a hard-coded 0.0, which
  # silently ignored CONFIG["output_init_scale"]. With 0.0 the head weights
  # start at zero.
  output_init_scale = config.get("output_init_scale", 0.0)

  def _representation_fun(observations: Array) -> Array:
    encoder = hk.Sequential(
      [hk.Flatten(), hk.nets.MLP(config["encoder_layers"])]
    )
    return encoder(observations)

  def _prediction_fun(state: Array) -> Array:
    # Module creation order is kept as-is so Haiku parameter names (and thus
    # saved checkpoints) stay compatible.
    torso = hk.nets.MLP(config["prediction_layers"], activate_final=True)
    head_state = torso(state)
    output_init = hk.initializers.VarianceScaling(scale=output_init_scale)
    head_qf = hk.nets.MLP([num_atoms * num_actions], w_init=output_init)
    return head_qf(head_state)

  def _q_value_fn(obs: jnp.ndarray):
    state = _representation_fun(obs)
    logits = _prediction_fun(state)
    # Reshape distribution and action dimension, since
    # rlax.quantile_q_learning expects it that way.
    bs = logits.shape[0]
    logits = jnp.reshape(logits, (bs, num_atoms, num_actions))  # (B, a, |A|)
    q_values = jnp.mean(logits, axis=1)  # (B, |A|)
    return {"q_dist": logits, "q_value": q_values}

  # `apply` needs no RNG key: the model is deterministic.
  qf = hk.without_apply_rng(hk.transform(_q_value_fn))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  q_network = networks_lib.FeedForwardNetwork(
    lambda key: qf.init(key, dummy_obs), qf.apply
  )

  return CQLNetworks(q_network=q_network, environment_specs=env_spec)


# ===== Configurations ===== #
def get_config(game_name: str) -> Dict:
  """Build the run configuration from ``CONFIG`` and the parsed flags.

  Args:
    game_name: Environment name (``main`` passes ``--game``). It is also part
      of the generated experiment name.

  Returns:
    A shallow copy of ``CONFIG`` extended with:
    - algorithm/benchmark labels,
    - flag values,
    - a generated ``exp_full_name``.
    In debug mode ``batch_size`` is overridden to 16.
  """
  config = CONFIG.copy()
  config["algo"] = "CQL"
  config["benchmark"] = "bsuite"
  config["run_set"] = FLAGS.run_set
  config["game_name"] = game_name
  config["seed"] = FLAGS.seed
  config["minq_weight"] = FLAGS.minq_weight
  config["use_original"] = FLAGS.use_original
  config["run_number"] = FLAGS.run_number
  # This script defines no `ckpt_number` flag, so `FLAGS.ckpt_number` fails
  # on access (absl raises AttributeError for unknown flags). Keep the key
  # for config/W&B consistency and fall back to None when it is undefined.
  config["ckpt_number"] = getattr(FLAGS, "ckpt_number", None)
  config["data_percentage"] = FLAGS.data_percentage
  config["batch_size"] = 16 if FLAGS.debug else config["batch_size"]
  exp_full_name = generate_experiment_name(f"{FLAGS.exp_id}_{game_name}")
  config["exp_full_name"] = exp_full_name
  logging.info(f"Configs: {config}")
  return config


# ===== Misc. ===== #
def get_logger_fn(
  exp_full_name: str,
  job_name: str,
  is_eval: bool = False,
  config: Optional[Dict] = None,
) -> Logger:
  """Create the logger for one job of the experiment.

  Args:
    exp_full_name: Experiment name the logger is associated with.
    job_name: Label for the job (``main`` uses "learner" and "evaluator").
    is_eval: Whether this is the evaluation job. Only evaluation loggers ever
      save data.
    config: Optional config forwarded to ``logger_fn``.

  Returns:
    The logger built by ``logger_fn``. TensorBoard and SOTA logging are always
    off. Debug and profiling runs also disable data saving and W&B.
  """
  quiet_run = FLAGS.debug or FLAGS.profile
  return logger_fn(
    exp_name=exp_full_name,
    label=job_name,
    save_data=is_eval and not quiet_run,
    use_tb=False,
    use_wb=not quiet_run,
    use_sota=False,
    config=config,
  )


def main(_):
  """Train CQL on an offline BSuite dataset with periodic eval and checkpoints.

  Steps:
  1. Build the config, environment and dataloader.
  2. Build the networks, learner and evaluation loop.
  3. Optionally restore from ``--checkpoint``.
  4. For non-debug/non-profile runs, initialise W&B and Discord notifications.
  5. Run the training loop.

  Args:
    _: Positional argv passed by ``absl.app.run`` (unused).

  Raises:
    FileNotFoundError: If ``--checkpoint`` is not an existing file, or if its
      experiment directory under ``./checkpoint`` is missing.
  """
  logging.info(f"Debug mode: {FLAGS.debug}")
  # ===== Profiler. ===== #
  profile_dir = "./profile"
  os.makedirs(profile_dir, exist_ok=True)
  tracer = PyTracer("./", FLAGS.exp_id, with_jax=True)

  # ===== Setup. ===== #
  cfg = get_config(FLAGS.game)

  data_dir = DATA_DIR
  if FLAGS.local:
    data_dir += "/bsuite-v2"

  env, dataloader = env_loader(
    env_name=FLAGS.game,
    dataset_dir=data_dir,
    data_percentage=cfg["data_percentage"],
    batch_size=cfg["batch_size"],
    trajectory_length=1,
  )
  networks = get_networks(cfg, env)

  # ===== Resume. ===== #
  _RESUME = False
  _exp_full_name = None
  ckpt_fn = None
  if FLAGS.checkpoint is not None:
    _RESUME = True
    ckpt_fn = FLAGS.checkpoint
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(ckpt_fn):
      raise FileNotFoundError(f"Checkpoint file not found: {ckpt_fn}")
    # Checkpoints are written to ./checkpoint/<exp_full_name>/ckpt_<step>.pkl,
    # so the parent directory name identifies the experiment being resumed.
    _exp_full_name = os.path.basename(os.path.dirname(ckpt_fn))
    # `logging.warn` is a deprecated alias of `logging.warning`.
    logging.warning(f"Resuming from {ckpt_fn}...")

  # ===== Essentials. ===== #
  # A resumed run keeps logging under the original experiment name.
  log_exp_name = _exp_full_name or cfg["exp_full_name"]
  learner = get_learner(
    cfg,
    networks,
    dataloader,
    get_logger_fn(log_exp_name, "learner", config=cfg),
  )
  observers = get_env_loop_observers()
  actor, eval_loop = get_actor_env_eval_loop(
    cfg,
    networks,
    env,
    observers,
    get_logger_fn(log_exp_name, "evaluator", is_eval=True, config=cfg),
  )
  evaluate_episodes = 2 if FLAGS.debug else cfg["evaluate_episodes"]

  # ===== Restore. ===== #
  init_step = 0
  if _RESUME:
    save_path = os.path.join("./checkpoint", _exp_full_name)
    if not os.path.isdir(save_path):
      raise FileNotFoundError(f"Checkpoint directory not found: {save_path}")
    # NOTE: unpickling can execute arbitrary code; only resume from
    # checkpoints this script wrote or that are otherwise trusted.
    with open(ckpt_fn, "rb") as f:
      train_state = pickle.load(f)
    # Recover <step> from "ckpt_<step>.pkl". `str.strip(".pkl")` removes any
    # of the characters '.', 'p', 'k', 'l' from both ends rather than the
    # suffix, so drop the extension with `os.path.splitext` instead.
    stem = os.path.splitext(os.path.basename(ckpt_fn))[0]
    init_step = int(stem.split("_")[-1])
    learner.restore(train_state)
    # Observers are restored to the first step index the loop below runs.
    for ob in observers:
      ob.restore(init_step + 1)
  else:
    save_path = os.path.join("./checkpoint", cfg["exp_full_name"])
    os.makedirs(save_path, exist_ok=True)

  if not (FLAGS.debug or FLAGS.profile):
    wb_name = cfg["exp_full_name"]
    wb_cfg = cfg
    if _RESUME:
      wb_name = f"{_exp_full_name}_resume"

    wandb.init(
      project="",
      entity="",
      name=wb_name,
      config=wb_cfg,
      sync_tensorboard=False,
    )

    notif.register(f"[Experiment started] {wb_name}")
    notif.execute()
    notif.register(f"[Experiment finished] {wb_name}")

  # ===== Training Loop. ===== #
  # Start time of the step-110..210 timing window. It stays None if step 110
  # is never reached (e.g. when resuming past it). Previously that case made
  # step 210 read an unbound variable.
  start = None
  for i in range(init_step + 1, cfg["total_steps"]):
    learner.step(transform=True)
    for ob in observers:
      ob.step()

    if (i + 1) % cfg["save_period"] == 0:
      with open(os.path.join(save_path, f"ckpt_{i}.pkl"), "wb") as f:
        pickle.dump(learner.save(), f)

    if FLAGS.debug or (i + 1) % cfg["eval_period"] == 0:
      actor.update_params(learner.save().params)
      eval_loop.run(evaluate_episodes)

    if FLAGS.debug:
      break

    if i == 110:
      start = time.perf_counter()
      if FLAGS.profile:
        tracer.start()
        logging.info("Start tracing...")
    if i == 210 and start is not None:
      logging.info(f"100 steps took {time.perf_counter() - start} seconds")
      if FLAGS.profile:
        tracer.stop_and_save()
        break

  # ===== Cleanup. ===== #
  # NOTE(review): reaches into private `_logger` attributes; no public close
  # hook is visible from this file.
  learner._logger.close()
  eval_loop._logger.close()
  del env, networks, dataloader, learner, observers, actor, eval_loop

  # ===== Notif. ===== #
  notif.execute()


if __name__ == "__main__":
  # absl parses the command-line flags defined above, then calls main(argv).
  app.run(main)
