# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic models for testing simple tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import pickle

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import common_video
from tensor2tensor.models.video import base
from tensor2tensor.models.video import basic_deterministic_params  # pylint: disable=unused-import
from tensor2tensor.utils import registry
from tensorflow.python.keras.utils import conv_utils

import tensorflow as tf


@registry.register_model
class NextFrameBasicDeterministic(base.NextFrameBase):
  """Basic next-frame model, may take actions and predict rewards too."""
  def __init__(self, *args, **kwargs):
    """Set up the dropout-mask bookkeeping according to the model mode.

    * TRAIN: tensors passed to the `apply_*` helpers are recorded so that
      noise masks for them can be generated and dumped to disk.
    * PREDICT: for `epoch >= 30` (read from the original hparams) all noise
      is zeroed; otherwise the dumped masks are loaded and replayed.
    * Any other mode: noise is sampled on the fly.

    Args:
      *args: forwarded to the base-class constructor.
      **kwargs: forwarded to the base-class constructor.
    """
    # Both flags are assigned before the base-class constructor is invoked.
    self.enable_dropout_cache = False
    self.dump_dropout_masks = False
    super(NextFrameBasicDeterministic, self).__init__(*args, **kwargs)
    self.zero_noise = False

    mode = self.mode_for_the_model
    if mode == tf.estimator.ModeKeys.TRAIN:
      # Record tensors during graph construction; masks are dumped later.
      self.dump_dropout_masks = True
      self.network_tensors_for_dropout = []
      self.dropout_cache = {}
    elif mode == tf.estimator.ModeKeys.PREDICT:
      if self._original_hparams.epoch >= 30:
        self.zero_noise = True
      else:
        # Replay the masks written during training.
        self.load_dropout_masks_gaussian()
        self.enable_dropout_cache = True

  @property
  def is_recurrent_model(self):
    """Whether the model is recurrent; this implementation returns False."""
    return False

  def inject_latent(self, layer, inputs, target, action):
    """Return `layer` unchanged together with a zero extra loss.

    Args:
      layer: tensor that a latent would be injected into.
      inputs: unused.
      target: unused.
      action: unused.

    Returns:
      A tuple `(layer, 0.0)`.
    """
    del action, target, inputs  # Unused in this implementation.
    extra_loss = 0.0
    return layer, extra_loss

  def apply_dropout(self, inp, keep_prob):
    """Apply dropout to `inp`, optionally recording it or replaying a mask.

    * When `self.dump_dropout_masks` is set, `(inp, keep_prob)` is appended
      to `self.network_tensors_for_dropout` and regular dropout is applied.
    * Otherwise, when `self.enable_dropout_cache` is set, `inp` is multiplied
      by the mask stored in `self.dropout_cache` under
      `get_unique_name(inp.name)`.
    * Otherwise regular dropout is applied.

    Args:
      inp: tensor to apply dropout to.
      keep_prob: keep probability handed to `tf.nn.dropout`.

    Returns:
      The tensor after dropout / masking.
    """
    if self.dump_dropout_masks:
      self.network_tensors_for_dropout.append((inp, keep_prob))
    elif self.enable_dropout_cache:
      # NOTE(review): `.get` yields None when no mask was stored for this
      # tensor, in which case the multiplication below fails.
      cached_mask = self.dropout_cache.get(get_unique_name(inp.name))
      return inp * cached_mask
    return tf.nn.dropout(inp, keep_prob)

  def apply_gaussian_dropout(self, inp, kernel_dim):
    """Run the Gaussian-noise convolution on `inp`, recording or replaying.

    * When `self.dump_dropout_masks` is set, `(inp, kernel_dim, channels)` is
      appended to `self.network_tensors_for_dropout` and freshly sampled
      noise is used.
    * Otherwise, when `self.enable_dropout_cache` is set, the noise stored in
      `self.dropout_cache` under `get_unique_name(inp.name)` is used.
    * Otherwise freshly sampled noise is used.

    Args:
      inp: input tensor; its last dimension is used as both the input and
        the output channel count.
      kernel_dim: list `[kernel_h, kernel_w]`.

    Returns:
      Output of the noisy convolution.
    """
    if self.dump_dropout_masks:
      out_channels = common_layers.shape_list(inp)[-1]  # Same as input.
      self.network_tensors_for_dropout.append((inp, kernel_dim, out_channels))
    elif self.enable_dropout_cache:
      cached_noise = self.dropout_cache.get(get_unique_name(inp.name))
      return self.gaussian_conv(inp, cached_noise, kernel_dim)
    return self.gaussian_dropout_conv(inp, kernel_dim)


  def apply_gaussian_dropout_extra_conv(self, rew_inp, tran_inp, kernel_dim, strides, num_filters, noisy_kernel_dim, transpose=False):
    """Wrap `transition_reward_extra_noisy_conv_for_reward` with mask logic.

    When `self.dump_dropout_masks` is set, the reward input is recorded in
    `self.network_tensors_for_dropout` (with `noisy_kernel_dim` and the
    channel count of `tran_inp`) and noise is left to the wrapped method.
    Otherwise, when `self.enable_dropout_cache` is set, the noise stored
    under `get_unique_name(rew_inp.name)` is passed along. In every other
    case no noise is passed, so the wrapped method picks its own.

    Args:
      rew_inp: input tensor of the reward branch.
      tran_inp: input tensor of the transition branch.
      kernel_dim: list `[kernel_h, kernel_w]` of the shared convolution.
      strides: strides of the shared convolution.
      num_filters: output channel count of the shared convolution.
      noisy_kernel_dim: list `[kernel_h, kernel_w]` of the noisy convolution.
      transpose: whether the shared convolution is transposed.

    Returns:
      Whatever `transition_reward_extra_noisy_conv_for_reward` returns: the
      `(reward, transition)` pair of hidden layers.
    """
    cached_noise = None
    if self.dump_dropout_masks:
      tran_channels = common_layers.shape_list(tran_inp)[-1]
      self.network_tensors_for_dropout.append(
          (rew_inp, noisy_kernel_dim, tran_channels))
    elif self.enable_dropout_cache:
      cached_noise = self.dropout_cache.get(get_unique_name(rew_inp.name))
    return self.transition_reward_extra_noisy_conv_for_reward(
        rew_inp, tran_inp, kernel_dim, strides, num_filters, noisy_kernel_dim,
        noise=cached_noise, transpose=transpose)

  def apply_gaussian_dropout_only_reward_noisy(self, rew_inp, tran_inp, kernel_dim, strides, num_filters, transpose=False):
    """Wrap `transition_reward_shared_conv` with mask recording / replay.

    When `self.dump_dropout_masks` is set, `(rew_inp, kernel_dim,
    num_filters)` is appended to `self.network_tensors_for_dropout`.
    Otherwise, when `self.enable_dropout_cache` is set, the noise stored
    under `get_unique_name(rew_inp.name)` is passed to the wrapped method.
    In every other case no noise is passed.

    Args:
      rew_inp: input tensor of the reward branch.
      tran_inp: input tensor of the transition branch.
      kernel_dim: list `[kernel_h, kernel_w]`.
      strides: convolution strides.
      num_filters: output channel count.
      transpose: whether the convolution is transposed.

    Returns:
      The `(reward, transition)` pair from `transition_reward_shared_conv`.
    """
    cached_noise = None
    if self.dump_dropout_masks:
      self.network_tensors_for_dropout.append(
          (rew_inp, kernel_dim, num_filters))
    elif self.enable_dropout_cache:
      cached_noise = self.dropout_cache.get(get_unique_name(rew_inp.name))
    return self.transition_reward_shared_conv(
        rew_inp, tran_inp, kernel_dim, strides, num_filters,
        noise=cached_noise, transpose=transpose)

  def apply_gaussian_weighted_channel_dropout(self, inp, kernel_dim, strides):
    """Run the per-channel Gaussian-noise convolution, recording or replaying.

    * When `self.dump_dropout_masks` is set, `(inp, kernel_dim, channels)` is
      appended to `self.network_tensors_for_dropout` and freshly sampled
      noise is used.
    * Otherwise, when `self.enable_dropout_cache` is set, the noise stored in
      `self.dropout_cache` under `get_unique_name(inp.name)` is used.
    * Otherwise freshly sampled noise is used.

    Args:
      inp: input tensor; its last dimension is the channel count.
      kernel_dim: list `[kernel_h, kernel_w]`.
      strides: convolution strides.

    Returns:
      Output of the noisy convolution.
    """
    if self.dump_dropout_masks:
      out_channels = common_layers.shape_list(inp)[-1]  # Same as input.
      self.network_tensors_for_dropout.append((inp, kernel_dim, out_channels))
    elif self.enable_dropout_cache:
      cached_noise = self.dropout_cache.get(get_unique_name(inp.name))
      return self.gaussian_conv_weighted_channel(
          inp, cached_noise, kernel_dim, strides)
    return self.gaussian_dropout_conv_weighted_channel(
        inp, kernel_dim, strides)



  def middle_network(self, layer, internal_states):
    """Apply `hparams.num_hidden_layers` 3x3, stride-1 convolutions.

    The first convolution replaces its input; every later one is added to
    its input and layer-normalized. No dropout is applied here.

    Args:
      layer: input tensor; its last dimension sets the number of filters.
      internal_states: passed through untouched.

    Returns:
      A tuple `(output_tensor, internal_states)`.
    """
    hidden = layer
    num_channels = common_layers.shape_list(hidden)[-1]
    for depth in range(self.hparams.num_hidden_layers):
      with tf.variable_scope("layer%d" % depth):
        conv_out = tf.layers.conv2d(
            hidden, num_channels, (3, 3), activation=common_layers.belu,
            strides=(1, 1), padding="SAME")
        if depth == 0:
          hidden = conv_out
        else:
          # Residual connection followed by layer normalization.
          hidden = common_layers.layer_norm(hidden + conv_out)
    return hidden, internal_states

  def update_internal_states_early(self, internal_states, frames):
    """Update the internal states early in the network if requested.

    This implementation performs no update: `frames` is discarded and
    `internal_states` is handed back as received.

    Args:
      internal_states: the current internal states.
      frames: unused.

    Returns:
      `internal_states`, unchanged.
    """
    del frames
    return internal_states

  def next_frame(self, frames, actions, rewards, target_frame,
                 internal_states, video_extra):
    """Predict the next frame, plus reward/policy/value heads when enabled.

    The stacked input frames are embedded, down-strided
    `hparams.num_compress_steps` times, passed through `middle_network` and
    up-strided back. Up-stride steps with `i <= 3` use a single path; later
    steps keep two paths, a transition branch (`tran_inp`, which produces the
    frame logits) and a reward branch (`rew_inp`, which feeds the reward
    head). Both branches go through `transition_reward_no_noise_shared_conv`;
    at steps 4 and 5 the reward branch first goes through
    `apply_gaussian_weighted_channel_dropout`.

    Args:
      frames: list of input frame tensors, concatenated on the last axis.
      actions: sequence of action tensors; only the last one is used.
      rewards: unused.
      target_frame: forwarded to `inject_latent`.
      internal_states: optional internal states. When
        `hparams.concat_internal_states` is set, the first part of the first
        state is concatenated to the frames.
      video_extra: unused.

    Returns:
      A tuple `(logits, reward_pred, policy_pred, value_pred, extra_loss,
      internal_states)`. `reward_pred` is None unless `self.has_rewards`;
      `policy_pred` and `value_pred` are None unless `self.has_actions`.
    """
    del rewards, video_extra

    hparams = self.hparams
    filters = hparams.hidden_size
    kernel2 = (4, 4)
    action = actions[-1]

    # Stack the inputs.
    if internal_states is not None and hparams.concat_internal_states:
      # Use the first part of the first internal state if asked to concatenate.
      batch_size = common_layers.shape_list(frames[0])[0]
      internal_state = internal_states[0][0][:batch_size, :, :, :]
      stacked_frames = tf.concat(frames + [internal_state], axis=-1)
    else:
      stacked_frames = tf.concat(frames, axis=-1)
    inputs_shape = common_layers.shape_list(stacked_frames)

    # Update internal states early if requested.
    if hparams.concat_internal_states:
      internal_states = self.update_internal_states_early(
          internal_states, frames)

    # Using non-zero bias initializer below for edge cases of uniform inputs.
    x = tf.layers.dense(
        stacked_frames, filters, name="inputs_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Down-stride. No dropout is applied to the inputs of these layers.
    layer_inputs = [x]
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("downstride%d" % i):
        layer_inputs.append(x)
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                             strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)

    if self.has_actions:
      with tf.variable_scope("policy"):
        x_flat = tf.layers.flatten(x)
        policy_pred = tf.layers.dense(x_flat, self.hparams.problem.num_actions)
        value_pred = tf.layers.dense(x_flat, 1)
        value_pred = tf.squeeze(value_pred, axis=-1)
    else:
      policy_pred, value_pred = None, None

    # Add embedded action if present.
    if self.has_actions:
      x = common_video.inject_additional_input(
          x, action, "action_enc", hparams.action_injection)

    # Inject latent if present. Only for stochastic models.
    x, extra_loss = self.inject_latent(x, frames, target_frame, action)

    x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    x, internal_states = self.middle_network(x, internal_states)

    # Up-convolve.
    layer_inputs = list(reversed(layer_inputs))

    # Bind both branches up front: with zero compress steps the loop below
    # never runs, and they would otherwise be undefined afterwards.
    tran_inp = rew_inp = x

    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("upstride%d" % i):
        if i <= 3:
          # Single path for the first up-stride steps.
          if self.has_actions:
            x = common_video.inject_additional_input(
                x, action, "action_enc", hparams.action_injection)
          if i >= hparams.num_compress_steps - hparams.filter_double_steps:
            filters //= 2
          x = tf.layers.conv2d_transpose(
              x, filters, kernel2, activation=common_layers.belu,
              strides=(2, 2), padding="SAME")
          y = layer_inputs[i]
          shape = common_layers.shape_list(y)
          x = x[:, :shape[1], :shape[2], :]
          x = common_layers.layer_norm(x + y)
          x = common_attention.add_timing_signal_nd(x)
          tran_inp = x
          rew_inp = x
        else:
          # Separate transition and reward paths.
          if self.has_actions:
            tran_inp = common_video.inject_additional_input(
                tran_inp, action, "action_enc", hparams.action_injection)
            rew_inp = common_video.inject_additional_input(
                rew_inp, action, "action_enc", hparams.action_injection)
          if i >= hparams.num_compress_steps - hparams.filter_double_steps:
            filters //= 2

          if i in [4, 5]:
            # Only the reward path goes through the noisy 1x1 convolution.
            rew_inp = self.apply_gaussian_weighted_channel_dropout(
                rew_inp, kernel_dim=[1, 1], strides=[1, 1, 1, 1])
          rew_inp, tran_inp = self.transition_reward_no_noise_shared_conv(
              rew_inp, tran_inp, kernel_dim=list(kernel2),
              strides=[1, 2, 2, 1], num_filters=filters, transpose=True)

          y = layer_inputs[i]
          shape = common_layers.shape_list(y)
          rew_inp = rew_inp[:, :shape[1], :shape[2], :]
          rew_inp = common_layers.layer_norm(rew_inp + y)
          rew_inp = common_attention.add_timing_signal_nd(rew_inp)
          tran_inp = tran_inp[:, :shape[1], :shape[2], :]
          tran_inp = common_layers.layer_norm(tran_inp + y)
          tran_inp = common_attention.add_timing_signal_nd(tran_inp)

    # Cut down to original size.
    rew_inp = rew_inp[:, :inputs_shape[1], :inputs_shape[2], :]
    tran_inp = tran_inp[:, :inputs_shape[1], :inputs_shape[2], :]
    x_fin = tf.reduce_mean(rew_inp, axis=[1, 2], keepdims=True)
    if self.is_per_pixel_softmax:
      x = tf.layers.dense(
          tran_inp, hparams.problem.num_channels * 256, name="logits")
    else:
      x = tf.layers.dense(tran_inp, hparams.problem.num_channels, name="logits")

    reward_pred = None
    if self.has_rewards:
      # Reward prediction from the spatially averaged middle activations and
      # the spatially averaged final reward-path activations.
      reward_pred = tf.concat([x_mid, x_fin], axis=-1)
      reward_pred = tf.nn.relu(tf.layers.dense(
          reward_pred, 128, name="reward_pred"))
      reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
      reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

    # Writes noise masks for the recorded tensors if recording is still on.
    self.generate_and_dump_gaussian_masks_for_simulation()
    return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states


  def generate_and_dump_masks_for_simulation(self):
    """Sample one dropout mask per recorded tensor and pickle them to disk.

    Only acts while `self.dump_dropout_masks` is set. Each entry of
    `self.network_tensors_for_dropout` is expected to be a
    `(tensor, keep_prob)` pair. The mask takes the tensor's static shape
    without its first dimension, holds `1 / keep_prob` where a uniform sample
    is `<= keep_prob` and 0 elsewhere, and is stored in `self.dropout_cache`
    under the tensor's name. The whole cache is then pickled to the file
    `masks` inside `model_dir`. The flag is cleared afterwards so later
    calls do nothing.
    """
    if self.dump_dropout_masks:
      for recorded, keep_prob in self.network_tensors_for_dropout:
        mask_shape = recorded.get_shape().as_list()[1:]
        scale = 1 / keep_prob
        kept = np.random.random(mask_shape) <= keep_prob
        self.dropout_cache[recorded.name] = scale * kept
      masks_path = os.path.join(self._original_hparams.model_dir, 'masks')
      with open(masks_path, 'wb+') as masks_file:
        pickle.dump(self.dropout_cache, masks_file)
    # Masks are dumped at most once.
    self.dump_dropout_masks = False

  def generate_and_dump_gaussian_masks_for_simulation(self):
    """Sample Gaussian filter noise per recorded tensor and pickle it to disk.

    Only acts while `self.dump_dropout_masks` is set. Each entry of
    `self.network_tensors_for_dropout` is expected to be a
    `(tensor, kernel_dim, num_filters)` triple. Noise of shape
    `kernel_dim + [tensor_channels, num_filters]` is drawn with NumPy and
    stored in `self.dropout_cache` under the tensor's name. The whole cache
    is then pickled to the file `masks` inside `model_dir`. The flag is
    cleared afterwards so later calls do nothing.
    """
    if self.dump_dropout_masks:
      for recorded, kernel_dim, num_filters in self.network_tensors_for_dropout:
        in_channels = common_layers.shape_list(recorded)[-1]
        noise_shape = kernel_dim + [in_channels, num_filters]
        self.dropout_cache[recorded.name] = sample_noise_cpu(noise_shape)
      masks_path = os.path.join(self._original_hparams.model_dir, 'masks')
      with open(masks_path, 'wb+') as masks_file:
        pickle.dump(self.dropout_cache, masks_file)
    # Masks are dumped at most once.
    self.dump_dropout_masks = False


  def load_dropout_masks(self):
    """Load pickled dropout masks from `model_dir` into `self.dropout_cache`.

    Reads the file `masks` inside `self._original_hparams.model_dir`. Every
    stored array gets a new leading axis of size 1 and is turned into a
    float32 constant, keyed by `get_unique_name` of the stored tensor name.

    Raises:
      IOError / OSError: if the mask file cannot be opened.
    """
    self.dropout_cache = {}
    masks_path = os.path.join(self._original_hparams.model_dir, 'masks')
    # Read-only mode: the file is never written here, and asking for write
    # access would fail on a read-only model directory.
    # NOTE: unpickling can execute arbitrary code; only load mask files that
    # were produced by this model's own training run.
    with open(masks_path, 'rb') as f:
      dropout_cache_numpy = pickle.load(f)
    for name, mask in dropout_cache_numpy.items():
      keyname = get_unique_name(name)
      # Leading axis of size 1 -- presumably so the mask broadcasts over the
      # batch dimension; TODO confirm.
      self.dropout_cache[keyname] = tf.constant(
          np.expand_dims(mask, 0), dtype=tf.float32)

  def load_dropout_masks_gaussian(self):
    """Load pickled Gaussian noise from `model_dir` into `self.dropout_cache`.

    Reads the file `masks` inside `self._original_hparams.model_dir`. Every
    stored array is turned into a float32 constant, keyed by
    `get_unique_name` of the stored tensor name.

    Raises:
      IOError / OSError: if the mask file cannot be opened.
    """
    self.dropout_cache = {}
    masks_path = os.path.join(self._original_hparams.model_dir, 'masks')
    # Read-only mode: the file is never written here, and asking for write
    # access would fail on a read-only model directory.
    # NOTE: unpickling can execute arbitrary code; only load mask files that
    # were produced by this model's own training run.
    with open(masks_path, 'rb') as f:
      dropout_cache_numpy = pickle.load(f)
    for name, noise in dropout_cache_numpy.items():
      keyname = get_unique_name(name)
      self.dropout_cache[keyname] = tf.constant(noise, dtype=tf.float32)


  def gaussian_dropout_conv(self, input_layer, kernel_dim):
    """Run `gaussian_conv` with zero or freshly sampled noise.

    Args:
      input_layer: input tensor; its last dimension is used as both the
        input and the output channel count of the filter.
      kernel_dim: list `[kernel_h, kernel_w]`.

    Returns:
      Output of `gaussian_conv`.
    """
    channels = common_layers.shape_list(input_layer)[-1]
    noise_shape = kernel_dim + [channels, channels]
    if self.zero_noise:
      filter_noise = tf.zeros(noise_shape)
    else:
      filter_noise = sample_noise(noise_shape)
    return self.gaussian_conv(input_layer, filter_noise, kernel_dim)
    
  
  def gaussian_conv(self, input_layer, noise, kernel_dim, mu=None, sigma=None, strides=None,  transpose=False):
    """Convolve with the multiplicatively perturbed filter `mu * (1 + sigma * noise)`.

    Args:
      input_layer: input tensor.
      noise: noise values; reshaped to the shape of the mean filter.
      kernel_dim: list `[kernel_h, kernel_w]`.
      mu: optional mean filter. When None, a variable `filter_mu` with equal
        input and output channel counts is created.
      sigma: optional noise-scale filter. When None, a variable
        `filter_sigma` of the same default shape is created.
      strides: convolution strides; `[1, 1, 1, 1]` when None.
      transpose: use `tf.nn.conv2d_transpose` instead of `tf.nn.conv2d`.

    Returns:
      The convolution output after `common_layers.belu`.
    """
    with tf.variable_scope("gaussian_dropout"):
      in_channels = common_layers.shape_list(input_layer)[-1]
      # Default filter keeps the channel count: in_channels == out_channels.
      default_filter_shape = kernel_dim + [in_channels, in_channels]

      filter_mu = mu
      if filter_mu is None:
        filter_mu = tf.get_variable("filter_mu", default_filter_shape)
      noise = tf.reshape(noise, shape=filter_mu.get_shape())
      filter_sigma = sigma
      if filter_sigma is None:
        filter_sigma = tf.get_variable("filter_sigma", default_filter_shape)
      conv_strides = strides if strides is not None else [1, 1, 1, 1]

      if mu is None:
        num_filters = in_channels
      else:
        # A transposed-conv filter keeps its output channels at index -2.
        mu_shape = common_layers.shape_list(mu)
        num_filters = mu_shape[-2] if transpose else mu_shape[-1]

      noisy_filter = filter_mu * (
          tf.ones_like(filter_mu) + tf.multiply(filter_sigma, noise))

      if transpose:
        conv_out = tf.nn.conv2d_transpose(
            value=input_layer,
            output_shape=get_output_shape(
                input_layer, kernel_dim, conv_strides, num_filters),
            filter=noisy_filter,
            strides=conv_strides,
            padding='SAME')
      else:
        conv_out = tf.nn.conv2d(
            input=input_layer,
            filter=noisy_filter,
            strides=conv_strides,
            padding='SAME')
      return common_layers.belu(conv_out)


  def transition_reward_shared_conv(self, rew_inp, tran_inp, kernel_dim, strides, num_filters, noise = None, transpose=False):
    """Convolve both branches with shared weights; only the reward one is noisy.

    The variables `filter_mu` and `filter_sigma` are created under the scope
    `conv_layer`. The reward branch goes through `gaussian_conv` with them
    and the noise; the transition branch is convolved with `filter_mu` alone.

    Args:
      rew_inp: reward-branch input; its last dimension sets the filter's
        input channel count.
      tran_inp: transition-branch input.
      kernel_dim: list `[kernel_h, kernel_w]`.
      strides: convolution strides.
      num_filters: output channel count.
      noise: optional noise for the reward branch. Ignored when
        `self.zero_noise` is set; sampled when None.
      transpose: use transposed convolutions.

    Returns:
      A tuple `(reward_hidden, transition_hidden)`, each passed through
      `common_layers.belu`.
    """
    in_channels = common_layers.shape_list(rew_inp)[-1]
    if transpose:
      deconv_output_shape = get_output_shape(
          tran_inp, kernel_dim, strides, num_filters)
      # Transposed-conv filters put the output channels before the input ones.
      filter_shape = kernel_dim + [num_filters, in_channels]
    else:
      deconv_output_shape = None
      filter_shape = kernel_dim + [in_channels, num_filters]

    if self.zero_noise:
      noise = tf.zeros(filter_shape)
    elif noise is None:
      noise = sample_noise(filter_shape)

    with tf.variable_scope("conv_layer"):
      filter_mu = tf.get_variable("filter_mu", filter_shape)
      filter_sigma = tf.get_variable("filter_sigma", filter_shape)
      noise = tf.reshape(noise, shape=filter_mu.get_shape())
      reward_hidden = self.gaussian_conv(
          rew_inp, noise, kernel_dim, filter_mu, filter_sigma, strides,
          transpose)
      if transpose:
        transition_hidden = tf.nn.conv2d_transpose(
            value=tran_inp,
            output_shape=deconv_output_shape,
            filter=filter_mu,
            strides=strides,
            padding='SAME')
      else:
        transition_hidden = tf.nn.conv2d(
            input=tran_inp,
            filter=filter_mu,
            strides=strides,
            padding='SAME')

    # NOTE(review): `gaussian_conv` already applies belu to its output, so the
    # reward branch is activated twice here -- confirm this is intended.
    return (common_layers.belu(reward_hidden),
            common_layers.belu(transition_hidden))


  def transition_reward_extra_noisy_conv_for_reward(self, rew_inp, tran_inp, kernel_dim, strides, num_filters, noisy_kernel_dim, noise = None, transpose=False):
    """Noisy convolution on the reward branch, then one conv shared by both.

    The reward input first goes through `gaussian_conv` with the variables
    `filter_mu` / `filter_sigma` created under the scope `gaussian_dropout`.
    The result and the transition input are then both convolved with the
    single variable `filter_conv` created under the scope `conv_layer`.

    Args:
      rew_inp: reward-branch input; its last dimension sets the channel
        counts of the noisy filter and the input channels of the shared one.
      tran_inp: transition-branch input.
      kernel_dim: list `[kernel_h, kernel_w]` of the shared convolution.
      strides: strides of the shared convolution.
      num_filters: output channel count of the shared convolution.
      noisy_kernel_dim: list `[kernel_h, kernel_w]` of the noisy convolution.
      noise: optional noise for the noisy convolution. Ignored when
        `self.zero_noise` is set; sampled when None.
      transpose: use transposed convolutions for the shared step.

    Returns:
      A tuple `(reward_hidden, transition_hidden)`, each passed through
      `common_layers.belu`.
    """
    in_channels = common_layers.shape_list(rew_inp)[-1]
    shared_filter_shape = kernel_dim + [in_channels, num_filters]
    noisy_filter_shape = noisy_kernel_dim + [in_channels, in_channels]

    if self.zero_noise:
      noise = tf.zeros(noisy_filter_shape)
    elif noise is None:
      noise = sample_noise(noisy_filter_shape)

    with tf.variable_scope('gaussian_dropout'):
      filter_sigma = tf.get_variable("filter_sigma", noisy_filter_shape)
      filter_mu = tf.get_variable("filter_mu", noisy_filter_shape)
      noise = tf.reshape(noise, shape=filter_mu.get_shape())
      noisy_reward = self.gaussian_conv(
          rew_inp, noise, noisy_kernel_dim, filter_mu, filter_sigma)

    rew_output_shape = None
    tran_output_shape = None
    if transpose:
      # Transposed-conv filters put the output channels before the input ones.
      shared_filter_shape[2], shared_filter_shape[3] = (
          shared_filter_shape[3], shared_filter_shape[2])
      rew_output_shape = get_output_shape(
          noisy_reward, kernel_dim, strides, num_filters)
      tran_output_shape = get_output_shape(
          tran_inp, kernel_dim, strides, num_filters)

    with tf.variable_scope("conv_layer"):
      filter_conv = tf.get_variable("filter_conv", shared_filter_shape)

      def _shared_conv(value, output_shape):
        """Convolves `value` with the shared filter."""
        if transpose:
          return tf.nn.conv2d_transpose(
              value=value, output_shape=output_shape, filter=filter_conv,
              strides=strides, padding='SAME')
        return tf.nn.conv2d(
            input=value, filter=filter_conv, strides=strides, padding='SAME')

      reward_hidden = _shared_conv(noisy_reward, rew_output_shape)
      transition_hidden = _shared_conv(tran_inp, tran_output_shape)

    return (common_layers.belu(reward_hidden),
            common_layers.belu(transition_hidden))

  def transition_reward_no_noise_shared_conv(self, rew_inp, tran_inp, kernel_dim, strides, num_filters, transpose=False):
    """Convolve both branches with one shared, noise-free filter.

    The single variable `filter_conv` is created under the scope
    `conv_layer` and applied to the reward input and the transition input.

    Args:
      rew_inp: reward-branch input; its last dimension sets the filter's
        input channel count.
      tran_inp: transition-branch input.
      kernel_dim: list `[kernel_h, kernel_w]`.
      strides: convolution strides.
      num_filters: output channel count.
      transpose: use transposed convolutions.

    Returns:
      A tuple `(reward_hidden, transition_hidden)`, each passed through
      `common_layers.belu`.
    """
    in_channels = common_layers.shape_list(rew_inp)[-1]
    rew_output_shape = None
    tran_output_shape = None
    if transpose:
      # Transposed-conv filters put the output channels before the input ones.
      filter_shape = kernel_dim + [num_filters, in_channels]
      rew_output_shape = get_output_shape(
          rew_inp, kernel_dim, strides, num_filters)
      tran_output_shape = get_output_shape(
          tran_inp, kernel_dim, strides, num_filters)
    else:
      filter_shape = kernel_dim + [in_channels, num_filters]

    with tf.variable_scope("conv_layer"):
      filter_conv = tf.get_variable("filter_conv", filter_shape)

      def _shared_conv(value, output_shape):
        """Convolves `value` with the shared filter."""
        if transpose:
          return tf.nn.conv2d_transpose(
              value=value, output_shape=output_shape, filter=filter_conv,
              strides=strides, padding='SAME')
        return tf.nn.conv2d(
            input=value, filter=filter_conv, strides=strides, padding='SAME')

      reward_hidden = _shared_conv(rew_inp, rew_output_shape)
      transition_hidden = _shared_conv(tran_inp, tran_output_shape)

    return (common_layers.belu(reward_hidden),
            common_layers.belu(transition_hidden))


  def gaussian_dropout_conv_weighted_channel(self, input_layer, kernel_dim, strides = None):
    """Run `gaussian_conv_weighted_channel` with zero or freshly sampled noise.

    Args:
      input_layer: input tensor; its last dimension is used as both the
        input and the output channel count of the filter.
      kernel_dim: list `[kernel_h, kernel_w]`.
      strides: optional convolution strides, forwarded as given.

    Returns:
      Output of `gaussian_conv_weighted_channel`.
    """
    channels = common_layers.shape_list(input_layer)[-1]
    noise_shape = kernel_dim + [channels, channels]
    if self.zero_noise:
      filter_noise = tf.zeros(noise_shape)
    else:
      filter_noise = sample_noise(noise_shape)
    return self.gaussian_conv_weighted_channel(
        input_layer, filter_noise, kernel_dim, strides=strides)
    
  
  def gaussian_conv_weighted_channel(self, input_layer, noise, kernel_dim, strides=None):
    """Noisy convolution whose filter is masked by a channel identity matrix.

    The perturbed filter `mu * (1 + sigma * noise)` is multiplied by
    `tf.eye(channels)` before the convolution, which zeroes every filter
    entry whose input and output channel indices differ.

    Args:
      input_layer: input tensor; its last dimension is the channel count.
      noise: noise values; reshaped to the shape of the mean filter.
      kernel_dim: list `[kernel_h, kernel_w]`.
      strides: convolution strides; `[1, 1, 1, 1]` when None.

    Returns:
      The convolution output after `common_layers.belu`.
    """
    with tf.variable_scope("gaussian_dropout_weighted_channel"):
      channels = common_layers.shape_list(input_layer)[-1]
      filter_shape = kernel_dim + [channels, channels]
      filter_mu = tf.get_variable("filter_mu", filter_shape)
      noise = tf.reshape(noise, shape=filter_mu.get_shape())
      filter_sigma = tf.get_variable("filter_sigma", filter_shape)
      conv_strides = strides if strides is not None else [1, 1, 1, 1]

      perturbed_filter = filter_mu * (
          tf.ones_like(filter_mu) + tf.multiply(filter_sigma, noise))
      # Keep only the entries on the channel diagonal.
      masked_filter = perturbed_filter * tf.eye(channels)

      conv_out = tf.nn.conv2d(
          input=input_layer,
          filter=masked_filter,
          strides=conv_strides,
          padding='SAME')
      return common_layers.belu(conv_out)



def get_unique_name(name):
  """Return the part of a '/'-separated name after its last 'body' component.

  Args:
    name: a '/'-separated name string, e.g. a tensor name.

  Returns:
    The components following the last component equal to 'body', joined
    with '/'. Empty string when 'body' is the final component.

  Raises:
    ValueError: if no component equals 'body'.
  """
  components = name.split('/')
  # Searching the reversed list finds the last occurrence of the marker.
  distance_from_end = components[::-1].index('body')
  return '/'.join(components[len(components) - distance_from_end:])

def sample_noise(shape):
  """Return a `tf.random_normal` tensor of the given shape."""
  return tf.random_normal(shape)

def sample_noise_cpu(shape):
  """Return a NumPy array of standard-normal samples with the given shape."""
  return np.random.normal(loc=0.0, scale=1.0, size=shape)

def get_output_shape(tensor, filter_size, strides, filters):
  """Compute the output shape of a "same"-padded transposed convolution.

  Args:
    tensor: 4-D input tensor; dimensions 1 and 2 are taken as height and
      width.
    filter_size: pair `(kernel_h, kernel_w)`.
    strides: sequence whose entries 1 and 2 are the height and width strides.
    filters: number of output channels.

  Returns:
    The list `[batch_size, out_height, out_width, filters]`.
  """
  batch_size, height, width, _ = common_layers.shape_list(tensor)
  kernel_h, kernel_w = filter_size
  stride_h, stride_w = strides[1], strides[2]

  def _deconv_length(input_length, kernel, stride):
    """Output length along one spatial axis."""
    return conv_utils.deconv_output_length(input_length,
                                           kernel,
                                           padding="same",
                                           output_padding=None,
                                           stride=stride,
                                           dilation=1)

  out_height = _deconv_length(height, kernel_h, stride_h)
  out_width = _deconv_length(width, kernel_w, stride_w)
  return [batch_size, out_height, out_width, filters]
