
"""
original code: https://github.com/Lsdefine/attention-is-all-you-need-keras
"""

import tensorflow as tf
from tensorflow.keras.callbacks import *
import tensorflow.keras.backend as K

class LRSchedulerPerStep(Callback):
    """Set the optimizer learning rate before every training batch.

    The rate follows the warm-up schedule
    ``lr = d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)``:
    it grows linearly while ``step < warmup`` and afterwards decays with
    the inverse square root of the step number.

    Args:
        d_model: model width; sets the overall scale ``d_model**-0.5``.
        warmup: number of linear warm-up steps.
    """

    def __init__(self, d_model, warmup=4000):
        # Callback.__init__ sets up the attributes Keras expects on every
        # callback (model, params, ...); it must not be skipped.
        super(LRSchedulerPerStep, self).__init__()
        self.basic = d_model ** -0.5
        self.warm = warmup ** -1.5
        self.step_num = 0

    def on_batch_begin(self, batch, logs=None):
        # Increment first so the step number is 1-based and
        # step_num ** -0.5 is never evaluated at zero.
        self.step_num += 1
        lr = self.basic * min(self.step_num ** -0.5, self.step_num * self.warm)
        self._set_learning_rate(lr)

    def _set_learning_rate(self, lr):
        """Write ``lr`` into the attached model's optimizer."""
        optimizer = self.model.optimizer
        # `lr` is only a deprecated alias of `learning_rate` on newer
        # optimizers; fall back to it for optimizers that lack the new name.
        lr_var = getattr(optimizer, 'learning_rate', None)
        if lr_var is None:
            lr_var = optimizer.lr
        K.set_value(lr_var, lr)

        
class LRSchedulerPerEpoch(Callback):
    """Set the optimizer learning rate once at the start of every epoch.

    The rate follows the warm-up schedule
    ``lr = d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)``, where
    the step counter starts at 1 and advances by ``num_per_epoch`` per epoch.

    Args:
        d_model: model width; sets the overall scale ``d_model**-0.5``.
        warmup: number of linear warm-up steps.
        num_per_epoch: steps credited per epoch (presumably the number of
            batches per epoch -- confirm with callers).
    """

    def __init__(self, d_model, warmup=4000, num_per_epoch=1000):
        # Callback.__init__ sets up the attributes Keras expects on every
        # callback (model, params, ...); it must not be skipped.
        super(LRSchedulerPerEpoch, self).__init__()
        self.basic = d_model ** -0.5
        self.warm = warmup ** -1.5
        self.num_per_epoch = num_per_epoch
        self.step_num = 1

    def on_epoch_begin(self, epoch, logs=None):
        # The counter is advanced *before* the rate is computed, so the k-th
        # call (0-based) uses step 1 + (k + 1) * num_per_epoch, i.e. the count
        # after that epoch's steps rather than before them.
        # NOTE(review): confirm that end-of-epoch evaluation is intended.
        # The counter lives on the instance and ignores `epoch`, so a fresh
        # callback used with fit(initial_epoch=...) restarts the schedule.
        self.step_num += self.num_per_epoch
        lr = self.basic * min(self.step_num ** -0.5, self.step_num * self.warm)
        self._set_learning_rate(lr)

    def _set_learning_rate(self, lr):
        """Write ``lr`` into the attached model's optimizer."""
        optimizer = self.model.optimizer
        # `lr` is only a deprecated alias of `learning_rate` on newer
        # optimizers; fall back to it for optimizers that lack the new name.
        lr_var = getattr(optimizer, 'learning_rate', None)
        if lr_var is None:
            lr_var = optimizer.lr
        K.set_value(lr_var, lr)


# class AddPosEncoding:
#
# 	def __call__(self, x):
# 		_, max_len, d_emb = K.int_shape(x)
# 		pos = GetPosEncodingMatrix(max_len, d_emb)
# 		x = Lambda(lambda x:x + pos)(x)
# 		return x


def slice_from_to(x, initial, final):
    """Return ``x[:, initial:final]``, a window of columns along axis 1.

    Either bound may be ``None`` to leave that side open, so
    ``slice_from_to(x, 1, None)`` is the same as ``x[:, 1:]``.
    """
    column_window = slice(initial, final)
    return x[:, column_window]


def clip_layer(inputs, min_value, max_value, eps=.5e-6):
    """Clip ``inputs`` element-wise into ``[min_value + eps, max_value - eps]``.

    The bounds are pulled inwards by ``eps``, presumably so that downstream
    operations never see the exact limits (confirm with callers).

    Args:
        inputs: tensor to clip.
        min_value: lower limit before the ``eps`` margin is applied.
        max_value: upper limit before the ``eps`` margin is applied.
        eps: inward margin applied to both limits; defaults to the value
            previously hard-coded here.

    Returns:
        The clipped tensor from ``tf.clip_by_value``.
    """
    return tf.clip_by_value(inputs, min_value + eps, max_value - eps)


def replace_column(matrix, new_column, r):
    """Return a float32 copy of ``matrix`` with column ``r`` replaced.

    Args:
        matrix: rank-2 tensor; cast to float32.
        new_column: the replacement values, cast to float32; assumed to be
            rank 1 with one entry per row of ``matrix`` (TODO confirm with
            callers).
        r: column index; it is squeezed and cast to int64, so a scalar or
            any single-element tensor works.

    Returns:
        A float32 tensor shaped like ``matrix``.

    The column is swapped in with ``tf.where`` rather than by subtracting and
    adding one-hot outer products. With the arithmetic form, an ``inf`` or
    ``nan`` in the old or new column turned every entry of its row into
    ``nan`` (``inf * 0 == nan``), which breaks e.g. ``-inf``-masked columns.
    """
    dynamic_index = tf.cast(tf.squeeze(r), dtype=tf.int64)
    matrix = tf.cast(matrix, dtype=tf.float32)
    new_column = tf.cast(new_column, dtype=tf.float32)
    num_rows = tf.shape(matrix)[0]
    num_cols = tf.shape(matrix)[1]
    # One-hot row marking the target column. Indexing an identity matrix
    # exactly as before keeps the original index handling unchanged.
    index_row = tf.eye(num_cols, dtype=tf.float32)[dynamic_index, :]
    # Broadcast the mask and the new column to the matrix shape explicitly so
    # tf.where sees three same-shaped tensors.
    is_target = tf.tile(tf.expand_dims(tf.equal(index_row, 1.0), 0), [num_rows, 1])
    replacement = tf.tile(tf.expand_dims(new_column, 1), [1, num_cols])
    return tf.where(is_target, replacement, matrix)
