"""Time-series Generative Adversarial Networks (TimeGAN) Codebase.

Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar, 
"Time-series Generative Adversarial Networks," 
Neural Information Processing Systems (NeurIPS), 2019.

Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks

Last updated Date: April 24th 2020
Code author: Jinsung Yoon (jsyoon0823@gmail.com)

-----------------------------

utils.py

(1) train_test_divide: Divide train and test data for both original and synthetic data.
(2) extract_time: Returns Maximum sequence length and each sequence length.
(3) rnn_cell: Basic RNN Cell.
(4) random_generator: random vector generator
(5) batch_generator: mini-batch generator
"""

## Necessary Packages
import numpy as np
import tensorflow as tf


def train_test_divide (data_x, data_x_hat, data_t, data_t_hat, train_rate = 0.8):
  """Randomly split original and synthetic data into train and test parts.

  The original pair (data_x, data_t) and the synthetic pair
  (data_x_hat, data_t_hat) are each shuffled with their own random
  permutation. Within a pair, samples and their time information stay aligned.

  Args:
    - data_x: original data
    - data_x_hat: generated data
    - data_t: original time
    - data_t_hat: generated time
    - train_rate: fraction of each set that goes to the train part

  Returns:
    - train_x, train_x_hat, test_x, test_x_hat,
      train_t, train_t_hat, test_t, test_t_hat: lists of the selected items
  """
  def _shuffled_split(n_samples):
    # The first int(n_samples * train_rate) shuffled positions form the train part.
    order = np.random.permutation(n_samples)
    cut = int(n_samples * train_rate)
    return order[:cut], order[cut:]

  def _take(items, positions):
    return [items[p] for p in positions]

  # Original data is permuted first, synthetic data second.
  real_train, real_test = _shuffled_split(len(data_x))
  train_x, test_x = _take(data_x, real_train), _take(data_x, real_test)
  train_t, test_t = _take(data_t, real_train), _take(data_t, real_test)

  synth_train, synth_test = _shuffled_split(len(data_x_hat))
  train_x_hat, test_x_hat = _take(data_x_hat, synth_train), _take(data_x_hat, synth_test)
  train_t_hat, test_t_hat = _take(data_t_hat, synth_train), _take(data_t_hat, synth_test)

  return train_x, train_x_hat, test_x, test_x_hat, train_t, train_t_hat, test_t, test_t_hat


def extract_time (data):
  """Collect the length of every sequence and the longest length seen.

  Each element of data is indexed as element[:, 0], so it must support
  two-dimensional indexing; its length along the first axis is recorded.

  Args:
    - data: original data

  Returns:
    - time: list with the length of each sequence
    - max_seq_len: maximum sequence length (0 when data is empty)
  """
  time = [len(sequence[:, 0]) for sequence in data]
  max_seq_len = max(time, default=0)
  return time, max_seq_len


def rnn_cell(module_name, hidden_dim):
  """Build one recurrent cell of the requested kind.

  Every cell is created with hidden_dim units and a tanh activation.

  Args:
    - module_name: gru, lstm, or lstmLN
    - hidden_dim: number of units in the cell

  Returns:
    - rnn_cell: RNN Cell
  """
  assert module_name in ['gru','lstm','lstmLN']

  # Arguments shared by all three cell constructors.
  cell_kwargs = dict(num_units=hidden_dim, activation=tf.nn.tanh)

  if module_name == 'gru':
    # GRU
    cell = tf.nn.rnn_cell.GRUCell(**cell_kwargs)
  elif module_name == 'lstm':
    # LSTM
    cell = tf.contrib.rnn.BasicLSTMCell(**cell_kwargs)
  elif module_name == 'lstmLN':
    # LSTM Layer Normalization
    cell = tf.contrib.rnn.LayerNormBasicLSTMCell(**cell_kwargs)
  return cell


def random_generator (batch_size, z_dim, T_mb, max_seq_len):
  """Random vector generation.

  For each of the batch_size samples, draws a uniform [0, 1) matrix for the
  first T_mb[i] time steps and leaves the remaining steps up to max_seq_len
  at zero, so every returned array has the same shape.

  Args:
    - batch_size: size of the random vector
    - z_dim: dimension of random vector
    - T_mb: time information for the random vector
    - max_seq_len: maximum sequence length

  Returns:
    - Z_mb: list of batch_size arrays, each of shape [max_seq_len, z_dim]
  """
  Z_mb = list()
  for i in range(batch_size):
    seq_len = T_mb[i]
    padded = np.zeros([max_seq_len, z_dim])
    padded[:seq_len, :] = np.random.uniform(0., 1, [seq_len, z_dim])
    # Append the zero-padded array, not the raw draw, so that sequences
    # shorter than max_seq_len still come out with the full length.
    Z_mb.append(padded)
  return Z_mb


def batch_generator(data, time, batch_size):
  """Draw one random mini-batch without replacement.

  A random permutation of all sample positions is generated and its first
  batch_size entries select the batch, so data and time stay aligned. When
  batch_size exceeds len(data), every sample is returned once.

  Args:
    - data: time-series data
    - time: time information
    - batch_size: the number of samples in each batch

  Returns:
    - X_mb: time-series data in each batch
    - T_mb: time information in each batch
  """
  shuffled_positions = np.random.permutation(len(data))
  chosen = shuffled_positions[:batch_size]

  X_mb = [data[k] for k in chosen]
  T_mb = [time[k] for k in chosen]

  return X_mb, T_mb
 
import pickle
import matplotlib.pyplot as plt
import numpy as np
import random
from copy import deepcopy
from collections import OrderedDict

import torch
from torch.utils.data import Dataset


import logging

def crop_data(driving_cycle, chunk_size):
    """
    Cut a sequence into consecutive, non-overlapping chunks.

    Parameters:
        driving_cycle (sequence): Data to be cut; must support len() and slicing.
        chunk_size (int): Number of samples per chunk.

    Returns:
        list: Chunks of exactly chunk_size samples for which if_data_is_DC()
            returns False. A shorter trailing remainder is dropped.
    """
    kept = []
    for start in range(0, len(driving_cycle), chunk_size):
        window = driving_cycle[start:start + chunk_size]
        # Only the last slice can come out short; it is discarded.
        if len(window) != chunk_size:
            continue
        if if_data_is_DC(window):
            continue
        kept.append(window)
    return kept

def crop_data_repeated(driving_cycle, chunk_size, slide_length = 100):
    """
    Cut a sequence into sliding-window chunks.

    Parameters:
        driving_cycle (sequence): Data to be cut; must support len() and slicing.
        chunk_size (int): Number of samples per chunk.
        slide_length (int): Distance between the starts of consecutive chunks.

    Returns:
        list: Full-length chunks for which if_data_is_DC() returns False.
            Empty when the sequence is shorter than chunk_size.
    """
    # Stopping at len - chunk_size + 1 guarantees every window is full length.
    starts = range(0, len(driving_cycle) - chunk_size + 1, slide_length)
    windows = (driving_cycle[start:start + chunk_size] for start in starts)
    return [window for window in windows if not if_data_is_DC(window)]


def symmetric_padding(data, window_size=21):
    """
    Extend the data on both ends by mirroring it, for edge handling.

    Parameters:
        data (array_like): Input data.
        window_size (int): Size of the moving average window; window_size // 2
            samples are added on each side.

    Returns:
        array_like: Padded data. Padding uses numpy's 'reflect' mode, so the
            edge sample itself is not repeated.
    """
    half_window = window_size // 2

    # A single integer pad width applies the same amount before and after.
    return np.pad(data, half_window, mode='reflect')

def moving_average(data, window_size=21):
    """
    Apply moving average smoothing to the given data.

    The data is extended at both ends with numpy 'reflect' padding and then
    averaged with a uniform window, so the result always has the same number
    of samples as the input. For odd window sizes the padding is
    window_size // 2 on each side (the same amount symmetric_padding() adds).
    For even window sizes one sample less is added at the end, because
    padding both sides equally would produce one extra output sample.

    Parameters:
        data (array_like): Input data to be smoothed.
        window_size (int): Size of the moving average window (>= 1).

    Returns:
        array_like: Smoothed data, same length as the input.

    Raises:
        ValueError: If window_size is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")

    # A 'valid' convolution drops window_size - 1 samples, so exactly that
    # many must be added in total to keep the output aligned with the input.
    pad_before = window_size // 2
    pad_after = window_size - 1 - pad_before
    padded_data = np.pad(data, (pad_before, pad_after), mode='reflect')

    # Uniform kernel: every sample in the window has the same weight.
    kernel = np.ones(window_size) / window_size

    return np.convolve(padded_data, kernel, mode='valid')

def shuffle_list(input_list):
    """
    Return a randomly reordered copy of a list, leaving the input untouched.

    Parameters:
        input_list (list): Input list to be shuffled.

    Returns:
        list: New list with the same elements in random order.
    """
    # Slice-copy first so the in-place shuffle does not affect the caller's list.
    reordered = input_list[:]
    random.shuffle(reordered)
    return reordered

def shuffle_dataset(dataset):
    """
    Return the items of a dataset as a list in random order.

    Parameters:
        dataset: Object supporting len() and integer indexing.

    Returns:
        list: Every item of the dataset exactly once, randomly ordered.
    """
    # Random permutation of the positions 0 .. len(dataset) - 1.
    order = torch.randperm(len(dataset))

    # Look the items up in the permuted order.
    return [dataset[position] for position in order]

def draw_figure(data, label='label'):
    """
    Add one line plot of the data to the current matplotlib figure.

    Parameters:
        data: Values passed to plt.plot.
        label: Legend entry for the line; it is formatted into a string.
    """
    legend_text = f'{label}'
    plt.plot(data, label=legend_text)



def show_figure():
    """
    Label the current matplotlib figure, add the legend and display it.
    """
    # Axis labels and title are independent of each other.
    plt.xlabel('Time')
    plt.ylabel('vs')
    plt.title('vs vs. time for Each Trip')

    # The legend is built from the labels of the lines plotted so far.
    plt.legend()
    plt.show()

def create_logger(logging_dir):
    """
    Configure root logging with a stream handler and a file handler.

    The file handler writes to "<logging_dir>/log.txt". The stream handler is
    created without arguments, so it uses the logging default stream (stderr).

    Parameters:
        logging_dir: Directory that receives log.txt.

    Returns:
        logging.Logger: The logger named after this module.
    """
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(f"{logging_dir}/log.txt")

    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[console_handler, file_handler],
    )

    return logging.getLogger(__name__)

def requires_grad(model, flag=True):
    """
    Switch gradient tracking on or off for every parameter of a model.

    Parameters:
        model: Object whose parameters() method yields the parameters.
        flag (bool): Value assigned to each parameter's requires_grad.
    """
    for parameter in model.parameters():
        parameter.requires_grad = flag

def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Every parameter of ema_model is updated in place to
    decay * ema + (1 - decay) * current, matched by parameter name.

    Parameters:
        ema_model: Model holding the running averages; modified in place.
        model: Model providing the current parameter values.
        decay (float): Weight kept from the previous average.

    Raises:
        KeyError: If model has a parameter name that ema_model lacks.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    # In-place ops on leaf parameters that require grad raise a RuntimeError
    # while autograd is recording, so the update runs with gradients disabled.
    with torch.no_grad():
        for name, param in model_params.items():
            # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
            ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)

def get_labels(dataset):
    """
    Collect the label of every sample as a plain Python value.

    Parameters:
        dataset: Iterable of (x, y) pairs where y provides an item() method.

    Returns:
        list: y.item() for each pair, in iteration order.
    """
    return [target.item() for _, target in dataset]

def if_nan(dataset):
    """
    Tell whether any sample tensor in the dataset contains a NaN.

    Parameters:
        dataset: Iterable of (x, y) pairs; only the tensor x is inspected.

    Returns:
        bool: True as soon as one x holds a NaN, otherwise False.
    """
    # any() stops at the first sample that contains a NaN.
    return any(torch.isnan(sample).any() for sample, _ in dataset)

def if_data_is_DC(data):
    """
    Tell whether the data is (nearly) constant.

    Parameters:
        data: Non-empty iterable of comparable numeric values.

    Returns:
        bool: True when the maximum is 0, or when the spread (max - min) is
            at most 10% of the maximum; otherwise False.
    """
    peak = max(data)
    trough = min(data)

    # A zero maximum is reported directly, which also avoids dividing by zero.
    if peak == 0:
        return True

    relative_spread = (peak - trough) / peak
    return bool(relative_spread <= 0.1)


class MyDataset(Dataset):
    """
    In-memory dataset of scaled sample tensors with integer labels.

    Items passed to the constructor are converted to float32 tensors, min-max
    scaled to [-1, 1] per item, and given a leading dimension of size 1.
    Labels are stored as torch.int tensors.
    """

    def __init__(self, data=None, labels=None):
        """
        Parameters:
            data (iterable, optional): Items convertible by torch.tensor.
            labels (iterable, optional): One label per item.
        """
        if data is None:
            self.data = []
        else:
            self.data = [
                self.map_to_range(torch.tensor(item, dtype=torch.float32)).unsqueeze(0)
                for item in data
            ]

        if labels is None:
            self.labels = []
        else:
            self.labels = [torch.tensor(label, dtype=torch.int) for label in labels]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the (sample, label) pair stored at index."""
        x = self.data[index]
        y = self.labels[index]

        return x, y

    def add_sample(self, data, label):
        """
        Append one sample and its label as given.

        Unlike the constructor, no tensor conversion or scaling is applied.
        """
        self.data.append(data)
        self.labels.append(label)

    def map_to_range(self, tensor):
        """
        Min-max scale a tensor to the range [-1, 1].

        A constant tensor (max == min) has no spread to scale; it is mapped
        to all zeros instead of the NaNs a 0/0 division would produce.
        """
        # Find the smallest and largest value.
        min_val = tensor.min().item()
        max_val = tensor.max().item()
        span = max_val - min_val

        if span == 0:
            return torch.zeros_like(tensor)

        # Map the values linearly onto [-1, 1].
        return 2 * (tensor - min_val) / span - 1

    def modify_labels(self, new_labels):
        """Replace the stored labels with new_labels."""
        self.labels = new_labels

    def remove_data(self, index):
        """
        Remove one or several samples together with their labels.

        Parameters:
            index (int or iterable of int): Position(s) to remove. Negative
                positions count from the end; duplicates are removed once.

        Raises:
            IndexError: If a position is out of range. Nothing is removed
                in that case.
        """
        positions = [index] if isinstance(index, int) else [int(i) for i in index]

        size = len(self.data)
        resolved = set()
        for position in positions:
            if not -size <= position < size:
                raise IndexError(
                    f"index {position} is out of range for dataset of size {size}"
                )
            resolved.add(position % size)

        # Delete from the back so earlier deletions do not shift later positions.
        for position in sorted(resolved, reverse=True):
            del self.data[position]
            del self.labels[position]
