#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import pickle
from typing import Union, List

import numpy as np

from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import DNDQHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, AgentParameters

from rl_coach.core_types import RunPhase, EnvironmentSteps, Episode, StateType
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters, MemoryGranularity
from rl_coach.schedules import ConstantSchedule


class NECNetworkParameters(NetworkParameters):
    def __init__(self):
        super().__init__()
        self.input_embedders_parameters = {'observation': InputEmbedderParameters()}
        self.middleware_parameters = FCMiddlewareParameters()
        self.heads_parameters = [DNDQHeadParameters()]
        self.optimizer_type = 'Adam'
        self.should_get_softmax_probabilities = False


class NECAlgorithmParameters(AlgorithmParameters):
    """
    :param dnd_size: (int)
        Defines the number of transitions that will be stored in each one of the DNDs. Note that the total number
        of transitions that will be stored is dnd_size x num_actions.

    :param l2_norm_added_delta: (float)
        A small value that will be added when calculating the weight of each of the DND entries. This follows the
        :math:`\delta` patameter defined in the paper.

    :param new_value_shift_coefficient: (float)
        In the case where a ew embedding that was added to the DND was already present, the value that will be stored
        in the DND is a mix between the existing value and the new value. The mix rate is defined by
        new_value_shift_coefficient.

    :param number_of_knn: (int)
        The number of neighbors that will be retrieved for each DND query.

    :param DND_key_error_threshold: (float)
        When the DND is queried for a specific embedding, this threshold will be used to determine if the embedding
        exists in the DND, since exact matches of embeddings are very rare.

    :param propagate_updates_to_DND: (bool)
        If set to True, when the gradients of the network will be calculated, the gradients will also be
        backpropagated through the keys of the DND. The keys will then be updated as well, as if they were regular
        network weights.

    :param n_step: (int)
        The bootstrap length that will be used when calculating the state values to store in the DND.

    :param bootstrap_total_return_from_old_policy: (bool)
        If set to True, the bootstrap that will be used to calculate each state-action value, is the network value
        when the state was first seen, and not the latest, most up-to-date network value.
    """
    def __init__(self):
        super().__init__()
        self.dnd_size = 500000
        self.l2_norm_added_delta = 0.001
        self.new_value_shift_coefficient = 0.1
        self.number_of_knn = 50
        self.DND_key_error_threshold = 0
        self.num_consecutive_playing_steps = EnvironmentSteps(4)
        self.propagate_updates_to_DND = False
        self.n_step = 100
        self.bootstrap_total_return_from_old_policy = True


class NECMemoryParameters(EpisodicExperienceReplayParameters):
    def __init__(self):
        super().__init__()
        self.max_size = (MemoryGranularity.Transitions, 100000)


class NECAgentParameters(AgentParameters):
    def __init__(self):
        super().__init__(algorithm=NECAlgorithmParameters(),
                         exploration=EGreedyParameters(),
                         memory=NECMemoryParameters(),
                         networks={"main": NECNetworkParameters()})
        self.exploration.epsilon_schedule = ConstantSchedule(0.1)
        self.exploration.evaluation_epsilon = 0.01

    @property
    def path(self):
        return 'rl_coach.agents.nec_agent:NECAgent'


# Neural Episodic Control - https://arxiv.org/pdf/1703.01988.pdf
class NECAgent(ValueOptimizationAgent):
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        self.current_episode_state_embeddings = []
        self.training_started = False
        self.current_episode_buffer = \
            Episode(discount=self.ap.algorithm.discount,
                    n_step=self.ap.algorithm.n_step,
                    bootstrap_total_return_from_old_policy=self.ap.algorithm.bootstrap_total_return_from_old_policy)

    def learn_from_batch(self, batch):
        if not self.networks['main'].online_network.output_heads[0].DND.has_enough_entries(self.ap.algorithm.number_of_knn):
            return 0, [], 0
        else:
            if not self.training_started:
                self.training_started = True
                screen.log_title("Finished collecting initial entries in DND. Starting to train network...")

        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        TD_targets = self.networks['main'].online_network.predict(batch.states(network_keys))
        bootstrapped_return_from_old_policy = batch.n_step_discounted_rewards()
        #  only update the action that we have actually done in this transition
        for i in range(batch.size):
            TD_targets[i, batch.actions()[i]] = bootstrapped_return_from_old_policy[i]

        # set the gradients to fetch for the DND update
        fetches = []
        head = self.networks['main'].online_network.output_heads[0]
        if self.ap.algorithm.propagate_updates_to_DND:
            fetches = [head.dnd_embeddings_grad, head.dnd_values_grad, head.dnd_indices]

        # train the neural network
        result = self.networks['main'].train_and_sync_networks(batch.states(network_keys), TD_targets, fetches)

        total_loss, losses, unclipped_grads = result[:3]

        # update the DND keys and values using the extracted gradients
        if self.ap.algorithm.propagate_updates_to_DND:
            embedding_gradients = np.swapaxes(result[-1][0], 0, 1)
            value_gradients = np.swapaxes(result[-1][1], 0, 1)
            indices = np.swapaxes(result[-1][2], 0, 1)
            head.DND.update_keys_and_values(batch.actions(), embedding_gradients, value_gradients, indices)

        return total_loss, losses, unclipped_grads

    def act(self):
        if self.phase == RunPhase.HEATUP:
            # get embedding in heatup (otherwise we get it through get_prediction)
            embedding = self.networks['main'].online_network.predict(
                self.prepare_batch_for_inference(self.curr_state, 'main'),
                outputs=self.networks['main'].online_network.state_embedding)
            self.current_episode_state_embeddings.append(embedding.squeeze())

        return super().act()

    def get_all_q_values_for_states(self, states: StateType, additional_outputs: List = None):
        # we need to store the state embeddings regardless if the action is random or not
        return self.get_prediction_and_update_embeddings(states)

    def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
        # get the actions q values and the state embedding
        embedding, actions_q_values, softmax_probabilities = self.networks['main'].online_network.predict(
            self.prepare_batch_for_inference(states, 'main'),
            outputs=[self.networks['main'].online_network.state_embedding,
                     self.networks['main'].online_network.output_heads[0].output,
                     self.networks['main'].online_network.output_heads[0].softmax]
        )
        if self.phase != RunPhase.TEST:
            # store the state embedding for inserting it to the DND later
            self.current_episode_state_embeddings.append(embedding.squeeze())
        actions_q_values = actions_q_values[0][0]
        return actions_q_values, softmax_probabilities

    def get_prediction_and_update_embeddings(self, states):
        # get the actions q values and the state embedding
        embedding, actions_q_values = self.networks['main'].online_network.predict(
            self.prepare_batch_for_inference(states, 'main'),
            outputs=[self.networks['main'].online_network.state_embedding,
                     self.networks['main'].online_network.output_heads[0].output]
        )
        if self.phase != RunPhase.TEST:
            # store the state embedding for inserting it to the DND later
            self.current_episode_state_embeddings.append(embedding[0].squeeze())
        actions_q_values = actions_q_values[0][0]
        return actions_q_values

    def reset_internal_state(self):
        super().reset_internal_state()
        self.current_episode_state_embeddings = []
        self.current_episode_buffer = \
            Episode(discount=self.ap.algorithm.discount,
                    n_step=self.ap.algorithm.n_step,
                    bootstrap_total_return_from_old_policy=self.ap.algorithm.bootstrap_total_return_from_old_policy)

    def handle_episode_ended(self):
        super().handle_episode_ended()

        # get the last full episode that we have collected
        episode = self.call_memory('get_last_complete_episode')
        if episode is not None and self.phase != RunPhase.TEST:
            assert len(self.current_episode_state_embeddings) == episode.length()
            discounted_rewards = episode.get_transitions_attribute('n_step_discounted_rewards')
            actions = episode.get_transitions_attribute('action')
            self.networks['main'].online_network.output_heads[0].DND.add(self.current_episode_state_embeddings,
                                                                         actions, discounted_rewards)

    def save_checkpoint(self, checkpoint_prefix):
        super().save_checkpoint(checkpoint_prefix)
        with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_prefix) + '.dnd'), 'wb') as f:
            pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
