"""Data-Loader for different time-series datasets."""

import os
import pickle
import sys

from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v2 as tf

import special_tree

# Global absl flag container; the flag values themselves are defined elsewhere
# and are only read lazily inside the methods below.
FLAGS = flags.FLAGS
# Explicitly register the imported module under the top-level name
# 'special_tree' in sys.modules.
# NOTE(review): presumably this lets pickle resolve classes stored in data.pkl
# that were pickled with the module path 'special_tree' -- confirm.
sys.modules['special_tree'] = special_tree


class Data:
  """Data loader for a time-series dataset with an associated tree.

  On construction the pickled dataset is read from
  `<FLAGS.all_data_path>/<FLAGS.dataset>/data.pkl` and the series are rescaled
  by `transform_data`.

  Attributes:
    tree: Object loaded from the pickle. This class reads its `levels` and
      `leaf_matrix` attributes.
    ts_data: 2-D array of shape `(T, num_ts)` holding the time series.
    global_cont_feats: Continuous global features, sliced along the first axis
      with the same indices as `ts_data` (presumably one row per time step).
    global_cat_feats: Iterable of categorical global features, each sliced
      along its first axis like `global_cont_feats`.
    global_cat_dims: Sequence with one entry per categorical feature; only its
      length is used here.
    T: Number of rows (time steps) in `ts_data`.
    num_ts: Number of columns (series) in `ts_data`.
    w: Per-series weights; only set after calling `compute_weights`.
  """

  def __init__(self):
    """Loads the dataset, transforms it and logs basic statistics."""
    self.read_data()
    self.transform_data()
    logging.info('Data shape: %s', self.ts_data.shape)
    logging.info('Number of nans: %s', np.count_nonzero(np.isnan(self.ts_data)))

  def read_data(self):
    """Loads the tree, the series and the global features from the pickle.

    Raises:
      OSError: If the pickle file cannot be opened.
    """
    pkl_path = os.path.join(FLAGS.all_data_path, FLAGS.dataset, 'data.pkl')
    with open(pkl_path, 'rb') as fin:
      logging.info('Found pickle. Loading from %s', pkl_path)
      # SECURITY: unpickling can execute arbitrary code. Only point
      # FLAGS.all_data_path / FLAGS.dataset at trusted data files.
      self.tree, self.ts_data, (self.global_cont_feats, self.global_cat_feats,
                                self.global_cat_dims) = pickle.load(fin)
    self.T, self.num_ts = self.ts_data.shape  # pylint: disable=invalid-name

  def compute_weights(self):
    """Computes per-series weights that sum to one and stores them in `w`.

    Every series listed in a level of `tree.levels` is divided by the size of
    that level, and the result is divided by the number of levels.
    """
    levels = self.tree.levels
    self.w = np.ones(self.num_ts)
    for level in levels.values():
      self.w[level] /= len(level)
    self.w /= len(levels)
    # Sanity check: holds when every series belongs to exactly one level.
    assert np.abs(np.sum(self.w) - 1.0) <= 1e-5, (
        'Level weights do not sum to 1.')

  def transform_data(self):
    """Divides each series by its leaf count and floors the result.

    The leaf count of a series is the corresponding column sum of
    `tree.leaf_matrix.T`, i.e. the series is replaced by the floored mean over
    the leaves it covers.
    """
    leaf_mat = self.tree.leaf_matrix.T
    # keepdims=True gives a single row that broadcasts over the time axis.
    num_leaf = np.sum(leaf_mat, axis=0, keepdims=True)
    self.ts_data = np.floor(self.ts_data / num_leaf)

  def _window_features(self, start, length):
    """Returns `(cont_feats, cat_feats)` for rows `start:start + length`."""
    end = start + length
    cont = self.global_cont_feats[start:end]
    cat = tuple(feat[start:end] for feat in self.global_cat_feats)
    return cont, cat

  def _eval_windows(self, start_idx, end_idx):
    """Yields full-width evaluation windows, shared by `val_gen`/`test_gen`.

    Window starts run from `start_idx` to `end_idx` (inclusive) in steps of
    `FLAGS.test_pred`; each window has `FLAGS.hist_len + FLAGS.test_pred` rows
    and contains every series.
    """
    pred_len = FLAGS.test_pred
    window = FLAGS.hist_len + pred_len
    for i in range(start_idx, end_idx + 1, pred_len):
      sub_ts = self.ts_data[i:i + window]
      j = np.arange(self.num_ts)
      yield self._window_features(i, window), sub_ts, j  # t x *

  def train_gen(self):
    """Generator for training data.

    Window start positions are visited in a random order. For every start a
    batch of `FLAGS.batch_size` series indices is drawn at random (with
    replacement, the default of `np.random.choice`).

    Yields:
      Tuples `((cont_feats, cat_feats), sub_ts, j)` where the features and
      `sub_ts` cover `FLAGS.hist_len + FLAGS.train_pred` rows and `j` holds
      the sampled series indices selecting the columns of `sub_ts`.
    """
    window = FLAGS.hist_len + FLAGS.train_pred
    tot_len = self.T

    # Number of candidate start positions once the validation and test
    # windows (and two history lengths) are excluded from the end.
    num_data = tot_len - (FLAGS.val_windows + FLAGS.test_windows
                         ) * FLAGS.test_pred - 2 * FLAGS.hist_len
    perm = np.random.permutation(num_data)

    for i in perm:
      feats = self._window_features(i, window)
      # Passing the int directly draws from arange(num_ts) without
      # materialising a range on every iteration.
      j = np.random.choice(self.num_ts, size=FLAGS.batch_size)
      sub_ts = self.ts_data[i:i + window, j]
      yield feats, sub_ts, j  # t x *

  def val_gen(self):
    """Validation generator.

    Yields:
      Tuples `((cont_feats, cat_feats), sub_ts, j)` for each validation
      window, with `j` covering all series.
    """
    tot_len = self.T
    start_idx = tot_len - (FLAGS.val_windows + FLAGS.test_windows
                          ) * FLAGS.test_pred - FLAGS.hist_len
    end_idx = tot_len - (FLAGS.test_windows +
                         1) * FLAGS.test_pred - FLAGS.hist_len
    yield from self._eval_windows(start_idx, end_idx)

  def test_gen(self):
    """Test generator.

    Yields:
      Tuples `((cont_feats, cat_feats), sub_ts, j)` for each test window, with
      `j` covering all series.
    """
    tot_len = self.T
    start_idx = tot_len - FLAGS.test_windows * FLAGS.test_pred - FLAGS.hist_len
    end_idx = tot_len - FLAGS.test_pred - FLAGS.hist_len
    yield from self._eval_windows(start_idx, end_idx)

  def tf_dataset(self, mode):
    """Builds a prefetched `tf.data.Dataset` over one of the generators.

    Args:
      mode: One of 'train', 'val' or 'test'.

    Returns:
      A dataset whose elements are `((cont_feats, cat_feats), y_obs, ids)`
      with dtypes `((float32, (int32, ...)), float32, int32)`.

    Raises:
      ValueError: If `mode` is not one of the supported values.
    """
    generators = {
        'train': self.train_gen,
        'val': self.val_gen,
        'test': self.test_gen,
    }
    # Fail early with a clear message instead of an UnboundLocalError below.
    if mode not in generators:
      raise ValueError(
          "Unknown mode %r; expected one of 'train', 'val' or 'test'." %
          (mode,))
    gen_fn = generators[mode]

    num_cat_feats = len(self.global_cat_dims)
    output_type = tuple([tf.int32] * num_cat_feats)
    dataset = tf.data.Dataset.from_generator(
        gen_fn,
        (
            (tf.float32, output_type),  # feats
            tf.float32,  # y_obs
            tf.int32,  # id
        ))
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
