"""Reference:
[1] https://github.com/jsyoon0823/GAIN/blob/master/utils.py
"""
#%%
import torch
import numpy as np
import random

#%%
def set_random_seed(seed):
    """Seed every random number generator used here, for reproducibility.

    Seeds Python's `random` module, NumPy's global RNG and PyTorch (CPU, the
    current CUDA device and all CUDA devices). It also puts cuDNN in
    deterministic mode with benchmark autotuning disabled.

    Args:
        - seed: seed value handed to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # cuDNN autotuning can pick different kernels run to run; pin it down.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

#%%
def sample_binary(m, n, p):
    '''Sample a matrix of independent Bernoulli(p) random variables.

    Args:
        - m: the number of rows
        - n: the number of columns
        - p: probability that an entry is 1

    Returns:
        - torch.Tensor of shape (m, n), dtype float64, entries in {0., 1.}.
    '''
    # Draws come from NumPy's global RNG, so np.random.seed controls them.
    uniform = np.random.uniform(0., 1., size=[m, n])
    # P(U < p) = p for U ~ Uniform[0, 1), so the comparison yields Bernoulli(p).
    return torch.tensor((uniform < p).astype(np.float64))
#%%
def sample_uniform(m, n, low=0., high=0.01):
    '''Sample a matrix of independent Uniform[low, high) random variables.

    Args:
        - m: the number of rows
        - n: the number of columns
        - low: lower limit (inclusive); defaults to 0.
        - high: upper limit (exclusive); defaults to 0.01

    Returns:
        - torch.Tensor of shape (m, n), dtype float64.

    The defaults reproduce the previously hard-coded [0, 0.01) range, so
    existing callers of sample_uniform(m, n) are unaffected.
    '''
    # Uses NumPy's global RNG, so np.random.seed controls the draws.
    z = np.random.uniform(low, high, size=[m, n])
    return torch.tensor(z)
#%%
def discriminator_loss(D, G, M, New_X, H):
    '''Binary cross-entropy loss for the GAIN discriminator.

    Args:
        - D: discriminator, called as D(data, H); its output is used as a
          probability (presumably in (0, 1) -- confirm D ends in a sigmoid)
        - G: generator, called as G(New_X, M)
        - M: mask; entries with M == 1 keep New_X, entries with M == 0 are
          replaced by the generator's output
        - New_X: data fed to the generator
        - H: hint matrix passed through to D

    Returns:
        - D_loss: scalar tensor, the mean cross-entropy of D predicting M.
    '''
    # This loss only trains D, so stop gradients from flowing into G. Without
    # detach(), D_loss.backward() would also accumulate gradients into G's
    # parameters. The loss value and D's gradients are unchanged.
    G_sample = G(New_X, M).detach()

    # Keep observed entries and fill the missing ones with the imputation.
    Hat_New_X = New_X * M + G_sample * (1 - M)
    D_prob = D(Hat_New_X, H)

    # 1e-8 keeps log() finite when D_prob saturates at 0 or 1.
    D_loss = -torch.mean(
        M * torch.log(D_prob + 1e-8) + (1 - M) * torch.log(1. - D_prob + 1e-8)
    )

    return D_loss
#%%
def generator_loss(D, G, X, M, New_X, H, alpha):
    '''Generator loss for GAIN plus train/test reconstruction errors.

    Args:
        - D: discriminator, called as D(data, H); its output is used as a
          probability (presumably in (0, 1) -- confirm)
        - G: generator, called as G(New_X, M)
        - X: reference data; only used for MSE_test_loss on the entries
          where M == 0 (presumably the complete ground truth -- confirm)
        - M: mask; entries with M == 1 keep New_X, entries with M == 0 are
          replaced by the generator's output
        - New_X: data fed to the generator
        - H: hint matrix passed through to D
        - alpha: weight of the reconstruction term in G_loss

    Returns:
        - G_loss: adversarial loss on imputed entries + alpha * MSE_train_loss
        - MSE_train_loss: masked MSE between G's output and New_X where M == 1
        - MSE_test_loss: masked MSE between G's output and X where M == 0

    If the mask selects no entries (M all 0 for the train term, or all 1 for
    the test term), that term is 0 instead of NaN from 0 / 0.
    '''
    def _masked_ratio(num, frac):
        # num / frac, but 0 / 0 -> 0: when frac == 0 every masked difference is
        # 0, so num is 0 too. tiny is far below any real mask fraction
        # (>= 1 / numel), so ordinary values are unchanged.
        return num / frac.clamp_min(torch.finfo(frac.dtype).tiny)

    G_sample = G(New_X, M)

    # Keep observed entries and fill the missing ones with the imputation.
    Hat_New_X = New_X * M + G_sample * (1 - M)
    D_prob = D(Hat_New_X, H)
    # Adversarial term: push D to score imputed (M == 0) entries as observed.
    G_loss1 = -torch.mean((1 - M) * torch.log(D_prob + 1e-8))
    MSE_train_loss = _masked_ratio(
        torch.mean((M * New_X - M * G_sample) ** 2), torch.mean(M)
    )

    G_loss = G_loss1 + alpha * MSE_train_loss
    MSE_test_loss = _masked_ratio(
        torch.mean(((1 - M) * X - (1 - M) * G_sample) ** 2), torch.mean(1 - M)
    )

    return G_loss, MSE_train_loss, MSE_test_loss
# %%

#%%
def renormalization(norm_data, norm_parameters):
    '''Map normalized data back to its original scale, column by column.

    Each column i becomes x * (max_val[i] + 1e-6) + min_val[i], the inverse
    of (x - min_val[i]) / (max_val[i] + 1e-6). So max_val presumably holds the
    per-feature scale stored by the matching normalization step -- confirm.

    Args:
        - norm_data: 2-D array of normalized data (copied, not modified)
        - norm_parameters: dict with 'min_val' and 'max_val', indexable per
          feature

    Returns:
        - a renormalized copy of norm_data with the same dtype.
    '''
    mins = norm_parameters['min_val']
    maxs = norm_parameters['max_val']

    _, n_cols = norm_data.shape
    restored = norm_data.copy()

    for col in range(n_cols):
        # Two separate assignments on purpose: each result is cast back to the
        # array's dtype before the next step.
        restored[:, col] = restored[:, col] * (maxs[col] + 1e-6)
        restored[:, col] = restored[:, col] + mins[col]

    return restored
