from rpy2.robjects.packages import importr
from rpy2 import robjects
from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler


def split_train_val_by_SPlit(data: pd.DataFrame, val_ratio=0.2, numeric_col=None, n_threads=8):
    """Split the rows of ``data`` into train/validation sets with R's ``SPlit``.

    The heavy lifting is delegated to the ``SPlit`` function of the R package
    of the same name, called through rpy2.

    Side effect: when ``numeric_col`` is given and non-empty, those columns of
    ``data`` are standardized **in place** with ``StandardScaler`` before the
    split, so the caller's DataFrame is modified.

    Args:
        data: DataFrame whose rows are to be split. It is handed to R via the
            pandas converter.
        val_ratio: Fraction of rows that should end up in the validation set;
            forwarded to R as ``splitRatio``.
        numeric_col: Optional column label(s) to standardize before
            splitting. ``None`` or an empty sequence skips scaling.
        n_threads: Forwarded to R as ``nThreads``.

    Returns:
        ``(train_indexes, val_indexes)``: two lists of 0-based row positions
        that together cover ``range(len(data))``.
    """
    if numeric_col is not None and len(numeric_col) > 0:
        # NOTE: this assignment writes into the caller's DataFrame.
        data[numeric_col] = StandardScaler().fit_transform(data[numeric_col])

    # Called for its side effect only: makes sure the R package is installed
    # and loaded (raises otherwise). The returned handle is not needed.
    importr('SPlit')

    # Evaluate an anonymous R closure instead of assigning a global `f`, so
    # nothing in R's global environment is overwritten.
    r_split = robjects.r('''
        function(data, val_ratio, n_threads) {
            SPlit::SPlit(data,
                         splitRatio = val_ratio,
                         tolerance = 1e-10,
                         nThreads = n_threads)
        }
    ''')

    with localconverter(robjects.default_converter + pandas2ri.converter):
        r_indexes = r_split(data, val_ratio, n_threads)

    # R indices start at 1; shift them to 0-based positions.
    subset = np.asarray(r_indexes, dtype=int) - 1

    n_rows = len(data)
    # setdiff1d yields the remaining positions in sorted order, which keeps
    # the result deterministic (iterating a set gives no ordering guarantee).
    complement = np.setdiff1d(np.arange(n_rows), subset)

    # NOTE(review): SPlit's manual appears to say it returns the *smaller*
    # subset regardless of which side of 0.5 splitRatio falls on - confirm.
    # Either way, use as validation set whichever part is closest in size to
    # val_ratio * n_rows; for val_ratio <= 0.5 this is the returned subset.
    target = val_ratio * n_rows
    if abs(len(subset) - target) > abs(len(complement) - target):
        return np.sort(subset).tolist(), complement.tolist()

    return complement.tolist(), subset.tolist()


# Script entry point: intentionally a no-op; this module only provides
# split_train_val_by_SPlit for import.
if __name__ == '__main__':
    pass

