"""Build model and coordinate training."""

from typing import Sequence

from absl import app
import tensorflow as tf

from src import models, train

# Run configuration consumed by main() / train_GPT2_models().

# Task / data settings (forwarded to train_GPT2_models by main()).
_INPUT_SIZE = 20  # passed as input_size
_MAX_PROMPT_LEN = 5  # passed as k_max; also caps k_init in main()
_NUM_COMPONENTS = 0  # passed as num_components
_SAMPLE_SIZE = 0  # passed as sample_size
_USE_NOISE = False  # passed as use_noise

# Model regularization.
_DROPOUT_RATE = 0.0  # passed as dropout_rate

# Learning-rate schedule (tf.keras CosineDecay with warmup).
_LEARNING_RATE = 1e-4  # passed as lr_rate; the rate the cosine decay ends at
_INITIAL_RATE = -1.0  # warmup start rate; a negative value means "use lr_rate"
_MAX_RATE = -1.0  # warmup target (peak) rate; negative means "use lr_rate"
_WARMUP_STEPS = 0  # passed as CosineDecay warmup_steps
_DECAY_STEPS = 0  # passed as CosineDecay decay_steps

# Trial index; only appears in the experiment directory name.
_TRIAL = 0

def train_GPT2_models(input_size,
                      d_init,
                      k_init,
                      k_max,
                      num_components,
                      nbatch_curriculum,
                      nbatch_final,
                      lr_rate,
                      sample_size,
                      dropout_rate,
                      num_embed=256,
                      num_heads=8,
                      num_layers=12,
                      use_noise=False,
                      *,
                      base_dir='PATH-HERE',
                      batch_size=64,
                      log_freq=1000):
  """Trains a GPT2 decoder-only model under a MirroredStrategy.

  The learning rate follows `tf.keras.optimizers.schedules.CosineDecay` with
  warmup. It moves from the initial rate to the peak rate over
  `_WARMUP_STEPS` steps, then cosine-decays down to `lr_rate`. The initial
  and peak rates come from `_INITIAL_RATE` / `_MAX_RATE`. Each falls back to
  `lr_rate` when negative. Training itself is delegated to
  `train.fixed_sample_size_training_loop`.

  Args:
    input_size: dimension of input tokens.
    d_init: initial dimension (d) for curriculum training.
    k_init: initial prompt length (k) for curriculum training.
    k_max: maximum prompt length; also recorded as `k` in the experiment
      directory name.
    num_components: forwarded to the training loop and recorded as `m` in the
      experiment directory name (presumably the number of mixture
      components/clusters -- TODO confirm against `train`).
    nbatch_curriculum: number of batches used during each phase of
      curriculum training.
    nbatch_final: number of batches used during the final phase of training
      (presumably once d and k reach their maxima -- see `train`).
    lr_rate: learning rate the cosine decay ends at. It is also the fallback
      for the initial and peak rates.
    sample_size: forwarded to the training loop and recorded as `n` in the
      experiment directory name.
    dropout_rate: dropout rate passed to the model.
    num_embed: embedding dimension (default 256).
    num_heads: number of attention heads (default 8).
    num_layers: number of decoder layers (default 12).
    use_noise: forwarded to the training loop.
    base_dir: prefix of the experiment directory name.
    batch_size: batch size passed to the training loop (default 64).
    log_freq: logging frequency passed to the training loop (default 1000).

  Raises:
    ValueError: if the resolved peak learning rate is not positive. It is the
      denominator of the decay floor `alpha`.
  """
  # Resolve the schedule endpoints; negative module-level values mean "unset".
  initial_rate = _INITIAL_RATE if _INITIAL_RATE >= 0 else lr_rate
  max_rate = _MAX_RATE if _MAX_RATE >= 0 else lr_rate
  if max_rate <= 0:
    raise ValueError(
        f'Peak learning rate must be positive, got {max_rate!r}.')

  # After warmup, CosineDecay computes step / decay_steps. With
  # decay_steps == 0 that is 0 / 0, so the learning rate becomes NaN.
  # Use at least one decay step; after it the rate stays at
  # alpha * max_rate == lr_rate.
  decay_steps = max(_DECAY_STEPS, 1)
  lr_scheduler = tf.keras.optimizers.schedules.CosineDecay(
      initial_rate,
      decay_steps,
      # Decay floor as a fraction of the peak, so the final rate is lr_rate.
      alpha=lr_rate / max_rate,
      warmup_target=max_rate,
      warmup_steps=_WARMUP_STEPS,
  )

  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    model = models.DecoderOnlyGPT2(
        input_size, num_embed, num_heads, num_layers, dropout_rate=dropout_rate
    )
    # TF's distributed-training guides create the optimizer under the same
    # strategy scope as the model, so its variables are distributed alike.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_scheduler)

  # The experiment directory name encodes the run configuration:
  # '<base_dir>_k_<k_max>_n_<sample_size>_..._trial_<_TRIAL>'.
  # The raw _DECAY_STEPS value is recorded, so directory names are unchanged.
  name_fields = (
      ('k', k_max),
      ('n', sample_size),
      ('dropout', dropout_rate),
      ('m', num_components),
      ('init', initial_rate),
      ('max', max_rate),
      ('warmsteps', _WARMUP_STEPS),
      ('decaysteps', _DECAY_STEPS),
      ('lr', lr_rate),
      ('trial', _TRIAL),
  )
  expt_dir = '_'.join(
      [base_dir] + [f'{key}_{value!s}' for key, value in name_fields])

  # Launch the training loop.
  train.fixed_sample_size_training_loop(
      model,
      optimizer,
      input_size,
      d_init,
      k_init,
      k_max,
      batch_size,
      sample_size,
      nbatch_curriculum,
      nbatch_final,
      num_components,
      expt_dir,
      log_freq=log_freq,
      use_noise=use_noise,
  )

def main(argv: Sequence[str]) -> None:
  """Trains one GPT-2 model using the module-level run configuration.

  Args:
    argv: command-line arguments. Only argv[0] (the program name) is
      accepted; all settings come from module-level constants.

  Raises:
    app.UsageError: if any extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Curriculum starts at d = 5 and k = min(11, _MAX_PROMPT_LEN). The final
  # phase runs for far more batches than each curriculum phase.
  run_kwargs = dict(
      input_size=_INPUT_SIZE,
      d_init=5,
      k_init=min(11, _MAX_PROMPT_LEN),
      k_max=_MAX_PROMPT_LEN,
      num_components=_NUM_COMPONENTS,
      nbatch_curriculum=2_000,
      nbatch_final=500_000,
      lr_rate=_LEARNING_RATE,
      sample_size=_SAMPLE_SIZE,
      dropout_rate=_DROPOUT_RATE,
      use_noise=_USE_NOISE,
  )
  train_GPT2_models(**run_kwargs)


if __name__ == '__main__':
  # app.run parses flags and then invokes main() with the remaining argv.
  app.run(main)