#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pickle
from typing import Any, List, Dict

import tensorflow as tf
import numpy as np

from rl_coach.saver import Saver


class GlobalVariableSaver(Saver):
    def __init__(self, name):
        self._names = [name]
        # if graph is finalized, savers must have already already been added. This happens
        # in the case of a MonitoredSession
        # self._variables = tf.global_variables()
        self._variables = tf.trainable_variables()
        self._teacher_variables = [v for v in self._variables if ('/teacher' in v.name and name.split('/')[0] in v.name)]
        self._online_variables = [v for v in self._variables if ('/online' in v.name and name.split('/')[0] in v.name)]

        # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
        # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
        # self._variables = [v for v in self._variables if "/target" not in v.name]
        self._variables = [v for v in self._variables if ('/target' not in v.name and name.split('/')[0] in v.name)]

        # Using a placeholder to update the variable during restore to avoid memory leak.
        # Ref: https://github.com/tensorflow/tensorflow/issues/4151
        self._variable_placeholders = []
        self._variable_update_ops = []
        for v in self._variables:
            variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
            self._variable_placeholders.append(variable_placeholder)
            self._variable_update_ops.append(v.assign(variable_placeholder))

        self._saver = tf.train.Saver(self._variables, max_to_keep=None)

        self._teacher_variable_placeholders = []
        self._teacher_variable_update_ops = []
        for v in self._teacher_variables:
            variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
            self._teacher_variable_placeholders.append(variable_placeholder)
            self._teacher_variable_update_ops.append(v.assign(variable_placeholder))

        self._online_variable_placeholders = []
        self._online_variable_update_ops = []
        for v in self._online_variables:
            variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
            self._online_variable_placeholders.append(variable_placeholder)
            self._online_variable_update_ops.append(v.assign(variable_placeholder))

    @property
    def path(self):
        """
        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
        """
        return ""  # use empty string for global file

    def save(self, sess: None, save_path: str) -> List[str]:
        """
        Save to save_path
        :param sess: active session
        :param save_path: full path to save checkpoint (typically directory plus checkpoint prefix plus self.path)
        :return: list of all saved paths
        """
        save_path = self._saver.save(sess, save_path)
        return [save_path]

    def to_arrays(self, session: Any) -> Dict[str, np.ndarray]:
        """
        Save to dictionary of arrays
        :param sess: active session
        :return: dictionary of arrays
        """
        return {
            k.name.split(":")[0]: v for k, v in zip(self._variables, session.run(self._variables))
        }

    def from_arrays(self, session: Any, tensors: Any):
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param tensors: {name: array}
        """
        # if variable was saved using global network, re-map it to online
        # network
        # TODO: Can this be more generic so that `global/` and `online/` are not
        # hardcoded here?
        if isinstance(tensors, dict):
            tensors = tensors.items()

        variables = {k.replace("global/", "online/"): v for k, v in tensors}

        # Assign all variables using placeholder
        placeholder_dict = {
            ph: variables[v.name.split(":")[0]]
            for ph, v in zip(self._variable_placeholders, self._variables)
        }
        session.run(self._variable_update_ops, placeholder_dict)

    def to_string(self, session: Any) -> str:
        """
        Save to byte string
        :param session: active session
        :return: serialized byte string
        """
        return pickle.dumps(self.to_arrays(session), protocol=-1)

    def from_string(self, session: Any, string: str):
        """
        Restore from byte string
        :param session: active session
        :param string: byte string to restore from
        """
        self.from_arrays(session, pickle.loads(string))

    def _read_tensors(self, restore_path: str):
        """
        Load tensors from a checkpoint
        :param restore_path: full path to load checkpoint from.
        """
        # We don't use saver.restore() because checkpoint is loaded to online
        # network, but if the checkpoint is from the global network, a namespace
        # mismatch exists and variable name must be modified before loading.
        reader = tf.contrib.framework.load_checkpoint(restore_path)
        for var_name, _ in reader.get_variable_to_shape_map().items():
            yield var_name, reader.get_tensor(var_name)

    def restore(self, sess: Any, restore_path: str):
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: full path to load checkpoint from.
        """
        self.from_arrays(sess, self._read_tensors(restore_path))

    def from_arrays_teacher(self, session: Any, tensors: Any):
        if isinstance(tensors, dict):
            tensors = tensors.items()

        variables = {k.replace("online/", "teacher/"): v for k, v in tensors}

        # Assign all variables using placeholder
        placeholder_dict = {
            ph: variables[v.name.split(":")[0]]
            for ph, v in zip(self._teacher_variable_placeholders, self._teacher_variables)
        }
        session.run(self._teacher_variable_update_ops, placeholder_dict)

    def from_arrays_online(self, session: Any, tensors: Any):
        if isinstance(tensors, dict):
            tensors = tensors.items()

        variables = {k.replace("global/", "online/"): v for k, v in tensors}

        # Assign all variables using placeholder
        placeholder_dict = {
            ph: variables[v.name.split(":")[0]]
            for ph, v in zip(self._online_variable_placeholders, self._online_variables)
        }
        session.run(self._online_variable_update_ops, placeholder_dict)

    def restore_teacher(self, sess: Any, restore_path: str):
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: full path to load checkpoint from.
        """
        self.from_arrays_teacher(sess, self._read_tensors(restore_path))

    def restore_online(self, sess: Any, restore_path: str):
        """
        Restore from restore_path
        :param sess: active session for session-based frameworks (e.g. TF)
        :param restore_path: full path to load checkpoint from.
        """
        self.from_arrays_online(sess, self._read_tensors(restore_path))

    def merge(self, other: "Saver"):
        """
        Merge other saver into this saver
        :param other: saver to be merged into self
        """
        assert isinstance(other, GlobalVariableSaver)
        self._names.extend(other._names)
        # There is nothing else to do because variables must already be part of
        # the global collection.
