from abc import ABC, abstractmethod
import numpy as np
import torch
from typing import Optional, Dict, Any, Union, overload, TYPE_CHECKING
from loguru import logger
from tqdm import tqdm
from sklearn.cluster import MiniBatchKMeans
from ..base import VectorMapper
if TYPE_CHECKING:
    from torch.utils.data import DataLoader
class BaseMoEMapper(VectorMapper):
    """Shared base for mixture-of-experts (MoE) vector mappers.

    Provides the public ``fit`` / ``fit_multi`` entry points and two
    strategies for transforming a whole dataset into an on-disk ``.npy``
    memmap cache. Subclasses must implement ``_fit_from_loader`` and
    ``transform``.

    Attributes that this class reads but does not define (they are expected
    to be supplied by subclasses): ``output_dim`` (required by the dataset
    transforms) and, optionally, ``router``, ``clusterer``, ``experts`` and
    ``distance_metric``.
    """

    def __init__(self):
        super().__init__()
        # Flipped to True by fit() / fit_multi(); checked by the transforms.
        self._is_fitted = False

    def transform_dataset(self, config, dataloader: "DataLoader", cache_path: str) -> None:
        """Transform every batch of ``dataloader`` and write it to ``cache_path``.

        The strategy is read from ``config.mapper.transform_strategy``
        (default ``'cluster_then_route'``). For the clustering strategy the
        number of clusters comes from ``config.mapper.transform_num_clusters``
        (default 16).

        Raises:
            ValueError: If the configured strategy name is not recognised.
        """
        strategy = getattr(config.mapper, 'transform_strategy', 'cluster_then_route')
        if strategy == 'cluster_then_route':
            num_clusters = getattr(config.mapper, 'transform_num_clusters', 16)
            self.cluster_and_transform(
                dataloader=dataloader,
                cache_path=cache_path,
                num_clusters=num_clusters
            )
        elif strategy == 'direct_route':
            self.direct_route_transform(
                dataloader=dataloader,
                cache_path=cache_path
            )
        else:
            raise ValueError(
                f"Unknown transform strategy: '{strategy}'. "
                f"Available: ['cluster_then_route', 'direct_route']"
            )

    @overload
    def fit(self, train_data: np.ndarray, target_data: np.ndarray,
            reference_indices: np.ndarray) -> None: ...

    @overload
    def fit(self, train_data: "DataLoader") -> None: ...

    def fit(self, train_data: Union[np.ndarray, "DataLoader"],
            target_data: Optional[np.ndarray] = None,
            reference_indices: Optional[np.ndarray] = None) -> None:
        """Train the mapper from numpy arrays or from a ready-made loader.

        When ``train_data`` is a numpy array, the rows selected by
        ``reference_indices`` are taken from both ``train_data`` and
        ``target_data``, wrapped in a shuffled ``DataLoader`` (batch size
        1024) and handed to ``_fit_from_loader``. Any other ``train_data``
        is passed to ``_fit_from_loader`` unchanged.

        Raises:
            ValueError: If ``train_data`` is a numpy array but
                ``target_data`` or ``reference_indices`` is missing.
        """
        if isinstance(train_data, np.ndarray):
            if target_data is None or reference_indices is None:
                raise ValueError("target_data and reference_indices required when using numpy arrays")
            # Runtime import: the module-level DataLoader import is only
            # performed under TYPE_CHECKING.
            from torch.utils.data import TensorDataset, DataLoader
            logger.info(f"Training with numpy arrays ({len(reference_indices)} reference samples)")
            source_ref = train_data[reference_indices]
            target_ref = target_data[reference_indices]
            dataset = TensorDataset(
                torch.from_numpy(source_ref).float(),
                torch.from_numpy(target_ref).float()
            )
            loader = DataLoader(dataset, batch_size=1024, shuffle=True)
            self._fit_from_loader(loader)
        else:
            logger.info("Training with DataLoader (optimized for large datasets)")
            self._fit_from_loader(train_data)
        self._is_fitted = True
        logger.info("✓ Training completed")

    def fit_multi(self, multi_dataloader) -> None:
        """Train from ``multi_dataloader`` by delegating to ``_fit_from_loader``."""
        logger.info("Training with MultiMemmapDatasetLoader (recommended for MoE)")
        self._fit_from_loader(multi_dataloader)
        self._is_fitted = True
        logger.info("✓ Training completed via fit_multi() interface")

    @abstractmethod
    def _fit_from_loader(self, train_loader):
        """Run the actual training over ``train_loader`` (subclass hook)."""
        pass

    @abstractmethod
    def transform(self, embeddings: np.ndarray) -> np.ndarray:
        """Map ``embeddings`` to the target space (subclass hook)."""
        pass

    def get_expert_assignments(self, embeddings: np.ndarray) -> np.ndarray:
        """Return the expert assignment of each embedding via ``self.router``.

        Raises:
            NotImplementedError: If the instance has no ``router`` attribute
                (or it is ``None``) and the subclass did not override this
                method.
        """
        if hasattr(self, 'router') and self.router is not None:
            return self.router.route(embeddings)
        raise NotImplementedError("get_expert_assignments must be implemented by subclass or have a router")

    def _find_nearest_experts_for_clusters(self, cluster_centroids: np.ndarray) -> Dict[int, int]:
        """Map each cluster centroid to the index of its closest expert centroid.

        Expert centroids are read from ``self.clusterer.centroids``. The
        metric is ``self.distance_metric`` when present, otherwise
        ``'cosine'``; any value other than ``'cosine'`` uses Euclidean
        distance.

        Returns:
            Dict from cluster index (position in ``cluster_centroids``) to
            expert index (row in the expert centroid matrix).

        Raises:
            ValueError: If the clusterer or its centroids are unavailable.
        """
        if not hasattr(self, 'clusterer') or self.clusterer is None:
            raise ValueError("Clusterer not initialized. Call fit() first.")
        if not hasattr(self.clusterer, 'centroids') or self.clusterer.centroids is None:
            raise ValueError("Expert cluster centers not available. Call fit() first.")
        expert_centroids = self.clusterer.centroids
        use_cosine = getattr(self, 'distance_metric', 'cosine') == "cosine"
        # Expert norms do not depend on the cluster, so compute them once.
        expert_norms = np.linalg.norm(expert_centroids, axis=1) if use_cosine else None
        cluster_to_expert: Dict[int, int] = {}
        for cluster_id, cluster_centroid in enumerate(cluster_centroids):
            if use_cosine:
                cluster_norm = np.linalg.norm(cluster_centroid)
                if cluster_norm < 1e-10:
                    # Direction of a (near-)zero centroid is undefined;
                    # fall back to the first expert.
                    expert_id = 0
                else:
                    # The epsilon guards against zero-norm expert centroids.
                    similarities = np.dot(expert_centroids, cluster_centroid) / (expert_norms * cluster_norm + 1e-10)
                    expert_id = int(np.argmax(similarities))
            else:
                distances = np.linalg.norm(expert_centroids - cluster_centroid, axis=1)
                expert_id = int(np.argmin(distances))
            cluster_to_expert[cluster_id] = expert_id
        return cluster_to_expert

    @staticmethod
    def _extract_source_batch(batch):
        """Return the source part of ``batch``, converted to numpy if it is a tensor.

        A tuple *or list* batch is treated as ``(source, ...)``: PyTorch's
        default collation yields a list for tuple samples (e.g. those of a
        ``TensorDataset``), so checking for ``tuple`` alone would pass the
        whole ``[source, target]`` list downstream.
        """
        src_batch = batch[0] if isinstance(batch, (tuple, list)) else batch
        if isinstance(src_batch, torch.Tensor):
            # detach() keeps the conversion valid for tensors that require grad.
            src_batch = src_batch.detach().cpu().numpy()
        return src_batch

    def _transform_with_expert(self, expert_id, samples):
        """Transform ``samples`` with expert ``expert_id``, else with ``self.transform``."""
        if hasattr(self, 'experts') and expert_id in self.experts:
            return self.experts[expert_id].transform(samples)
        return self.transform(samples)

    def _open_output_cache(self, cache_path: str, n_samples: int):
        """Create a writable float32 ``.npy`` memmap of shape ``(n_samples, output_dim)``."""
        import os
        # dirname is '' for a bare file name; fall back to the current directory.
        os.makedirs(os.path.dirname(cache_path) or '.', exist_ok=True)
        return np.lib.format.open_memmap(
            cache_path,
            mode='w+',
            dtype=np.float32,
            shape=(n_samples, self.output_dim)
        )

    def cluster_and_transform(
        self,
        dataloader: "DataLoader",
        cache_path: str,
        num_clusters: int = 16,
        batch_size: int = 10000
    ) -> None:
        """Cluster the source data, route each cluster to one expert, and cache the result.

        The dataloader is iterated twice: once to fit a ``MiniBatchKMeans``
        incrementally (and count samples), and once to transform each batch
        and write it into a memmap at ``cache_path``. Rows are written at
        their running position, so both passes must yield the same number
        of samples.

        Args:
            dataloader: Yields source batches, or ``(source, ...)`` sequences.
            cache_path: Destination of the ``.npy`` memmap file.
            num_clusters: Number of k-means clusters.
            batch_size: Upper bound for the k-means mini-batch size, which
                is additionally capped at 1024.

        Raises:
            ValueError: If the mapper is not fitted, the dataloader yields
                no samples, or expert centroids are unavailable.
        """
        if not self._is_fitted:
            raise ValueError("Mapper not fitted. Call fit() first.")
        logger.info(f"Cluster-and-transform: {num_clusters} clusters -> {cache_path}")
        clusterer = MiniBatchKMeans(
            n_clusters=num_clusters,
            random_state=42,
            batch_size=min(batch_size, 1024),
            verbose=0,
            compute_labels=False,
            n_init=3
        )
        # Pass 1: fit the clusterer incrementally and count the samples.
        n_samples = 0
        for batch in tqdm(dataloader, desc="Clustering"):
            src_batch = self._extract_source_batch(batch)
            clusterer.partial_fit(src_batch)
            n_samples += len(src_batch)
        if n_samples == 0:
            # Without this the unfitted clusterer fails below with an
            # opaque AttributeError on ``cluster_centers_``.
            raise ValueError("Dataloader yielded no samples; nothing to cluster.")
        cluster_to_expert = self._find_nearest_experts_for_clusters(clusterer.cluster_centers_)
        logger.info(f"Mapped {num_clusters} clusters to experts")
        output_memmap = self._open_output_cache(cache_path, n_samples)
        # Pass 2: transform each cluster's samples with its assigned expert.
        processed = 0
        for batch in tqdm(dataloader, desc="Transforming"):
            src_batch = self._extract_source_batch(batch)
            batch_clusters = clusterer.predict(src_batch)
            for cluster_id in range(num_clusters):
                mask = (batch_clusters == cluster_id)
                if not mask.any():
                    continue
                transformed = self._transform_with_expert(
                    cluster_to_expert[cluster_id], src_batch[mask]
                )
                # Offset in-batch positions by the rows already written.
                global_indices = processed + np.where(mask)[0]
                output_memmap[global_indices] = transformed
            processed += len(src_batch)
        output_memmap.flush()
        del output_memmap
        assert processed == n_samples, f"Processed {processed} samples, expected {n_samples}"
        logger.info(f"✓ Saved {n_samples:,} samples to {cache_path}")

    def direct_route_transform(
        self,
        dataloader: "DataLoader",
        cache_path: str
    ) -> None:
        """Route every sample with ``get_expert_assignments`` and cache the result.

        The dataloader is iterated twice: once to count samples (needed to
        size the memmap) and once to transform and write them. Rows are
        written at their running position, so both passes must yield the
        same number of samples.

        Args:
            dataloader: Yields source batches, or ``(source, ...)`` sequences.
            cache_path: Destination of the ``.npy`` memmap file.

        Raises:
            ValueError: If the mapper is not fitted.
        """
        if not self._is_fitted:
            raise ValueError("Mapper not fitted. Call fit() first.")
        logger.info(f"Direct route transform -> {cache_path}")
        # Pass 1: count samples so the output can be allocated up front.
        n_samples = 0
        for batch in tqdm(dataloader, desc="Counting samples"):
            n_samples += len(self._extract_source_batch(batch))
        logger.info(f"Total samples: {n_samples:,}")
        output_memmap = self._open_output_cache(cache_path, n_samples)
        # Pass 2: group each batch by assigned expert and transform per group.
        processed = 0
        for batch in tqdm(dataloader, desc="Transforming"):
            src_batch = self._extract_source_batch(batch)
            batch_size = len(src_batch)
            expert_ids = self.get_expert_assignments(src_batch)
            for expert_id in np.unique(expert_ids):
                mask = (expert_ids == expert_id)
                if not mask.any():
                    continue
                transformed = self._transform_with_expert(expert_id, src_batch[mask])
                # Offset in-batch positions by the rows already written.
                global_indices = processed + np.where(mask)[0]
                output_memmap[global_indices] = transformed
            processed += batch_size
        output_memmap.flush()
        del output_memmap
        assert processed == n_samples, f"Processed {processed} samples, expected {n_samples}"
        logger.info(f"✓ Saved {n_samples:,} samples to {cache_path}")
