"""Main driving training code."""

import multiprocessing as mp
import os
import random
import string
import sys
import time

from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow.compat.v2 import keras
from tqdm import tqdm


import data_loader
import fixed_rnn

FLAGS = flags.FLAGS

# Tunable hyper-parameters.
flags.DEFINE_integer('train_epochs', 40, 'Number of epochs to train')
flags.DEFINE_integer('batch_size', 512,
                     'Batch size for the randomly sampled batch')
flags.DEFINE_float('learning_rate', 2.77e-3, 'Learning rate')

### Non tunable flags
flags.DEFINE_string('expt', 'exp_first', 'The name of the experiment dir')
# The dataset name is also used as a sub-directory of `logging_path` when the
# experiment directory is created in `training()`.
flags.DEFINE_string('dataset', 'favorita', 'The name of the dataset')
flags.DEFINE_string('all_data_path',
                    './datasets',
                    'Path to directory with all file paths')
flags.DEFINE_string('inference_class', 'MixZeroNBinomNormal',
                    'inference class for predictions.')
# JSON-encoded string; it is not parsed in this file.
flags.DEFINE_string(
    'config',
    '{"max_val": 10000, "num_samples": 20000, "min_val": 1, "quantile": 0.50}',
    'config for mle inference.')
flags.DEFINE_string('error_metric', 'MSE',
                    'error metric for mle inference.')
flags.DEFINE_string('logging_path',
                    './log_dirs',
                    'Path to directory where logging will be saved.')
flags.DEFINE_string('loss_function', 'mix_znbg_mle',
                    'which loss function to use for train')

# Window lengths and counts.
flags.DEFINE_integer('hist_len', 28, 'Length of the history provided as input')
flags.DEFINE_integer('train_pred', 14, 'Length of pred len during training')
flags.DEFINE_integer('test_pred', 14, 'Length of pred len during test/val')
flags.DEFINE_integer('val_windows', 1, 'Number of validation windows')
flags.DEFINE_integer('test_windows', 1, 'Number of test windows')

# Model and optimisation settings.
flags.DEFINE_integer('fixed_lstm_hidden', 256,
                     'Number of LSTM hidden units in the local model')
flags.DEFINE_integer('output_dim', 7, 'Output dimension of the model.')
# None leaves both TF and numpy unseeded (see `training()`).
flags.DEFINE_integer('random_seed', None,
                     'The random seed to be used for TF and numpy')
flags.DEFINE_integer('patience', 10, 'Patience for early stopping')
flags.DEFINE_integer('num_changes', 6, 'Number of changes in the learning rate')
flags.DEFINE_float('huber_delta', 64, 'huber loss delta.')
flags.DEFINE_float('quantile', 0.50, 'quantile for quantile loss training.')
flags.DEFINE_float('alpha', 3.0, 'alpha for pareto.')


def _get_random_string(num_chars):
  """Returns '_' followed by `num_chars - 1` random alphanumeric characters.

  Args:
    num_chars: Length of the result including the leading underscore. Values
      below 1 yield just '_'.

  Returns:
    A string made of '_' plus characters drawn with `random.choice` from the
    ASCII upper-case letters, lower-case letters and digits.
  """
  alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits
  suffix = [random.choice(alphabet) for _ in range(num_chars - 1)]
  return '_' + ''.join(suffix)


def training():
  """Trains the FixedRNN model and keeps the best validation result.

  Loads the data, builds the model and runs at most `FLAGS.train_epochs`
  epochs. The learning rate follows a piecewise-constant schedule that halves
  at each boundary, evaluated once per epoch. After every epoch the model is
  evaluated on the 'val' and 'test' splits; whenever the validation loss
  improves, both metric tables are written as CSV files into the experiment
  directory `<logging_path>/<dataset>/<random id>`. Training stops early once
  the validation loss has failed to improve for more than `FLAGS.patience`
  consecutive epochs. Per-epoch summaries are written to the same directory.
  """
  tf.random.set_seed(FLAGS.random_seed)
  np.random.seed(FLAGS.random_seed)

  experiment_id = _get_random_string(8)

  # Load data
  data = data_loader.Data()
  # Create model
  model = fixed_rnn.FixedRNN(
      num_ts=data.num_ts,
      cat_dims=data.global_cat_dims,
      tree=data.tree,
      output_dim=FLAGS.output_dim)
  # Compute path to experiment directory
  expt_dir = os.path.join(
      FLAGS.logging_path, FLAGS.dataset, str(experiment_id))
  os.makedirs(expt_dir, exist_ok=True)

  # `step` counts completed epochs, not optimizer updates.
  step = tf.Variable(0)
  # LR scheduling: `num_changes` epoch boundaries spread evenly over the run
  # and `num_changes + 1` values, each half of the previous one.
  boundaries = FLAGS.train_epochs * np.linspace(0.0, 1.0, FLAGS.num_changes)
  boundaries = boundaries.astype(np.int32).tolist()

  lr = FLAGS.learning_rate * np.asarray(
      [0.5**i for i in range(FLAGS.num_changes + 1)])
  lr = lr.tolist()

  sch = keras.optimizers.schedules.PiecewiseConstantDecay(
      boundaries=boundaries, values=lr)
  optimizer = keras.optimizers.Adam(learning_rate=lr[0], clipvalue=1e5)

  summary = Summary(expt_dir)

  # Start from +inf so the first finite validation loss always counts as an
  # improvement, whatever its magnitude.
  best_loss = float('inf')
  pat = 0

  while step.numpy() < FLAGS.train_epochs:
    logging.info('Epoch %s', step.numpy())
    sys.stdout.flush()
    # The schedule is indexed by epoch, so refresh the LR once per epoch.
    optimizer.learning_rate.assign(sch(step))

    iterator = tqdm(data.tf_dataset(mode='train'), mininterval=2)
    for i, (feats, y_obs, nid) in enumerate(iterator):
      loss = model.train_step(feats, y_obs, nid, optimizer)
      # Train metrics
      summary.update({'train/reg_loss': loss, 'train/loss': loss})
      if i % 100 == 0:
        mean_loss = summary.metric_dict['train/reg_loss'].result().numpy()
        iterator.set_description(f'Reg + Loss {mean_loss:.4f}')
    step.assign_add(1)
    # ckpt_manager.save()
    # Other metrics
    summary.update({'train/learning_rate': optimizer.learning_rate.numpy()})
    # Test metrics
    val_metrics, _, val_loss = model.eval(data, 'val')
    test_metrics, _, test_loss = model.eval(data, 'test')
    logging.info('Val Loss: %s', val_loss)
    logging.info('Test Loss: %s', test_loss)
    tracked_loss = val_loss
    stop_early = False
    if tracked_loss < best_loss:
      best_loss = tracked_loss
      # best_check_path = ckpt_manager.latest_checkpoint
      pat = 0

      val_metrics.to_csv(os.path.join(expt_dir, 'val_metrics.csv'))
      test_metrics.to_csv(os.path.join(expt_dir, 'test_metrics.csv'))

      logging.info('saved best result so far at %s', expt_dir)
    else:
      pat += 1
      if pat > FLAGS.patience:
        logging.info('Early stopping')
        stop_early = True
    time.sleep(4)

    # Write before any early exit so the final epoch's metrics are not lost.
    summary.write(step=step.numpy())
    if stop_early:
      break


class Summary:
  """Keeps running means of named scalars and writes them as TF summaries."""

  def __init__(self, log_dir):
    """Creates an empty metric table and a summary file writer for `log_dir`."""
    self.metric_dict = {}
    self.writer = tf.summary.create_file_writer(log_dir)

  def update(self, update_dict):
    """Adds one observation per entry of `update_dict` to its running mean.

    Args:
      update_dict: Mapping from metric name to the value to accumulate. A
        `keras.metrics.Mean` is created the first time a name is seen.
    """
    for name, value in update_dict.items():
      try:
        tracker = self.metric_dict[name]
      except KeyError:
        tracker = self.metric_dict[name] = keras.metrics.Mean()
      tracker.update_state(values=[value])

  def write(self, step):
    """Writes every tracked mean as a scalar at `step`, then starts afresh.

    Args:
      step: Step value attached to each scalar summary.
    """
    with self.writer.as_default():
      for name, tracker in self.metric_dict.items():
        tf.summary.scalar(name, tracker.result(), step=step)
    # Drop all trackers so the next round of updates starts from empty means.
    self.metric_dict = {}
    self.writer.flush()


def main(_):
  """Entry point passed to `app.run`; ignores its argument and runs training."""
  training()


if __name__ == '__main__':
  # Select the 'spawn' multiprocessing start method before absl runs `main`;
  # force=True replaces a start method that was already set instead of raising.
  # NOTE(review): presumably 'spawn' is used so worker processes do not inherit
  # TensorFlow state through fork — confirm.
  mp.set_start_method('spawn', force=True)
  app.run(main)
