# -*- coding: utf-8 -*-
"""
Created on Thu Sep 19 13:54:18 2024
"""

import logging
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
import tensorflow as tf

import tensorflow_text

#%% Fix random seed
import random

seed = 0
# Seed every RNG this script draws from (TensorFlow, Python's `random`,
# NumPy) so repeated runs start from the same state.
for _seed_fn in (tf.random.set_seed, random.seed, np.random.seed):
  _seed_fn(seed)

#%% Data handling

examples, metadata = tfds.load(
    'ted_hrlr_translate/pt_to_en', with_info=True, as_supervised=True)

train_examples = examples['train']
val_examples = examples['validation']

# Show the first three sentence pairs of the training split.
# NOTE: `pt_examples` / `en_examples` stay bound after this loop on purpose;
# the tokenizer demo below reuses them.
for pt_examples, en_examples in train_examples.batch(3).take(1):
  print('> Examples in Portuguese:')
  for sentence in pt_examples.numpy():
    print(sentence.decode('utf-8'))
  print()

  print('> Examples in English:')
  for sentence in en_examples.numpy():
    print(sentence.decode('utf-8'))
    
# Download (once) and load the pre-built pt/en subword tokenizers.
model_name = 'ted_hrlr_translate_pt_en_converter'
tf.keras.utils.get_file(
    f'{model_name}.zip',
    f'https://storage.googleapis.com/download.tensorflow.org/models/{model_name}.zip',
    cache_dir='.', cache_subdir='', extract=True
)

tokenizers = tf.saved_model.load(model_name)

# Public attributes of the English tokenizer. A bare expression only displays
# in an interactive cell, so print it to make it visible when run as a script.
print([item for item in dir(tokenizers.en) if not item.startswith('_')])

print('> This is a batch of strings:')
for en in en_examples.numpy():
  print(en.decode('utf-8'))

# Round trip: text -> token IDs -> text.
encoded = tokenizers.en.tokenize(en_examples)

print('> This is a padded-batch of token IDs:')
for row in encoded.to_list():
  print(row)

round_trip = tokenizers.en.detokenize(encoded)

print('> This is human-readable text:')
for line in round_trip.numpy():
  print(line.decode('utf-8'))

print('> This is the text split into tokens:')
tokens = tokenizers.en.lookup(encoded)
print(tokens)  # was a bare (no-op) expression outside a notebook

# Token count of every training sentence (both languages), used to choose
# MAX_TOKENS below.
lengths = []
for pt_batch, en_batch in train_examples.batch(1024):
  for tokenizer, batch in ((tokenizers.pt, pt_batch), (tokenizers.en, en_batch)):
    lengths.append(tokenizer.tokenize(batch).row_lengths())
  print('.', end='', flush=True)

all_lengths = np.concatenate(lengths)
max_length = max(all_lengths)

plt.hist(all_lengths, np.linspace(0, 500, 101))
plt.ylim(plt.ylim())  # freeze the y-limits before drawing the marker line
plt.plot([max_length, max_length], plt.ylim())
plt.title(f'Maximum tokens per example: {max_length}')
plt.savefig('MaxToken.jpg')
plt.show()

MAX_TOKENS = 128


def prepare_batch(pt, en):
  """Tokenize a batch of (pt, en) strings into model inputs and labels.

  Returns ``((pt_ids, en_inputs), en_labels)``: Portuguese IDs trimmed to
  MAX_TOKENS, and English IDs trimmed to MAX_TOKENS + 1, then shifted into
  teacher-forcing inputs (last token of each row dropped) and labels (first
  token dropped). All three are converted to zero-padded dense tensors.
  """
  pt_ids = tokenizers.pt.tokenize(pt)[:, :MAX_TOKENS].to_tensor()

  en_ids = tokenizers.en.tokenize(en)[:, :MAX_TOKENS + 1]
  en_inputs = en_ids[:, :-1].to_tensor()
  en_labels = en_ids[:, 1:].to_tensor()

  return (pt_ids, en_inputs), en_labels
    
BUFFER_SIZE = 20000
BATCH_SIZE = 64


def make_batches(ds):
  """Shuffle, batch, tokenize (via prepare_batch) and prefetch a tf.data pipeline."""
  shuffled = ds.shuffle(BUFFER_SIZE)
  batched = shuffled.batch(BATCH_SIZE)
  tokenized = batched.map(prepare_batch, tf.data.AUTOTUNE)
  return tokenized.prefetch(buffer_size=tf.data.AUTOTUNE)
      
#%% Test the Dataset

# Build the training and validation pipelines.
train_batches = make_batches(train_examples)
val_batches = make_batches(val_examples)

# Take a single batch to sanity-check shapes and the input/label shift.
(pt, en), en_labels = next(iter(train_batches.take(1)))

for tensor in (pt, en, en_labels):
  print(tensor.shape)

print(en[0][:10])
print(en_labels[0][:10])

#%% Define the components

def positional_encoding(length, depth):
  """Build a sinusoidal positional-encoding table.

  Args:
    length: number of positions (rows).
    depth: encoding width (columns).

  Returns:
    A float32 tensor of shape ``(length, depth)``. The first half of the
    columns holds ``sin`` terms and the second half ``cos`` terms
    (concatenated, not interleaved). Frequencies are
    ``1 / 10000**(i / (depth/2))``.
  """
  half_depth = depth / 2

  positions = np.arange(length)[:, np.newaxis]                # (length, 1)
  depths = np.arange(half_depth)[np.newaxis, :] / half_depth  # (1, ceil(depth/2))

  angle_rates = 1 / (10000**depths)        # (1, ceil(depth/2))
  angle_rads = positions * angle_rates     # (length, ceil(depth/2))

  pos_encoding = np.concatenate(
      [np.sin(angle_rads), np.cos(angle_rads)], axis=-1)

  # For odd `depth`, sin+cos gives 2 * ceil(depth/2) = depth + 1 columns.
  # Trim so the table is always exactly `depth` wide (no-op for even depth).
  pos_encoding = pos_encoding[:, :int(depth)]

  return tf.cast(pos_encoding, dtype=tf.float32)
  
# Sample table used only for the visual checks below (2048 positions, depth 512).
pos_encoding = positional_encoding(length=2048, depth=512)

# Check the shape.
print(pos_encoding.shape)

# Plot the dimensions.
plt.pcolormesh(pos_encoding.numpy().T, cmap='RdBu')
plt.ylabel('Depth')
plt.xlabel('Position')
plt.colorbar()
plt.savefig('PosEnc.jpg')
plt.show()

# Scale each position's vector to unit L2 norm, so the dot products below are
# cosine similarities between position 1000 and every position.
pos_encoding/=tf.norm(pos_encoding, axis=1, keepdims=True)
p = pos_encoding[1000]
dots = tf.einsum('pd,d -> p', pos_encoding, p)
# Top panel: all positions, with the 950-1050 zoom window outlined in black.
plt.subplot(2,1,1)
plt.plot(dots)
plt.ylim([0,1])
plt.plot([950, 950, float('nan'), 1050, 1050],
         [0,1,float('nan'),0,1], color='k', label='Zoom')
plt.legend()
# Bottom panel: zoom into positions 950-1050.
plt.subplot(2,1,2)
plt.plot(dots)
plt.xlim([950, 1050])
plt.ylim([0,1])
plt.savefig('EncVec.jpg')
plt.show()

class PositionalEmbedding(tf.keras.layers.Layer):
  """Token embedding + sinusoidal positional encoding, joined by concatenation.

  Unlike the common "add the positional encoding" design, this layer
  normalizes the token embeddings and then *concatenates* the positional
  encoding, so its main output is twice as wide as ``d_model``.

  Args:
    vocab_size: size of the token vocabulary.
    d_model: width of the token embedding (and of the positional encoding).

  ``call(x)`` takes integer token IDs (assumed (batch, seq_len); the 3-axis
  transpose below requires the embedded tensor to be rank 3). It returns a
  tuple ``(y_all, x, z)``:
    y_all: ``concat([x, z], axis=-1)`` -- (batch, seq_len, 2 * d_model) for
           even ``d_model``.
    x:     normalized token embeddings -- (batch, seq_len, d_model).
    z:     positional encodings tiled over the batch -- (batch, seq_len,
           d_model) for even ``d_model``.

  Sequences longer than 2048 tokens are not supported. The precomputed table
  has 2048 rows, so ``z`` would be shorter than ``x`` and the concat would fail.
  """
  def __init__(self, vocab_size, d_model):
    super().__init__()
    self.d_model = d_model  # not used by call() while the sqrt(d_model) scaling is commented out
    # mask_zero=True: token ID 0 is treated as padding and produces a mask.
    self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
    # Fixed table (a plain tensor, not a trainable variable) for up to 2048 positions.
    self.pos_encoding = positional_encoding(length=2048, depth=d_model)
    self.eps = 1e-05  # added to the variance before the sqrt to avoid division by zero

  def compute_mask(self, *args, **kwargs):
    # Reuse the Embedding's padding mask (True where the token ID != 0).
    # NOTE(review): call() returns three tensors. Later code reads the mask
    # from the first output (en_emb._keras_mask); confirm which outputs Keras
    # attaches it to.
    return self.embedding.compute_mask(*args, **kwargs)

  def call(self, x):
    length = tf.shape(x)[1]
    batch_size = tf.shape(x)[0]
    x = self.embedding(x)
    # The usual sqrt(d_model) embedding scaling is disabled (commented out):
    # x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) # (batch, seq_len, emb_dim)

    # Normalize each embedding channel across the *sequence* axis (zero mean,
    # unit variance over time steps), with no learnable scale/bias. This is
    # done by transposing so the sequence axis is last, then reducing over -1.
    # ref: https://discuss.pytorch.org/t/nn-layernorm-for-a-specific-dimension-of-my-tensor/77274
    # NOTE(review): the mean/variance include every position. That covers
    # padding (the mask is not applied here) and, for the decoder, *future*
    # tokens. Each position's output therefore depends on later tokens, which
    # can leak target information under teacher forcing. The causal prefix
    # check further down in this script is the place to confirm.
    x = tf.transpose(x, perm=[0, 2, 1]) # [B, T, E] --> [B, E, T]
    u = tf.reduce_mean(x, axis=-1, keepdims=True)
    s = tf.reduce_mean(tf.pow((x - u),2), axis=-1, keepdims=True)
    x = (x - u) / tf.sqrt(s + self.eps)
    x = tf.transpose(x, perm=[0, 2, 1]) # [B, E, T] --> [B, T, E]

    # Positional encodings for the first `length` positions, repeated per batch row.
    z = self.pos_encoding[tf.newaxis, :length, :] # (1, seq_len, emb_dim)
    z = tf.tile(z, [batch_size,1,1]) # (batch, seq_len, emb_dim)
    # Concatenate (not add) token and position information.
    y_all = tf.concat([x,z],axis=-1) # (batch, seq_len, 2*emb_dim)

    return y_all, x, z
    
# Embed the sample batch. Each call returns (combined, token_emb, pos_enc);
# with d_model=512 the combined stream is 2 * 512 wide.
embed_pt = PositionalEmbedding(vocab_size=tokenizers.pt.get_vocab_size().numpy(), d_model=512)
embed_en = PositionalEmbedding(vocab_size=tokenizers.en.get_vocab_size().numpy(), d_model=512)

pt_emb, pt_tk_emb, pt_p = embed_pt(pt)
en_emb, en_tk_emb, en_p = embed_en(en)

# Padding mask propagated from the embedding. The bare expression here was a
# no-op when run as a script.
print(en_emb._keras_mask)

class BaseAttention(tf.keras.layers.Layer):
  """Shared scaffolding for the attention blocks below.

  Holds a MultiHeadAttention layer (configured by every keyword argument given
  to the constructor), a LayerNormalization, and an Add layer. Subclasses
  implement ``call``.
  """

  def __init__(self, **mha_kwargs):
    super().__init__()
    # Keep this creation order: Keras tracks sub-layers (and their weights) in
    # attribute-assignment order.
    self.mha = tf.keras.layers.MultiHeadAttention(**mha_kwargs)
    self.layernorm = tf.keras.layers.LayerNormalization()
    self.add = tf.keras.layers.Add()
    
class CrossAttention(BaseAttention):
  """Decoder-to-encoder attention with a residual connection and LayerNorm.

  Queries come from the decoder stream ``x`` and keys from the encoder output
  ``context``. Values come from the encoder's token embeddings
  ``context_tk_emb`` rather than from ``context``. ``x_tk_emb`` and ``p`` are
  accepted but unused. The most recent attention scores are stored in
  ``self.last_attn_scores``.
  """

  def call(self, x, context, x_tk_emb, context_tk_emb, p):
    attended, scores = self.mha(
        query=x,
        key=context,
        value=context_tk_emb,
        return_attention_scores=True)
    self.last_attn_scores = scores  # cached for attention plots

    # Residual connection, then LayerNorm.
    return self.layernorm(self.add([x, attended]))
    
# Smoke test: cross-attend from the English sample to the Portuguese sample.
# output_shape=1024 must equal the query width for the residual add.
sample_ca = CrossAttention(num_heads=2, key_dim=1024, output_shape=1024)

for emb in (pt_emb, en_emb):
  print(emb.shape)
print(sample_ca(en_emb, pt_emb, en_tk_emb, pt_tk_emb, en_p).shape)

class GlobalSelfAttention(BaseAttention):
  """Self-attention without a causal mask, plus residual and LayerNorm (encoder side).

  Queries and keys are the combined stream ``x``; values are the token
  embeddings ``tk_emb``. ``p`` is accepted but unused.
  """

  def call(self, x, tk_emb, p):
    attended = self.mha(query=x, value=tk_emb, key=x)
    residual = self.add([x, attended])
    return self.layernorm(residual)
    
sample_gsa = GlobalSelfAttention(num_heads=2, key_dim=1024, output_shape=1024)

print(pt_emb.shape)
gsa_out = sample_gsa(pt_emb, pt_tk_emb, pt_p)
print(gsa_out.shape)

class CausalSelfAttention(BaseAttention):
  """Causally masked self-attention, plus residual and LayerNorm (decoder side).

  Wired like GlobalSelfAttention (query/key ``x``, value ``tk_emb``) but with
  ``use_causal_mask=True``, so position i attends only to positions <= i.
  ``p`` is accepted but unused.
  """

  def call(self, x, tk_emb, p):
    attended = self.mha(query=x, value=tk_emb, key=x, use_causal_mask=True)
    return self.layernorm(self.add([x, attended]))
    
sample_csa = CausalSelfAttention(num_heads=2, key_dim=1024, output_shape=1024)

print(en_emb.shape)
print(sample_csa(en_emb, en_tk_emb, en_p).shape)

# Causality check. The outputs at the first 3 positions should not depend on
# later tokens, so a 3-token prefix and the full sequence should agree
# (max difference ~0). A clearly nonzero value means some component lets later
# positions influence earlier ones.
# NOTE(review): PositionalEmbedding normalizes over the sequence axis, which
# may be such a component.
# Separate names keep the global `en_tk_emb` from being overwritten.
prefix_emb, prefix_tk_emb, prefix_p = embed_en(en[:, :3])
out1 = sample_csa(prefix_emb, prefix_tk_emb, prefix_p)
full_emb, full_tk_emb, full_p = embed_en(en)
out2 = sample_csa(full_emb, full_tk_emb, full_p)[:, :3]

# Printed explicitly; the bare expression was a no-op when run as a script.
print(tf.reduce_max(abs(out1 - out2)).numpy())

class FeedForward(tf.keras.layers.Layer):
  """Position-wise feed-forward sub-layer with residual and LayerNorm.

  Dense(dff, relu) -> Dense(d_model) -> Dropout, added back onto the input and
  then normalized. ``d_model`` must equal the input's last dimension for the
  residual add. ``p`` is accepted but unused.
  """

  def __init__(self, d_model, dff, dropout_rate=0.1):
    super().__init__()
    expand = tf.keras.layers.Dense(dff, activation='relu')
    project = tf.keras.layers.Dense(d_model)
    drop = tf.keras.layers.Dropout(dropout_rate)
    self.seq = tf.keras.Sequential([expand, project, drop])
    self.add = tf.keras.layers.Add()
    self.layer_norm = tf.keras.layers.LayerNormalization()

  def call(self, x, p):
    residual = self.add([x, self.seq(x)])
    return self.layer_norm(residual)
    
sample_ffn = FeedForward(1024, 2048)

print(en_emb.shape)
ffn_out = sample_ffn(en_emb, en_p)
print(ffn_out.shape)

class EncoderLayer(tf.keras.layers.Layer):
  """One encoder block: global self-attention followed by a feed-forward layer.

  The working stream is ``2 * d_model`` wide (token embedding concatenated with
  positional encoding), so the FFN projects back to ``d_model * 2``.
  ``tk_emb`` is the self-attention's value input. ``p`` is passed through but
  unused by the sub-layers.
  NOTE(review): the FFN uses its default dropout (0.1), not ``dropout_rate``.
  """

  def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):
    super().__init__()
    self.self_attention = GlobalSelfAttention(
        num_heads=num_heads, key_dim=d_model, dropout=dropout_rate)
    self.ffn = FeedForward(d_model * 2, dff)

  def call(self, x, tk_emb, p):
    attended = self.self_attention(x, tk_emb, p)
    return self.ffn(attended, p)
    
sample_encoder_layer = EncoderLayer(d_model=512, num_heads=8, dff=1024)

print(pt_emb.shape)
enc_layer_out = sample_encoder_layer(pt_emb, pt_tk_emb, pt_p)
print(enc_layer_out.shape)

class Encoder(tf.keras.layers.Layer):
  """Stack of EncoderLayers on top of a PositionalEmbedding.

  ``call(x)`` takes token IDs of shape (batch, seq_len) and returns
  ``(encoded, tk_emb)``:
    encoded: the encoder output, (batch, seq_len, 2 * d_model).
    tk_emb:  the normalized token embeddings, (batch, seq_len, d_model). The
             decoder's cross-attention uses these as values.
  """

  def __init__(self, *, num_layers, d_model, num_heads,
               dff, vocab_size, dropout_rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEmbedding(
        vocab_size=vocab_size, d_model=d_model)

    self.enc_layers = [
        EncoderLayer(d_model=d_model, num_heads=num_heads, dff=dff,
                     dropout_rate=dropout_rate)
        for _ in range(num_layers)
    ]
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self, x):
    h, tk_emb, p = self.pos_embedding(x)
    h = self.dropout(h)  # dropout on the combined stream only
    for layer in self.enc_layers:
      h = layer(h, tk_emb, p)
    return h, tk_emb
    
# Instantiate the encoder and run it on the sample Portuguese batch.
sample_encoder = Encoder(num_layers=4, d_model=512, num_heads=8,
                         dff=2048, vocab_size=8500)

sample_encoder_output, pt_tk_emb = sample_encoder(pt, training=False)

print(pt.shape)
print(sample_encoder_output.shape)  # (batch_size, input_seq_len, 2 * d_model)

class DecoderLayer(tf.keras.layers.Layer):
  """One decoder block: causal self-attention, cross-attention, feed-forward.

  The working stream is ``2 * d_model`` wide. ``x_tk_emb`` (the decoder token
  embeddings) is the self-attention's value input. ``context_tk_emb`` (the
  encoder token embeddings) is the cross-attention's value input. The latest
  cross-attention scores are cached in ``self.last_attn_scores``.
  NOTE(review): the FFN uses its default dropout (0.1), not ``dropout_rate``.
  """

  def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):
    super().__init__()
    attn_kwargs = dict(num_heads=num_heads, key_dim=d_model,
                       dropout=dropout_rate)
    self.causal_self_attention = CausalSelfAttention(**attn_kwargs)
    self.cross_attention = CrossAttention(**attn_kwargs)
    self.ffn = FeedForward(d_model * 2, dff)

  def call(self, x, context, x_tk_emb, context_tk_emb, p):
    h = self.causal_self_attention(x=x, tk_emb=x_tk_emb, p=p)
    h = self.cross_attention(x=h, context=context, x_tk_emb=x_tk_emb,
                             context_tk_emb=context_tk_emb, p=p)
    # Cache the cross-attention scores for plotting.
    self.last_attn_scores = self.cross_attention.last_attn_scores
    return self.ffn(h, p)
    
sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=1024)

sample_decoder_layer_output = sample_decoder_layer(
    x=en_emb, context=pt_emb,
    x_tk_emb=en_tk_emb, context_tk_emb=pt_tk_emb,
    p=en_p)

for t in (en_emb, pt_emb, sample_decoder_layer_output):
  print(t.shape)

class Decoder(tf.keras.layers.Layer):
  """Stack of DecoderLayers on top of a PositionalEmbedding.

  ``call(x, context, e_tk_emb)`` takes target token IDs ``x`` of shape
  (batch, target_seq_len), the encoder output ``context``, and the encoder
  token embeddings ``e_tk_emb``. It returns a tensor of shape
  (batch, target_seq_len, 2 * d_model). The last layer's cross-attention
  scores are kept in ``self.last_attn_scores``.
  """

  def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
               dropout_rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                             d_model=d_model)
    self.dropout = tf.keras.layers.Dropout(dropout_rate)
    self.dec_layers = [
        DecoderLayer(d_model=d_model, num_heads=num_heads, dff=dff,
                     dropout_rate=dropout_rate)
        for _ in range(num_layers)
    ]

    self.last_attn_scores = None

  def call(self, x, context, e_tk_emb):
    h, d_tk_emb, p = self.pos_embedding(x)
    h = self.dropout(h)
    for layer in self.dec_layers:
      h = layer(h, context, d_tk_emb, e_tk_emb, p)
    self.last_attn_scores = self.dec_layers[-1].last_attn_scores
    return h
    
# Instantiate the decoder and run it against the sample embeddings.
sample_decoder = Decoder(num_layers=4, d_model=512, num_heads=8,
                         dff=2048, vocab_size=8000)

output = sample_decoder(x=en, context=pt_emb, e_tk_emb=pt_tk_emb)

# Print the shapes.
print(en.shape)
print(pt_emb.shape)
print(output.shape)

# (batch, heads, target_seq, input_seq). Printed explicitly; the bare
# expression was a no-op when run as a script.
print(sample_decoder.last_attn_scores.shape)

#%% The Transformer

class Transformer(tf.keras.Model):
  """Encoder-decoder Transformer that returns next-token logits.

  Called with ``(context_ids, target_ids)``. Returns logits of shape
  (batch, target_len, target_vocab_size).
  """

  def __init__(self, *, num_layers, d_model, num_heads, dff,
               input_vocab_size, target_vocab_size, dropout_rate=0.1):
    super().__init__()
    shared = dict(num_layers=num_layers, d_model=d_model,
                  num_heads=num_heads, dff=dff, dropout_rate=dropout_rate)
    self.encoder = Encoder(vocab_size=input_vocab_size, **shared)
    self.decoder = Decoder(vocab_size=target_vocab_size, **shared)
    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inputs):
    # Keras `fit` passes all model inputs as the single first argument.
    context_ids, target_ids = inputs

    encoded, e_tk_emb = self.encoder(context_ids)
    decoded = self.decoder(target_ids, encoded, e_tk_emb)
    logits = self.final_layer(decoded)

    # Drop the propagated Keras mask so it doesn't scale the losses/metrics;
    # masked_loss / masked_accuracy already ignore padding. b/250038731
    try:
      del logits._keras_mask
    except AttributeError:
      pass
    return logits
    
# Reduced hyperparameters for this experiment (previous values in parentheses):
# num_layers 2 (4), d_model 64 (128), dff 256 (512), num_heads 4 (8).
num_layers = 2
d_model = 64
dff = 256
num_heads = 4
dropout_rate = 0.1

transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=tokenizers.pt.get_vocab_size().numpy(),
    target_vocab_size=tokenizers.en.get_vocab_size().numpy(),
    dropout_rate=dropout_rate)

# One forward pass on the sample batch builds the model.
output = transformer((pt, en))
for t in (en, pt, output):
  print(t.shape)

attn_scores = transformer.decoder.dec_layers[-1].last_attn_scores
print(attn_scores.shape)  # (batch, heads, target_seq, input_seq)

# transformer.summary()

print("\n\nCheck end\n\n")

#%% Training

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Warm-up then inverse-square-root learning-rate schedule.

  lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

  The rate rises linearly for ``warmup_steps`` steps, then decays in
  proportion to 1/sqrt(step). At step 0 the rate is 0.

  Args:
    d_model: model width used to scale the rate.
    warmup_steps: number of linear warm-up steps.
  """

  def __init__(self, d_model, warmup_steps=4000):
    super().__init__()
    # Keep the raw constructor argument so get_config() can serialize it.
    self._d_model_arg = d_model
    self.d_model = tf.cast(d_model, tf.float32)
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    step = tf.cast(step, dtype=tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

  def get_config(self):
    """Return the constructor arguments.

    This lets the schedule, and any optimizer or saved model that uses it, be
    serialized and rebuilt via ``from_config``. The base-class version raises
    NotImplementedError.
    """
    return {'d_model': self._d_model_arg, 'warmup_steps': self.warmup_steps}
    
# The schedule is scaled by d_model*2, the width of the model's concatenated
# (token embedding + positional encoding) stream.
learning_rate = CustomSchedule(d_model*2)  # changed from d_model

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)

# Plot the learning rate over the first 40k training steps.
plt.plot(learning_rate(tf.range(40000, dtype=tf.float32)))
plt.ylabel('Learning Rate')
plt.xlabel('Train Step')
plt.savefig('lr.jpg')
plt.show()

def masked_loss(label, pred):
  """Mean sparse cross-entropy over non-padding positions.

  Args:
    label: integer target IDs, where 0 marks padding (prepare_batch yields
      zero-padded dense labels).
    pred: unnormalized logits with a trailing vocabulary axis.

  Returns:
    Scalar loss: the summed per-token loss divided by the number of
    non-padding tokens. A batch containing only padding yields 0 instead of
    NaN from 0/0.
  """
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction='none')
  loss = loss_object(label, pred)

  mask = tf.cast(label != 0, dtype=loss.dtype)
  loss *= mask

  # divide_no_nan returns 0 when the denominator is 0.
  return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))


def masked_accuracy(label, pred):
  """Token-level accuracy over non-padding positions.

  Args:
    label: integer target IDs, where 0 marks padding.
    pred: logits whose class axis is axis 2 (argmax is taken there).

  Returns:
    Scalar float32: the fraction of non-padding positions where
    argmax(pred) == label. A batch containing only padding yields 0 instead
    of NaN from 0/0.
  """
  pred = tf.argmax(pred, axis=2)
  label = tf.cast(label, pred.dtype)

  mask = label != 0
  match = (label == pred) & mask

  match = tf.cast(match, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  return tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))
  
# NOTE: run_eagerly=True executes the model eagerly, which is easier to debug
# but slower than graph mode.
transformer.compile(
    loss=masked_loss,
    optimizer=optimizer,
    metrics=[masked_accuracy],
    run_eagerly=True)

history = transformer.fit(train_batches,
                          epochs=10,
                          validation_data=val_batches)

train_loss = history.history['loss']
val_loss = history.history['val_loss']

np.save('train_loss.npy', train_loss)
np.save('val_loss.npy', val_loss)

# Keras names history entries after the metric function, so accuracy is
# logged as 'masked_accuracy' / 'val_masked_accuracy'. The old keys
# 'accuracy' / 'val_accuracy' never existed, which silently skipped the plot.
train_accuracy = history.history.get('masked_accuracy')
val_accuracy = history.history.get('val_masked_accuracy')

plt.figure(figsize=(12, 6))

# Plot training & validation loss values
plt.plot(train_loss, label='Training Loss', linewidth=3)  # Increase line thickness
plt.plot(val_loss, label='Validation Loss', linewidth=3)  # Increase line thickness
plt.title('Model Loss', fontsize=30)  # Increase title font size
plt.xlabel('Epoch', fontsize=25)      # Increase xlabel font size
plt.ylabel('Loss', fontsize=25)       # Increase ylabel font size
plt.legend(loc='upper right', fontsize=20)  # Increase legend font size
plt.grid(True)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.savefig('loss.jpg', bbox_inches='tight')
plt.show()

if train_accuracy and val_accuracy:
    plt.figure(figsize=(12, 6))

    # Plot training & validation accuracy values
    plt.plot(train_accuracy, label='Training Accuracy')
    plt.plot(val_accuracy, label='Validation Accuracy')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend(loc='lower right')
    plt.grid(True)
    plt.savefig('accuracy.jpg')
    plt.show()