"""Reimplement TimeGAN-pytorch Codebase.

Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
"Time-series Generative Adversarial Networks,"
Neural Information Processing Systems (NeurIPS), 2019.

Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks

Last updated Date: October 18th 2021
Code author: Zhiwei Zhang (bitzzw@gmail.com)

-----------------------------

data.py

(0) MinMaxScaler: Min-max normalizer
(1) sine_data_generation: Generate sine dataset
(2) real_data_loading: Load and preprocess real data
  - stock_data: https://finance.yahoo.com/quote/GOOG/history?p=GOOG
  - energy_data: http://archive.ics.uci.edu/ml/datasets/Appliances+energy+prediction
(3) load_data: load or generate the dataset selected by opt.data_name
(4) batch_generator: mini-batch generator
(5) filter_triplets / densify_index / discrete_data_loading: load and preprocess
    the discrete (Steam) interaction log
"""

import numpy as np
from os.path import dirname, abspath
import ast
from pathlib import Path
import pandas as pd
import math
import csv
from tensorflow.keras.utils import to_categorical


def MinMaxScaler(data):
  """Rescale every column of `data` into roughly [0, 1].

  Args:
    - data: array-like; min/max are taken along axis 0 (per column).

  Returns:
    - norm_data: (data - column_min) / (column_max - column_min + 1e-7).
      The 1e-7 guards against division by zero for constant columns.
  """
  col_min = np.min(data, 0)
  col_span = np.max(data, 0) - col_min
  return (data - col_min) / (col_span + 1e-7)


def sine_data_generation(no, seq_len, dim):
  """Generate a synthetic dataset of multivariate sine waves.

  For every sample and every feature, a frequency and then a phase are drawn
  from U(0, 0.1) (in that order, using numpy's global RNG). The feature is
  sin(freq * t + phase) for t = 0 .. seq_len - 1.

  Args:
    - no: the number of samples
    - seq_len: sequence length of the time-series
    - dim: feature dimensions

  Returns:
    - data: list of `no` arrays of shape (seq_len, dim), values in [0, 1]
  """
  samples = []
  for _ in range(no):
    channels = []
    for _ in range(dim):
      freq = np.random.uniform(0, 0.1)
      phase = np.random.uniform(0, 0.1)
      channels.append([np.sin(freq * t + phase) for t in range(seq_len)])
    # (dim, seq_len) -> (seq_len, dim), then map [-1, 1] onto [0, 1].
    series = np.transpose(np.asarray(channels))
    samples.append((series + 1) * 0.5)
  return samples
    

def filter_triplets(df, min_item_count=24, min_user_count=24):
  """Drop rarely seen items, then users with few remaining interactions.

  Args:
    - df: DataFrame with at least 'uid' and 'sid' columns, one row per interaction.
    - min_item_count: keep items ('sid') that occur in at least this many rows.
    - min_user_count: keep users ('uid') with at least this many rows, counted
      after the item filter has been applied.

  Returns:
    - df: the subset of input rows that survive both filters (index preserved).
  """
  item_sizes = df.groupby('sid').size()
  good_items = item_sizes.index[item_sizes >= min_item_count]
  df = df[df['sid'].isin(good_items)]

  # Single pass: dropping users can push some item counts back below
  # min_item_count; those items are not re-filtered.
  user_sizes = df.groupby('uid').size()
  good_users = user_sizes.index[user_sizes >= min_user_count]
  return df[df['uid'].isin(good_users)]

def densify_index(df):
    """Re-map raw user/item ids to dense 1-based integers and sort by time.

    Args:
      - df: DataFrame with 'uid', 'sid' and 'timestamp' columns. It is not
        modified; a remapped copy is returned.

    Returns:
      - df: copy with 'uid'/'sid' replaced by integers 1..n (assigned in order
        of first appearance), rows stably sorted by 'timestamp'.
    """
    # Work on a copy: callers may pass a filtered slice, and assigning columns
    # on it would mutate / warn about the caller's frame.
    df = df.copy()
    # pd.unique preserves first-appearance order, so the id mapping is
    # reproducible across runs. Iterating a set of strings is not, because
    # str hashing is randomized per process.
    umap = {u: i for i, u in enumerate(pd.unique(df['uid']), start=1)}
    smap = {s: i for i, s in enumerate(pd.unique(df['sid']), start=1)}
    df['uid'] = df['uid'].map(umap)
    df['sid'] = df['sid'].map(smap)
    # Stable sort keeps rows that share a timestamp in their original order.
    return df.sort_values(by=['timestamp'], axis=0, kind='mergesort')

def discrete_data_loading(data_name, seq_len):
    """Load the Steam interaction log and build per-user event sequences.

    Reads `steam.json` from the `data` folder that sits next to this file's
    parent directory. The file is parsed one Python-literal dict per line.
    Frequent items/users are kept (filter_triplets) and ids are densified
    (densify_index). Two CSVs are written as side effects:
      - ./data/steam_item.csv: a single row listing every remaining item id.
      - ./data/steam_data.csv: one row per user of [timestamp, item] pairs.

    Args:
      - data_name: unused; kept for signature parity with the other loaders.
      - seq_len: unused; kept for signature parity with the other loaders.

    Returns:
      - output: one list per user (users in order of first appearance after the
        timestamp sort) of [timestamp, item_id] pairs.
    """
    data_dir = dirname(dirname(abspath(__file__))) + '/data'

    prev_data = []
    with open(data_dir + '/steam.json', 'r', encoding='utf-8') as f:
        for line in f:
            # ast.literal_eval only evaluates literals, never arbitrary code.
            record = ast.literal_eval(line)
            stamp = record['date']
            # Approximate day number from a date string laid out like
            # 'YYYY-MM-DD' (assumed format): 360-day years, 30-day months,
            # shifted by a fixed constant.
            date = int(stamp[0:4]) * 360 + int(stamp[5:7]) * 30 + int(stamp[8:10]) - 723915
            prev_data.append([record['username'], date, record['product_id']])

    df = pd.DataFrame(prev_data, columns=['uid', 'timestamp', 'sid'])
    df = filter_triplets(df)
    df = densify_index(df)

    item_list = [str(sid) for sid in df.drop_duplicates(['sid'])['sid'].tolist()]
    # NOTE(review): outputs go to ./data relative to the CWD, whereas the input
    # is read relative to this file; they coincide only when run from the
    # matching directory - confirm.
    with open('./data/steam_item.csv', 'w', newline='') as f:
        csv.writer(f).writerow(item_list)

    output = []
    # groupby(sort=False) yields users in order of first appearance and keeps
    # each user's rows in frame order. This gives the same result as a boolean
    # mask per user, but in one pass instead of two full scans per user.
    for _, events in df.groupby('uid', sort=False):
        data = []
        for ts, item in zip(events['timestamp'].tolist(), events['sid'].tolist()):
            # Shift this event by 0.01 per earlier event whose rounded time is
            # ts or ts + 1. Presumably this keeps same-day events distinct.
            count = 0
            for prev in data:
                if (ts == round(prev[0])) or (ts + 1 == round(prev[0])):
                    count += 1
            data.append([ts + 0.01 * count, item])
        output.append(data)

    with open('./data/steam_data.csv', 'w', newline='') as f:
        csv.writer(f).writerows(output)

    return output

def real_data_loading(data_name, seq_len):
  """Load and preprocess real-world datasets.

  Args:
    - data_name: 'stock' or 'energy'
    - seq_len: number of rows in each sliding window

  Returns:
    - data: list of min-max-normalized windows of `seq_len` rows, in random
      order (numpy global RNG).

  Raises:
    - ValueError: if data_name is not 'stock' or 'energy'. The 'discrete'
      dataset has its own loader, discrete_data_loading.
  """
  file_names = {'stock': 'stock_data.csv', 'energy': 'energy_data.csv'}
  if data_name not in file_names:
    raise ValueError(
        f"Unsupported data_name {data_name!r}: expected 'stock' or 'energy' "
        "(use discrete_data_loading for 'discrete').")

  data_path = dirname(dirname(abspath(__file__))) + '/data/' + file_names[data_name]
  ori_data = np.loadtxt(data_path, delimiter=",", skiprows=1)
  # Reverse the row order. Presumably the CSVs are stored newest-first and
  # this restores chronological order - confirm against the data files.
  ori_data = ori_data[::-1]
  # Normalize the data
  ori_data = MinMaxScaler(ori_data)

  # Cut into overlapping windows of seq_len rows. NOTE: the range stops at
  # len - seq_len, so the very last possible window is not included; this is
  # kept so window counts stay unchanged.
  temp_data = [ori_data[i:i + seq_len] for i in range(0, len(ori_data) - seq_len)]

  # Mix the datasets (to make it similar to i.i.d)
  idx = np.random.permutation(len(temp_data))
  return [temp_data[i] for i in idx]


def load_data(opt):
  """Load or generate the dataset selected by `opt.data_name`.

  Args:
    - opt: options object; reads `opt.data_name` and `opt.seq_len`.

  Returns:
    - ori_data: for 'stock'/'energy', shuffled normalized windows from
      real_data_loading. For 'sine', 10000 generated (seq_len, 5) arrays. For
      'discrete', per-user lists of [timestamp, item_id] pairs from
      discrete_data_loading.

  Raises:
    - ValueError: if `opt.data_name` is not a supported dataset.
  """
  ## Data loading
  if opt.data_name in ['stock', 'energy']:
    ori_data = real_data_loading(opt.data_name, opt.seq_len)  # list: 3661; [24,6]
  elif opt.data_name == 'sine':
    # Set number of samples and its dimensions
    no, dim = 10000, 5
    ori_data = sine_data_generation(no, opt.seq_len, dim)
  elif opt.data_name == 'discrete':
    # real_data_loading only reads the 'stock'/'energy' CSVs; the discrete
    # (Steam) data has its own loader.
    ori_data = discrete_data_loading(opt.data_name, opt.seq_len)
  else:
    raise ValueError(f"Unknown data_name {opt.data_name!r}: "
                     "expected 'stock', 'energy', 'sine' or 'discrete'.")
  print(opt.data_name + ' dataset is ready.')

  return ori_data


def batch_generator(data, time, batch_size):
  """Draw one random mini-batch (without replacement) from `data` and `time`.

  Args:
    - data: indexable sequence of time-series samples
    - time: indexable sequence of time information, aligned with `data`
    - batch_size: the number of samples in each batch (fewer are returned if
      `data` is shorter)

  Returns:
    - X_mb: list of the selected samples of `data`
    - T_mb: list of the matching entries of `time`
  """
  # The first batch_size entries of a fresh random permutation choose the batch.
  picked = np.random.permutation(len(data))[:batch_size]
  return [data[j] for j in picked], [time[j] for j in picked]