# coding=utf-8
# Copyright 2022 The Multi Task Atari Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module defining how to run the experiments."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time

from absl import logging
import functools

from dopamine.discrete_domains import atari_lib
from dopamine.discrete_domains import checkpointer
from dopamine.discrete_domains import iteration_statistics
from dopamine.discrete_domains import logger

from multi_task_atari import multi_task_dqn_agent
from multi_task_atari import multi_task_offline_dqn_agent
from multi_task_atari import multi_task_atari_env
from multi_task_atari import atari_config

from multi_task_atari.fine_tuning import ft_cql_agent

import numpy as np
import tensorflow as tf
import pickle

import jax

# tf.config.set_visible_devices([], "GPU")

import gin


def load_gin_configs(gin_files, gin_bindings):
  """Parses the experiment's gin files and applies the binding overrides.

  The files and bindings are handed to gin with `skip_unknown=False`.

  Args:
    gin_files: list, paths to the gin configuration files for this experiment.
    gin_bindings: list, gin parameter bindings that override values coming
      from the config files.
  """
  gin.parse_config_files_and_bindings(
      gin_files, bindings=gin_bindings, skip_unknown=False)


@gin.configurable
def create_agent(sess, environment, agent_name=None, summary_writer=None,
                 debug_mode=False, num_devices=1, game_names=('Asterix',)):
  """Builds the multi-task agent selected by `agent_name`.

  Args:
    sess: Unused by this function; kept so the signature matches the
      `create_agent_fn(sess, environment, ...)` calling convention.
    environment: An environment exposing `action_space.n`, which sets the
      agent's number of actions.
    agent_name: str, either 'multi_task_dqn' or 'multi_task_offline_dqn'.
    summary_writer: A Tensorflow summary writer for in-agent statistics. It is
      only forwarded to the agent when `debug_mode` is True.
    debug_mode: bool, whether the agent receives `summary_writer`. When False
      the agent is created with `summary_writer=None`.
    num_devices: int, forwarded to the agent as `num_devices`.
    game_names: sequence of str, the games being trained on. Its length is
      forwarded as `num_games`; the offline agent also receives the names.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  assert agent_name is not None
  if agent_name not in ('multi_task_dqn', 'multi_task_offline_dqn'):
    raise ValueError('Unknown agent: {}'.format(agent_name))

  # Arguments shared by both agent constructors. The summary writer is
  # withheld unless debugging was requested.
  shared_kwargs = dict(
      num_actions=environment.action_space.n,
      summary_writer=summary_writer if debug_mode else None,
      num_devices=num_devices,
      num_games=len(game_names))

  if agent_name == 'multi_task_dqn':
    return multi_task_dqn_agent.MultiTaskJaxDQNAgent(**shared_kwargs)
  return multi_task_offline_dqn_agent.OfflineMultiTaskJaxDQNAgent(
      game_names=game_names, **shared_kwargs)


@gin.configurable
def create_runner(base_dir, schedule='continuous_train_and_eval'):
  """Instantiates the Runner class that matches `schedule`.

  Args:
    base_dir: str, base directory for hosting all subdirectories.
    schedule: str, 'continuous_train_and_eval' (train and evaluate each
      iteration) or 'continuous_train' (train only).

  Returns:
    runner: A `Runner` like object, built with `create_agent` as its agent
      factory.

  Raises:
    ValueError: When an unknown schedule is encountered.
  """
  assert base_dir is not None
  # NOTE(review): `Runner` and `TrainRunner` are not defined in the part of
  # this file that was reviewed -- confirm they exist at module level.
  if schedule == 'continuous_train_and_eval':
    # Continuously runs training and evaluation until max num_iterations.
    runner_cls = Runner
  elif schedule == 'continuous_train':
    # Continuously runs training until max num_iterations is hit.
    runner_cls = TrainRunner
  else:
    raise ValueError('Unknown schedule: {}'.format(schedule))
  return runner_cls(base_dir, create_agent)


@gin.configurable
class MultiTaskRunner(object):

  def __init__(self,
               base_dir,
               create_agent_fn,
               create_environment_fn=multi_task_atari_env.create_multi_task_atari_environment,
               checkpoint_file_prefix='ckpt',
               logging_file_prefix='log',
               log_every_n=1,
               num_iterations=200,
               training_steps=250000,
               evaluation_steps=125000,
               max_steps_per_episode=27000,
               game_names=('Asterix',),
               clip_rewards=True,
               only_train=False,
               only_eval=False,
               use_single_game_action_space=False,
               no_agent=False,
               sticky_actions=True):
    """Builds the environment, agent, logger and checkpointer for a run.

    Construction proceeds in this order: TF2 behavior is enabled and memory
    growth is switched on for every GPU visible to TensorFlow; the logger and
    the summary writer are created under `base_dir`; the multi-task
    environment is built; the agent is created (unless `no_agent`); finally
    the checkpointer is set up and the latest checkpoint, if any, is restored
    (unless `only_eval`).

    Args:
      base_dir: str, the base directory to host all required sub-directories.
      create_agent_fn: Callable invoked as `create_agent_fn(None, environment,
        summary_writer=..., game_names=..., use_single_game_action_space=...)`
        that returns an agent.
      create_environment_fn: Callable invoked with `game_names`,
        `use_single_game_action_space` and `sticky_actions` keyword arguments
        that returns the environment.
      checkpoint_file_prefix: str, the prefix to use for checkpoint files.
      logging_file_prefix: str, prefix to use for the log files.
      log_every_n: int, the frequency for writing logs.
      num_iterations: int, the iteration number threshold (must be greater than
        start_iteration).
      training_steps: int, the number of training steps to perform.
      evaluation_steps: int, the number of evaluation steps to perform.
      max_steps_per_episode: int, maximum number of steps after which an episode
        terminates.
      game_names: sequence of str, the games this runner cycles through.
      clip_rewards: bool, whether to clip rewards in [-1, 1].
      only_train: bool, if True the evaluation phase is skipped and reports
        zeros.
      only_eval: bool, if True no checkpointer is created and no checkpoint is
        restored by this constructor.
      use_single_game_action_space: bool, forwarded to both the environment
        and the agent factories.
      no_agent: bool, if True the agent is left as None and `create_agent_fn`
        is stored on `self._create_agent_fn` for later use.
      sticky_actions: bool, forwarded to `create_environment_fn`.
    """
    assert base_dir is not None
    tf.compat.v1.enable_v2_behavior()
    visible_gpus = tf.config.list_physical_devices('GPU')
    print('All devices visible to tf: ', visible_gpus)
    for gpu in visible_gpus:
      tf.config.experimental.set_memory_growth(gpu, True)

    # Run configuration. `_base_dir` must be set before `_create_directories`
    # is called, since that method reads it.
    self._base_dir = base_dir
    self._logging_file_prefix = logging_file_prefix
    self._log_every_n = log_every_n
    self._num_iterations = num_iterations
    self._training_steps = training_steps
    self._evaluation_steps = evaluation_steps
    self._max_steps_per_episode = max_steps_per_episode
    self._clip_rewards = clip_rewards
    self._use_single_game_action_space = use_single_game_action_space
    self._create_directories()
    self._summary_writer = tf.summary.create_file_writer(self._base_dir)

    self._only_train = only_train
    self._game_names = game_names
    print('Game names in runner: ', self._game_names)
    print('Making env with sticky actions: ', sticky_actions)
    self._environment = create_environment_fn(
        game_names=self._game_names,
        use_single_game_action_space=self._use_single_game_action_space,
        sticky_actions=sticky_actions)

    if no_agent:
      # Agent creation is deferred; keep the factory around for later.
      self._agent = None
      self._create_agent_fn = create_agent_fn
    else:
      # The first positional argument is the (unused) session slot.
      self._agent = create_agent_fn(
          None,
          self._environment,
          summary_writer=self._summary_writer,
          game_names=self._game_names,
          use_single_game_action_space=use_single_game_action_space)

    self._checkpoint_file_prefix = checkpoint_file_prefix
    if not only_eval:
      self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)

  def _create_directories(self):
    """Sets the checkpoint directory path and creates the experiment logger.

    Both locations live directly under `self._base_dir`: 'checkpoints' for
    checkpoint files and 'logs' for the `Logger`.
    """
    root = self._base_dir
    self._checkpoint_dir = os.path.join(root, 'checkpoints')
    logs_dir = os.path.join(root, 'logs')
    self._logger = logger.Logger(logs_dir)

  def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    """Creates the checkpointer and resumes from the latest checkpoint, if any.

    A `Checkpointer` is created on `self._checkpointer` and
    `self._start_iteration` is reset to 0. If
    `checkpointer.get_latest_checkpoint_number` reports a checkpoint in
    `self._checkpoint_dir`, its bundled data is loaded and handed to the
    agent's `unbundle`. When the agent reports success and the bundle is not
    None, the bundle must contain the keys 'logs' and 'current_iteration':
    the logger data is restored from the former and `self._start_iteration`
    is set to one past the latter.

    Nothing is returned; results are stored on `self`.

    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.
    """
    self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                   checkpoint_file_prefix)
    self._start_iteration = 0

    # The existence of checkpoint 0 means iteration 0 has finished, so a
    # resumed run starts from iteration 1. A negative number means there is
    # nothing to resume from.
    latest_version = checkpointer.get_latest_checkpoint_number(
        self._checkpoint_dir)
    if latest_version < 0:
      return

    bundle = self._checkpointer.load_checkpoint(latest_version)
    restored = self._agent.unbundle(
        self._checkpoint_dir, latest_version, bundle)
    if not restored:
      return

    if bundle is not None:
      assert 'logs' in bundle
      assert 'current_iteration' in bundle
      self._logger.data = bundle['logs']
      self._start_iteration = bundle['current_iteration'] + 1
    logging.info('Reloaded checkpoint and will start from iteration %d',
                 self._start_iteration)


  def _initialize_episode(self):
    """Initialization for a new episode.

    Returns:
      action: int, the initial action chosen by the agent.
    """
    initial_observation, env_idx = self._environment.reset()
    return self._agent.begin_episode(initial_observation, game_index=env_idx)

  def _run_one_step(self, action):
    """Executes a single step in the environment.

    Args:
      action: int, the action to perform in the environment.

    Returns:
      The observation, reward, and is_terminal values returned from the
        environment.
    """
    observation, reward, is_terminal, _ = self._environment.step(action)
    return observation, reward, is_terminal

  def _end_episode(self, reward, terminal=True):
    """Finalizes an episode run.

    Args:
      reward: float, the last reward from the environment.
      terminal: bool, whether the last state-action led to a terminal state.
    """
    self._agent.end_episode(reward, terminal)

  def _run_one_episode(self):
    """Executes a full trajectory of the agent interacting with the environment.

    Returns:
      The number of steps taken and the total reward.
    """
    step_number = 0
    total_reward = 0.

    action = self._initialize_episode()
    game_idx = self._environment.env_idx
    is_terminal = False

    # Keep interacting until we reach a terminal state.
    while True:
      observation, reward, is_terminal = self._run_one_step(action)

      total_reward += reward
      step_number += 1

      if self._clip_rewards:
        # Perform reward clipping.
        reward = np.clip(reward, -1, 1)

      if (self._environment.game_over or
          step_number == self._max_steps_per_episode):
        # Stop the run loop once we reach the true end of episode.
        break
      elif is_terminal:
        # If we lose a life but the episode is not over, signal an artificial
        # end of episode to the agent.
        self._end_episode(reward, is_terminal)
        action = self._agent.begin_episode(observation, game_index=game_idx)
      else:
        action = self._agent.step(reward, observation, game_index=game_idx)

    self._end_episode(reward, is_terminal)

    return step_number, total_reward, game_idx

  def _run_one_phase(self, min_steps, statistics, run_mode_str):
    """Runs the agent/environment loop until a desired number of steps.

    We follow the Machado et al., 2017 convention of running full episodes,
    and terminating once we've run a minimum number of steps.

    Args:
      min_steps: int, minimum number of steps to generate in this phase.
      statistics: `IterationStatistics` object which records the experimental
        results.
      run_mode_str: str, describes the run mode for this agent.

    Returns:
      Tuple containing the number of steps taken in this phase (int), the sum of
        returns (float), and the number of episodes performed (int).
    """
    step_count = [0.0,] * len(self._game_names)
    num_episodes = [0,] * len(self._game_names)
    sum_returns = [0.,] * len(self._game_names)

    for idx in range(len(self._game_names)):
      self._environment._set_env(idx)
      print ('Coming here, and resetting game name to ', idx, self._game_names[idx], self._game_names)
      while step_count[idx] < min_steps:
        episode_length, episode_return, episode_game_id = self._run_one_episode()
        assert episode_game_id == idx, "Game indices don't match"
        statistics.append({
            '{}_{}_episode_lengths'.format(
                self._game_names[episode_game_id], run_mode_str): episode_length,
            '{}_{}_episode_returns'.format(
                self._game_names[episode_game_id], run_mode_str): episode_return
        })
        step_count[idx] += episode_length
        sum_returns[idx] += episode_return
        num_episodes[idx] += 1
        # We use sys.stdout.write instead of logging so as to flush frequently
        # without generating a line break.
        sys.stdout.write('Games: {} '.format(self._game_names[idx]) +
                         'Steps executed: {} '.format(step_count[idx]) +
                         'Episode length: {} '.format(episode_length) +
                         'Return: {}\r'.format(episode_return))
        sys.stdout.flush()
    return step_count, sum_returns, num_episodes

  def _run_train_phase(self, statistics):
    """Run training phase.

    Args:
      statistics: `IterationStatistics` object which records the experimental
        results. Note - This object is modified by this method.

    Returns:
      num_episodes: int, The number of episodes run in this phase.
      average_reward: float, The average reward generated in this phase.
      average_steps_per_second: float, The average number of steps per second.
    """
    # Run training for standard online setting
    # Perform the training phase, during which the agent learns.
    self._agent.eval_mode = False
    start_time = time.time()
    number_steps, sum_returns, num_episodes = self._run_one_phase(
        self._training_steps, statistics, 'train')

    if isinstance(num_episodes, list):
      num_episodes = num_episodes[0]
      sum_returns = sum_returns[0]
      number_steps = number_steps[0]

    average_return = sum_returns / num_episodes if num_episodes > 0 else 0.0
    statistics.append({'train_average_return': average_return})
    time_delta = time.time() - start_time
    average_steps_per_second = number_steps / time_delta
    statistics.append(
        {'train_average_steps_per_second': average_steps_per_second})
    logging.info('Average undiscounted return per training episode: %.2f',
                 average_return)
    logging.info('Average training steps per second: %.2f',
                 average_steps_per_second)

    if isinstance(num_episodes, int):
      num_episodes = [num_episodes,]
      average_return = [average_return,]
      average_steps_per_second = [average_steps_per_second,]

    return num_episodes, average_return, average_steps_per_second

  def _run_eval_phase(self, statistics):
    """Run evaluation phase.

    Args:
      statistics: `IterationStatistics` object which records the experimental
        results. Note - This object is modified by this method.

    Returns:
      num_episodes: int, The number of episodes run in this phase.
      average_reward: float, The average reward generated in this phase.
    """
    # Perform the evaluation phase -- no learning.
    if not self._only_train:
      self._agent.eval_mode = True
      self._agent.unreplicate_params()
      _, sum_returns, num_episodes = self._run_one_phase(
          self._evaluation_steps, statistics, 'eval')
      average_returns = []
      for idx in range(len(self._game_names)):
        average_return_idx = sum_returns[idx] / num_episodes[idx] if num_episodes[idx] > 0 else 0.0
        logging.info('Average undiscounted return per evaluation episode: %.2f',
                    average_return_idx)
        statistics.append({'{}_eval_average_return'.format(self._game_names[idx]): average_return_idx})
        average_returns.append(average_return_idx)

      self._agent.replicate_params_to_devices()
    else:
      num_episodes = [0.0,] * len(self._game_names)
      average_returns = [0.0,] * len(self._game_names)

    return num_episodes, average_returns


  def _run_one_iteration(self, iteration):
    """Runs a training phase followed by an evaluation phase.

    An iteration involves running several episodes until a certain number of
    steps are obtained. The interleaving of train/eval phases implemented here
    are to match the implementation of (Mnih et al., 2015). The results of
    both phases are written out as Tensorboard summaries.

    Args:
      iteration: int, current iteration number, used as a global_step for saving
        Tensorboard summaries.

    Returns:
      A dict containing summary statistics for this iteration.
    """
    statistics = iteration_statistics.IterationStatistics()
    logging.info('Starting iteration %d', iteration)

    train_results = self._run_train_phase(statistics)
    train_episodes, train_returns, steps_per_second = train_results

    eval_results = self._run_eval_phase(statistics)
    eval_episodes, eval_returns = eval_results

    self._save_tensorboard_summaries(
        iteration, train_episodes, train_returns, eval_episodes, eval_returns,
        steps_per_second)
    return statistics.data_lists


  def _save_tensorboard_summaries(self, iteration,
                                  num_episodes_train,
                                  average_reward_train,
                                  num_episodes_eval,
                                  average_reward_eval,
                                  average_steps_per_second):
    """Save statistics as tensorboard summaries.

    Args:
      iteration: int, The current iteration number.
      num_episodes_train: int, number of training episodes run.
      average_reward_train: float, The average training reward.
      num_episodes_eval: int, number of evaluation episodes run.
      average_reward_eval: float, The average evaluation reward.
      average_steps_per_second: float, The average number of steps per second.
    """
    summary_list = []
    with self._summary_writer.as_default():
      for idx in range(len(self._game_names)):
        tf.summary.scalar(
            'Train_' + self._game_names[idx] + '/NumEpisodes',
            num_episodes_train[idx],
            step=iteration)
        tf.summary.scalar(
            'Train_' + self._game_names[idx] + '/AverageReturns',
            average_reward_train[idx],
            step=iteration)
        tf.summary.scalar(
            'Train_' + self._game_names[idx] + '/AverageStepsPerSecond',
            average_steps_per_second[idx],
            step=iteration)
        tf.summary.scalar(
            'Eval_' + str(self._game_names[idx]) + '/NumEpisodes',
            num_episodes_eval[idx],
            step=iteration)
        tf.summary.scalar(
            'Eval_' + str(self._game_names[idx]) + '/AverageReturns',
            average_reward_eval[idx],
            step=iteration)


  def _log_experiment(self, iteration, statistics):
    """Records the results of the current iteration.

    Args:
      iteration: int, iteration number.
      statistics: `IterationStatistics` object containing statistics to log.
    """
    self._logger['iteration_{:d}'.format(iteration)] = statistics
    if iteration % self._log_every_n == 0:
      self._logger.log_to_file(self._logging_file_prefix, iteration)

  def _checkpoint_experiment(self, iteration):
    """Checkpoint experiment data.

    Args:
      iteration: int, iteration number for checkpointing.
    """
    experiment_data = self._agent.bundle_and_checkpoint(self._checkpoint_dir,
                                                        iteration)
    if experiment_data:
      experiment_data['current_iteration'] = iteration
      experiment_data['logs'] = self._logger.data
      self._checkpointer.save_checkpoint(iteration, experiment_data)

  def run_experiment(self):
    """Runs a full experiment, spread over multiple iterations."""
    logging.info('Beginning training...')
    if self._num_iterations <= self._start_iteration:
      logging.warning('num_iterations (%d) < start_iteration(%d)',
                      self._num_iterations, self._start_iteration)
      return

    for iteration in range(self._start_iteration, self._num_iterations):
      statistics = self._run_one_iteration(iteration)
      self._log_experiment(iteration, statistics)
      self._checkpoint_experiment(iteration)
    self._summary_writer.flush()


@gin.configurable
class MultiTaskFinetuningRunner(MultiTaskRunner):
  """Object that handles running finetuning on multi-task Atari."""

  def __init__(self, *args, **kwargs):
    """Builds the runner in train-only mode and adapts it to the agent's mode.

    The parent constructor is always called with `only_train=True`. The
    agent's `finetuning_mode` is then read; when it contains 'online', the
    parent class's train phase, experiment loop and summary writer are bound
    onto this instance, `_run_one_iteration` is pointed at
    `_run_one_train_iteration`, and `_only_train` is set to False. For any
    other mode the methods defined on this class stay in effect.

    Args:
      *args: Positional arguments forwarded to `MultiTaskRunner.__init__`.
      **kwargs: Keyword arguments forwarded to `MultiTaskRunner.__init__`.
    """
    super().__init__(*args, only_train=True, **kwargs)

    # Do some initial setup
    self.finetuning_mode = self._agent.finetuning_mode
    print('Finetuning mode in the Multi-task runner: ', self.finetuning_mode)

    if 'online' not in self.finetuning_mode:
      return

    # Online finetuning reuses the standard runner's behavior: instance
    # attributes set here take precedence over this class's overrides.
    parent = super()
    self._only_train = False
    self._agent.eval_mode = False
    self._run_train_phase = parent._run_train_phase
    self.run_experiment = parent.run_experiment
    self._run_one_iteration = self._run_one_train_iteration
    self._save_tensorboard_summaries = parent._save_tensorboard_summaries

  def _run_one_train_iteration(self, iteration):
    """Runs one training-only iteration and writes its summaries.

    No evaluation phase is run, so zeros are reported for the evaluation
    summaries. The summary writer used alongside this method indexes every
    sequence it receives once per game, so the placeholders hold one zero per
    entry of `self._game_names` rather than a single zero.

    Args:
      iteration: int, current iteration number, used as the step for the
        Tensorboard summaries.

    Returns:
      A dict containing summary statistics for this iteration.
    """
    statistics = iteration_statistics.IterationStatistics()
    logging.info('Starting iteration %d', iteration)
    num_episodes_train, average_reward_train, average_steps_per_second = (
        self._run_train_phase(statistics))

    num_games = len(self._game_names)
    self._save_tensorboard_summaries(iteration, num_episodes_train,
                                     average_reward_train,
                                     [0] * num_games, [0] * num_games,
                                     average_steps_per_second)
    return statistics.data_lists

  def _run_train_phase(self):
    """Run training phase."""
    self._agent.eval_mode = False
    start_time = time.time()

    for i in range(self._training_steps):
      if i % 100 == 0:
        # We use sys.stdout.write instead of logging so as to flush frequently
        # without generating a line break.
        sys.stdout.write('Training step: {}/{}\r'.format(
            i, self._training_steps))
        sys.stdout.flush()

      self._agent.train_step()

    time_delta = time.time() - start_time
    logging.info('Average training steps per second: %.2f',
                 self._training_steps / time_delta)

  def run_experiment(self):
    """Runs a full experiment, spread over multiple iterations."""
    logging.info('Beginning training...')
    if self._num_iterations <= self._start_iteration:
      logging.warning('num_iterations (%d) < start_iteration(%d)',
                      self._num_iterations, self._start_iteration)
      return

    for iteration in range(self._start_iteration, self._num_iterations):
      statistics = self._run_one_iteration(iteration)
      if jax.process_index() == 0:
        # Only log and save checkpoints when process_index == 0
        self._log_experiment(iteration, statistics)
        self._checkpoint_experiment(iteration)
    self._summary_writer.flush()

  def _run_one_iteration(self, iteration):
    """Reloads data, trains, evaluates and writes the evaluation summaries.

    Args:
      iteration: int, current iteration number, used as the step for the
        Tensorboard summaries.

    Returns:
      A dict containing summary statistics for this iteration.
    """
    statistics = iteration_statistics.IterationStatistics()
    logging.info('Starting iteration %d', iteration)

    # Reload the replay buffer at every iteration
    self._agent.reload_data()
    self._run_train_phase()

    eval_results = self._run_eval_phase(statistics)
    eval_episodes, eval_returns = eval_results
    self._save_tensorboard_summaries(iteration, eval_episodes, eval_returns)
    return statistics.data_lists

  def _save_tensorboard_summaries(self, iteration, num_episodes_eval,
                                  average_reward_eval):
    """Save statistics as tensorboard summaries.

    Args:
      iteration: int, The current iteration number.
      num_episodes_eval: int, number of evaluation episodes run.
      average_reward_eval: float, The average evaluation reward.
    """
    if not self._only_train:
      with self._summary_writer.as_default():
        for idx in range(len(self._game_names)):
          tf.summary.scalar(
            'Eval_' + str(self._game_names[idx]) + '/NumEpisodes',
            num_episodes_eval[idx], step=iteration
          )
          tf.summary.scalar(
            'Eval_' + str(self._game_names[idx]) + '/AverageReturns',
            average_reward_eval[idx], step=iteration
          )


@gin.configurable
class MultiTaskFixedReplayEvalRunner(MultiTaskRunner):
  """Object that handles running Dopamine evaluation for
     multi-task Atari in a different thread."""

  def __init__(self, *args, eval_job_index=0,
               total_eval_jobs=1, full_games=(), sticky_actions=True,
               eval_ckpt_number=None,
               eval_ckpt_dir=None, eval_begins_at=-1, **kwargs):
    """Sets up an evaluation-only runner.

    Args:
      *args: positional arguments forwarded to the parent constructor.
      eval_job_index: int, index of this evaluation job. It is embedded in the
        log, checkpoint and summary names, and is the default first iteration
        to evaluate.
      total_eval_jobs: int, number of evaluation jobs; used as the stride
        between the iterations this job evaluates.
      full_games: sequence of game names. When it is longer than the runner's
        game list and the runner has exactly one game, the agent is created
        with `override_game_index` and `override_num_games` derived from it.
      sticky_actions: forwarded to the parent constructor.
      eval_ckpt_number: optional checkpoint number to evaluate.
      eval_ckpt_dir: optional run directory. When both this and
        `eval_ckpt_number` are given, checkpoints are read from its
        'checkpoints' subdirectory.
      eval_begins_at: int, if not smaller than the resumed start iteration,
        evaluation starts from this iteration instead.
      **kwargs: keyword arguments forwarded to the parent constructor.
    """
    self._eval_job_index = eval_job_index
    self._total_eval_jobs = total_eval_jobs
    self._eval_ckpt_number = eval_ckpt_number
    self._eval_ckpt_dir = eval_ckpt_dir
    self._eval_begins_at = eval_begins_at

    print('Asked to evaluate ckpt and directory: ',
          self._eval_ckpt_dir, self._eval_ckpt_number)

    # The parent is asked not to build an agent; one is created further down.
    super(MultiTaskFixedReplayEvalRunner, self).__init__(
        *args, only_eval=True, no_agent=True,
        sticky_actions=sticky_actions, **kwargs)

    # Redirect checkpoint reads when an explicit run directory and checkpoint
    # number were both supplied.
    explicit_ckpt_requested = (self._eval_ckpt_dir is not None and
                               self._eval_ckpt_number is not None)
    if explicit_ckpt_requested:
      self._checkpoint_dir = os.path.join(self._eval_ckpt_dir, 'checkpoints')

    # The game name only becomes part of file names for single-game runs.
    self._game_name_to_append = (
        self._game_names[0] if len(self._game_names) == 1 else '')

    self._logging_file_prefix = (
        'log_eval_' + str(eval_job_index) + '_' + self._game_name_to_append)
    logger_dir = os.path.join(
        self._base_dir,
        'log_eval_' + str(self._eval_job_index) + '_' +
        self._game_name_to_append)
    self._logger = logger.Logger(logger_dir)

    # Get the full game list for the current data
    self.full_games = full_games
    print('[Eval Runner]', self.full_games, self._game_names)

    agent_kwargs = dict(
        summary_writer=self._summary_writer,
        game_names=self._game_names,
        use_single_game_action_space=self._use_single_game_action_space)
    if (len(self.full_games) > len(self._game_names) and
        len(self._game_names) == 1):
      # Evaluating one game out of a larger set: tell the agent which slot of
      # the full game list this game occupies.
      self.current_game_index = self.full_games.index(self._game_names[0])
      agent_kwargs.update(
          override_game_index=self.current_game_index,
          override_num_games=len(self.full_games))
    self._agent = self._create_agent_fn(
        None, self._environment, **agent_kwargs)

    # The runner's own summaries go to a separate directory to prevent any
    # interference. The agent created above keeps the writer that was in
    # place before this reassignment.
    summary_base_dir = os.path.join(
        self._base_dir, 'eval_' + self._game_name_to_append)
    self._summary_writer = tf.summary.create_file_writer(summary_base_dir)

    print('[Eval Runner] Current eval job: ', self._eval_job_index,
          self._total_eval_jobs,
          self._logging_file_prefix)

    # Resume this evaluation job's own progress, if it checkpointed before.
    self._initialize_checkpointer_and_maybe_resume(
        'eval_ckpt_' + str(eval_job_index) + '_' + self._game_name_to_append)
    print('Had to reload eval runner, starting from: ', self._start_iteration)

    if self._eval_begins_at >= self._start_iteration:
      print('starting to eval from checkpoint: ', self._eval_begins_at)
      self._start_iteration = self._eval_begins_at


  def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    """Restores this evaluation job's progress from its own checkpoints.

    The start iteration defaults to `self._eval_job_index`. If an evaluator
    checkpoint is found and the agent unbundles it successfully, the logger
    data is restored and the start iteration becomes the checkpointed
    iteration plus `self._total_eval_jobs`.

    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.
    """
    print('Trying to initialize evaluator checkpoint')
    job_tag = str(self._eval_job_index) + '_' + self._game_name_to_append
    sentinel_id = 'checkpoint_eval_' + job_tag
    log_key = 'log_eval_' + job_tag

    eval_checkpointer = checkpointer.Checkpointer(
        self._checkpoint_dir,
        checkpoint_file_prefix,
        sentinel_file_identifier=sentinel_id)
    self._start_iteration = self._eval_job_index

    # Check if checkpoint exists. Note that the existence of checkpoint 0 means
    # that we have finished iteration 0.
    latest_version = checkpointer.get_latest_checkpoint_number(
        self._checkpoint_dir, sentinel_file_identifier=sentinel_id)
    print('Latest evaluator checkpoint found: ', latest_version)
    if latest_version < 0:
      return

    experiment_data = eval_checkpointer.load_checkpoint(latest_version)
    if not self._agent.unbundle(
        self._checkpoint_dir, latest_version, experiment_data):
      return

    if experiment_data is not None:
      assert log_key in experiment_data
      assert 'current_iteration' in experiment_data
      self._logger.data = experiment_data[log_key]
      # Continue with the next iteration assigned to this job.
      self._start_iteration = (
          experiment_data['current_iteration'] + self._total_eval_jobs)
    logging.info(
        'Reloaded start iteration for evaluator will start from iteration %d',
        self._start_iteration)

  def _run_train_phase(self):
    logging.info('Skipping training steps, since this is evaluation runner: %.2f',
                 self._training_steps)

  def _run_one_iteration(self, iteration):
    """Waits for a training checkpoint, loads it and evaluates it.

    Args:
      iteration: int, the number of the checkpoint to evaluate.

    Returns:
      The `data_lists` attribute of the `IterationStatistics` object that was
      handed to the evaluation phase.
    """
    statistics = iteration_statistics.IterationStatistics()
    logging.info('[Eval Runner] Starting iteration %d', iteration)

    # Poll until the requested checkpoint can be loaded. With 10000 attempts
    # spaced 25 seconds apart the wait is bounded at roughly 69 hours, so the
    # training process must produce each checkpoint well within that window.
    max_attempts = 10000
    poll_interval_secs = 25
    print ('Trying to load', self._checkpoint_file_prefix, iteration)
    for attempt in range(max_attempts):
      if self.load_given_checkpoint(self._checkpoint_file_prefix, iteration):
        break
      print ('[Eval Runner] Checkpoint not yet found:', iteration, attempt)
      time.sleep(poll_interval_secs)
    else:
      # Every attempt failed. Evaluation still proceeds (with whatever
      # parameters the agent currently holds), so make that visible in the
      # logs rather than reporting the numbers as if they belonged to this
      # checkpoint.
      logging.warning(
          '[Eval Runner] Checkpoint %d was not found after %d attempts; '
          'evaluating with the parameters currently held by the agent.',
          iteration, max_attempts)

    num_episodes_eval, average_reward_eval = self._run_eval_phase(statistics)
    self._save_tensorboard_summaries(iteration, num_episodes_eval,
                                     average_reward_eval)
    return statistics.data_lists

  def _save_tensorboard_summaries(self, iteration, num_episodes_eval,
                                  average_reward_eval):
    """Save statistics as tensorboard summaries.

    Args:
      iteration: int, The current iteration number.
      num_episodes_eval: int, number of evaluation episodes run.
      average_reward_eval: float, The average evaluation reward.
    """
    with self._summary_writer.as_default():
      for idx in range(len(self._game_names)):
        print ('Writing tb summaries for: ', self._game_names[idx], len(self._game_names))
        tf.summary.scalar(
          'Eval_Runner_' + str(self._eval_job_index) + '_' + str(self._game_names[idx]) + '/NumEpisodes',
          num_episodes_eval[idx], step=iteration
        )
        tf.summary.scalar(
          'Eval_Runner_' + str(self._eval_job_index) + '_' + str(self._game_names[idx]) + '/AverageReturns',
          average_reward_eval[idx], step=iteration
        )

  def load_given_checkpoint(self, checkpoint_file_prefix, ckpt_iteration):
    """Loads checkpoint number `ckpt_iteration` into the agent, if available.

    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.
      ckpt_iteration: int, the number of the checkpoint to load.

    Returns:
      bool, True if the requested checkpoint exists and the agent unbundled
      it successfully, False otherwise.

    Raises:
      RuntimeError: if the checkpoint exists but reading it keeps failing
        with `EOFError` / `pickle.UnpicklingError` on every attempt.
    """
    # A checkpoint file may exist while it is still being written, so reads
    # are retried a bounded number of times before giving up.
    max_load_attempts = 100
    retry_delay_secs = 10

    self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                   checkpoint_file_prefix)
    latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
        self._checkpoint_dir)

    print('Latest checkpoint: ', latest_checkpoint_version)
    if (latest_checkpoint_version < 0 or
        ckpt_iteration > latest_checkpoint_version):
      # The requested checkpoint has not been produced yet.
      return False

    print('checkpoint found.....')
    for attempt in range(1, max_load_attempts + 1):
      try:
        experiment_data = self._checkpointer.load_checkpoint(ckpt_iteration)
      except (EOFError, pickle.UnpicklingError) as e:
        if attempt == max_load_attempts:
          raise RuntimeError(
              'Unable to load checkpoint {ckpt} after {tries} tries.'.format(
                  ckpt=ckpt_iteration, tries=max_load_attempts)) from e
        logging.warning('Unable to load checkpoint, trying again after %ds',
                        retry_delay_secs)
        time.sleep(retry_delay_secs)
      else:
        print('Found checkpoint to load at....', ckpt_iteration)
        break

    if not self._agent.unbundle(
        self._checkpoint_dir, ckpt_iteration, experiment_data):
      return False

    checkpointed_iteration = None
    if experiment_data is not None:
      assert 'logs' in experiment_data
      assert 'current_iteration' in experiment_data
      # NOTE(review): this replaces the evaluator's logger data with the
      # 'logs' entry of the loaded checkpoint -- confirm this is intended.
      self._logger.data = experiment_data['logs']
      checkpointed_iteration = experiment_data['current_iteration']
    # `self._start_iteration` is deliberately left untouched here; the
    # evaluator tracks its own progress separately.
    print('Starting evaluation from checkpoint ',
          ckpt_iteration, checkpointed_iteration)
    return True

  def _clear_params(self,):
    """Clear parameters that are passed in by dropping ref to self._agent, and
       recreating one."""
    print ('Clearing params..................................')
    del self._agent
    if len(self.full_games) > len(self._game_names) and len(self._game_names) == 1:
      self.current_game_index = self.full_games.index(self._game_names[0])
      self._agent = self._create_agent_fn(
          None, self._environment,
          summary_writer=None,
          game_names=self._game_names,
          use_single_game_action_space=self._use_single_game_action_space,
          override_game_index=self.current_game_index,
          override_num_games=len(self.full_games))
    else:
      self._agent = self._create_agent_fn(
          None, self._environment,
          summary_writer=None,
          game_names=self._game_names,
          use_single_game_action_space=self._use_single_game_action_space)
    print ('Cleared params......................................')


  def run_experiment(self):
    """Runs a full experiment, spread over multiple iterations."""
    logging.info('Beginning evaluation...')
    print ('Beginning evaluation.....' , self._eval_job_index)

    if self._num_iterations <= self._start_iteration:
      logging.warning('num_iterations (%d) < start_iteration(%d)',
                      self._num_iterations, self._start_iteration)
      return

    for iteration in range(self._start_iteration, self._num_iterations,
                           self._total_eval_jobs):
      # self._clear_params()
      statistics = self._run_one_iteration(iteration)
      self._log_experiment(iteration, statistics)
      self._checkpoint_eval_experiment(iteration)
    self._summary_writer.flush()


  def run_eval_experiment(self):
    """Run an evaluation loop over multiple iterations."""
    logging.info('Beginning evaluation...')
    print ('Beginning evaluation.....' , self._eval_ckpt_number)

    if self._eval_ckpt_number is not None:
      self._start_iteration = self._eval_ckpt_number

    if self._num_iterations <= self._start_iteration:
      logging.warning('num_iterations (%d) < start_iteration(%d)',
                      self._num_iterations, self._start_iteration)
      return

    statistics = self._run_one_iteration(self._start_iteration)
    self._log_experiment(self._start_iteration, statistics)
    self._summary_writer.flush()


  def _log_experiment(self, iteration, statistics):
    self._logger[str(self._game_names[0]) + '_iteration_{:d}'.format(iteration)] = statistics
    if iteration % self._log_every_n == 0:
      self._logger.log_to_file(self._logging_file_prefix, iteration)


  def _checkpoint_eval_experiment(self, iteration):
    """Saves the evaluator's progress and logger data for `iteration`.

    Nothing is saved when the agent returns a falsy bundle.

    Args:
      iteration: int, iteration number for checkpointing.
    """
    bundle = self._agent.bundle_and_checkpoint(
        self._checkpoint_dir, iteration, no_buffer=True)

    # Evaluator checkpoints carry a job- and game-specific prefix and sentinel
    # identifier.
    job_tag = str(self._eval_job_index) + '_' + self._game_name_to_append
    eval_checkpointer = checkpointer.Checkpointer(
        self._checkpoint_dir,
        checkpoint_file_prefix='eval_ckpt_' + job_tag,
        sentinel_file_identifier='checkpoint_eval_' + job_tag,
        checkpoint_duration=self._total_eval_jobs * 4)

    if not bundle:
      return

    bundle['current_iteration'] = iteration
    bundle['log_eval_' + job_tag] = self._logger.data
    eval_checkpointer.save_checkpoint(iteration, bundle)


@gin.configurable
class MultiTaskFixedReplayRunner(MultiTaskRunner):
  """Object that handles running Dopamine experiments with
     multi-task fixed replay buffer."""

  def __init__(self, *args,
               init_ckpt_number=None,
               init_ckpt_dir=None, **kwargs):
    """Sets up a training-only runner, optionally initialised from another run.

    Args:
      *args: positional arguments forwarded to the parent constructor.
      init_ckpt_number: optional checkpoint number to initialise from.
      init_ckpt_dir: optional directory of the run to initialise from; its
        'checkpoints' subdirectory is read. Both `init_ckpt_number` and
        `init_ckpt_dir` must be given for any initial checkpoint to be loaded.
      **kwargs: keyword arguments forwarded to the parent constructor.
    """
    super(MultiTaskFixedReplayRunner, self).__init__(
        *args, only_train=True, **kwargs)
    self._only_train = True

    if init_ckpt_number is None or init_ckpt_dir is None:
      return

    print('Init ckpt, and number are not None: ',
          init_ckpt_number, init_ckpt_dir)
    self._init_model_dir = os.path.join(init_ckpt_dir, 'checkpoints')

    # Only initialise from the other run while this run's start iteration is
    # still 0, so that the load does not replace something already loaded.
    if self._start_iteration != 0:
      return

    self._load_other_ckpt(self._init_model_dir,
                          self._checkpoint_file_prefix,
                          init_ckpt_number)
    print('Loaded from start iteration from another run....')

  def _load_other_ckpt(self, ckpt_dir, ckpt_file_prefix, ckpt_number):
    """Initialises the agent from a checkpoint in another directory.

    The start iteration is first reset to 0. If checkpoint `ckpt_number`
    exists in `ckpt_dir` and the agent's `unbundle_without_opt` accepts it,
    the logger data is restored and the start iteration becomes the
    checkpointed iteration plus one.

    Args:
      ckpt_dir: str, directory holding the checkpoints to read.
      ckpt_file_prefix: str, the checkpoint file prefix.
      ckpt_number: int, the number of the checkpoint to load.
    """
    print('Starting to load other ckpt.....')
    other_checkpointer = checkpointer.Checkpointer(ckpt_dir, ckpt_file_prefix)
    self._start_iteration = 0

    newest_version = checkpointer.get_latest_checkpoint_number(ckpt_dir)
    ckpt_available = newest_version >= 0 and ckpt_number <= newest_version
    if not ckpt_available:
      return

    experiment_data = other_checkpointer.load_checkpoint(ckpt_number)
    if not self._agent.unbundle_without_opt(
        ckpt_dir, ckpt_number, experiment_data):
      return

    if experiment_data is not None:
      assert 'logs' in experiment_data
      assert 'current_iteration' in experiment_data
      self._logger.data = experiment_data['logs']
      self._start_iteration = experiment_data['current_iteration'] + 1
    logging.info(
        '[Other] Reloaded checkpoint and will start from iteration %d',
        self._start_iteration)

  def _run_train_phase(self):
    """Run training phase."""
    self._agent.eval_mode = False
    start_time = time.time()

    for i in range(self._training_steps):
      if i % 100 == 0:
        # We use sys.stdout.write instead of logging so as to flush frequently
        # without generating a line break.
        sys.stdout.write('Training step: {}/{}\r'.format(
            i, self._training_steps))
        sys.stdout.flush()

      self._agent.train_step()

    time_delta = time.time() - start_time
    logging.info('Average training steps per second: %.2f',
                 self._training_steps / time_delta)

  def run_experiment(self):
    """Runs a full experiment, spread over multiple iterations."""
    logging.info('Beginning training...')
    if self._num_iterations <= self._start_iteration:
      logging.warning('num_iterations (%d) < start_iteration(%d)',
                      self._num_iterations, self._start_iteration)
      return

    for iteration in range(self._start_iteration, self._num_iterations):
      statistics = self._run_one_iteration(iteration)
      if jax.process_index() == 0:
        # Only log and save checkpoints when process_index == 0
        self._log_experiment(iteration, statistics)
        self._checkpoint_experiment(iteration)
    self._summary_writer.flush()

  def _run_one_iteration(self, iteration):
    """Runs a single iteration: data reload, training, then evaluation.

    Args:
      iteration: int, index of the iteration being run.

    Returns:
      The `data_lists` attribute of the `IterationStatistics` object that was
      handed to the evaluation phase.
    """
    iteration_stats = iteration_statistics.IterationStatistics()
    logging.info('Starting iteration %d', iteration)

    # The agent's data is reloaded before every training phase.
    self._agent.reload_data()
    self._run_train_phase()

    eval_episode_counts, eval_mean_returns = self._run_eval_phase(
        iteration_stats)
    self._save_tensorboard_summaries(
        iteration, eval_episode_counts, eval_mean_returns)

    return iteration_stats.data_lists

  def _save_tensorboard_summaries(self, iteration, num_episodes_eval,
                                  average_reward_eval):
    """Save statistics as tensorboard summaries.

    Args:
      iteration: int, The current iteration number.
      num_episodes_eval: int, number of evaluation episodes run.
      average_reward_eval: float, The average evaluation reward.
    """
    if not self._only_train:
      with self._summary_writer.as_default():
        for idx in range(len(self._game_names)):
          tf.summary.scalar(
              'Eval_' + str(self._game_names[idx]) + '/NumEpisodes',
              num_episodes_eval[idx],
              step=iteration)
          tf.summary.scalar(
              'Eval_' + str(self._game_names[idx]) + '/AverageReturns',
              average_reward_eval[idx],
              step=iteration)