import numpy as np
import time
from api.build_index import build_range_index
from api.query_index import query_range_filter
from models.base_ann import BaseANN
from datasets.dataset_loader import DatasetLoader
from datasets.dataset_loader import DatasetLoader
from utils.transform import Transformer
from utils.indexing import LineIndex
from models.range_ann import CylinderIndex

from utils.transform import Transformer

def adaptive_line_sampling(vectors, attributes, alpha, beta, num_samples=1000,
                           attr_margin=0.05):
    """Sample dataset points and turn each one into a transformed line segment.

    For every sampled point the attribute value is widened to the interval
    ``[attr - attr_margin, attr + attr_margin]``. Both interval ends are passed,
    together with the point's vector, through ``Transformer.single_transform``;
    the two results become the start and end point of that point's line.

    Args:
        vectors: NumPy array of data vectors, indexed along its first axis.
        attributes: NumPy array of attribute values, indexed with the same
            sampled positions as ``vectors`` (equal length is assumed, not
            checked).
        alpha: Forwarded unchanged to ``Transformer.single_transform``.
        beta: Forwarded unchanged to ``Transformer.single_transform``.
        num_samples: Maximum number of points to draw without replacement.
            Capped at ``len(vectors)`` so small or empty datasets yield
            fewer (possibly zero) lines instead of raising ``ValueError``.
        attr_margin: Half-width of the attribute interval built around each
            sampled attribute value.

    Returns:
        A list of ``(start_point, end_point)`` tuples, one per sampled point.

    Note:
        Every call reseeds NumPy's *global* random state with 42, so the
        sample is reproducible and later ``np.random`` calls made by the
        caller continue from that seeded state.
    """
    # Kept as a global reseed on purpose: callers in this file draw further
    # random numbers afterwards and currently inherit this seeded state.
    np.random.seed(42)

    # Sampling without replacement cannot exceed the population size.
    sample_size = min(num_samples, len(vectors))
    sampled_indices = np.random.choice(len(vectors), sample_size, replace=False)
    sampled_vectors = vectors[sampled_indices]
    sampled_attrs = attributes[sampled_indices]

    lines = []
    for vec, attr in zip(sampled_vectors, sampled_attrs):
        attr_low = attr - attr_margin
        attr_high = attr + attr_margin
        # ``[None]`` adds a leading axis so a single point is passed as a
        # batch of one; ``[0]`` unwraps the single transformed result.
        start_point = Transformer.single_transform(vec[None], attr_low[None], alpha, beta)[0]
        end_point = Transformer.single_transform(vec[None], attr_high[None], alpha, beta)[0]
        lines.append((start_point, end_point))

    return lines

def _find_line_position(lines, target):
    """Return the position in *lines* of the segment that matches *target*."""
    for position, line in enumerate(lines):
        if line is target:
            return position
        # Segments are tuples of NumPy arrays; comparing them with ``==``
        # (as ``list.index`` does) raises "truth value of an array is
        # ambiguous", so the endpoints are compared array by array instead.
        if all(np.array_equal(a, b) for a, b in zip(line, target)):
            return position
    raise ValueError("nearest line is not one of the indexed lines")


def run_complete_three_level_indexing_experiment():
    """Run a single range query through the three-level indexing pipeline.

    Steps:
        1. Load the "DEEP" dataset and sample line segments from it with
           ``adaptive_line_sampling``.
        2. Build a ``LineIndex`` over those segments.
        3. Build one ``CylinderIndex`` per segment, each populated with
           randomly chosen, transformed dataset points.
        4. Turn the first dataset point into a query segment, look up its
           nearest indexed line, query that line's cylinder index and print
           the result.

    Returns:
        None. The query results are printed to stdout.

    Raises:
        ValueError: If the line reported by ``LineIndex.find_nearest_line``
            does not match any sampled line.
    """
    alpha, beta = 10, 2
    vectors, attributes = DatasetLoader.load_dataset("DEEP")

    # Level 1: Adaptive line sampling
    lines = adaptive_line_sampling(vectors, attributes, alpha, beta, num_samples=1000)

    # Level 2: Hierarchical line indexing
    line_idx = LineIndex(lines)

    # Level 3: one cylindrical index per line.
    # The points are drawn at random from the whole dataset (not selected by
    # proximity to the line) and the radius is an arbitrary fixed value.
    # Sampling without replacement cannot exceed the dataset size.
    points_per_cylinder = min(500, len(vectors))
    cylinder_indexes = []
    for start, end in lines:
        points_idxs = np.random.choice(len(vectors), points_per_cylinder, replace=False)
        points = Transformer.single_transform(vectors[points_idxs], attributes[points_idxs], alpha, beta)
        cylinder_indexes.append(CylinderIndex(start, end, points, radius=0.2))

    # Querying: build the query segment the same way the indexed lines were
    # built (attribute +/- 0.05, transformed at both ends).
    query_vec, query_attr = vectors[0], attributes[0]
    attr_low, attr_high = query_attr - 0.05, query_attr + 0.05
    query_start = Transformer.single_transform(query_vec[None], attr_low[None], alpha, beta)[0]
    query_end = Transformer.single_transform(query_vec[None], attr_high[None], alpha, beta)[0]

    # NOTE(review): assumes find_nearest_line returns one of the
    # (start, end) pairs it was built from -- confirm against LineIndex.
    nearest_line = line_idx.find_nearest_line(query_start, query_end)
    idx_line = _find_line_position(lines, nearest_line)
    cylinder_index = cylinder_indexes[idx_line]

    results = cylinder_index.query_cylinder(query_start, query_end, query_radius=0.2)

    print("Three-Level Indexing Range Query Results:", results)

def run_range_filter_experiment(ann_algorithm=None):
    """Benchmark range-filtered ANN queries on the "DEEP" dataset.

    Builds a range index with ``build_range_index``, runs every query from
    the "DEEP_query" dataset through ``query_range_filter`` with an attribute
    window of +/- 0.05 around the query's attribute, and prints the average
    recall@10 and the average per-query latency in milliseconds.

    Args:
        ann_algorithm: ANN algorithm instance handed to
            ``build_range_index``. When ``None`` (the default) a
            ``SimpleANN()`` is created, as before.

    Returns:
        None. The aggregated metrics are printed to stdout.
    """
    if ann_algorithm is None:
        # NOTE(review): SimpleANN is neither imported nor defined in the
        # visible part of this file, so this fallback raises NameError unless
        # the name is provided elsewhere -- add the missing import or pass
        # ``ann_algorithm`` explicitly.
        ann_algorithm = SimpleANN()
    range_ann = build_range_index("DEEP", ann_algorithm)

    query_vectors, query_attrs = DatasetLoader.load_dataset("DEEP_query")

    top_k = 10
    attr_margin = 0.05
    radius = 0.2
    # NOTE(review): placeholder ground truth -- every query is scored against
    # ids 0..top_k-1, so the reported recall is only meaningful once real
    # per-query neighbours are supplied.
    true_neighbors = np.arange(top_k)

    recalls = []
    times = []
    for i, query_vec in enumerate(query_vectors):
        attr_low = query_attrs[i] - attr_margin
        attr_high = query_attrs[i] + attr_margin
        # perf_counter is the high-resolution clock intended for measuring
        # short durations; time.time() can jump when the system clock is
        # adjusted.
        start = time.perf_counter()
        results = query_range_filter(range_ann, query_vec, attr_low, attr_high, radius, top_k=top_k)
        times.append(time.perf_counter() - start)
        recalls.append(len(set(results).intersection(true_neighbors)) / top_k)

    print("Range Filtering - Avg Recall@10:", np.mean(recalls))
    print("Range Filtering - Avg Query Time (ms):", np.mean(times)*1000)

if __name__ == "__main__":
    # Script entry point: run the range-filter benchmark first, then the
    # three-level indexing experiment.
    for experiment in (
        run_range_filter_experiment,
        run_complete_three_level_indexing_experiment,
    ):
        experiment()