"""Training loop."""

import os
import numpy as np
from src import data, util

def training_loop(
    model,
    optimizer,
    input_size,
    d_init,
    k_init,
    k_max,
    batch_size,
    nbatch_curriculum,
    nbatch_final,
    num_components,
    noise_level,
    directory,
    log_freq=1000,
):
  """Training loop for GPT model.

  This loop uses fresh data on each iteration. Training follows a
  curriculum: it starts at dimension ``d_init`` and prompt length
  ``k_init`` and runs ``nbatch_curriculum`` steps per stage. After each
  stage, while ``d < input_size``, ``d`` is incremented by one and ``k`` by
  two (capped at ``k_max``). Once ``d`` reaches ``input_size`` the
  remaining steps are run at that final ``(d, k)``. Note that ``k`` only
  advances together with ``d``, so it may stop short of ``k_max``.

  Every ``log_freq`` steps the model is evaluated on a fresh validation
  batch generated at ``(input_size, k_max)``. A step-numbered checkpoint is
  written each time, and ``saved_model.h5`` is overwritten whenever the
  validation loss improves on the best seen so far.

  Args:
    model: model exposing ``train_step(prompts, labels, optimizer)``.
    optimizer: optimizer forwarded to ``model.train_step``.
    input_size: full input dimension; the curriculum ends when ``d``
      reaches it.
    d_init: starting dimension of the curriculum.
    k_init: starting prompt length of the curriculum.
    k_max: maximum prompt length; also used for validation batches.
    batch_size: number of prompts per training and validation batch.
    nbatch_curriculum: training steps per curriculum stage.
    nbatch_final: training steps per stage once the curriculum is over.
    num_components: number of mixture centers to sample.
    noise_level: forwarded to ``data.generate_batch``.
    directory: output directory for the centers and model checkpoints.
    log_freq: evaluate and checkpoint every ``log_freq`` training steps.
  """
  # Start at the initial dimension and prompt length with the curriculum on.
  d, k = d_init, k_init
  curriculum_training = True

  # Sample cluster centers, then rescale each row to norm sqrt(input_size).
  centers = np.random.randn(num_components, input_size)
  centers = (centers / np.linalg.norm(centers, axis=1)[:, None]) * np.sqrt(
      input_size
  )

  # Persist the centers so the data distribution can be reproduced later.
  with open(os.path.join(directory, 'mixture_centers.npy'), 'wb') as fp:
    np.save(fp, centers)

  def train_gen(cur_d, cur_k):
    # Fresh training batch at the current curriculum setting.
    return data.generate_batch(
        cur_d, cur_k, batch_size, input_size,
        noise_level=noise_level, centers=centers,
    )

  def val_gen():
    # Fresh validation batch at the full (input_size, k_max) setting.
    # NOTE(review): assumes generate_batch(..., return_reg=True) returns a
    # tuple whose trailing two entries are (labels, prompts) — TODO confirm.
    return data.generate_batch(
        input_size,
        k_max,
        batch_size,
        input_size,
        noise_level=noise_level,
        return_reg=True,
        centers=centers,
    )[3:]

  nbatch_total = nbatch_curriculum * (input_size - d_init) + nbatch_final
  nbatch_remaining = nbatch_total
  # Lowercase np.inf: the np.Inf alias was removed in NumPy 2.0.
  best_val_loss = np.inf
  count = 0

  while nbatch_remaining > 0:
    # Stage length depends on the training phase, but never exceeds the
    # number of batches still remaining.
    num_batches = nbatch_curriculum if curriculum_training else nbatch_final
    num_batches = min(num_batches, nbatch_remaining)

    for i in range(num_batches):
      prompts, all_labels = train_gen(d, k)
      model.train_step(prompts, all_labels, optimizer)

      # Global step index across all stages.
      curr_it = nbatch_total - nbatch_remaining + i
      if curr_it % log_freq == 0:
        val_labels, val_prompts = val_gen()
        val_loss_value = util.get_model_error(model, val_prompts, val_labels)

        # Periodic checkpoint, named by training step.
        util.save_weights(
            model,
            os.path.join(directory,
                         'chkpt_model_{}.h5'.format(count * log_freq)),
        )
        count += 1

        # Keep a copy of the best model seen so far.
        if val_loss_value < best_val_loss:
          best_val_loss = val_loss_value
          util.save_weights(model, os.path.join(directory, 'saved_model.h5'))

    nbatch_remaining -= num_batches

    # Advance the curriculum: d by one, k by two (capped at k_max).
    if d < input_size:
      d += 1
      k = min(k + 2, k_max)
      curriculum_training = d != input_size

def fixed_sample_size_training_loop(
    model,
    optimizer,
    input_size,
    d_init,
    k_init,
    k_max,
    batch_size,
    sample_size,
    nbatch_curriculum,
    nbatch_final,
    num_components,
    directory,
    log_freq=100,
    use_noise=False
):
  """Training loop for GPT model on a fixed, finite dataset.

  A single dataset of ``sample_size`` prompts is generated up front, at
  dimension ``input_size`` and prompt length ``k_max``. Training batches are
  then drawn from that dataset. The curriculum is over the prompt length
  only: ``k`` starts at ``k_init`` and increases by one after each stage of
  ``nbatch_curriculum`` steps until it reaches ``k_max``. After that, stages
  of ``nbatch_final`` steps run until the step budget is exhausted.

  Every ``log_freq`` steps the model is evaluated on a freshly generated
  validation batch. ``saved_model.h5`` is overwritten whenever the
  validation loss improves on the best seen so far.

  Args:
    model: callable model exposing ``train_step(prompts, labels, optimizer)``.
    optimizer: optimizer forwarded to ``model.train_step``.
    input_size: input dimension of the generated data.
    d_init: only used to compute the total step budget (see note below).
    k_init: starting prompt length of the curriculum.
    k_max: maximum prompt length; also used for the dataset and validation.
    batch_size: number of prompts per training and validation batch.
    sample_size: number of prompts in the fixed training dataset.
    nbatch_curriculum: training steps per curriculum stage.
    nbatch_final: training steps per stage once the curriculum is over.
    num_components: number of mixture centers to sample.
    directory: output directory for centers, dataset and model weights.
    log_freq: evaluate every ``log_freq`` training steps.
    use_noise: if True generate data with noise level 1.0, otherwise 0.0.
  """
  curriculum_training = True
  k = k_init
  noise_level = 1.0 if use_noise else 0.0

  # Sample cluster centers, then rescale each row to norm sqrt(input_size).
  centers = np.random.randn(num_components, input_size)
  centers = (centers / np.linalg.norm(centers, axis=1)[:, None]) * np.sqrt(
      input_size
  )

  # The fixed training set, generated once at full dimension and k_max.
  dataset = data.generate_batch(d=input_size,
                                k=k_max,
                                batch_size=sample_size,
                                input_size=input_size,
                                noise_level=noise_level,
                                centers=centers)

  # Persist centers and dataset (entries 0/1 saved as prompts/labels).
  with open(os.path.join(directory, 'mixture_centers.npy'), 'wb') as fp:
    np.save(fp, centers)
  with open(os.path.join(directory, 'dataset_prompts.npy'), 'wb') as fp:
    np.save(fp, dataset[0])
  with open(os.path.join(directory, 'dataset_labels.npy'), 'wb') as fp:
    np.save(fp, dataset[1])

  def train_gen(cur_k):
    # Training batch drawn from the fixed dataset at prompt length cur_k.
    return data.generate_batch_from_dataset(
        cur_k, batch_size, dataset=dataset
    )

  def val_gen():
    # Fresh validation batch at the full (input_size, k_max) setting.
    # NOTE(review): assumes generate_batch(..., return_reg=True) returns a
    # tuple whose trailing two entries are (labels, prompts) — TODO confirm.
    return data.generate_batch(
        input_size,
        k_max,
        batch_size,
        input_size,
        return_reg=True,
        noise_level=noise_level,
        centers=centers,
    )[3:]

  # NOTE(review): the step budget assumes (input_size - d_init) curriculum
  # stages, but this loop's curriculum has (k_max - k_init) stages. If the
  # two differ, k either never reaches k_max or extra nbatch_final stages
  # run at k_max. Kept unchanged to preserve the total training length.
  nbatch_total = nbatch_curriculum * (input_size - d_init) + nbatch_final
  nbatch_remaining = nbatch_total
  # Lowercase np.inf: the np.Inf alias was removed in NumPy 2.0.
  best_val_loss = np.inf

  while nbatch_remaining > 0:
    # Stage length depends on the training phase, but never exceeds the
    # number of batches still remaining.
    num_batches = nbatch_curriculum if curriculum_training else nbatch_final
    num_batches = min(num_batches, nbatch_remaining)

    for i in range(num_batches):
      prompts, all_labels = train_gen(k)
      model.train_step(prompts, all_labels, optimizer)

      # Global step index across all stages.
      curr_it = nbatch_total - nbatch_remaining + i
      if curr_it % log_freq == 0:
        true_labels, val_prompts = val_gen()
        # Prediction at the last sequence position, first output channel.
        pred_labels = model(val_prompts, training=False)[:, -1, 0]
        # Squared error, normalized by input_size * number of prompts.
        val_loss_value = (np.linalg.norm(pred_labels - true_labels) ** 2) / (
            input_size * len(true_labels)
        )

        # Keep a copy of the best model seen so far.
        if val_loss_value < best_val_loss:
          best_val_loss = val_loss_value
          util.save_weights(model, os.path.join(directory, 'saved_model.h5'))

    nbatch_remaining -= num_batches

    # Advance the curriculum over prompt length by one.
    if k < k_max:
      k += 1
      curriculum_training = k != k_max
  
  
