# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Various masked_softmax implementations, both in numpy and tensorflow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
# NOTE(review): both calls below mutate NumPy's process-global state as a side
# effect of importing this module, so they also affect every other NumPy user
# in the process.
# Print floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)
# Silence divide-by-zero floating-point warnings. Presumably this is here so
# that np.log(0) on illegal actions in a legal-action mask yields -inf quietly.
np.seterr(divide='ignore')


def np_masked_softmax(logits, legal_actions_mask):
    """Returns the softmax over the valid actions defined by `legal_actions_mask`.

    Args:
      logits: A tensor [..., num_actions] (e.g. [num_actions] or [B, num_actions])
        representing the logits to mask.
      legal_actions_mask: The legal action mask, same shape as logits. 1 means
        it's a legal action, 0 means it's illegal.

    Returns:
      An array shaped like `logits`. Illegal actions get probability 0, and
      the probability mass is distributed by softmax over the legal actions
      along the last axis. A row with no legal action has no valid
      distribution and comes out as NaN.
    """
    # log(1) = 0 leaves legal logits unchanged; log(0) = -inf sends illegal
    # logits to probability 0 after exponentiation. The divide-by-zero warning
    # is suppressed locally so the result does not depend on NumPy's
    # process-global error settings (np.seterr).
    with np.errstate(divide='ignore'):
        masked_logits = logits + np.log(legal_actions_mask)
    # Subtract the per-row maximum so np.exp cannot overflow.
    max_logit = np.amax(masked_logits, axis=-1, keepdims=True)
    exp_logit = np.exp(masked_logits - max_logit)
    return exp_logit / np.sum(exp_logit, axis=-1, keepdims=True)


def np_softmax(logits):
    """Returns the numerically stable softmax of `logits` over the last axis.

    Args:
      logits: Array-like [..., num_classes] (e.g. [num_classes] or
        [B, num_classes]) of unnormalized scores.

    Returns:
      An array shaped like `logits`, normalized to sum to 1 along the last
      axis.
    """
    # Shifting by the per-row maximum does not change the softmax, but it
    # keeps np.exp from overflowing on large logits.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    normalizer = np.sum(weights, axis=-1, keepdims=True)
    return weights / normalizer


def extract_trainable_variables(models):
    """Returns the trainable variables of all `models`, flattened into one list.

    Args:
      models: An iterable of objects that each expose an iterable
        `trainable_variables` attribute (presumably TF/Keras modules; confirm
        at call sites).

    Returns:
      A list containing every model's `trainable_variables`, concatenated in
      the iteration order of `models`.
    """
    return [variable
            for model in models
            for variable in model.trainable_variables]
