import numpy as np
import math

from scipy.spatial.distance import pdist

from datasets.dataset_loader import DatasetLoader
from models.fused_ann import FusedANN
from models.range_ann import RangeANN

def compute_delta_sigma(vectors, attributes, max_samples=10000, rng=None):
    """Estimate the largest vector distance and the smallest attribute distance.

    A random subset of at most ``max_samples`` rows is drawn, using the same
    row indices for ``vectors`` and ``attributes``. All pairwise Euclidean
    distances inside that subset are computed with
    ``scipy.spatial.distance.pdist``.

    Args:
        vectors: 2-D array-like of row vectors; ``vectors[i]`` pairs with
            ``attributes[i]``.
        attributes: 2-D array-like of attribute rows with the same number of
            rows as ``vectors``.
        max_samples: upper bound on the number of sampled rows. Each pairwise
            distance array holds ``k * (k - 1) / 2`` float64 entries for ``k``
            sampled rows, so memory grows quadratically with this value.
        rng: ``None`` to draw the sample from NumPy's global random state
            (the original behaviour). Otherwise a seed or a
            ``numpy.random.Generator``, for a reproducible sample.

    Returns:
        ``(delta_max, sigma_min)``: the largest pairwise vector distance in
        the sample, and the smallest pairwise attribute distance plus
        ``1e-8``. If the sample contains duplicate attribute rows, their
        distance is 0, so ``sigma_min`` collapses to ``1e-8``.

    Raises:
        ValueError: if the row counts differ, or fewer than two rows would be
            sampled (no pair exists to measure).
    """
    vectors = np.asarray(vectors)
    attributes = np.asarray(attributes)
    n_rows = len(vectors)
    if len(attributes) != n_rows:
        raise ValueError(
            f"vectors and attributes must have the same number of rows, "
            f"got {n_rows} and {len(attributes)}"
        )

    sample_size = min(max_samples, n_rows)
    if sample_size < 2:
        raise ValueError(
            f"need at least two rows to compute pairwise distances, got {sample_size}"
        )

    if rng is None:
        sample_indices = np.random.choice(n_rows, sample_size, replace=False)
    else:
        sample_indices = np.random.default_rng(rng).choice(
            n_rows, sample_size, replace=False
        )

    sampled_vectors = vectors[sample_indices]
    sampled_attributes = attributes[sample_indices]

    delta_max = pdist(sampled_vectors).max()
    # Adding the epsilon after the reduction gives the same value as adding it
    # element-wise first (float addition is monotone), but avoids copying the
    # whole pairwise-distance array.
    sigma_min = pdist(sampled_attributes).min() + 1e-8
    return delta_max, sigma_min

def set_alpha_beta(delta_max, sigma_min, d, m, epsilon_f, margin=1e-3):
    """Compute the fusion weights ``alpha`` and ``beta`` from data statistics.

    The formulas are:

        beta  = delta_max / epsilon_f + margin
        alpha = beta * delta_max / (sigma_min * sqrt(d / m))
                * (1 + epsilon_f * beta / delta_max) + margin

    NOTE(review): the derivation of these bounds is not documented in this
    file. Confirm against the reference they were taken from.

    Args:
        delta_max: largest pairwise vector distance; must be > 0.
        sigma_min: smallest pairwise attribute distance; must be > 0.
        d: vector dimensionality; must be > 0.
        m: attribute dimensionality; must be > 0.
        epsilon_f: positive tolerance; ``beta`` scales as ``1 / epsilon_f``.
        margin: small constant added to both ``beta`` and the ``alpha`` lower
            bound so that each lies strictly above its bound.

    Returns:
        ``(alpha, beta)`` as floats.

    Raises:
        ValueError: if any of ``delta_max``, ``sigma_min``, ``d``, ``m`` or
            ``epsilon_f`` is not strictly positive (this includes NaN).
            Such values would otherwise divide by zero or silently produce
            negative/NaN weights.
    """
    for name, value in (
        ("delta_max", delta_max),
        ("sigma_min", sigma_min),
        ("d", d),
        ("m", m),
        ("epsilon_f", epsilon_f),
    ):
        # `not value > 0` also rejects NaN, which compares False to everything.
        if not value > 0:
            raise ValueError(f"{name} must be positive, got {value!r}")

    beta = (delta_max / epsilon_f) + margin
    factor = beta * delta_max / (sigma_min * math.sqrt(d / m))
    alpha_lower_bound = factor * (1 + (epsilon_f * beta / delta_max))
    alpha = alpha_lower_bound + margin
    return alpha, beta

def build_fused_index(dataset_name, ann_algorithm, epsilon_f=0.1):
    """Load a dataset and build a ``FusedANN`` index with data-derived weights.

    The caller does not supply the fusion weights. ``compute_delta_sigma``
    derives two distance statistics from the loaded data. ``set_alpha_beta``
    combines them with the vector and attribute dimensionalities and
    ``epsilon_f`` to produce ``alpha`` and ``beta``.

    Args:
        dataset_name: identifier passed to ``DatasetLoader.load_dataset``.
        ann_algorithm: passed through to the ``FusedANN`` constructor.
        epsilon_f: forwarded unchanged to ``set_alpha_beta``.

    Returns:
        The ``FusedANN`` instance after ``build_single`` has run on the data.
    """
    vecs, attrs = DatasetLoader.load_dataset(dataset_name)
    vector_dim = vecs.shape[1]
    # Indexing shape[1] assumes attrs is at least 2-D (rows x attribute columns).
    attribute_dim = attrs.shape[1]

    # Unpack (delta_max, sigma_min) straight into the weight computation.
    alpha, beta = set_alpha_beta(
        *compute_delta_sigma(vecs, attrs),
        vector_dim,
        attribute_dim,
        epsilon_f,
    )

    index = FusedANN(ann_algorithm, alpha, beta)
    index.build_single(vecs, attrs)

    # Reported only once the build has completed without raising.
    print(f"Computed parameters: alpha={alpha:.4f}, beta={beta:.4f}")
    return index

def build_range_index(dataset_name, ann_algorithm, alpha=10.0, beta=2.0):
    """Load a dataset and build a ``RangeANN`` range index over it.

    Unlike ``build_fused_index``, the weights are not derived from the data.
    ``alpha`` and ``beta`` are taken as given (defaults 10.0 and 2.0).

    Args:
        dataset_name: identifier passed to ``DatasetLoader.load_dataset``.
        ann_algorithm: passed through to the ``RangeANN`` constructor.
        alpha: first weight forwarded to ``RangeANN``.
        beta: second weight forwarded to ``RangeANN``.

    Returns:
        The ``RangeANN`` instance after ``build_range_index`` has run on the data.
    """
    data, attrs = DatasetLoader.load_dataset(dataset_name)
    index = RangeANN(ann_algorithm, alpha, beta)
    index.build_range_index(data, attrs)
    return index