import torch
import tensorflow as tf
import os
import logging
import random
import numpy as np

def fix_random_seed(random_num):
    """Seed every random number generator used here and pin cuDNN to deterministic mode.

    Args:
      - random_num: seed passed unchanged to torch (CPU, current CUDA device and
        all CUDA devices), numpy's global RNG and Python's ``random`` module.
    """
    # PyTorch generators, in the same order as before: CPU, current GPU, all GPUs.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(random_num)

    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Host-side generators.
    np.random.seed(random_num)
    random.seed(random_num)

def restore_checkpoint(ckpt_dir, state, device):
    """Restore training state from a checkpoint file into ``state``, in place.

    Args:
      - ckpt_dir: path of the checkpoint file handed to ``torch.load``.
      - state: dict whose ``'optimizer'``, ``'conditional_model'`` and ``'ema'``
        entries expose ``load_state_dict``; its ``'step'`` entry is overwritten.
      - device: ``map_location`` for ``torch.load``.

    Returns:
      - the same ``state`` dict, after mutation.
    """
    # NOTE(review): depending on the torch version's weights_only default,
    # torch.load may unpickle arbitrary objects -- only load trusted files.
    checkpoint = torch.load(ckpt_dir, map_location=device)

    # (key, extra kwargs for load_state_dict), applied in this exact order.
    # strict=False on the model tolerates missing/unexpected parameter names.
    restore_plan = (
        ('optimizer', {}),
        ('conditional_model', {'strict': False}),
        ('ema', {}),
    )
    for key, extra_kwargs in restore_plan:
        state[key].load_state_dict(checkpoint[key], **extra_kwargs)

    state['step'] = checkpoint['step']
    return state

def restore_ER_checkpoint(ckpt_dir, state, device):
    """Restore encoder/decoder networks and their optimizers into ``state``, in place.

    Args:
      - ckpt_dir: path of the checkpoint file handed to ``torch.load``.
      - state: dict whose ``'encoder'``, ``'decoder'``, ``'opt_e'`` and
        ``'opt_r'`` entries expose ``load_state_dict``.
      - device: ``map_location`` for ``torch.load``.

    Returns:
      - the same ``state`` dict, after mutation.
    """
    # NOTE(review): depending on the torch version's weights_only default,
    # torch.load may unpickle arbitrary objects -- only load trusted files.
    checkpoint = torch.load(ckpt_dir, map_location=device)

    # Networks first, tolerating missing/unexpected parameter names ...
    for net_key in ('encoder', 'decoder'):
        state[net_key].load_state_dict(checkpoint[net_key], strict=False)
    # ... then their optimizers, loaded with default arguments.
    for opt_key in ('opt_e', 'opt_r'):
        state[opt_key].load_state_dict(checkpoint[opt_key])

    return state

def save_checkpoint(ckpt_dir, state):
    """Serialize the optimizer, model, EMA state dicts and step counter to disk.

    Args:
      - ckpt_dir: destination file path for ``torch.save``.
      - state: dict with ``'optimizer'``, ``'conditional_model'`` and ``'ema'``
        objects exposing ``state_dict()`` plus a ``'step'`` value.
    """
    # Key order (and the order state_dict() is called) matches the file layout.
    payload = {key: state[key].state_dict()
               for key in ('optimizer', 'conditional_model', 'ema')}
    payload['step'] = state['step']
    torch.save(payload, ckpt_dir)

def save_ER_checkpoint(ckpt_dir, state):
    """Serialize encoder/decoder networks and their optimizers to disk.

    Args:
      - ckpt_dir: destination file path for ``torch.save``.
      - state: dict with ``'encoder'``, ``'decoder'``, ``'opt_e'`` and
        ``'opt_r'`` objects exposing ``state_dict()``.
    """
    er_keys = ('encoder', 'decoder', 'opt_e', 'opt_r')
    torch.save({key: state[key].state_dict() for key in er_keys}, ckpt_dir)

def train_test_divide(data_x, data_x_hat, data_t, data_t_hat, train_rate=0.8):
    """Randomly split original and synthetic data (and their times) into train/test.

    The original and synthetic sets are shuffled independently, each with one
    ``np.random.permutation`` call (original first, then synthetic). The first
    ``int(len * train_rate)`` shuffled indices form the training split.

    Args:
      - data_x: original data
      - data_x_hat: generated data
      - data_t: original time
      - data_t_hat: generated time
      - train_rate: fraction of each set assigned to the training split

    Returns:
      - train_x, train_x_hat, test_x, test_x_hat,
        train_t, train_t_hat, test_t, test_t_hat (all lists)
    """
    def _split(samples, times):
        # Shuffle indices once and cut, so samples and times stay aligned.
        count = len(samples)
        order = np.random.permutation(count)
        cut = int(count * train_rate)
        train_ids, test_ids = order[:cut], order[cut:]
        return (
            [samples[k] for k in train_ids],
            [samples[k] for k in test_ids],
            [times[k] for k in train_ids],
            [times[k] for k in test_ids],
        )

    train_x, test_x, train_t, test_t = _split(data_x, data_t)
    train_x_hat, test_x_hat, train_t_hat, test_t_hat = _split(data_x_hat, data_t_hat)

    return train_x, train_x_hat, test_x, test_x_hat, train_t, train_t_hat, test_t, test_t_hat


def extract_time(data):
    """Return each sequence's length and the maximum length.

    Args:
      - data: sequence of 2-D arrays (each is indexed as ``seq[:, 0]``)

    Returns:
      - time: list with the length (first-axis size) of every sequence
      - max_seq_len: the largest of those lengths, or 0 when ``data`` is empty
    """
    time = [len(seq[:, 0]) for seq in data]
    max_seq_len = max(time, default=0)
    return time, max_seq_len


def random_generator(batch_size, z_dim, T_mb, max_seq_len, *, pad=False):
    """Random vector generation.

    For each of the first ``batch_size`` entries of ``T_mb``, draw a
    ``(T_mb[i], z_dim)`` block of noise from ``U[0, 1)`` with numpy's global RNG.

    Args:
      - batch_size: number of random sequences to generate
      - z_dim: dimension of each random vector
      - T_mb: per-sequence lengths; only the first ``batch_size`` are used
      - max_seq_len: maximum sequence length (upper bound on every ``T_mb[i]``)
      - pad: keyword-only. When False (default, the historical behaviour), each
        element has shape ``(T_mb[i], z_dim)``. When True, each element is
        zero-padded along the time axis to ``(max_seq_len, z_dim)``.

    Returns:
      - Z_mb: list of ``batch_size`` numpy arrays

    Raises:
      - ValueError: if some ``T_mb[i]`` exceeds ``max_seq_len``
    """
    Z_mb = []
    for i in range(batch_size):
        seq_len = T_mb[i]
        # Kept explicit: the old code raised here via a failed broadcast into
        # its (otherwise discarded) padding buffer.
        if seq_len > max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_seq_len}")
        noise = np.random.uniform(0., 1, [seq_len, z_dim])
        if pad:
            padded = np.zeros([max_seq_len, z_dim])
            padded[:seq_len, :] = noise
            noise = padded
        Z_mb.append(noise)
    return Z_mb


def NormMinMax(data):
    """Min-Max Normalizer.

    The minimum and maximum are taken by reducing axis 0 twice. For a 3-D
    input (presumably samples x time steps x features -- confirm with callers)
    this gives one min/max per feature. ``max_val`` is the range of the
    shifted data, so the inverse transform is ``norm * max_val + min_val``.

    Args:
      - data: raw data

    Returns:
      - norm_data: normalized data
      - min_val: minimum values (for renormalization)
      - max_val: maximum of ``data - min_val`` (for renormalization)
    """
    min_over_first = np.min(data, axis=0)
    min_val = np.min(min_over_first, axis=0)
    shifted = data - min_val

    max_over_first = np.max(shifted, axis=0)
    max_val = np.max(max_over_first, axis=0)

    # Small epsilon avoids division by zero for constant features.
    return shifted / (max_val + 1e-7), min_val, max_val


def scaler(x):
    """Linearly map values from [0, 1] onto [-1, 1], computing ``2x - 1``."""
    doubled = x * 2.
    return doubled - 1.

def inverse_scaler(x):
    """Linearly map values from [-1, 1] back onto [0, 1], computing ``(x + 1) / 2``."""
    shifted = x + 1.
    return shifted / 2.

# def restore_data(data, ori_max, ori_min):
#   shape = data.shape[1:]
#   order = np.logical_and(data[:,0,0]<ori_max[0,0], data[:,0,0]>ori_min[0,0])
#   import pdb; pdb.set_trace()
#   for i in range(shape[0]):
#     for j in range(shape[1]):
#       pre_order = np.logical_and(data[:,i,j]<ori_max[i,j], data[:,i,j]>ori_min[i,j])
#       order = np.logical_and(order, pre_order)
#   import pdb; pdb.set_trace()
#   denorm_data = data[order, :, :]
#   return denorm_data

def restore_data(data, ori_max, ori_min):
    """Clip every (axis-1, axis-2) column of ``data`` into its own bounds, in place.

    For each index pair ``(t, f)``, ``data[:, t, f]`` is clipped with
    ``torch.clip`` to ``[ori_min[t, f], ori_max[t, f]]`` and written back.

    Args:
      - data: tensor with at least 3 dims; mutated in place
      - ori_max: upper bounds indexable as ``ori_max[t, f]``
      - ori_min: lower bounds indexable as ``ori_min[t, f]``

    Returns:
      - the same ``data`` object, after clipping
    """
    trailing = data.shape[1:]
    for t in range(trailing[0]):
        for f in range(trailing[1]):
            column = data[:, t, f]
            lower, upper = ori_min[t, f], ori_max[t, f]
            data[:, t, f] = torch.clip(column, lower, upper)
    return data