# coding=utf-8
# Copyright 2022 The Multi Task Atari Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpoint loading utilities."""

import os

from absl import logging
from dopamine.discrete_domains import checkpointer
import flax
import jax
import jax.numpy as jnp
import tensorflow as tf


def get_bundle_dictionary(checkpoint_dir,
                          checkpoint_number=None,
                          checkpoint_file_prefix='ckpt'):
  """Returns the bundle dictionary stored in `checkpoint_dir`.

  Args:
    checkpoint_dir: Directory containing the checkpoints.
    checkpoint_number: Number of the checkpoint to load. If None, the latest
      checkpoint number reported for `checkpoint_dir` is used.
    checkpoint_file_prefix: Prefix of the checkpoint files. When it contains
      the substring 'tf', the checkpoint is read as a TF Nature DQN checkpoint
      and converted with `create_dqn_checkpoint_data`; otherwise it is loaded
      through the Dopamine `Checkpointer`.

  Returns:
    The loaded experiment data, or None when the resolved checkpoint number is
    negative.

  Raises:
    AssertionError: If `checkpoint_number` is greater than the latest
      checkpoint number found for `checkpoint_dir`.
  """
  model_checkpointer = checkpointer.Checkpointer(
      checkpoint_dir, checkpoint_file_prefix)
  # Check if checkpoint exists. Note that the existence of checkpoint 0 means
  # that we have finished iteration 0 (so we will start from iteration 1).
  latest_checkpoint_number = checkpointer.get_latest_checkpoint_number(
      checkpoint_dir)
  if checkpoint_number is None:
    checkpoint_number = latest_checkpoint_number
  # Raised explicitly (rather than via `assert`) so that the check is not
  # stripped when Python runs with -O; the exception type is kept for callers.
  if checkpoint_number > latest_checkpoint_number:
    raise AssertionError(
        f"checkpoint_number {checkpoint_number} doesn't exist")
  if checkpoint_number < 0:
    # Nothing to restore.
    return None
  if 'tf' in checkpoint_file_prefix:
    return create_dqn_checkpoint_data(
        checkpoint_dir, checkpoint_file_prefix, checkpoint_number)
  experiment_data = model_checkpointer.load_checkpoint(checkpoint_number)
  # Log only once the load call has returned, so the message is not emitted
  # when loading raises.
  logging.info('Loaded experiment data from %s and checkpoint_number %d',
               checkpoint_dir, checkpoint_number)
  return experiment_data


def load_tf_nature_dqn_weights(checkpoint_path,
                               prefix='Online'):
  """Loads TF NatureDQNNetwork weights and converts them to JAX parameters.

  Args:
    checkpoint_path: Path of the TF checkpoint to read.
    prefix: Variable scope prefix of the network inside the checkpoint;
      tensors are looked up as `{prefix}/{layer}/biases` and
      `{prefix}/{layer}/weights`.

  Returns:
    A `flax.core.FrozenDict` of the form
    `{'params': {jax_layer: {'bias': ..., 'kernel': ...}}}` whose leaves have
    been converted with `jnp.asarray`.
  """
  ckpt_reader = tf.train.load_checkpoint(checkpoint_path)
  # Maps the layer names used in the JAX parameter tree to the layer names
  # used for the variables stored in the TF checkpoint.
  jax_to_tf_layer_mapping = {
      'Conv_0': 'Conv',
      'Conv_1': 'Conv_1',
      'Conv_2': 'Conv_2',
      'Dense_0': 'fully_connected',
      'Dense_1': 'fully_connected_1',
  }
  params = {
      jax_layer: {
          'bias': ckpt_reader.get_tensor(f'{prefix}/{tf_layer}/biases'),
          'kernel': ckpt_reader.get_tensor(f'{prefix}/{tf_layer}/weights'),
      }
      for jax_layer, tf_layer in jax_to_tf_layer_mapping.items()
  }
  # `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is the
  # stable spelling available across JAX versions.
  jax_params = jax.tree_util.tree_map(jnp.asarray, {'params': params})
  return flax.core.FrozenDict(jax_params)


def create_dqn_checkpoint_data(checkpoint_path, checkpoint_file_prefix,
                               iteration_number, auxiliary_info=True,
                               steps_per_iteration=250000):
  """Loads tf Nature DQN weights and creates a dict to restore a JAX agent.

  Args:
    checkpoint_path: Directory containing the TF checkpoint files.
    checkpoint_file_prefix: Prefix of the TF checkpoint file; the checkpoint
      is read from `{checkpoint_path}/{checkpoint_file_prefix}-{iteration}`.
    iteration_number: Iteration whose checkpoint should be loaded.
    auxiliary_info: If True, also store 'current_iteration' and
      'training_steps' in the returned dictionary.
    steps_per_iteration: Number of training steps per iteration used to
      reconstruct 'training_steps'. The default of 250000 assumes that the
      agent was trained using standard Dopamine params.

  Returns:
    A dict with 'online_params' and 'target_params' and, when
    `auxiliary_info` is True, 'current_iteration' and 'training_steps'.
  """
  bundle_dictionary = {}
  tf_checkpoint_path = os.path.join(
      checkpoint_path, f'{checkpoint_file_prefix}-{iteration_number}')
  if auxiliary_info:
    bundle_dictionary['current_iteration'] = iteration_number
    # Iterations are counted from 0, so `iteration_number + 1` iterations have
    # been completed once this checkpoint exists.
    bundle_dictionary['training_steps'] = (
        (iteration_number + 1) * steps_per_iteration)
  bundle_dictionary['online_params'] = load_tf_nature_dqn_weights(
      tf_checkpoint_path, 'Online')
  bundle_dictionary['target_params'] = load_tf_nature_dqn_weights(
      tf_checkpoint_path, 'Target')
  return bundle_dictionary
