import numpy as np
from counterfactual_explanations.dim_reduction import DimensionalityReduction

def sample_points(X, y, n_points=300):
    """Draw a reproducible random subset of rows from ``X`` and ``y``.

    The same row indices are applied to both arrays, so the returned
    samples stay aligned with their labels.

    Args:
        X: Indexable collection of samples; must support ``len`` and
            indexing with an integer index array.
        y: Values aligned with ``X``, indexed with the same indices.
        n_points: Maximum number of rows to draw. When ``X`` has fewer
            rows than this, every row is returned (in shuffled order)
            instead of raising.

    Returns:
        Tuple ``(X_subset, y_subset)``.
    """
    # A fixed seed keeps the selection identical across calls.
    rng = np.random.default_rng(0)
    # Sampling without replacement cannot draw more rows than exist, so
    # cap the request at the dataset size.
    n_draw = min(n_points, len(X))
    indices = rng.choice(len(X), size=n_draw, replace=False)
    return X[indices], y[indices]
    
def get_feature_ranges(input_properties, X):
    """Build per-feature lower and upper bounds.

    For every entry returned by ``input_properties.get_feature_details()``
    the element at index 2 is read as a ``(lower, upper)`` pair. A bound
    that is infinite is replaced by the observed extreme of the matching
    column of ``X``, pushed outward by 20% of that column's observed
    range; finite bounds are used unchanged.

    Args:
        input_properties: Object exposing ``get_feature_details()``.
        X: 2-D array whose column ``i`` holds the values of feature ``i``.

    Returns:
        Tuple ``(feature_lb, feature_ub)`` of lists, one entry per feature.
    """

    def _column_stats(col):
        # Observed minimum, maximum and the 20% margin for one column.
        values = X[:, col]
        low, high = np.min(values), np.max(values)
        return low, high, 0.2 * (high - low)

    lower_bounds = []
    upper_bounds = []

    for col, detail in enumerate(input_properties.get_feature_details()):
        if detail[2][0] == float('-inf'):
            low, _, margin = _column_stats(col)
            lower_bounds.append(low - margin)
        else:
            lower_bounds.append(detail[2][0])

        if detail[2][1] == float('inf'):
            _, high, margin = _column_stats(col)
            upper_bounds.append(high + margin)
        else:
            upper_bounds.append(detail[2][1])

    return lower_bounds, upper_bounds

def generate_grid_points(input_properties, X, dim_reduction: DimensionalityReduction = None, factor=1.5):
    """Generate a regular grid covering the feature (or encoded) space.

    Without ``dim_reduction`` the per-axis bounds come from
    ``get_feature_ranges(input_properties, X)``. With ``dim_reduction``
    the data is first passed through ``dim_reduction.encode`` and each of
    the first ``dim_reduction.target_dim`` encoded columns is bounded by
    its observed min/max widened by 20% of the observed range.

    Every axis receives the same number of evenly spaced values,
    ``int(factor * n_rows ** (1 / n_axes))``, and the grid is the
    Cartesian product of those axes.

    Args:
        input_properties: Passed to ``get_feature_ranges``; unused when
            ``dim_reduction`` is given.
        X: 2-D array of samples; ``X.shape[0]`` sets the grid density.
        dim_reduction: Optional reducer exposing ``target_dim`` and
            ``encode(X)``; ``encode`` is expected to return a 2-D array
            with at least ``target_dim`` columns.
        factor: Multiplier applied to the per-axis point count.

    Returns:
        Array of shape ``(points_per_axis ** n_axes, n_axes)`` with one
        grid point per row.
    """
    if dim_reduction is None:
        features_lb, features_ub = get_feature_ranges(input_properties, X)
    else:
        data_encoded = dim_reduction.encode(X)

        features_lb = []
        features_ub = []
        for i in range(dim_reduction.target_dim):
            column = data_encoded[:, i]
            col_min, col_max = np.min(column), np.max(column)
            # Widen the observed range by 20% on each side.
            margin = 0.2 * (col_max - col_min)
            features_lb.append(col_min - margin)
            features_ub.append(col_max + margin)

    n_points = X.shape[0]
    n_axes = len(features_lb)

    # The fractional power is inexact (e.g. 64 ** (1/3) evaluates to
    # 3.9999999999999996), so strip representation error before
    # truncating; otherwise an axis can silently lose one grid value.
    raw_count = factor * n_points ** (1 / n_axes)
    points_per_axis = int(round(raw_count, 9))

    axes = [
        np.linspace(lb, ub, num=points_per_axis)
        for lb, ub in zip(features_lb, features_ub)
    ]
    mesh = np.meshgrid(*axes)
    # One flattened coordinate array per axis, transposed into rows.
    grid_points = np.array([coords.flatten() for coords in mesh]).T

    return grid_points

def median_pairwise_distances(X, dim_reduction: DimensionalityReduction = None):
    """Return the median Euclidean distance over all unordered row pairs.

    Args:
        X: Sequence of points; ``X[i] - X[j]`` must yield a vector that
            ``np.linalg.norm`` accepts.
        dim_reduction: Optional reducer; when given, distances are
            measured on ``dim_reduction.encode(X)`` instead of ``X``.

    Returns:
        The ``np.median`` of the distances between every pair ``(i, j)``
        with ``i < j``.
    """
    points = dim_reduction.encode(X) if dim_reduction else X
    count = len(points)

    # Each unordered pair is visited exactly once (first < second).
    distances = [
        np.linalg.norm(points[first] - points[second])
        for first in range(count)
        for second in range(first + 1, count)
    ]

    return np.median(distances)