# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Training of model-based RL agents.

Example invocation:

python -m tensor2tensor.rl.trainer_model_based \
    --output_dir=$HOME/t2t/rl_v1 \
    --loop_hparams_set=rlmb_base \
    --loop_hparams='num_real_env_frames=10000,epochs=3'
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import math
import os
import pickle
import pprint
import random
import time

import six

import numpy as np

from tensor2tensor.bin import t2t_trainer  # pylint: disable=unused-import
from tensor2tensor.models.research import rl
from tensor2tensor.rl import rl_utils
from tensor2tensor.rl import trainer_model_based_params
from tensor2tensor.rl.dopamine_connector import DQNLearner  # pylint: disable=unused-import
from tensor2tensor.rl.restarter import Restarter
from tensor2tensor.utils import trainer_lib

import tensorflow.compat.v1 as tf


# Short aliases for TensorFlow's command-line flag module and its global flag
# values object; FLAGS is read below (e.g. FLAGS.output_dir).
flags = tf.flags
FLAGS = flags.FLAGS


def real_env_step_increment(hparams, epoch):
  """Returns how many real-environment frames to collect in `epoch`.

  The total budget `hparams.num_real_env_frames` is divided by
  `hparams.epochs` and rounded up; the initial epoch (-1) gets twice that
  share.

  Args:
    hparams: object exposing `num_real_env_frames` and `epochs`.
    epoch: epoch index; -1 denotes the initial data-collection epoch.

  Returns:
    Integer number of frames.
  """
  frames_per_epoch = int(
      math.ceil(hparams.num_real_env_frames / hparams.epochs)
  )
  if epoch == -1:
    return 2 * frames_per_epoch
  return frames_per_epoch


def world_model_step_increment(hparams, is_initial_epoch):
  """Returns the number of world-model training steps to add this epoch.

  Args:
    hparams: object exposing `model_train_steps` and
      `initial_epoch_train_steps_multiplier`.
    is_initial_epoch: whether this is the initial epoch, in which case the
      step count is scaled by `initial_epoch_train_steps_multiplier`.

  Returns:
    `model_train_steps`, multiplied by the initial-epoch multiplier when
    `is_initial_epoch` is truthy.
  """
  multiplier = (
      hparams.initial_epoch_train_steps_multiplier if is_initial_epoch else 1
  )
  return multiplier * hparams.model_train_steps



def setup_directories(base_dir, subdirs):
  """Creates `base_dir` and the requested sub-directories beneath it.

  Args:
    base_dir: root directory; a leading `~` is expanded.
    subdirs: iterable whose items are either a single directory name (string)
      or a tuple of path components describing a nested directory.

  Returns:
    Dict mapping each item of `subdirs` (unchanged, string or tuple) to the
    full path of the directory created for it.
  """
  root = os.path.expanduser(base_dir)
  tf.gfile.MakeDirs(root)

  created = {}
  for entry in subdirs:
    # A bare string is a one-component path; tuples are joined as given.
    components = (entry,) if isinstance(entry, six.string_types) else entry
    full_path = os.path.join(root, *components)
    tf.gfile.MakeDirs(full_path)
    created[entry] = full_path
  return created


def make_relative_timing_fn():
  """Builds a function that logs the time elapsed since this call.

  Returns:
    A zero-argument callable; each invocation logs, at INFO level, the
    duration since `make_relative_timing_fn` was called.
  """
  created_at = time.time()

  def log_relative_time():
    elapsed = datetime.timedelta(seconds=time.time() - created_at)
    tf.logging.info("Timing: %s", str(elapsed))

  return log_relative_time


def make_log_fn(epoch, log_relative_time_fn):
  """Builds a logger that prefixes messages with the epoch number.

  Args:
    epoch: epoch number included in every logged line.
    log_relative_time_fn: zero-argument callable invoked after each message.

  Returns:
    A function `log(msg, *args)` that %-formats `msg` with `args`, logs it at
    INFO level and then calls `log_relative_time_fn`.
  """

  def log(msg, *args):
    formatted = msg % args
    tf.logging.info("%s Epoch %d: %s", ">>>>>>>", epoch, formatted)
    log_relative_time_fn()

  return log


def random_rollout_subsequences(rollouts, num_subsequences, subsequence_length):
  """Samples random contiguous subsequences from a set of rollouts.

  Each subsequence is drawn by picking a rollout uniformly at random,
  rejecting it if it is shorter than `subsequence_length`, and then picking a
  uniformly random start index within it.

  Args:
    rollouts: sequence of rollouts; each must support `len()` and slicing.
    num_subsequences: number of subsequences to draw.
    subsequence_length: length of every returned subsequence.

  Returns:
    List of `num_subsequences` slices, each of length `subsequence_length`.

  Raises:
    ValueError: if at least one subsequence is requested but no rollout is
      long enough to provide it.
  """
  def is_long_enough(rollout):
    return len(rollout) >= subsequence_length

  # Without this check the rejection sampling below could never terminate.
  if num_subsequences > 0 and not any(is_long_enough(r) for r in rollouts):
    raise ValueError(
        "No rollout has at least {} frames; cannot sample "
        "subsequences.".format(subsequence_length)
    )

  def choose_subsequence():
    # TODO(koz4k): Weigh rollouts by their lengths so sampling is uniform over
    # frames and not rollouts.
    # Iterate rather than recurse so that many short rollouts cannot exhaust
    # the interpreter's recursion limit.
    while True:
      rollout = random.choice(rollouts)
      if not is_long_enough(rollout):
        # Rollout too short; repeat.
        continue
      from_index = random.randrange(len(rollout) - subsequence_length + 1)
      return rollout[from_index:(from_index + subsequence_length)]

  return [choose_subsequence() for _ in range(num_subsequences)]


def train_supervised(problem, model_name, hparams, data_dir, output_dir, epoch,
                     train_steps, eval_steps, local_eval_frequency=None,
                     schedule="continuous_train_and_eval"):
  """Builds an experiment for `model_name` and runs the given schedule.

  Args:
    problem: problem passed to `trainer_lib.create_experiment_fn`.
    model_name: name of the model to train.
    hparams: model hyperparameters handed to the experiment function.
    data_dir: directory with the training data.
    output_dir: model directory used for the run config.
    epoch: epoch number forwarded as `epoch=` to the schedule method.
    train_steps: number of training steps for the experiment.
    eval_steps: number of evaluation steps for the experiment.
    local_eval_frequency: minimum evaluation frequency; defaults to
      `FLAGS.local_eval_frequency` when None.
    schedule: name of the experiment method to invoke.

  Returns:
    Whatever the invoked schedule method returns.
  """
  eval_frequency = local_eval_frequency
  if eval_frequency is None:
    eval_frequency = FLAGS.local_eval_frequency

  experiment_fn = trainer_lib.create_experiment_fn(
      model_name, problem, data_dir, train_steps, eval_steps,
      min_eval_frequency=eval_frequency
  )
  experiment = experiment_fn(
      trainer_lib.create_run_config(model_name, model_dir=output_dir),
      hparams
  )
  run_schedule = getattr(experiment, schedule)
  return run_schedule(epoch=epoch)


def train_agent(real_env, learner, world_model_dir, hparams, epoch):
  """Trains the agent in the environment simulated by the world model.

  Args:
    real_env: real environment, used to choose initial frames and to build
      the simulated environment.
    learner: learner whose `train` method is called; its `agent_model_dir`
      hosts the simulated-rollout video directory.
    world_model_dir: directory of the world model used for simulation.
    hparams: loop hyperparameters.
    epoch: current epoch number; also selects the env-step multiplier.
  """
  frame_chooser = rl_utils.make_initial_frame_chooser(
      real_env, hparams.frame_stack_size, hparams.simulation_random_starts,
      hparams.simulation_flip_first_random_for_beginning
  )
  sim_video_dir = os.path.join(
      learner.agent_model_dir, "sim_videos_{}".format(epoch)
  )
  env_fn = rl.make_simulated_env_fn_from_hparams(
      real_env, hparams, batch_size=hparams.simulated_batch_size,
      initial_frame_chooser=frame_chooser, model_dir=world_model_dir,
      sim_video_dir=sim_video_dir
  )

  algo_hparams = trainer_lib.create_hparams(hparams.base_algo_params)
  if hparams.wm_policy_param_sharing:
    algo_hparams.optimizer_zero_grads = True
  rl_utils.update_hparams_from_hparams(
      algo_hparams, hparams, hparams.base_algo + "_"
  )

  last_epoch = hparams.epochs - 1
  midpoint = last_epoch // 2
  # Epochs that get a doubled step budget: 3 and 7 epochs before both the
  # midpoint and the last epoch, plus epoch 1 and its counterpart in the
  # second half (midpoint + 2).
  boosted_epochs = (
      midpoint - 3, midpoint - 7,
      last_epoch - 3, last_epoch - 7,
      1, midpoint + 2,
  )
  if epoch == last_epoch or epoch == hparams.epochs:
    # The last epoch and the extra pass after the loop get a tripled budget.
    env_step_multiplier = 3
  elif epoch in boosted_epochs:
    env_step_multiplier = 2
  else:
    env_step_multiplier = 1

  learner.train(
      env_fn, algo_hparams, simulated=True, save_continuously=True,
      epoch=epoch, env_step_multiplier=env_step_multiplier
  )


def train_agent_real_env(env, learner, hparams, epoch):
  """Trains the agent in the real environment.

  Args:
    env: real environment; reset after training so that unfinished rollouts
      are saved to its history.
    learner: learner whose `train` method is called.
    hparams: loop hyperparameters.
    epoch: current epoch number; determines the number of env steps.
  """
  algo_hparams = trainer_lib.create_hparams(hparams.base_algo_params)
  rl_utils.update_hparams_from_hparams(
      algo_hparams, hparams, "real_" + hparams.base_algo + "_"
  )
  if hparams.wm_policy_param_sharing:
    algo_hparams.optimizer_zero_grads = True

  real_env_fn = rl.make_real_env_fn(env)
  steps_this_epoch = real_env_step_increment(hparams, epoch)
  learner.train(
      real_env_fn,
      algo_hparams,
      simulated=False,
      save_continuously=False,
      epoch=epoch,
      sampling_temp=hparams.real_sampling_temp,
      num_env_steps=steps_this_epoch,
  )
  # Save unfinished rollouts to history.
  env.reset()


def train_world_model(
    env, data_dir, output_dir, hparams, world_model_steps_num, epoch, directories
):
  """Trains every world-model particle and samples one for policy training.

  Particles are processed in groups of `total_parallel_runs`. Within a group,
  the first particle is trained in this process via `train_supervised`; each
  remaining particle is trained by a child process that shells out to
  `python -m tensor2tensor.rl.train_supervised`. When there is more than one
  particle, the per-particle losses are turned into a softmax distribution
  (probability proportional to exp(-loss)) from which the next model is drawn.

  Args:
    env: environment passed as the problem to `train_supervised`.
    data_dir: directory with the training data.
    output_dir: directory of particle 0; the `loss_value` file of particle `i`
      is read from `output_dir + "_i"`.
    hparams: loop hyperparameters.
    world_model_steps_num: number of world-model steps trained so far.
    epoch: current epoch number.
    directories: dict of directories keyed by "world_model" and
      "world_model_<i>".

  Returns:
    Tuple `(world_model_steps_num, next_model_id, first_process)`: the updated
    step count, the index of the sampled particle (0 when there is a single
    particle), and the list of particle indices trained in this process.
  """
  from multiprocessing import Process

  world_model_steps_num += world_model_step_increment(
      hparams, is_initial_epoch=(epoch == 0)
  )

  def run_particle_supervised_training(
      particle_index, world_model_steps_num, epoch, d_id
  ):
    """Trains one particle by running the supervised trainer in a shell."""
    command_prefix = (
        'python -m tensor2tensor.rl.train_supervised -output_dir={}  '
        '--loop_hparams_set=rlmb_base '
    ).format(FLAGS.output_dir)
    command_args = (
        '--particle_id_wm {} --epoch_number_wm {} --steps_done_wm {} '
        '--d_id {} '
    ).format(particle_index, epoch, world_model_steps_num, d_id)
    game_specific_args = '--loop_hparams=game={},model_train_steps={}'.format(
        hparams.game, hparams.model_train_steps
    )
    # stdout of the child trainer goes to <output_dir>/wm_<particle>.log.
    log_path = os.path.join(
        str(FLAGS.output_dir), 'wm_{}.log'.format(particle_index)
    )
    output_log_cmd = ' >  {}'.format(log_path)
    command = (
        command_prefix + command_args + game_specific_args + output_log_cmd
    )
    os.system(command)

  def run_particle(particle_id):
    """Trains one particle in the current process."""
    if particle_id == 0:
      out_dir = directories["world_model"]
    else:
      out_dir = directories["world_model_{}".format(particle_id)]
    restarter = Restarter("world_model", out_dir, world_model_steps_num)
    model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
    model_hparams.learning_rate = model_hparams.learning_rate_constant
    if epoch > 0:
      model_hparams.learning_rate *= hparams.learning_rate_bump
    if hparams.wm_policy_param_sharing:
      model_hparams.optimizer_zero_grads = True

    if restarter.should_skip:
      # NOTE(review): this returns an int, while the caller indexes the
      # result with ['loss'] when there are several particles - confirm that
      # skipping cannot happen in the multi-particle setting.
      return world_model_steps_num
    with restarter.training_loop():
      return train_supervised(
          problem=env,
          model_name=hparams.generative_model,
          hparams=model_hparams,
          epoch=epoch,
          data_dir=data_dir,
          output_dir=out_dir,
          train_steps=restarter.target_global_step,
          eval_steps=100,
          local_eval_frequency=2000
      )

  losses = []
  first_process = []
  num_particles = int(hparams.num_particles)
  runs_per_gpu = 3
  total_gpus = 1
  total_parallel_runs = total_gpus * runs_per_gpu
  # Ceiling division: number of groups needed to cover all particles.
  total_sequential_runs = (
      (num_particles + total_parallel_runs - 1) // total_parallel_runs
  )
  for run in range(total_sequential_runs):
    leader = int(total_parallel_runs * run)
    followers = range(
        leader + 1, min(total_parallel_runs * (run + 1), num_particles)
    )
    process_list = []
    for particle in followers:
      d_id = str(int(hparams.d_id) + int(particle) % total_gpus)
      process_list.append(Process(
          target=run_particle_supervised_training,
          args=(particle, world_model_steps_num, epoch, d_id)
      ))
    try:
      for p in process_list:
        p.start()

      metrics = run_particle(leader)
      first_process.append(leader)
      for p in process_list:
        p.join()
    finally:
      for p in process_list:
        # Only terminate children that are actually running; calling
        # terminate() on a never-started process would raise and mask the
        # original exception.
        if p.is_alive():
          p.terminate()
    if num_particles > 1:
      losses.append(metrics['loss'])  # loss of the in-process particle
    # Read the losses of the particles trained in child processes.
    for particle in followers:
      loss_path = os.path.join(
          output_dir + '_{}'.format(particle), 'loss_value'
      )
      # The file is only read, so open it read-only. It is unpickled, so it
      # must come from a trusted source.
      with open(loss_path, 'rb') as f:
        metric_i = pickle.load(f)
      losses.append(metric_i['loss'])

  if num_particles > 1:
    # prob_i is proportional to exp(-loss_i). Subtracting the minimum loss
    # leaves the distribution unchanged but keeps the largest weight at
    # exp(0) = 1, so large losses cannot underflow every weight to zero
    # (which would yield NaN probabilities).
    loss_array = np.asarray(losses, dtype=np.float64)
    unnormalized_probabilities = np.exp(-(loss_array - np.min(loss_array)))
    normalizing_constant = np.sum(unnormalized_probabilities)
    probabilities = unnormalized_probabilities / normalizing_constant
    next_model_id = np.random.choice(num_particles, p=probabilities)
  else:
    next_model_id = 0  # only 1 particle
  return world_model_steps_num, next_model_id, first_process


def load_metrics(event_dir, epoch):
  """Loads metrics for this epoch if they have already been written.

  This reads the entire event file but it's small with just per-epoch metrics.

  Args:
    event_dir: directory whose files are all read as summary event files.
    epoch: epoch number; only events whose `step` equals it are considered.

  Returns:
    Dict mapping the tag of the first summary value of each matching event to
    its `simple_value`; empty if no matching event was found.
  """
  metrics = {}
  event_paths = [
      os.path.join(event_dir, name)
      for name in tf.gfile.ListDirectory(event_dir)
  ]
  for event_path in event_paths:
    for event in tf.train.summary_iterator(event_path):
      if event.step != epoch or not event.HasField("summary"):
        continue
      first_value = event.summary.value[0]
      metrics[first_value.tag] = first_value.simple_value
  return metrics

def log_time_diff(start_time, name, epoch):
  """Appends the time elapsed since `start_time` to a per-phase log file.

  Args:
    start_time: start timestamp, as returned by `time.time()`.
    name: phase name; the log file is `time_log_<name>` in `FLAGS.output_dir`.
    epoch: epoch number written alongside the elapsed time.
  """
  log_path = os.path.join(FLAGS.output_dir, 'time_log_{}'.format(name))
  with open(log_path, "a+") as log_file:
    elapsed = time.time() - start_time
    log_file.write("Epoch: {} , Time {} \n".format(epoch, elapsed))

def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
  """Run the main training loop.

  First trains the policy in the real environment (epoch -1) to collect
  initial data. Then, for each epoch: trains the world-model particles, trains
  the policy in the environment simulated by the sampled particle, trains the
  policy in the real environment, and evaluates. After the last epoch, one
  extra simulated policy-training pass is run with particle 0's world model.

  Args:
    hparams: loop hyperparameters.
    output_dir: base directory under which all sub-directories are created.
    report_fn: optional callable invoked as `report_fn(metric_value, epoch)`
      after each epoch that is evaluated.
    report_metric: name of the eval metric passed to `report_fn`; the special
      value "mean_reward" is resolved via `rl_utils.get_metric_name`. Must not
      be None when `report_fn` is given.

  Returns:
    0.0 if `hparams.stop_loop_early` is set (returned right after the first
    epoch's real-environment training); otherwise the metrics dict of the
    final extra pass.
  """
  if report_fn:
    assert report_metric is not None

  # Directories
  subdirectories = [
      "data", "tmp", "world_model", ("world_model", "debug_videos"),
      "policy", "eval_metrics"
  ]

  # Particle 0 uses "world_model"; every further particle gets its own
  # "world_model_<p>" directory (with a "debug_videos" sub-directory).
  num_particles = int(hparams.num_particles)
  for p in range(1,num_particles) :
    subdirectories.append('world_model_{}'.format(p))
    subdirectories.append(('world_model_{}'.format(p), "debug_videos"))
  directories = setup_directories(output_dir, subdirectories)
  # Epoch -1 is the initial real-environment data-collection phase.
  epoch = -1
  data_dir = directories["data"]
  env = rl_utils.setup_env(
      hparams, batch_size=hparams.real_batch_size,
      max_num_noops=hparams.max_num_noops,
      rl_env_max_episode_steps=hparams.rl_env_max_episode_steps
  )
  env.start_new_epoch(epoch, data_dir)

  if hparams.wm_policy_param_sharing:
    policy_model_dir = directories["world_model"]
  else:
    policy_model_dir = directories["policy"]
  learner = rl_utils.LEARNERS[hparams.base_algo](
      hparams.frame_stack_size, policy_model_dir,
      policy_model_dir, hparams.epochs
  )

  # Timing log function
  log_relative_time = make_relative_timing_fn()

  # Per-epoch state
  epoch_metrics = []
  metrics = {}
  # Collect data from the real environment.
  # NOTE(review): this overrides the directory chosen above, so evaluation
  # below always uses directories["policy"], even when
  # wm_policy_param_sharing is set - confirm this is intended.
  policy_model_dir = directories["policy"]
  tf.logging.info("Initial training of the policy in real environment.")
  start = time.time()
  train_agent_real_env(env, learner, hparams, epoch)
  log_time_diff(start, 'real_env', epoch)
  metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
      env.current_epoch_rollouts(), clipped=True
  )
  tf.logging.info("Mean training reward (initial): {}".format(
      metrics["mean_reward/train/clipped"]
  ))
  env.generate_data(data_dir)

  eval_metrics_writer = tf.summary.FileWriter(
      directories["eval_metrics"]
  )

  # Cumulative world-model step count, advanced by train_world_model.
  world_model_steps_num = 0

  for epoch in range(hparams.epochs):
    log = make_log_fn(epoch, log_relative_time)

    # Train world model
    log("Training world model")
    start = time.time()
    world_model_steps_num, next_model_id, first_process = train_world_model(
        env, data_dir, directories["world_model"], hparams,
        world_model_steps_num, epoch, directories
    )
    log_time_diff(start, 'world_model', epoch)
    # Train agent
    log("Training policy in simulated environment using world model {}.".format(next_model_id))
    # The sampled particle decides which world-model directory is simulated.
    next_wm = "world_model_{}".format(next_model_id) if next_model_id != 0 else "world_model"
    start = time.time()
    train_agent(env, learner, directories[next_wm], hparams, epoch)
    log_time_diff(start, 'sim_env', epoch)
    env.start_new_epoch(epoch, data_dir)

    # Train agent on real env (short)
    log("Training policy in real environment.")
    start = time.time()
    train_agent_real_env(env, learner, hparams, epoch)
    log_time_diff(start, 'real_env', epoch)
    if hparams.stop_loop_early:
      return 0.0

    env.generate_data(data_dir)

    metrics = load_metrics(directories["eval_metrics"], epoch)
    if metrics:
      # Skip eval if metrics have already been written for this epoch. Otherwise
      # we'd overwrite them with wrong data.
      log("Metrics found for this epoch, skipping evaluation.")
    else:
      metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
          env.current_epoch_rollouts(), clipped=True
      )
      log("Mean training reward: {}".format(
          metrics["mean_reward/train/clipped"]
      ))

      eval_metrics = rl_utils.evaluate_all_configs(hparams, policy_model_dir)
      log("Agent eval metrics:\n{}".format(pprint.pformat(eval_metrics)))
      metrics.update(eval_metrics)
      # FIX THIS
      if hparams.eval_world_model:
        for particle in range(num_particles) : 
          # first_process lists the particles that train_world_model trained
          # in this process; those are evaluated here.
          if particle in first_process : #wm metrics not yet calculated; calculate it
            if particle == 0 :
              wm_dir_str = "world_model"
            else :
              wm_dir_str = "world_model_{}".format(particle)
              
            debug_video_path = os.path.join(
              directories[wm_dir_str, "debug_videos"],
              "{}.avi".format(env.current_epoch))

            wm_metrics = rl_utils.evaluate_world_model(
                env, hparams, directories[wm_dir_str], debug_video_path
            )
            log("World model {} eval metrics:\n{}".format(particle, pprint.pformat(wm_metrics)))
            # NOTE(review): the keys of wm_metrics are merged as-is, so
            # several in-process particles may overwrite each other's entries
            # if they share key names - confirm.
            metrics.update(wm_metrics)
          else : 
            # The remaining particles' metrics are read from a pickled
            # 'reward_metrics' file in their directory (presumably written by
            # the child training process - verify).
            with open(os.path.join(directories["world_model_{}".format(particle)], 'reward_metrics'), 'rb+') as f :
              wm_particle_metric = pickle.load(f)
              # Re-key each metric: append "_<particle>" to the first
              # whitespace-separated token of the key and concatenate the
              # remaining tokens without separators.
              new_wm_particle_metric = {key.split()[0]+'_{}'.format(particle)+''.join(key.split()[1:]) : value for (key,value) in wm_particle_metric.items()}         
              metrics.update(new_wm_particle_metric)

      rl_utils.summarize_metrics(eval_metrics_writer, metrics, epoch)

      # Report metrics
      if report_fn:
        if report_metric == "mean_reward":
          metric_name = rl_utils.get_metric_name(
              sampling_temp=hparams.eval_sampling_temps[0],
              max_num_noops=hparams.eval_max_num_noops,
              clipped=False
          )
          report_fn(eval_metrics[metric_name], epoch)
        else:
          report_fn(eval_metrics[report_metric], epoch)

    epoch_metrics.append(metrics)


  # Extra simulated policy-training pass after the loop, using particle 0's
  # world model and epoch == hparams.epochs, followed by a final evaluation.
  train_agent(env, learner, directories["world_model"], hparams, hparams.epochs) # final run with full model
  metrics = load_metrics(directories["eval_metrics"], hparams.epochs)
  eval_metrics = rl_utils.evaluate_all_configs(hparams, policy_model_dir)
  metrics.update(eval_metrics)
  rl_utils.summarize_metrics(eval_metrics_writer, metrics, hparams.epochs)
  epoch_metrics.append(metrics)

  # Return the evaluation metrics from the final epoch
  return epoch_metrics[-1]


def main(_):
  """Creates the loop hparams, selects the GPU and runs the training loop."""
  loop_hparams = trainer_model_based_params.create_loop_hparams()
  # Restrict the visible CUDA devices to the one named by the hparams.
  os.environ["CUDA_VISIBLE_DEVICES"] = loop_hparams.d_id
  assert not FLAGS.job_dir_to_evaluate
  training_loop(loop_hparams, FLAGS.output_dir)


if __name__ == "__main__":
  # Script entry point: enable INFO-level logging, then hand control to
  # tf.app.run().
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
