#%%
import numpy as np
# %%
"""
Reference:
[1] https://github.com/vanderschaarlab/hyperimpute/blob/main/src/hyperimpute/plugins/utils/metrics.py
"""

#%%
def SMAPE(train_dataset, imputed):
    """Score imputed values against ground truth on the masked entries.

    Columns ``[:C]`` (``C = train_dataset.num_continuous_features``) are
    scored as continuous features and columns ``[C:]`` as categorical ones.
    Only entries where ``train_dataset.mask == 1`` contribute to either score.

    Args:
        train_dataset: object providing ``num_continuous_features`` (int),
            ``raw_data.values`` (2-D array of ground-truth values) and
            ``mask`` (2-D array aligned with ``raw_data.values``).
        imputed: object providing ``.values``, a 2-D array of imputed values
            column-aligned with ``train_dataset.raw_data.values``.

    Returns:
        tuple ``(smape, error)``:
            smape -- mean of ``|x - x_hat| / (|x| + |x_hat| + 1e-6)`` over the
                selected continuous entries (no factor of 2 in the
                denominator, so each term lies in [0, 1]).
            error -- fraction of selected categorical entries whose imputed
                value differs from the ground truth (``1 - accuracy``).
        Each value is ``nan`` when no entry of its column group is selected.
    """
    C = train_dataset.num_continuous_features
    truth = train_dataset.raw_data.values
    filled = imputed.values
    # Boolean selector built once and shared by both column groups.
    selected = np.asarray(train_dataset.mask) == 1

    # --- continuous features: symmetric absolute percentage error ---
    original = truth[:, :C][selected[:, :C]]
    imputation = filled[:, :C][selected[:, :C]]
    if original.size:
        # Out-of-place division so integer-typed inputs work: an in-place
        # `/=` on an int array by a float array raises a casting error.
        # The 1e-6 term keeps the ratio finite when both values are zero.
        denominator = np.abs(original) + np.abs(imputation) + 1e-6
        smape = (np.abs(original - imputation) / denominator).mean()
    else:
        smape = np.nan  # avoids numpy's "Mean of empty slice" warning

    # --- categorical features: exact-match error rate ---
    original = truth[:, C:][selected[:, C:]]
    imputation = filled[:, C:][selected[:, C:]]
    if original.size:
        error = 1. - (original == imputation).mean()
    else:
        error = np.nan

    return smape, error
#%%