from abc import ABC, abstractmethod
import numpy as np
from typing import Tuple
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
class BaseRouter(ABC):
    """Abstract interface shared by all routers.

    Concrete subclasses must override :meth:`route`; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def route(self, embeddings: np.ndarray, **kwargs):
        """Route the given embeddings.

        Args:
            embeddings: Embedding data to route.
            **kwargs: Router-specific options; the base class defines none.

        The base implementation does nothing and returns ``None``.
        """
class FlatRouter(BaseRouter):
    """Route each embedding to its single nearest centroid.

    Distances between every embedding and every centroid are computed with
    the configured metric, and the closest centroid's row index is returned
    as the expert id.
    """

    def __init__(self, centroids: np.ndarray, distance_metric: str = "cosine"):
        """Store the centroids and the name of the distance metric.

        Args:
            centroids: Centroid matrix, one row per expert. The row index is
                what :meth:`route` reports as the expert id.
            distance_metric: ``"cosine"`` or ``"euclidean"``. The value is
                not validated here; an unsupported name raises when
                :meth:`route` is called.
        """
        self.centroids = centroids
        self.distance_metric = distance_metric

    def route(self, embeddings: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """Return the nearest centroid and its distance for each embedding.

        Args:
            embeddings: Embeddings to route, one row per sample. They are
                passed unchanged to the scikit-learn pairwise distance
                function, which expects a 2-D array with the same number of
                columns as ``centroids``.
            **kwargs: Accepted so the signature stays compatible with
                ``BaseRouter.route``; this router has no extra options and
                ignores them.

        Returns:
            A tuple ``(expert_ids, min_distances)``: for each row of
            ``embeddings``, the index of the closest centroid and the
            distance to it.

        Raises:
            ValueError: If ``distance_metric`` is neither ``"cosine"`` nor
                ``"euclidean"``.
        """
        if self.distance_metric == "cosine":
            distances = cosine_distances(embeddings, self.centroids)
        elif self.distance_metric == "euclidean":
            distances = euclidean_distances(embeddings, self.centroids)
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")

        # Rows index embeddings and columns index centroids, so reduce over
        # axis 1 to pick one centroid per embedding.
        expert_ids = np.argmin(distances, axis=1)
        min_distances = np.min(distances, axis=1)
        return expert_ids, min_distances
class CascadeRouter(BaseRouter):
    """Router configured with a tree and a distance metric.

    Routing is not implemented yet: :meth:`route` is a placeholder that
    returns ``None``.
    """

    def __init__(self, tree, distance_metric: str = "cosine"):
        """Store the routing tree and the name of the distance metric.

        Args:
            tree: Tree object to route through. Its structure is not
                visible in this class; it is stored as-is.
            distance_metric: Name of the distance metric. Stored only;
                nothing in this class reads it yet.
        """
        self.tree = tree
        self.distance_metric = distance_metric

    def route(self, embedding: np.ndarray, **kwargs):
        """Placeholder for tree-based routing; currently returns ``None``.

        Args:
            embedding: Embedding to route. The name differs from the base
                class's ``embeddings`` and is kept so existing keyword
                callers continue to work.
            **kwargs: Accepted so the signature stays compatible with
                ``BaseRouter.route``; currently ignored.
        """
        # TODO: implement routing through ``self.tree``.
        return None
