import numpy as np
from models.base_ann import BaseANN
from utils.transform import Transformer
from sklearn.neighbors import KDTree


class RangeANN:
    """Attribute-range nearest-neighbour search layered over a ``BaseANN`` index.

    Vectors and their attributes are combined into one space by
    ``Transformer.single_transform(vectors, attributes, alpha, beta)`` and the
    result is indexed by ``base_ann``.

    A range query works in three steps:
      1. Transform ``(vector, attr_low)`` and ``(vector, attr_high)`` into two
         points, the ends of a query line.
      2. Fetch ANN candidates around the midpoint of that line.
      3. Keep the candidates whose stored vectors lie within ``radius`` of the
         (infinite) line through the two points.
    """

    def __init__(self, base_ann: BaseANN, alpha=10.0, beta=2.0):
        """Wrap ``base_ann``.

        Args:
            base_ann: Index used to build, query and fetch stored vectors.
            alpha: Passed through to ``Transformer.single_transform``.
            beta: Passed through to ``Transformer.single_transform``.
        """
        self.base_ann = base_ann
        self.alpha = alpha
        self.beta = beta

    def build_range_index(self, vectors, attributes):
        """Jointly transform ``vectors`` and ``attributes``, then index the result."""
        transformed = Transformer.single_transform(vectors, attributes, self.alpha, self.beta)
        self.base_ann.build_index(transformed)

    @staticmethod
    def _line_distance(point, line_start, line_end):
        """Return the perpendicular distance from ``point`` to the line start->end.

        Works in any dimension (``np.cross`` only supports 2-D/3-D vectors).
        If the endpoints coincide, the line degenerates to a single point and
        the plain Euclidean distance to it is returned instead of NaN.
        """
        point = np.asarray(point, dtype=float)
        start = np.asarray(line_start, dtype=float)
        direction = np.asarray(line_end, dtype=float) - start
        offset = point - start
        length_sq = float(direction @ direction)
        if length_sq == 0.0:
            return float(np.linalg.norm(offset))
        # Subtract the projection onto the line; the remainder is perpendicular.
        perpendicular = offset - (offset @ direction / length_sq) * direction
        return float(np.linalg.norm(perpendicular))

    def query_range(self, vector, attr_low, attr_high, radius, top_k=10, candidate_factor=10):
        """Return up to ``top_k`` ids whose vectors lie within ``radius`` of the query line.

        Args:
            vector: Query vector (anything ``np.asarray`` accepts; used as one row).
            attr_low: Lower attribute bound, scalar or 1-D array-like.
            attr_high: Upper attribute bound, same form as ``attr_low``.
            radius: Maximum perpendicular distance from the line through the
                transformed ``(vector, attr_low)`` and ``(vector, attr_high)``.
            top_k: Maximum number of ids to return; ``<= 0`` returns ``[]``.
            candidate_factor: Over-fetch multiplier; ``top_k * candidate_factor``
                candidates are requested from ``base_ann`` before filtering.

        Returns:
            Candidate ids from ``base_ann.query``, sorted by increasing distance
            to the line (ties keep the ANN order).
        """
        if top_k <= 0:
            return []
        query_vec = np.asarray(vector).reshape(1, -1)
        # asarray lets scalar bounds (e.g. a plain float) be reshaped to (1, 1).
        low = np.asarray(attr_low).reshape(1, -1)
        high = np.asarray(attr_high).reshape(1, -1)
        line_start = np.asarray(
            Transformer.single_transform(query_vec, low, self.alpha, self.beta)[0])
        line_end = np.asarray(
            Transformer.single_transform(query_vec, high, self.alpha, self.beta)[0])

        # Search around the segment midpoint and over-fetch, because many ANN
        # neighbours will be rejected by the radius filter below.
        line_center = (line_start + line_end) / 2
        candidates = self.base_ann.query(line_center, top_k * candidate_factor)

        scored = []
        for idx in candidates:
            dist = self._line_distance(self.base_ann.get_vector(idx), line_start, line_end)
            if dist <= radius:
                scored.append((idx, dist))
        scored.sort(key=lambda pair: pair[1])
        return [idx for idx, _ in scored[:top_k]]

class CylinderIndex:
    """KD-tree over a point set that answers finite-cylinder range queries.

    ``line_start``, ``line_end`` and ``radius`` are stored as given but are not
    read by any method in this class; each query supplies its own cylinder.
    """

    def __init__(self, line_start, line_end, points, radius):
        self.line_start = line_start
        self.line_end = line_end
        self.radius = radius
        self.points = points
        self.tree = KDTree(points)

    @staticmethod
    def _in_cylinder(points, start, end, radius):
        """Return a boolean mask of the rows of ``points`` inside the cylinder.

        The cylinder has axis segment ``start``->``end`` and the given radius;
        its flat caps are inclusive. It works in any dimension. If
        ``start == end``, the test degenerates to a ball of ``radius`` around
        ``start``.
        """
        pts = np.atleast_2d(np.asarray(points, dtype=float))
        start = np.asarray(start, dtype=float)
        axis = np.asarray(end, dtype=float) - start
        offsets = pts - start
        length_sq = float(axis @ axis)
        if length_sq == 0.0:
            return np.linalg.norm(offsets, axis=1) <= radius
        # t is the position of each point's projection along the axis (0 at
        # start, 1 at end); the remainder is the perpendicular offset.
        t = offsets @ axis / length_sq
        perp_dist = np.linalg.norm(offsets - np.outer(t, axis), axis=1)
        return (t >= 0.0) & (t <= 1.0) & (perp_dist <= radius)

    def query_cylinder(self, query_start, query_end, query_radius):
        """Return indices into ``points`` that lie inside the query cylinder.

        Args:
            query_start: One end of the cylinder axis.
            query_end: Other end of the cylinder axis.
            query_radius: Cylinder radius (inclusive).

        Returns:
            List of int indices, in the order the KD-tree reported them.
        """
        start = np.asarray(query_start, dtype=float)
        end = np.asarray(query_end, dtype=float)
        midpoint = (start + end) / 2
        half_length = float(np.linalg.norm(end - start)) / 2
        # Every point of the finite cylinder lies within hypot(r, L/2) of the
        # midpoint. A ball of only r would miss points near the ends. The tiny
        # slack guards against round-off at the corners; the exact test below
        # removes any extra candidates.
        search_radius = float(np.hypot(query_radius, half_length)) * (1 + 1e-9)
        idxs = np.asarray(self.tree.query_radius([midpoint], r=search_radius)[0], dtype=np.intp)
        if idxs.size == 0:
            return []
        candidates = np.asarray([self.points[i] for i in idxs], dtype=float)
        mask = self._in_cylinder(candidates, start, end, query_radius)
        return idxs[mask].tolist()