import numpy as np
import time
from api.build_index import build_fused_index
from api.query_index import query_single_filter
from models.base_ann import BaseANN
from datasets.dataset_loader import DatasetLoader

class SimpleANN(BaseANN):
    """Exact brute-force nearest-neighbour baseline.

    The vectors passed to ``build_index`` are stored as-is, and each query
    computes the Euclidean distance to every stored vector.
    """

    def __init__(self):
        # Stays None until build_index() is called.
        self.vectors = None

    def build_index(self, vectors):
        """Store a reference to *vectors*; no preprocessing is performed.

        Assumes *vectors* is a 2-D array-like with one vector per row
        (``query`` reduces over ``axis=1``) -- TODO confirm with callers.
        """
        self.vectors = vectors

    def query(self, vector, top_k):
        """Return the indices of the stored vectors closest to *vector*.

        Args:
            vector: Query vector; must broadcast against the stored vectors.
            top_k: Maximum number of indices to return.

        Returns:
            Array of at most ``top_k`` row indices, ordered by increasing
            Euclidean distance to *vector*.

        Raises:
            RuntimeError: If called before ``build_index``.
        """
        if self.vectors is None:
            # Fail with a clear message instead of the obscure TypeError
            # that `None - vector` would otherwise produce.
            raise RuntimeError("SimpleANN.query() called before build_index()")
        dists = np.linalg.norm(self.vectors - vector, axis=1)
        return np.argsort(dists)[:top_k]

def run_single_attribute_experiment(top_k=10):
    """Run the single-attribute filtering benchmark and print the results.

    Builds a fused index for "SIFT1M" on top of ``SimpleANN``, loads the
    "SIFT1M_query" set, issues one filtered query per query vector, and
    prints the average recall and the average per-query latency.

    Args:
        top_k: Number of neighbours requested per query, also used as the
            recall denominator. Defaults to 10.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.

    NOTE: the "ground truth" used here is a placeholder (indices
    ``0..top_k-1``), so the printed recall is illustrative only.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    ann_algorithm = SimpleANN()
    fused_ann = build_fused_index("SIFT1M", ann_algorithm)

    query_vectors, query_attrs = DatasetLoader.load_dataset("SIFT1M_query")

    # Simulated ground truth for the example; identical for every query, so
    # it is built once outside the loop.
    true_neighbors = np.arange(top_k)

    recalls = []
    times = []
    # Index-based access is kept on purpose: the loader's return types are
    # not visible here, so only `len()` and `[i]` are assumed to work.
    for i in range(len(query_vectors)):
        # perf_counter() is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        results = query_single_filter(
            fused_ann, query_vectors[i], query_attrs[i], top_k=top_k
        )
        times.append(time.perf_counter() - start)
        recalls.append(len(set(results).intersection(true_neighbors)) / top_k)

    print(
        "Single Attribute Filtering - Avg Recall@{}:".format(top_k),
        np.mean(recalls),
    )
    print(
        "Single Attribute Filtering - Avg Query Time (ms):",
        np.mean(times) * 1000,
    )

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    run_single_attribute_experiment()