"""Build model and coordinate training."""

import os
import sys
from typing import Sequence

import tensorflow as tf

from src import models, train

# Experiment configuration consumed by `main` below.
# NOTE(review): `main` passes k_init=11 but k_max=_MAX_PROMPT_LEN (5), so the
# initial prompt length exceeds this maximum — confirm the intended value.
_MAX_PROMPT_LEN: int = 5  # forwarded as `k_max`
_INPUT_SIZE: int = 20  # forwarded as `input_size` (token dimension)
_NUM_COMPONENTS: int = 5  # forwarded as `num_components` ("m" in expt dir)
_NOISE_LEVEL: float = 1.0  # forwarded as `noise_level`


def train_GPT2_models(input_size,
                      d_init,
                      k_init,
                      k_max,
                      num_components,
                      noise_level,
                      nbatch_curriculum,
                      nbatch_final,
                      num_embed=256,
                      num_heads=8,
                      num_layers=12,
                      base_dir='path-here',
                      learning_rate=1e-4,
                      batch_size=64):
  """Builds a GPT-2 style decoder-only model and runs its training loop.

  The model and its Adam optimizer are both created under a
  `tf.distribute.MirroredStrategy` scope. An experiment directory is derived
  from the data parameters, and everything is then handed to
  `train.training_loop`.

  Args:
    input_size: dimension of input tokens.
    d_init: dimension used to initialize curriculum training.
    k_init: prompt length used to initialize curriculum training.
    k_max: maximum prompt length; forwarded unchanged to the training loop.
    num_components: number of clusters (recorded as `m` in the experiment
      directory name); forwarded to the training loop.
    noise_level: noise level (recorded in the experiment directory name);
      forwarded to the training loop.
    nbatch_curriculum: number of batches used during each phase of
      curriculum training.
    nbatch_final: number of batches used during the final phase of training
      (when d = input_size, k = 2 * d + 1).
    num_embed: embedding dimension (default 256).
    num_heads: number of heads for the multi-head attention layer
      (default 8).
    num_layers: number of decoder layers (default 12).
    base_dir: root directory under which the experiment directory is
      created (default 'path-here', the original placeholder).
    learning_rate: learning rate for Adam (default 1e-4).
    batch_size: batch size forwarded to the training loop (default 64).
  """
  strategy = tf.distribute.MirroredStrategy()
  # Build the model AND the optimizer under the strategy scope so the
  # optimizer's slot variables are created under the same distribution
  # strategy as the model weights they track.
  with strategy.scope():
    model = models.DecoderOnlyGPT2(
        input_size, num_embed, num_heads, num_layers, dropout_rate=0.0
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

  # Experiment directory name encodes the noise level, the number of
  # clusters (m) and the input dimension (d).
  expt_dir = os.path.join(
      base_dir, f'noise_level_{noise_level}_m_{num_components}_d_{input_size}'
  )

  train.training_loop(
      model,
      optimizer,
      input_size,
      d_init,
      k_init,
      k_max,
      batch_size,
      nbatch_curriculum,
      nbatch_final,
      num_components,
      noise_level,
      expt_dir,
  )

def main(argv: Sequence[str]) -> None:
  """Launches GPT model training loops.

  Args:
    argv: command-line arguments, where argv[0] is the program name. No
      further positional arguments are accepted.

  Raises:
    SystemExit: if extra command-line arguments are supplied.
  """
  if len(argv) > 1:
    # `app` (absl) is not imported in this file; SystemExit with a message
    # is the stdlib equivalent: it prints the message to stderr and exits
    # with status 1.
    raise SystemExit('Too many command-line arguments.')

  nbatch_final = 250000
  nbatch_curriculum = 2000

  # NOTE(review): k_init=11 is larger than k_max=_MAX_PROMPT_LEN (5);
  # confirm which value is intended before launching a long run.
  train_GPT2_models(
      input_size=_INPUT_SIZE,
      d_init=5,
      k_init=11,
      k_max=_MAX_PROMPT_LEN,
      num_components=_NUM_COMPONENTS,
      noise_level=_NOISE_LEVEL,
      nbatch_curriculum=nbatch_curriculum,
      nbatch_final=nbatch_final,
      num_embed=256,
      num_heads=8,
      num_layers=12,
  )


if __name__ == '__main__':
  # absl's `app` is not imported in this file, so `app.run(main)` would raise
  # NameError; invoke main directly with the process arguments instead.
  # NOTE(review): if absl flags are defined elsewhere (e.g. in src.*) and must
  # be parsed, import `from absl import app` and use `app.run(main)` instead.
  main(sys.argv)