"""
Functions that generate the synthetic dataset used to train the predictive model, and the population of agents applying at every time step.
"""

from sklearn.preprocessing import MinMaxScaler
import numpy as np
import pandas as pd

def _default_feature_ranges(means, std_devs, fallback_width=0.01):
    """Build a ``[min, max]`` target range for each continuous feature.

    Each range spans ``mean ± 3 * std``, clamped to ``[0, 1]``. If clamping leaves
    an empty or inverted range, a window of width ``fallback_width`` centred on the
    mean (again clamped to ``[0, 1]``) is used instead. That fallback can itself be
    empty when the mean lies outside ``[0, 1]``.
    """
    ranges = []
    for mean, std in zip(means, std_devs):
        low = max(0, mean - 3 * std)
        high = min(1, mean + 3 * std)
        if low >= high:
            low = max(0, mean - fallback_width / 2)
            high = min(1, mean + fallback_width / 2)
        ranges.append([low, high])
    return ranges


def _scale_columns_to_ranges(X, feature_ranges):
    """Min-max scale column ``i`` of ``X`` into ``feature_ranges[i]``.

    Returns an array of shape ``(n_rows, len(feature_ranges))``. With no ranges,
    an ``(n_rows, 0)`` array is returned so it can still be stacked with others.
    """
    if len(feature_ranges) == 0:
        return np.empty((X.shape[0], 0))
    scaled = [
        MinMaxScaler(feature_range=(low, high))
        .fit_transform(X[:, i].reshape(-1, 1))
        .flatten()
        for i, (low, high) in enumerate(feature_ranges)
    ]
    return np.array(scaled).T


def generate_synthetic_data(
    n_agents,
    n_continuous,
    n_categorical,
    scaler=None,
    rng=None,
    noise_std=0.1,
    continuous_means=None,
    continuous_std_devs=None,
    categorical_probs=None,
    continuous_weights=None,
    feature_ranges=None,  # optional per-feature [min, max] ranges used for scaling
    **kwargs
):
    """Generate a synthetic agent population with a binary target.

    Continuous features are drawn from per-feature normal distributions and then
    min-max scaled into per-feature ranges. Binary categorical features are drawn
    from Bernoulli distributions. The target is 1 when a weighted sum of the
    *unscaled* continuous features plus Gaussian noise is ``>= 0.5``.

    Parameters
    ----------
    n_agents : int
        Number of rows (agents) to generate.
    n_continuous : int
        Number of continuous features. It sizes the randomly drawn means, standard
        deviations and weights, and names the columns ``cont_0 ... cont_{n-1}``.
    n_categorical : int
        Number of binary categorical features (columns ``cat_0 ...``).
    scaler : object, optional
        Transformer with ``fit_transform``, applied to the range-scaled
        continuous block when given.
    rng : numpy random generator, optional
        Source of randomness. Defaults to ``np.random.default_rng()``.
    noise_std : float
        Standard deviation of the Gaussian noise added to the latent score.
    continuous_means, continuous_std_devs : sequence of float, optional
        Per-feature normal parameters. If either is missing, both are drawn
        (means from U[0, 1), std devs from U[0.1, 0.25)). Any vector the caller
        supplied is then used in place of its drawn counterpart.
    categorical_probs : sequence of float, optional
        Bernoulli probability per categorical feature. Defaults to draws from
        U[0.4, 0.6).
    continuous_weights : array-like, optional
        Weights of the continuous features in the latent score. Defaults to
        random positive weights normalised to sum to 1.
    feature_ranges : sequence of [min, max], optional
        Target range per continuous feature. Defaults to ``mean ± 3 * std``,
        clamped to [0, 1].
    **kwargs
        Accepted and ignored.

    Returns
    -------
    tuple
        ``(X, y_df, continuous_columns, categorical_columns, means, std_devs,
        probabilities, continuous_weights, original_means, original_std_devs,
        feature_ranges)``. ``X`` holds the scaled continuous and categorical
        features. ``y_df`` has a single ``"target"`` column. ``original_means``
        and ``original_std_devs`` are the empirical column statistics of the
        unscaled continuous draws. The remaining items are the parameters
        actually used.

    Raises
    ------
    ValueError
        If ``feature_ranges`` does not have one entry per continuous feature.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Both vectors are drawn whenever either is missing, which keeps the RNG
    # consumption the same as before. A caller-supplied vector then takes
    # precedence over its draw instead of being silently discarded.
    means, std_devs = continuous_means, continuous_std_devs
    if means is None or std_devs is None:
        drawn_means = rng.uniform(low=0.0, high=1.0, size=n_continuous)
        drawn_std_devs = rng.uniform(low=0.1, high=0.25, size=n_continuous)
        if means is None:
            means = drawn_means
        if std_devs is None:
            std_devs = drawn_std_devs

    # One column of normal draws per (mean, std) pair. With no features, use an
    # (n_agents, 0) array so the later hstack still works.
    draws = [rng.normal(mean, std, n_agents) for mean, std in zip(means, std_devs)]
    X_cont = np.array(draws).T if draws else np.empty((n_agents, 0))

    # Empirical statistics of the features before any scaling.
    original_means = np.mean(X_cont, axis=0)
    original_std_devs = np.std(X_cont, axis=0)

    if feature_ranges is None:
        feature_ranges = _default_feature_ranges(means, std_devs)
    if len(feature_ranges) != X_cont.shape[1]:
        raise ValueError(
            f"feature_ranges has {len(feature_ranges)} entries but there are "
            f"{X_cont.shape[1]} continuous features"
        )

    X_cont_scaled = _scale_columns_to_ranges(X_cont, feature_ranges)
    if scaler is not None and X_cont_scaled.shape[1] > 0:
        X_cont_scaled = scaler.fit_transform(X_cont_scaled)

    if n_categorical > 0:
        if categorical_probs is None:
            probabilities = rng.uniform(low=0.4, high=0.6, size=n_categorical)
        else:
            probabilities = categorical_probs
        X_cat = np.array([rng.binomial(1, p, n_agents) for p in probabilities]).T
        categorical_columns = [f"cat_{i}" for i in range(n_categorical)]
    else:
        X_cat = np.empty((n_agents, 0))
        probabilities = []
        categorical_columns = []

    continuous_columns = [f"cont_{i}" for i in range(n_continuous)]
    X = pd.DataFrame(
        np.hstack([X_cont_scaled, X_cat]),
        columns=continuous_columns + categorical_columns,
    )

    if continuous_weights is None:
        random_weights = rng.uniform(low=0.1, high=1.0, size=n_continuous)
        continuous_weights = random_weights / random_weights.sum()

    # Categorical features carry zero weight: they appear in X but never
    # influence the target.
    categorical_weights = np.zeros(n_categorical)

    # NOTE(review): the score is computed from the *unscaled* X_cont, while X
    # exposes the range-scaled features. Per-feature min-max scaling is affine,
    # so the target stays linear in the exposed features (unless `scaler` is
    # non-linear). Confirm this is intended.
    y_continuous = np.dot(X_cont, continuous_weights)
    y_categorical = np.dot(X_cat, categorical_weights) if n_categorical > 0 else 0
    y = y_continuous + y_categorical + rng.normal(0, noise_std, n_agents)

    # The result of this draw is unused because the threshold is fixed at 0.5.
    # The draw is kept only so that later draws from `rng` (e.g. by the caller)
    # match those of earlier versions of this function.
    rng.normal(loc=np.median(y), scale=0.05)
    threshold = 0.5
    y_binarized = (y >= threshold).astype(int)
    y_df = pd.DataFrame(y_binarized, columns=["target"])

    return (
        X,
        y_df,
        continuous_columns,
        categorical_columns,
        means,
        std_devs,
        probabilities,
        continuous_weights,
        original_means,
        original_std_devs,
        feature_ranges,
    )