"""Define GPT2 and supporting model classes"""

import numpy as np
import tensorflow as tf
from tensorflow import keras


class ResidualLayer(keras.layers.Layer):
  """Two-layer feed-forward block with a linear skip path and layer norm.

  Computes ``LayerNorm(lin_b(relu(lin_a(x))) + lin_res(x))`` where ``x`` is the
  input after dropout. The skip path is a learned linear projection rather
  than the identity.

  Args:
    hidden_dim: number of units in the ReLU hidden layer.
    output_dim: number of units produced by the main path and the skip path.
    dropout_rate: dropout rate applied to the inputs.
  """

  def __init__(self, hidden_dim, output_dim, dropout_rate=0.0):
    super().__init__()
    self.dropout = keras.layers.Dropout(dropout_rate)

    self.lin_a = keras.layers.Dense(hidden_dim, activation='relu')
    self.lin_b = keras.layers.Dense(output_dim, activation=None)
    self.lin_res = keras.layers.Dense(output_dim, activation=None)
    self.lnorm = keras.layers.LayerNormalization(epsilon=1e-5)

  @tf.function
  def call(self, inputs, training=None):
    """Apply the block to `inputs`.

    Args:
      inputs: input tensor.
      training: whether dropout is active. `None` lets Keras decide.

    Returns:
      The layer-normalised sum of the feed-forward path and the skip path.
    """
    # `training` has to be a real argument of this tf.function: traces are
    # cached per argument value, so without it the dropout mode seen during the
    # first trace would be silently reused for every later call.
    inputs = self.dropout(inputs, training=training)
    h_state = self.lin_a(inputs)
    out = self.lin_b(h_state)
    res = self.lin_res(inputs)
    return self.lnorm(out + res)


def positional_encoding(length, depth):
  """Build a sinusoidal positional-encoding table.

  The sine features occupy the leading columns and the cosine features the
  trailing columns (the two halves are concatenated, not interleaved).

  Args:
    length: number of positions (rows of the table).
    depth: number of features per position (columns of the table).

  Returns:
    A float32 tensor of shape (length, depth).
  """
  half_depth = depth / 2

  positions = np.arange(length)[:, np.newaxis]  # (length, 1)
  # np.arange rounds a fractional stop up, so an odd depth yields
  # ceil(depth / 2) frequencies here.
  depths = np.arange(half_depth)[np.newaxis, :] / half_depth

  angle_rates = 1 / (10000**depths)
  angle_rads = positions * angle_rates  # (length, ceil(depth / 2))

  pos_encoding = np.concatenate(
      [np.sin(angle_rads), np.cos(angle_rads)], axis=-1
  )

  # For odd depth the concatenation is one column too wide; trim it so the
  # table always has exactly `depth` columns. Even depths are unaffected.
  pos_encoding = pos_encoding[:, :int(depth)]

  return tf.cast(pos_encoding, dtype=tf.float32)


class PositionalEmbedding(tf.keras.layers.Layer):
  """Dense input projection plus a fixed sinusoidal positional encoding.

  Inputs are passed through dropout, projected to `hidden_size` features with
  a linear Dense layer, scaled by sqrt(hidden_size), and then offset by the
  positional encoding of each position along axis 1.

  Args:
    hidden_size: width of the projected features and of the encoding table.
    dropout_rate: dropout rate applied to the raw inputs.
    max_length: number of positions in the precomputed encoding table. Inputs
      whose axis-1 length exceeds this cannot be broadcast against the table.
  """

  def __init__(self, hidden_size, dropout_rate=0.0, max_length=1024):
    super().__init__()
    self.d_model = hidden_size
    self.max_length = max_length
    self.dropout = keras.layers.Dropout(dropout_rate)
    self.embedding = keras.layers.Dense(hidden_size, activation=None)
    self.pos_encoding = positional_encoding(
        length=max_length, depth=hidden_size
    )

  @tf.function
  def call(self, x, training=None):
    """Embed `x` and add the positional encoding.

    Args:
      x: input tensor; axis 1 is treated as the position axis.
      training: whether dropout is active. `None` lets Keras decide.

    Returns:
      The scaled projection of `x` plus the positional encoding.
    """
    length = tf.shape(x)[1]
    # `training` is an explicit argument so that tf.function keeps separate
    # traces for training and inference instead of reusing the first one.
    x = self.dropout(x, training=training)
    x = self.embedding(x)
    # This factor sets the relative scale of the embedding and the positional
    # encoding.
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    return x + self.pos_encoding[tf.newaxis, :length, :]


class DecoderLayer(tf.keras.Model):
  """One decoder block: causal self-attention followed by a residual layer.

  The block computes ``residual(LayerNorm(attention(x) + x))`` with a causal
  mask over axis 1.

  Args:
    num_heads: number of attention heads.
    num_hidden: number of hidden units in the residual layer.
    num_output: number of output units in the residual layer, i.e. the
      encoding/model dimension. Also used as the per-head `key_dim`.
    dropout_rate: dropout rate used inside the attention and residual layers.
  """

  def __init__(self, num_heads, num_hidden, num_output, dropout_rate=0.0):
    super().__init__()

    self.attention = keras.layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=num_output,
        dropout=dropout_rate,
        attention_axes=(1,),
    )

    self.lnorm = keras.layers.LayerNormalization(epsilon=1e-5)

    self.residual = ResidualLayer(
        num_hidden, num_output, dropout_rate=dropout_rate
    )

  @tf.function
  def call(self, inputs, training=None):
    """Run the block on `inputs`.

    Args:
      inputs: input tensor; axis 1 is the axis attended over.
      training: whether dropout is active. `None` lets Keras decide.

    Returns:
      The output of the residual layer.
    """
    # `training` is forwarded explicitly: this method is a tf.function, and a
    # trace made without it would freeze whichever dropout mode was active
    # when the function was first traced.
    attended = self.attention(
        inputs, inputs, use_causal_mask=True, training=training
    )

    x = self.lnorm(attended + inputs)

    return self.residual(x, training=training)


class DecoderOnlyGPT2(tf.keras.Model):
  """Decoder-only transformer with a scalar regression head.

  Inputs are embedded with `PositionalEmbedding`, passed through a stack of
  `DecoderLayer` blocks, and mapped to one output value per position by a
  linear Dense layer.

  Args:
    input_size: stored on the instance as `input_size`; not otherwise used by
      this class.
    hidden_size: embedding/model width. The residual layers use a hidden width
      of 4 * hidden_size.
    num_heads: number of attention heads in every decoder layer.
    num_hidden_layers: number of decoder layers.
    dropout_rate: dropout rate used throughout the model.
  """

  def __init__(
      self,
      input_size,
      hidden_size,
      num_heads,
      num_hidden_layers,
      dropout_rate=0.0,
  ):
    super().__init__()

    self.embed = PositionalEmbedding(hidden_size, dropout_rate=dropout_rate)

    self.input_size = input_size

    self.dropout_rate = dropout_rate

    # The hidden dimension of each residual layer is 4 x the embedding size.
    self.decoder_layers = [
        DecoderLayer(
            num_heads, 4 * hidden_size, hidden_size, dropout_rate=dropout_rate
        )
        for _ in range(num_hidden_layers)
    ]

    self.dropout = keras.layers.Dropout(dropout_rate)

    self.output_layer = keras.layers.Dense(1, activation=None)

  @tf.function
  def call(self, inputs, training=None):
    """Compute per-position predictions.

    Args:
      inputs: input tensor; axis 1 is the sequence axis.
      training: whether dropout is active. `None` lets Keras decide.

    Returns:
      A tensor whose last axis has size 1.
    """
    # `training` must be an argument of this tf.function. Traces are cached by
    # argument value, so without it train_step (training=True) and val_step
    # (training=False) could share one trace and dropout would stay in
    # whichever mode was traced first.
    x = self.embed(inputs, training=training)

    # Pass through the decoder stack.
    for curr_layer in self.decoder_layers:
      x = curr_layer(x, training=training)

    # Output the predictions.
    x = self.dropout(x, training=training)

    return self.output_layer(x)

  @tf.function
  def loss_fn(self, y_true, y_pred):
    """Mean squared error.

    Args:
      y_true: true values.
      y_pred: predictions.

    Returns:
      Scalar mean of the squared differences over all elements.
    """
    return tf.reduce_mean(tf.math.square(y_pred - y_true))

  @tf.function
  def train_step(self, prompts, labels, optimizer):
    """Run one optimisation step and return the loss.

    NOTE(review): this shadows `keras.Model.train_step` with a different
    signature, so it appears intended for a hand-written training loop rather
    than `Model.fit` -- confirm against the caller.

    Args:
      prompts: model inputs.
      labels: targets compared against every second position (0, 2, 4, ...)
        of channel 0 of the model output.
      optimizer: optimizer whose `apply_gradients` is called.

    Returns:
      The scalar loss computed before the update.
    """
    with tf.GradientTape() as tape:
      y_pred = self(prompts, training=True)[:, ::2, 0]
      curr_loss = self.loss_fn(labels, y_pred)

    # Read the variables only after the forward pass: on the first call the
    # layers create their weights while the model runs.
    variables = self.trainable_variables
    grads = tape.gradient(curr_loss, variables)

    # Clip each gradient tensor to norm 1.0 on its own (not a global-norm
    # clip). Variables that received no gradient are skipped, because
    # tf.clip_by_norm cannot handle None.
    clipped = [
        (tf.clip_by_norm(grad, 1.0), var)
        for grad, var in zip(grads, variables)
        if grad is not None
    ]

    optimizer.apply_gradients(clipped)
    return curr_loss

  @tf.function
  def val_step(self, prompts, labels):
    """Evaluate the loss on the last position only.

    Args:
      prompts: model inputs.
      labels: targets compared against channel 0 of the final position of the
        model output.

    Returns:
      The scalar loss.
    """
    y_pred = self(prompts, training=False)[:, -1, 0]
    return self.loss_fn(labels, y_pred)