"""Reimplement TimeGAN-pytorch Codebase.

Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
"Time-series Generative Adversarial Networks,"
Neural Information Processing Systems (NeurIPS), 2019.

Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks

Last updated Date: October 18th 2021
Code author: Zhiwei Zhang (bitzzw@gmail.com)

-----------------------------

visualization_metrics.py

Note: Use t-SNE to visualize generated and original data
"""

# Necessary packages
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def visualization(ori_data, generated_data, samp_iter=1):
  """Plot a joint t-SNE embedding of original vs. generated sequences and save it.

  Up to 1000 sequences are sampled at random from each dataset. Each sequence
  is reduced to a 1-D profile by averaging over its feature axis. The two sets
  are then embedded together with 2-D t-SNE. The scatter plot (original in
  red, synthetic in blue) is written to ``./sample_{samp_iter}/t-SNE plot.png``.

  Args:
    - ori_data: original data, array-like of shape (n, seq_len, features)
    - generated_data: generated synthetic data, array-like of shape
      (n_hat, seq_len, features_hat); must share ``seq_len`` with ori_data
    - samp_iter: suffix of the output directory ``./sample_{samp_iter}``

  Raises:
    - ValueError: if either dataset is not 3-D or is empty.
  """
  # Data preprocessing
  ori_data = np.asarray(ori_data)
  generated_data = np.asarray(generated_data)
  if ori_data.ndim != 3 or generated_data.ndim != 3:
    raise ValueError(
        "ori_data and generated_data must be 3-D (n, seq_len, features)")
  n_ori, n_gen = len(ori_data), len(generated_data)
  if n_ori == 0 or n_gen == 0:
    raise ValueError("ori_data and generated_data must be non-empty")

  # Analysis sample size (capped for faster computation)
  sample_no = min(1000, n_ori, n_gen)
  idx = np.random.permutation(n_ori)[:sample_no]
  # Reuse the same indices for the synthetic set when they are all valid there
  # (the usual equal-size case). Otherwise draw a separate subset, so a smaller
  # generated set no longer raises IndexError.
  idx_hat = idx if n_gen >= n_ori else np.random.permutation(n_gen)[:sample_no]

  # Average over the feature axis: (sample_no, seq_len, F) -> (sample_no, seq_len)
  prep_data = np.mean(ori_data[idx], axis=2)
  prep_data_hat = np.mean(generated_data[idx_hat], axis=2)

  # Embed both sets together so they share one t-SNE coordinate system
  prep_data_final = np.concatenate((prep_data, prep_data_hat), axis=0)

  # scikit-learn requires perplexity < n_samples; keep 40 when there are enough points
  perplexity = min(40, len(prep_data_final) - 1)
  tsne = _make_tsne(perplexity=perplexity, n_iter=300)
  tsne_results = tsne.fit_transform(prep_data_final)

  # Plotting: first sample_no rows are original, the rest synthetic
  fig, ax = plt.subplots(1)
  ax.scatter(tsne_results[:sample_no, 0], tsne_results[:sample_no, 1],
             color="red", alpha=0.2, label="Original")
  ax.scatter(tsne_results[sample_no:, 0], tsne_results[sample_no:, 1],
             color="blue", alpha=0.2, label="Synthetic")
  ax.legend()
  ax.set_title('t-SNE plot')
  ax.set_xlabel('x-tsne')
  ax.set_ylabel('y-tsne')

  sample_dir = f"./sample_{samp_iter}"
  tf.io.gfile.makedirs(sample_dir)
  fig.savefig(f'{sample_dir}/t-SNE plot.png', dpi=300)
  # Release the figure; repeated calls (one per samp_iter) would otherwise
  # keep accumulating open figures.
  plt.close(fig)


def _make_tsne(perplexity, n_iter):
  """Build a 2-D TSNE estimator that works across scikit-learn versions.

  scikit-learn 1.5 deprecated ``n_iter`` in favour of ``max_iter`` (slated for
  removal). Try the new keyword first and fall back for older releases.
  """
  common = dict(n_components=2, verbose=1, perplexity=perplexity)
  try:
    return TSNE(max_iter=n_iter, **common)
  except TypeError:
    # Older scikit-learn: TSNE.__init__ has no ``max_iter`` keyword.
    return TSNE(n_iter=n_iter, **common)