# -*- coding: utf-8 -*-
import numpy as np
from data.split import Splitter

# Default seed handed to Splitter.from_shuffle by the fold loader in this module.
SPLIT_SEED = 123

# Default number of cross-validation folds.
NUM_FOLDS = 5
# Default validation share of the samples left after the test fold is removed;
# with the defaults this is 0.2 of 0.8 = 0.16 of the total data.
VAL_FRACTION = 0.2


def simple_load_fold(load):
    """Wrap a dataset loader so that it returns train/validation/test parts.

    `load` is called with the extra keyword arguments given to the returned
    function and must produce an object supporting `len()` and `.iloc`
    (presumably a pandas DataFrame -- confirm against the callers).

    The returned `load_fold(fold, num_folds, val_fraction, seed, **kwargs)`
    builds a seeded `Splitter` over the sample count, takes fold number
    `fold` of `num_folds` as the test part, and divides the remaining samples
    into train and validation parts in the proportions
    `(1 - val_fraction, val_fraction)`. It returns the tuple
    `(train, val, test)` of `.iloc` selections.
    """
    def load_fold(fold=0, num_folds=NUM_FOLDS, val_fraction=VAL_FRACTION, seed=SPLIT_SEED, **kwargs):
        frame = load(**kwargs)

        # Seeded splitter over the row positions of the loaded data.
        shuffled = Splitter.from_shuffle(len(frame), seed=seed)
        # `remainder` is itself a splitter (return_splitter=True) covering
        # everything outside the selected test fold.
        remainder, held_out = shuffled.cv(num_folds, fold, return_splitter=True)

        proportions = (1 - val_fraction, val_fraction)
        train_rows, val_rows = remainder.split(proportions)

        return tuple(
            frame.iloc[rows] for rows in (train_rows, val_rows, held_out.index)
        )

    return load_fold


def check_discrete(a, start=0, stop=5, rtol=0.1, return_digits=False):
    """Find the coarsest decimal step 10**-p that all values of `a` lie on.

    For every integer power p in `start..stop` (inclusive) the step
    `d = 10**-p` is tested: the values count as discrete at that level when
    the largest distance of any value to its nearest multiple of `d` is
    strictly below `rtol * d`. The first (coarsest) level that passes wins.

    Parameters
    ----------
    a : array-like
        Values to inspect. Any shape is accepted; the input is flattened
        so that every element is checked against every level.
    start, stop : int
        Inclusive range of decimal powers p to try.
    rtol : float
        Tolerance expressed as a fraction of the step size.
    return_digits : bool
        If true, return the power p; otherwise return the step `10**-p`.

    Returns
    -------
    The matching power or step, or -1 when no level matches, when `a` is
    empty, or when the range of powers is empty. Note that with
    `return_digits=True` and a negative `start`, a genuine result of
    p == -1 cannot be told apart from the "not found" sentinel.
    """
    values = np.asarray(a).ravel()
    # A max-reduction over zero elements would raise; with no data there is
    # no evidence of discreteness, so report "not found".
    if values.size == 0:
        return -1

    powers = np.arange(start, stop + 1)
    d = 10. ** -powers[:, None]  # column of step sizes, one row per level

    # Distance of each value to the nearest multiple of each step size.
    r = values % d
    r = np.minimum(r, d - r)

    m = np.max(r, axis=-1) < rtol * d[:, 0]
    # Covers both "no level passed" and an empty power range (start > stop).
    if not m.any():
        return -1

    dlevel = np.argmax(m)  # index of the first passing level
    return powers[dlevel] if return_digits else d[dlevel, 0]

def convert_fixed_point(df, rtol=1e-3):
    """Convert float columns holding fixed-point decimals to scaled integers.

    Each NumPy floating-point column of `df` is passed to `check_discrete`.
    When a number of decimal digits p >= 0 is reported, the column is
    replaced by `round(column * 10**p)` stored as 64-bit integers. All other
    columns are left untouched.

    Parameters
    ----------
    df : DataFrame
        Modified in place.
    rtol : float
        Tolerance forwarded to `check_discrete`.

    Returns
    -------
    The same `df` object, for convenience.
    """
    for col in df.columns:
        column = df[col]
        dtype = column.dtype

        # Extension dtypes (category, nullable integers, strings, ...) are not
        # NumPy dtypes and make np.issubdtype raise, so rule them out first.
        if not (isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.floating)):
            continue
        # Nothing to inspect in an empty column; keep its dtype as is.
        if column.size == 0:
            continue

        digits = check_discrete(column.values, rtol=rtol, return_digits=True)
        if digits >= 0:
            # np.int64 rather than int: the width of `int` is platform
            # dependent and may be 32 bits, which scaled values can overflow.
            df[col] = (column * 10**digits).round().astype(np.int64)

    return df