import numpy as np
from sklearn.preprocessing import QuantileTransformer
from tqdm import tqdm


def synthetic_dataset_function_perlin(
        min_features=3,
        max_features=100,
        n_samples=1000,
        max_classes=10,
        min_complexity=0.001,
        max_complexity=1.0,
        n_octaves=7,
        categorical_x=True,
    ):
    """Generate one random classification dataset from 1-D Perlin-style noise.

    Every feature column of ``x`` is drawn uniformly from [0, 1). For each
    feature, ``n_octaves`` random piecewise-linear functions of that feature
    are summed: octave ``i`` has ``2**(i+1) - 1`` interior knots (plus the two
    endpoints), knot heights drawn from U[0, 1), and amplitude
    ``complexity**i``. The per-feature sums are combined with random positive
    weights into one score per sample. The score is quantile-transformed and
    cut into ``n_classes`` buckets at random thresholds to give the labels.

    Args:
        min_features, max_features: Bounds passed to ``get_n_features``.
        n_samples: Number of rows to generate.
        max_classes: Upper bound passed to ``get_n_classes``.
        min_complexity, max_complexity: Bounds of the uniform draw for
            ``complexity``, the per-octave amplitude decay factor.
        n_octaves: Number of noise octaves summed per feature.
        categorical_x: Passed to ``get_categorical_perc``. The categorical
            transform itself is currently disabled (see the note below), so
            this flag only changes which random draws are consumed.

    Returns:
        ``(x, z)``: ``x`` is a float array of shape ``(n_samples, n_features)``;
        ``z`` holds the per-sample class indices produced by ``put_in_buckets``.
    """
    categorical_perc = get_categorical_perc(categorical_x)
    n_features = get_n_features(min_features, max_features)
    n_classes = get_n_classes(max_classes)
    n_categorical_features = get_n_categorical_features(categorical_perc, n_features)
    n_categorical_classes = get_n_categorical_classes(n_categorical_features)
    complexity = np.random.uniform(min_complexity, max_complexity, size=(1,)).item()

    x = np.random.uniform(size=(n_samples, n_features))

    # NOTE(review): the categorical transform is disabled, so the categorical
    # values above only consume random draws. Before re-enabling, confirm the
    # transform keeps x inside [0, 1], which the interpolation below assumes:
    # x = transform_some_features_to_categorical(x, n_categorical_features, n_categorical_classes)

    # Broadcasts against the (n_samples, n_features) index arrays so that
    # column f of every sample reads knot row f.
    feature_idx = np.arange(n_features)
    y = np.zeros((n_samples, n_features))

    for octave in range(n_octaves):
        n_boundaries = 2 ** (octave + 1) - 1   # interior knots
        n_knots = n_boundaries + 2             # plus both endpoints
        amplitude = complexity ** octave

        knot_values = np.random.uniform(size=(n_features, n_knots))

        # Map [0, 1] onto the n_knots - 1 segments between knots. The clip keeps
        # x == 1 on the last segment, so `right` never runs past the row end.
        x_big = x * (n_knots - 1)
        left = np.clip(np.floor(x_big).astype(int), 0, n_knots - 2)
        right = left + 1
        weights = x_big - left

        # Index each feature's own knot row. np.take without `axis` flattens
        # knot_values, which would make every feature read the first row.
        y_i = (
            knot_values[feature_idx, left] * (1 - weights)
            + knot_values[feature_idx, right] * weights
        )
        y += amplitude * y_i

    y_weights = np.random.uniform(size=(n_features,))
    z = y @ y_weights

    z = quantile_transform(z)
    z = put_in_buckets(z, n_classes)

    return x, z




def get_n_classes(max_classes: int) -> int:
    """Draw a number of classes uniformly from the inclusive range [2, max_classes]."""
    # randint's upper bound is exclusive, hence the +1.
    upper_exclusive = max_classes + 1
    draw = np.random.randint(2, upper_exclusive, size=1)
    return draw.item()

def get_categorical_perc(categorical_x: bool) -> float:
    """Return the fraction of features to treat as categorical.

    When ``categorical_x`` is false, this is 0 and no random draw is consumed.
    Otherwise it is a single uniform draw from [0, 1).
    """
    if not categorical_x:
        return 0
    fraction = np.random.uniform(0, 1, size=(1,))
    return fraction.item()
    
def get_depth(min_depth: int, max_depth: int) -> int:
    """Draw a depth uniformly from the inclusive range [min_depth, max_depth].

    When the bounds are equal, the bound is returned without consuming a
    random draw.

    Raises:
        ValueError: if ``min_depth > max_depth``.
    """
    if min_depth > max_depth:
        raise ValueError(f"min_depth ({min_depth}) must not exceed max_depth ({max_depth})")
    if min_depth == max_depth:
        return min_depth
    # randint's upper bound is exclusive; +1 makes max_depth reachable, which
    # matches the inclusive shortcut above and get_n_classes.
    return np.random.randint(min_depth, max_depth + 1, size=1).item()
    
def get_n_features(min_features: int, max_features: int) -> int:
    """Draw a feature count uniformly from the inclusive range [min_features, max_features].

    When the bounds are equal, the bound is returned without consuming a
    random draw.

    Raises:
        ValueError: if ``min_features > max_features``.
    """
    if min_features > max_features:
        raise ValueError(
            f"min_features ({min_features}) must not exceed max_features ({max_features})"
        )
    if min_features == max_features:
        return min_features
    # randint's upper bound is exclusive; +1 makes max_features reachable,
    # which matches the inclusive shortcut above and get_n_classes.
    return np.random.randint(min_features, max_features + 1, size=1).item()
    
def get_n_categorical_features(categorical_perc: float, n_features: int) -> int:
    """Convert a categorical fraction into a count of categorical features.

    Truncates ``categorical_perc * (n_features + 1)`` toward zero. For
    ``categorical_perc`` in [0, 1), the result lies in [0, n_features].
    """
    scaled = categorical_perc * (n_features + 1)
    return int(scaled)

def get_n_categorical_classes(n_categorical_features: int) -> np.ndarray:
    """Draw a category count for each categorical feature.

    Returns an integer array of length ``n_categorical_features``. Each entry
    is a Geometric(p=0.5) draw shifted up by one.
    """
    # numpy's geometric draws are >= 1, so the shift guarantees at least 2 categories.
    draws = np.random.geometric(p=0.5, size=(n_categorical_features,))
    return 1 + draws


def transform_some_features_to_categorical(
        x: np.ndarray,
        n_categorical_features: int,
        n_categorical_classes: np.ndarray,
    ) -> np.ndarray:
    """Discretise a random subset of the columns of ``x``.

    ``n_categorical_features`` distinct columns are chosen at random. Each
    chosen column ``i`` is quantile-transformed to [0, 1] and cut into
    ``k = n_categorical_classes[i]`` buckets by ``put_in_buckets``. Every value
    is then replaced by the centre of its bucket, ``(b + 0.5) / k``, so the
    column stays inside (0, 1).

    Args:
        x: 2-D array of shape ``(n_samples, n_features)``. Modified in place.
        n_categorical_features: Number of columns to discretise (0 = no-op).
        n_categorical_classes: Indexable sequence of per-column category
            counts, with at least ``n_categorical_features`` entries.

    Returns:
        ``x`` itself, after modification.
    """
    if n_categorical_features == 0:
        return x

    x_index_categorical = np.random.choice(
        np.arange(x.shape[1]), size=(n_categorical_features,), replace=False
    )
    x_categorical = x[:, x_index_categorical]

    # sklearn clamps n_quantiles to the sample count anyway; passing the
    # clamped value explicitly avoids its warning when n_samples < 1000.
    quantile_transformer = QuantileTransformer(
        output_distribution='uniform',
        n_quantiles=min(1000, x.shape[0]),
    )
    x_categorical = quantile_transformer.fit_transform(x_categorical)

    for i in range(n_categorical_features):
        k = n_categorical_classes[i]
        bucket = put_in_buckets(x_categorical[:, i], k)
        # Bucket centre (b + 0.5) / k. The previous "- 1/(2k)" offset pushed
        # bucket 0 below zero, outside the [0, 1) feature domain.
        x_categorical[:, i] = (bucket + 0.5) / k

    x[:, x_index_categorical] = x_categorical

    return x


def quantile_transform(z: np.ndarray) -> np.ndarray:
    """Map the values of ``z`` to [0, 1] by their empirical quantiles.

    ``z`` is treated as a single feature: it is reshaped to one column, fitted
    and transformed with a uniform-output ``QuantileTransformer``, and
    returned as a 1-D array.
    """
    column = np.asarray(z).reshape(-1, 1)
    # sklearn clamps n_quantiles to the sample count internally but warns when
    # the default (1000) exceeds it. Passing the clamped value gives the same
    # result without the warning.
    quantile_transformer = QuantileTransformer(
        output_distribution='uniform',
        n_quantiles=min(1000, column.shape[0]),
    )
    return quantile_transformer.fit_transform(column).ravel()


def put_in_buckets(z: np.ndarray, n_classes: int) -> np.ndarray:
    """Assign each value of ``z`` to one of ``n_classes`` randomly-cut buckets.

    ``n_classes - 1`` interior thresholds are drawn from U[0, 1) and sorted.
    Each value gets the index of the first threshold that is ``>= value``.
    Values above every interior threshold (including values > 1) land in the
    top bucket ``n_classes - 1``. Intended for values in [0, 1].

    Returns:
        Integer array with the same shape as ``z``, entries in [0, n_classes - 1].
    """
    thresholds = np.random.uniform(0, 1, size=(n_classes - 1,))
    thresholds.sort()
    # side='left' counts thresholds strictly below each value, i.e. the index
    # of the first threshold >= value. Unlike argmax over a boolean mask, this
    # never maps values above every threshold back to class 0, and it avoids
    # building an (n_classes, len(z)) temporary.
    return np.searchsorted(thresholds, z, side='left')


def synthetic_dataset_generator_perlin(**kwargs):
    """Yield an endless stream of ``(x, y)`` datasets.

    Each draw calls ``synthetic_dataset_function_perlin`` with the same
    keyword arguments; the generator never terminates on its own.
    """
    while True:
        features, labels = synthetic_dataset_function_perlin(**kwargs)
        yield features, labels



if __name__ == '__main__':
    # Smoke run: draw 100 datasets from the generator with a progress bar.
    dataset_stream = synthetic_dataset_generator_perlin(
        min_features=3,
        max_features=100,
        n_samples=1000,
        max_classes=3,
    )

    for _ in tqdm(range(100)):
        x, y = next(dataset_stream)