# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Training of model-based RL agents.

Example invocation:

python -m tensor2tensor.rl.trainer_model_based \
    --output_dir=$HOME/t2t/rl_v1 \
    --loop_hparams_set=rlmb_base \
    --loop_hparams='num_real_env_frames=10000,epochs=3'
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import math
import os
import pickle
import pprint
import random
import time

import six

from tensor2tensor.bin import t2t_trainer  # pylint: disable=unused-import
from tensor2tensor.models.research import rl
from tensor2tensor.rl import rl_utils
from tensor2tensor.rl import trainer_model_based_params
from tensor2tensor.rl.dopamine_connector import DQNLearner  # pylint: disable=unused-import
from tensor2tensor.rl.restarter import Restarter
from tensor2tensor.utils import trainer_lib

import tensorflow.compat.v1 as tf


# Shorthand for TensorFlow's command-line flags module and its global flag
# values object; functions below read settings such as `FLAGS.output_dir`.
flags = tf.flags
FLAGS = flags.FLAGS


def world_model_step_increment(hparams, is_initial_epoch):
  """Return the number of world-model training steps to add for an epoch.

  Args:
    hparams: hyperparameters object. `model_train_steps` is always read;
      `initial_epoch_train_steps_multiplier` is read only for the initial
      epoch.
    is_initial_epoch: whether the increment is for the initial epoch.

  Returns:
    `hparams.model_train_steps`, scaled by
    `hparams.initial_epoch_train_steps_multiplier` on the initial epoch and
    by 1 otherwise.
  """
  scale = (
      hparams.initial_epoch_train_steps_multiplier if is_initial_epoch else 1
  )
  return scale * hparams.model_train_steps


def setup_directories(base_dir, subdirs):
  """Create `base_dir` and the requested subdirectories beneath it.

  Args:
    base_dir: root directory; a leading `~` is expanded.
    subdirs: iterable of subdirectory specs. Each spec is either a string
      (one relative path) or a sequence of path components that are joined
      under `base_dir`.

  Returns:
    dict mapping each spec, exactly as it appears in `subdirs`, to the path
    of the directory created for it.
  """
  root = os.path.expanduser(base_dir)
  tf.gfile.MakeDirs(root)

  created = {}
  for spec in subdirs:
    # A bare string is treated as a one-component path.
    components = (spec,) if isinstance(spec, six.string_types) else spec
    path = os.path.join(root, *components)
    tf.gfile.MakeDirs(path)
    created[spec] = path
  return created


def make_relative_timing_fn():
  """Return a function that logs the time elapsed since this call.

  Returns:
    A zero-argument function that logs "Timing: <elapsed>" at INFO level,
    where the elapsed time is measured from the moment this factory ran.
  """
  created_at = time.time()

  def log_relative_time():
    elapsed = datetime.timedelta(seconds=time.time() - created_at)
    tf.logging.info("Timing: %s", str(elapsed))

  return log_relative_time


def make_log_fn(epoch, log_relative_time_fn):
  """Make a logging function that tags every message with the epoch.

  Args:
    epoch: epoch number included in each logged line.
    log_relative_time_fn: zero-argument callable invoked after each message
      is logged.

  Returns:
    A function `log(msg, *args)` that logs `msg` at INFO level prefixed with
    the epoch and then calls `log_relative_time_fn`. When `args` are given,
    `msg` is first interpolated with them using `%` formatting.
  """

  def log(msg, *args):
    # Interpolate only when arguments were supplied: a pre-formatted message
    # that happens to contain a literal "%" would otherwise raise here.
    if args:
      msg %= args
    tf.logging.info("%s Epoch %d: %s", ">>>>>>>", epoch, msg)
    log_relative_time_fn()

  return log


def train_supervised(problem, model_name, hparams, data_dir, output_dir, epoch,
                     train_steps, eval_steps, local_eval_frequency=None,
                     schedule="continuous_train_and_eval"):
  """Build a trainer_lib experiment and run one of its schedules.

  Args:
    problem: problem passed to `trainer_lib.create_experiment_fn`.
    model_name: name of the model to train.
    hparams: model hyperparameters handed to the experiment function.
    data_dir: directory with the training data.
    output_dir: model directory used for the run config.
    epoch: value passed as the only argument to the schedule method.
    train_steps: number of training steps for the experiment.
    eval_steps: number of evaluation steps for the experiment.
    local_eval_frequency: minimum evaluation frequency; when None,
      `FLAGS.local_eval_frequency` is used.
    schedule: name of the experiment method to invoke.

  Returns:
    Whatever the invoked schedule method returns.
  """
  eval_frequency = local_eval_frequency
  if eval_frequency is None:
    eval_frequency = FLAGS.local_eval_frequency

  experiment_fn = trainer_lib.create_experiment_fn(
      model_name, problem, data_dir, train_steps, eval_steps,
      min_eval_frequency=eval_frequency)
  config = trainer_lib.create_run_config(model_name, model_dir=output_dir)
  experiment = experiment_fn(config, hparams)
  run_schedule = getattr(experiment, schedule)
  return run_schedule(epoch)

def train_world_model(
    env, data_dir, output_dir, hparams, world_model_steps_num, epoch
):
  """Train the world model and pickle the metrics returned by training.

  Args:
    env: environment, passed to `train_supervised` as the problem.
    data_dir: directory with the training data.
    output_dir: world-model directory; the metrics are pickled to the file
      `loss_value` inside it.
    hparams: loop hyperparameters. Reads `generative_model_params`,
      `generative_model`, `learning_rate_bump`, `wm_policy_param_sharing`
      and `particle_id_wm`.
    world_model_steps_num: step count handed to the `Restarter`.
    epoch: current epoch; for epochs after 0 the learning rate is multiplied
      by `hparams.learning_rate_bump`.

  Returns:
    `world_model_steps_num`, unchanged.
  """
  # The step count is used as given; no per-epoch increment is applied here.
  wm_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  learning_rate = wm_hparams.learning_rate_constant
  if epoch > 0:
    learning_rate *= hparams.learning_rate_bump
  wm_hparams.learning_rate = learning_rate
  if hparams.wm_policy_param_sharing:
    wm_hparams.optimizer_zero_grads = True

  model_mode = "world_model_{}".format(hparams.particle_id_wm)
  restarter = Restarter(model_mode, output_dir, world_model_steps_num)
  if restarter.should_skip:
    # Nothing to train, and no metrics file is written in this case.
    return world_model_steps_num

  with restarter.training_loop():
    training_metrics = train_supervised(
        problem=env,
        model_name=hparams.generative_model,
        hparams=wm_hparams,
        epoch=epoch,
        data_dir=data_dir,
        output_dir=output_dir,
        train_steps=restarter.target_global_step,
        eval_steps=100,
        local_eval_frequency=2000
    )

  metrics_path = os.path.join(output_dir, 'loss_value')
  with open(metrics_path, 'wb+') as metrics_file:
    pickle.dump(training_metrics, metrics_file)

  return world_model_steps_num


def load_metrics(event_dir, epoch):
  """Loads metrics for this epoch if they have already been written.

  This reads every event file in full, but they are small since they hold
  just per-epoch metrics.

  Args:
    event_dir: directory whose entries are all read as summary event files.
    epoch: step number to look for; only events whose `step` equals it are
      used.

  Returns:
    dict mapping each summary tag to its `simple_value`. Empty when no
    matching event has been written.
  """
  metrics = {}
  for filename in tf.gfile.ListDirectory(event_dir):
    path = os.path.join(event_dir, filename)
    for event in tf.train.summary_iterator(path):
      if event.step != epoch or not event.HasField("summary"):
        continue
      # A summary may hold any number of values: collect all of them rather
      # than indexing the first, which fails on an empty summary and drops
      # the rest of a multi-value one.
      for value in event.summary.value:
        metrics[value.tag] = value.simple_value
  return metrics

def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
  """Train, and optionally evaluate, one world model for one epoch.

  The world-model id, epoch number and step count are read from `hparams`
  (`particle_id_wm`, `epoch_number_wm`, `steps_done_wm`). Results are written
  to the `world_model_<id>` directory under `output_dir`:
    * `reward_metrics`: pickled world-model evaluation metrics, written only
      when `hparams.eval_world_model` is set.
    * `step_log`: pickled step count returned by `train_world_model`.

  Args:
    hparams: loop hyperparameters.
    output_dir: base directory under which all subdirectories are created.
    report_fn: not used here beyond the check below.
    report_metric: must not be None when `report_fn` is given.
  """
  if report_fn:
    assert report_metric is not None

  particle_id_wm = int(hparams.particle_id_wm)
  epoch = int(hparams.epoch_number_wm)
  world_model_steps_num = int(hparams.steps_done_wm)
  wm_name = "world_model_{}".format(particle_id_wm)

  # Directories
  subdirectories = [
      "data", "tmp", wm_name, (wm_name, "debug_videos"),
      "policy", "eval_metrics"
  ]
  directories = setup_directories(output_dir, subdirectories)
  data_dir = directories["data"]
  wm_dir = directories[wm_name]

  env = rl_utils.setup_env(
      hparams, batch_size=hparams.real_batch_size,
      max_num_noops=hparams.max_num_noops,
      rl_env_max_episode_steps=hparams.rl_env_max_episode_steps
  )
  log_relative_time = make_relative_timing_fn()
  log = make_log_fn(epoch, log_relative_time)

  # Train world model
  log("Training world model")
  # Only experience from the previous epoch is available.
  env.start_new_epoch(epoch - 1, data_dir)
  world_model_steps_num = train_world_model(
      env, data_dir, wm_dir, hparams, world_model_steps_num, epoch
  )

  if hparams.eval_world_model:
    debug_video_path = os.path.join(
        directories[wm_name, "debug_videos"],
        "{}.avi".format(env.current_epoch)
    )
    wm_metrics = rl_utils.evaluate_world_model(
        env, hparams, wm_dir, debug_video_path
    )
    # Pass the metrics as an argument so that a "%" inside them is never
    # treated as a format directive.
    log("World model eval metrics:\n%s", pprint.pformat(wm_metrics))
    # Metrics exist only when evaluation ran, so they are written here;
    # dumping them unconditionally would fail on an unbound `wm_metrics`.
    with open(os.path.join(wm_dir, 'reward_metrics'), 'wb') as f:
      pickle.dump(wm_metrics, f)

  with open(os.path.join(wm_dir, 'step_log'), 'wb') as f:
    pickle.dump(world_model_steps_num, f)


def main(_):
  """Create the loop hparams, select the GPU and run the training loop.

  Args:
    _: unused positional argument supplied by `tf.app.run`.
  """
  hp = trainer_model_based_params.create_loop_hparams()
  # Environment values must be strings; coerce in case `d_id` is numeric.
  os.environ["CUDA_VISIBLE_DEVICES"] = str(hp.d_id)
  assert not FLAGS.job_dir_to_evaluate
  training_loop(hp, FLAGS.output_dir)


if __name__ == "__main__":
  # Log at INFO level so the progress and timing messages are shown, then
  # hand control to tf.app.run(), which parses flags and calls main().
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
