"""Fixed RNN Model."""

import gc
import json
import multiprocessing as mp

from absl import flags
from absl import logging
import numpy as np
import pandas as pd
from tabulate import tabulate
import tensorflow.compat.v2 as tf
from tensorflow.compat.v2 import keras
from tensorflow.compat.v2.keras import layers

import get_predictions as gp
import loss_functions

# Upper bound on the width of any categorical-feature embedding built below.
MAX_FEAT_EMB_DIM = 5
# Small tolerance the metric functions use to avoid dividing by (near) zero.
EPS = 1e-7
# absl command-line flags; values are read at call time throughout this file.
FLAGS = flags.FLAGS


def inference_transform(y_pred):
  """Turns raw multi-output model predictions into one value per cell.

  The 3-D array is flattened so each (step, series) pair becomes one row.
  The rows are handed to `gp.run_inference` with a config parsed from
  `FLAGS.config` (plus `FLAGS.alpha`), and the result is reshaped back to
  (steps, series).

  Args:
    y_pred: 3-D array of shape (t, n, o) — model outputs per step and series.

  Returns:
    Array of shape (t, n). Assumes `gp.run_inference` returns t * n values —
    TODO confirm against get_predictions.
  """
  num_steps, num_series, num_outputs = y_pred.shape
  flat_outputs = y_pred.reshape(num_steps * num_series, num_outputs)

  inference_config = json.loads(FLAGS.config)
  inference_config['alpha'] = FLAGS.alpha

  point_preds = gp.run_inference(
      inference_class=FLAGS.inference_class,
      model_outputs=flat_outputs,
      config=inference_config,
      error_metric=FLAGS.error_metric,
      num_jobs=mp.cpu_count(),
  )
  # Release memory held by the inference workers before returning.
  gc.collect()
  return point_preds.reshape(num_steps, num_series)


class FixedRNN(keras.Model):
  """Encoder-decoder LSTM that forecasts a fixed horizon for a batch of series.

  The encoder reads `FLAGS.hist_len` steps of past targets concatenated with
  the shared covariates. Its final (h, c) state initialises a decoder that
  runs over the covariates of the remaining steps. A ReLU/Dense head then maps
  every decoder output to `output_dim` values.

  Both LSTMs are built with `time_major=True`, so sequences are laid out as
  (time, batch, features).
  """

  def __init__(self, num_ts, cat_dims, tree, output_dim=1):
    """Builds the embeddings, encoder, decoder and output head.

    Args:
      num_ts: number of time series; stored on the instance only.
      cat_dims: vocabulary size of each categorical feature. One embedding of
        width min(MAX_FEAT_EMB_DIM, (dim + 1) // 2) is created per entry.
      tree: hierarchy object; stored on the instance only (`eval` reads the
        levels from `data.tree` instead).
      output_dim: number of values emitted per time step and series.
    """
    super().__init__()

    self.num_ts = num_ts
    self.cat_dims = cat_dims

    self.tree = tree

    cat_emb_dims = [
        min(MAX_FEAT_EMB_DIM, (dim + 1) // 2) for dim in self.cat_dims
    ]
    logging.info('Cat feat emb dims: %s', cat_emb_dims)
    self.cat_feat_embs = [
        layers.Embedding(input_dim=idim, output_dim=odim)
        for idim, odim in zip(self.cat_dims, cat_emb_dims)
    ]
    # The encoder returns its final (h, c) state to seed the decoder.
    self.encoder = layers.LSTM(
        FLAGS.fixed_lstm_hidden, return_state=True, time_major=True)
    # The decoder emits one hidden vector per future step.
    self.decoder = layers.LSTM(
        FLAGS.fixed_lstm_hidden, return_sequences=True, time_major=True)
    self.output_layer = keras.Sequential()
    self.output_layer.add(layers.ReLU())
    self.output_layer.add(layers.Dense(FLAGS.fixed_lstm_hidden, use_bias=True))
    self.output_layer.add(layers.ReLU())
    self.output_layer.add(layers.Dense(output_dim, use_bias=True))

  def assemble_feats(self, feats):
    """Concatenates embedded categorical covariates with continuous ones.

    Args:
      feats: pair `(feats_cont, feats_cat)`. `feats_cont` is a (t, d) tensor.
        `feats_cat` is a sequence with one length-t index tensor per
        categorical feature, in the same order as `cat_dims`. (Shapes follow
        the original inline comments — TODO confirm against the dataset.)

    Returns:
      (t, d') tensor: all categorical embeddings followed by `feats_cont`,
      joined along the last axis.
    """
    feats_cont = feats[0]  # t x d
    feats_cat = feats[1]  # one length-t index tensor per categorical feature
    feats_emb = [
        emb(feat) for emb, feat in zip(self.cat_feat_embs, feats_cat)  # t x e
    ]
    all_feats = feats_emb + [feats_cont]  # [t x *]
    all_feats = tf.concat(all_feats, axis=-1)  # t x d'
    return all_feats

  @tf.function
  def call(self, feats, y_prev, nid):
    """Encodes the history window and decodes the forecast horizon.

    Args:
      feats: `(feats_cont, feats_cat)` pair covering all t steps (history
        plus horizon); see `assemble_feats`.
      y_prev: (hist_len, b) tensor of observed targets for the history window.
      nid: length-b tensor. Only its static length is used here, to tile the
        shared covariates across the b series.

    Returns:
      (t - hist_len, b, output_dim) tensor of per-step outputs.
    """
    feats = self.assemble_feats(feats)  # t x d
    y_prev = tf.expand_dims(y_prev, -1)  # hist_len x b x 1

    # Covariates are shared by every series in the batch: tile them along a
    # new batch axis.
    feats = tf.expand_dims(feats, 1)  # t x 1 x d
    feats = tf.repeat(feats, repeats=nid.shape[0], axis=1)  # t x b x d

    feats_prev = feats[:FLAGS.hist_len]  # hist_len x b x d
    feats_futr = feats[FLAGS.hist_len:]  # (t - hist_len) x b x d

    enc_inp = tf.concat([y_prev, feats_prev], axis=-1)  # hist_len x b x (1+d)

    _, h, c = self.encoder(inputs=enc_inp)  # b x h
    # (t - hist_len) x b x h
    output = self.decoder(inputs=feats_futr, initial_state=(h, c))
    output = self.output_layer(output)  # (t - hist_len) x b x o

    return output

  @tf.function
  def train_step(self, feats, y_obs, nid, optimizer):
    """Runs one optimisation step on a single batch.

    Args:
      feats: covariate pair accepted by `call`.
      y_obs: (t, b) observed targets. The first `FLAGS.hist_len` steps feed
        the encoder; the rest are the targets of the loss.
      nid: length-b tensor forwarded to `call`.
      optimizer: Keras optimizer whose `apply_gradients` updates the weights.

    Returns:
      The loss value computed by `LOSS_DICT[FLAGS.loss_function]`.
    """
    loss_func = loss_functions.LOSS_DICT[FLAGS.loss_function]
    with tf.GradientTape() as tape:
      # (t - hist_len) x b x o
      pred = self(feats, y_obs[:FLAGS.hist_len], nid)
      loss = loss_func(pred, y_obs[FLAGS.hist_len:])

    grads = tape.gradient(loss, self.trainable_variables)
    optimizer.apply_gradients(zip(grads, self.trainable_variables))

    # Python side effect inside tf.function: only executes while tracing.
    logging.info('# Parameters in model %d',
                 np.sum([np.prod(v.shape) for v in self.trainable_variables]))

    return loss

  def eval(self, data, mode):
    """Evaluates the model on every batch of `data.tf_dataset(mode=mode)`.

    Args:
      data: dataset wrapper exposing `tf_dataset(mode=...)`, which yields
        `(feats, y_obs, nid)` batches, and `tree.levels`, which maps a level
        name to the column (series) indices of that level.
      mode: split name forwarded to `data.tf_dataset`.

    Returns:
      Tuple `(df, (all_y_pred, all_y_true), mean_loss)`:
        - `df` has one column per METRICS entry and rows indexed by level:
          'all', then each level, then 'mean' (the mean over levels).
        - `all_y_pred` and `all_y_true` are the stacked (time, series) arrays.
        - `mean_loss` is the per-batch loss averaged with weights equal to each
          batch's number of series.

    Raises:
      ValueError: if the dataset yields no batches.
    """
    iterator = data.tf_dataset(mode=mode)
    level_dict = data.tree.levels
    hist_len = FLAGS.hist_len
    pred_len = FLAGS.test_pred
    loss_func = loss_functions.LOSS_DICT[FLAGS.loss_function]

    # Collect per-batch arrays and concatenate once after the loop. Repeated
    # concatenation would copy all earlier batches on every iteration.
    pred_batches = []
    true_batches = []
    all_test_loss = 0
    all_test_num = 0
    for feats, y_obs, nid in iterator:
      assert y_obs.numpy().shape[0] == hist_len + pred_len
      assert feats[0].numpy().shape[0] == hist_len + pred_len

      y_pred = self(feats, y_obs[:hist_len], nid)  # pred_len x b x o
      test_loss = loss_func(y_pred, y_obs[hist_len:]).numpy()
      # Weight each batch's loss by its number of series (axis 1).
      all_test_loss += test_loss * y_pred.shape[1]
      all_test_num += y_pred.shape[1]

      pred_batches.append(y_pred.numpy())
      true_batches.append(y_obs[hist_len:].numpy())

    if not pred_batches:
      raise ValueError(
          'No batches produced by data.tf_dataset(mode=%r).' % (mode,))

    # Batches are stacked along axis 0 (time) and series stay on axis 1, which
    # the per-level column indexing below relies on. NOTE(review): this assumes
    # every batch covers the same series in the same order — confirm.
    all_y_pred = np.concatenate(pred_batches, axis=0)
    all_y_true = np.concatenate(true_batches, axis=0)

    if FLAGS.output_dim > 1:
      all_y_pred = inference_transform(all_y_pred)
    else:
      t, n, _ = all_y_pred.shape
      all_y_pred = all_y_pred.reshape(t, n)

    results_list = []
    # Compute metrics for all time series together
    results_dict = {}
    results_dict['level'] = 'all'
    for metric in METRICS:
      eval_fn = METRICS[metric]
      results_dict[metric] = eval_fn(all_y_pred, all_y_true)
    results_list.append(results_dict)
    # Compute metrics for individual levels and their mean across levels
    mean_dict = {metric: [] for metric in METRICS}

    for d in level_dict:
      results_dict = {}
      sub_pred = all_y_pred[:, level_dict[d]]
      sub_true = all_y_true[:, level_dict[d]]
      for metric in METRICS:
        eval_fn = METRICS[metric]
        eval_val = eval_fn(sub_pred, sub_true)
        results_dict[metric] = eval_val
        mean_dict[metric].append(eval_val)
      results_dict['level'] = d
      results_list.append(results_dict)
    # Compute the mean result of all the levels
    for metric in mean_dict:
      mean_dict[metric] = np.mean(mean_dict[metric])
    mean_dict['level'] = 'mean'
    results_list.append(mean_dict)

    mean_test_loss = all_test_loss / all_test_num

    df = pd.DataFrame(data=results_list)
    df.set_index('level', inplace=True)
    logging.info(tabulate(df, headers='keys', tablefmt='psql'))
    # Log the same batch-weighted mean loss that is returned, not just the
    # final batch's loss.
    logging.info('Loss: %f', mean_test_loss)

    return df, (all_y_pred, all_y_true), mean_test_loss


def mape(y_pred, y_true):
  """Mean absolute percentage error over entries with a non-negligible target.

  Entries with |y_true| <= EPS are skipped so they cannot cause a division
  by (near) zero.

  Args:
    y_pred: predicted values. Expected to have the same shape as `y_true`,
      because the two are flattened independently.
    y_true: ground-truth values.

  Returns:
    mean(|y_pred - y_true| / |y_true|) over the kept entries, or NaN when no
    entry has |y_true| > EPS.
  """
  abs_diff = np.abs(y_pred - y_true).flatten()
  abs_val = np.abs(y_true).flatten()
  keep = abs_val > EPS
  if not np.any(keep):
    # np.mean of an empty selection also gives NaN, but it emits a
    # "Mean of empty slice" RuntimeWarning; return NaN explicitly instead.
    return np.nan
  return np.mean(abs_diff[keep] / abs_val[keep])


def wape(y_pred, y_true):
  """Weighted absolute percentage error: sum|y_pred - y_true| / sum|y_true|.

  EPS is added to the denominator so an all-zero target does not divide by
  zero.
  """
  total_abs_error = np.sum(np.abs(y_pred - y_true))
  total_abs_target = np.sum(np.abs(y_true))
  return total_abs_error / (total_abs_target + EPS)


def smape(y_pred, y_true):
  """Symmetric MAPE: mean of |y_pred - y_true| / ((|y_true| + |y_pred|) / 2).

  EPS is added to each element's denominator. A pair where both values are
  zero therefore contributes 0 instead of dividing by zero.
  """
  denom = (np.abs(y_true) + np.abs(y_pred)) / 2 + EPS
  return np.mean(np.abs(y_pred - y_true) / denom)


def quantile(y_pred, y_true):
  """Normalised pinball (quantile) loss at quantile level `FLAGS.quantile`.

  Each residual r = y_true - y_pred is penalised by q * r when r >= 0 and by
  -r * (1 - q) when r < 0. Twice the summed penalty is divided by
  sum(|y_true| + EPS); note that EPS is added per element.

  NOTE(review): NaN residuals match neither mask, so they contribute 0 to the
  numerator.
  """
  tau = FLAGS.quantile
  residual = y_true - y_pred
  over = residual >= 0
  under = residual < 0

  pinball = np.zeros(shape=residual.shape)
  pinball[over] = residual[over] * tau
  pinball[under] = -residual[under] * (1 - tau)

  numerator = 2 * np.sum(pinball)
  denominator = np.sum(np.abs(y_true) + EPS)
  return numerator / denominator


def rmse(y_pred, y_true):
  """Root mean squared error over every element of the inputs."""
  squared_err = np.square(y_pred - y_true)
  return np.sqrt(np.mean(squared_err))


# Metric name -> fn(y_pred, y_true), iterated by FixedRNN.eval. Insertion
# order fixes the column order of the results table.
METRICS = dict(
    mape=mape,
    wape=wape,
    smape=smape,
    rmse=rmse,
    quantile=quantile,
)
