import torch
import tensorflow as tf
import os
import logging
import random
import numpy as np

def fix_random_seed(random_num):
    """Seed every random number generator used by this project and force determinism.

    Seeds PyTorch (CPU, the current CUDA device, and all CUDA devices), NumPy's
    global RNG, and Python's `random` module with the same value. It also puts
    cuDNN into deterministic, non-benchmarking mode. TensorFlow's global seed
    is not touched by this function.

    Args:
      - random_num: integer seed applied to every generator
    """
    # PyTorch generators: default (CPU), current CUDA device, every CUDA device (multi-GPU).
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(random_num)

    # Disable cuDNN autotuning and request deterministic kernels so runs are repeatable.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Non-torch generators.
    np.random.seed(random_num)
    random.seed(random_num)

def restore_checkpoint(ckpt_dir, state, device):
    """Load a full training checkpoint from disk into the objects held by `state`.

    Args:
      - ckpt_dir: path of the checkpoint file written by torch.save
      - state: dict holding 'optimizer', 'conditional_model', 'ema', 'encoder'
        and 'decoder' objects (each exposing load_state_dict); its 'step'
        entry is overwritten
      - device: map_location passed to torch.load

    Returns:
      - state: the same dict, updated in place
    """
    checkpoint = torch.load(ckpt_dir, map_location=device)

    # (key, strict) pairs in loading order. strict=None means "call with the
    # object's default signature"; the networks are loaded with strict=False so
    # missing/unexpected parameter keys are tolerated.
    loaders = (
        ('optimizer', None),
        ('conditional_model', False),
        ('ema', None),
        ('encoder', False),
        ('decoder', False),
    )
    for name, strict in loaders:
        target = state[name]
        if strict is None:
            target.load_state_dict(checkpoint[name])
        else:
            target.load_state_dict(checkpoint[name], strict=strict)

    state['step'] = checkpoint['step']
    return state

def restore_ER_checkpoint(ckpt_dir, state, device):
    """Load an encoder/decoder checkpoint plus its two optimizers into `state`.

    Args:
      - ckpt_dir: path of the checkpoint file written by torch.save
      - state: dict holding 'encoder', 'decoder', 'opt_e' and 'opt_r' objects
      - device: map_location passed to torch.load

    Returns:
      - state: the same dict, updated in place
    """
    checkpoint = torch.load(ckpt_dir, map_location=device)
    # Networks are loaded leniently so missing/unexpected keys do not raise.
    for net_name in ('encoder', 'decoder'):
        state[net_name].load_state_dict(checkpoint[net_name], strict=False)
    for opt_name in ('opt_e', 'opt_r'):
        state[opt_name].load_state_dict(checkpoint[opt_name])
    return state

def save_checkpoint(ckpt_dir, state):
  """Serialize the full training state to `ckpt_dir` with torch.save.

  Args:
    - ckpt_dir: destination file path
    - state: dict holding 'optimizer', 'conditional_model', 'ema', 'encoder'
      and 'decoder' objects (each exposing state_dict) and a plain 'step' value
  """
  # Key order matches the layout read back by the restore helper; 'step' is
  # stored as-is, every other entry via its state_dict().
  field_order = ('optimizer', 'conditional_model', 'ema', 'step', 'encoder', 'decoder')
  payload = {
    key: (state[key] if key == 'step' else state[key].state_dict())
    for key in field_order
  }
  torch.save(payload, ckpt_dir)

def save_ER_checkpoint(ckpt_dir, state):
  """Serialize the encoder/decoder networks and their optimizers to `ckpt_dir`.

  Args:
    - ckpt_dir: destination file path
    - state: dict holding 'encoder', 'decoder', 'opt_e' and 'opt_r' objects,
      each exposing state_dict()
  """
  component_names = ('encoder', 'decoder', 'opt_e', 'opt_r')
  payload = {name: state[name].state_dict() for name in component_names}
  torch.save(payload, ckpt_dir)

def train_test_divide(data_x, data_x_hat, data_t, data_t_hat, train_rate = 0.8):
  """Randomly split original and synthetic data into train/test partitions.

  The original and the synthetic sets are shuffled independently, each with
  its own draw from NumPy's global RNG (original first). The first
  int(len * train_rate) shuffled indices go to train and the rest go to test.

  Args:
    - data_x: original data
    - data_x_hat: generated data
    - data_t: original time (indexed in parallel with data_x)
    - data_t_hat: generated time (indexed in parallel with data_x_hat)
    - train_rate: ratio of training data from each set

  Returns:
    - train_x, train_x_hat, test_x, test_x_hat,
      train_t, train_t_hat, test_t, test_t_hat (all Python lists)
  """
  def _split(values, times):
    # Shuffle indices of `values`, then apply the same split to `times`.
    count = len(values)
    order = np.random.permutation(count)
    cut = int(count * train_rate)
    head, tail = order[:cut], order[cut:]
    return ([values[k] for k in head],
            [values[k] for k in tail],
            [times[k] for k in head],
            [times[k] for k in tail])

  train_x, test_x, train_t, test_t = _split(data_x, data_t)
  train_x_hat, test_x_hat, train_t_hat, test_t_hat = _split(data_x_hat, data_t_hat)

  return train_x, train_x_hat, test_x, test_x_hat, train_t, train_t_hat, test_t, test_t_hat

def extract_time(data):
  """Return each sequence's length and the maximum length over all sequences.

  A sequence's length is the size of its first column, `len(seq[:, 0])`, so
  each element must support 2-D indexing (e.g. a 2-D NumPy array).

  Args:
    - data: original data (iterable of sequences)

  Returns:
    - time: list with the length of each sequence
    - max_seq_len: maximum sequence length (0 when `data` is empty)
  """
  lengths = [len(seq[:, 0]) for seq in data]
  longest = max(lengths, default=0)
  return lengths, longest

def batch_generator(data, time, batch_size):
  """Sample one random mini-batch without replacement.

  One permutation of all indices is drawn from NumPy's global RNG and its
  first `batch_size` entries are kept. If `batch_size` exceeds len(data),
  every sample is returned in shuffled order.

  Args:
    - data: time-series data
    - time: time information (indexed in parallel with data)
    - batch_size: the number of samples in each batch

  Returns:
    - X_mb: list of time-series samples in the batch
    - T_mb: list of matching time information
  """
  chosen = np.random.permutation(len(data))[:batch_size]
  X_mb = [data[k] for k in chosen]
  T_mb = [time[k] for k in chosen]
  return X_mb, T_mb