from typing import Tuple, Optional, Dict, Literal
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
import numpy as np
from loguru import logger
import wandb
from pprint import pprint
class GatingMechanism:
    """Route input vectors to experts by their distance to expert centroids.

    Each row of ``centroids`` represents one expert. An input is assigned to
    the expert whose centroid is closest under the configured distance metric
    (hard routing), or to the expert with the largest softmax weight over the
    negated, temperature-scaled distances (soft routing).

    Attributes:
        centroids: Array whose first axis indexes the experts.
        distance_metric: ``"cosine"`` or ``"euclidean"``; any other value
            raises ``ValueError`` the first time distances are computed.
        temperature: Positive softmax temperature used by soft routing.
        use_soft_routing: Whether routing weights are computed and used.
        num_experts: Number of centroids (``centroids.shape[0]``).
    """

    def __init__(self,
                 centroids: np.ndarray,
                 distance_metric: str = "cosine",
                 temperature: float = 1.0,
                 use_soft_routing: bool = False):
        """Store the routing configuration.

        Args:
            centroids: Expert centroids, one row per expert.
            distance_metric: Name of the distance metric to use.
            temperature: Softmax temperature; must be positive.
            use_soft_routing: Start in soft-routing mode when True.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        # Same invariant that set_temperature() enforces: zero would divide by
        # zero in the softmax and a negative value would invert the weights.
        self._validate_temperature(temperature)
        self.centroids = centroids
        self.distance_metric = distance_metric
        self.temperature = temperature
        self.use_soft_routing = use_soft_routing
        self.num_experts = centroids.shape[0]
        logger.info(f"Initialized GatingMechanism with {self.num_experts} experts, "
                   f"metric={distance_metric}, soft_routing={use_soft_routing}")

    @staticmethod
    def _validate_temperature(temperature: float) -> None:
        """Raise ``ValueError`` unless ``temperature`` is strictly positive."""
        if temperature <= 0:
            raise ValueError("Temperature must be positive")

    def route_input(self,
                   input_vector: np.ndarray) -> Tuple[int, float, Optional[np.ndarray], Optional[dict]]:
        """Route a single input vector to one expert.

        Args:
            input_vector: The input; it is reshaped to a single row before
                distances are computed.

        Returns:
            A tuple ``(expert_id, distance, routing_weights, stats)`` where
            ``distance`` is the distance to the selected expert,
            ``routing_weights`` is the softmax weight vector in soft-routing
            mode and ``None`` otherwise, and ``stats`` is always ``None`` for
            single-input routing.
        """
        distances = self._compute_distances(input_vector.reshape(1, -1), self.centroids)[0]
        routing_weights = None
        if self.use_soft_routing:
            routing_weights = self._compute_soft_weights(distances)
            expert_id = np.argmax(routing_weights)
        else:
            expert_id = np.argmin(distances)
        return expert_id, distances[expert_id], routing_weights, None

    def batch_route(self,
                   input_vectors: np.ndarray) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[dict]]:
        """Route a batch of input vectors and collect routing statistics.

        Args:
            input_vectors: One input per row.

        Returns:
            A tuple ``(expert_ids, selected_distances, routing_weights, stats)``.
            ``routing_weights`` has one row of softmax weights per input in
            soft-routing mode and is ``None`` otherwise; ``stats`` is the
            dictionary produced by ``get_routing_statistics``.
        """
        distances = self._compute_distances(input_vectors, self.centroids)
        routing_weights = None
        if self.use_soft_routing:
            routing_weights = self._compute_soft_weights(distances)
            expert_ids = np.argmax(routing_weights, axis=1)
        else:
            expert_ids = np.argmin(distances, axis=1)
        # Pick, for every row, the distance to the expert chosen for that row.
        selected_distances = distances[np.arange(distances.shape[0]), expert_ids]
        stats = self.get_routing_statistics(input_vectors, expert_ids, selected_distances, routing_weights)
        return expert_ids, selected_distances, routing_weights, stats

    def _compute_distances(self,
                          input_vectors: np.ndarray,
                          centroids: np.ndarray) -> np.ndarray:
        """Return the pairwise distance matrix between inputs and centroids.

        Raises:
            ValueError: If ``self.distance_metric`` is neither ``"cosine"``
                nor ``"euclidean"``.
        """
        if self.distance_metric == "cosine":
            return cosine_distances(input_vectors, centroids)
        if self.distance_metric == "euclidean":
            return euclidean_distances(input_vectors, centroids)
        raise ValueError(f"Unsupported distance metric: {self.distance_metric}")

    def _compute_soft_weights(self, distances: np.ndarray) -> np.ndarray:
        """Convert distances into softmax routing weights.

        The softmax is taken over the last axis of ``-distances /
        temperature``, so both a single distance vector and a matrix with one
        row per input are supported.

        Args:
            distances: Distances to every expert (last axis indexes experts).

        Returns:
            Weights of the same shape as ``distances``; each slice along the
            last axis sums to 1.
        """
        scaled = -np.asarray(distances) / self.temperature
        # Subtracting the per-row maximum keeps np.exp from overflowing and
        # does not change the normalized result.
        shifted = scaled - np.max(scaled, axis=-1, keepdims=True)
        exp_similarities = np.exp(shifted)
        return exp_similarities / np.sum(exp_similarities, axis=-1, keepdims=True)

    def compute_distance_based_distinguishability(self,
                                                   input_vectors: np.ndarray,
                                                   method: Literal["margin", "ratio", "gap_normalized", "std"] = "margin") -> Dict[str, np.ndarray]:
        """Score how clearly each input prefers its nearest expert.

        Args:
            input_vectors: One input per row.
            method: Scoring rule:
                ``"margin"`` - second-smallest minus smallest distance;
                ``"ratio"`` - second-smallest divided by smallest distance;
                ``"gap_normalized"`` - margin divided by smallest distance;
                ``"std"`` - standard deviation of all distances of the input.

        Returns:
            Dictionary with ``scores``, ``min_distances``,
            ``second_min_distances``, ``best_expert_ids``,
            ``second_best_expert_ids``, ``all_distances`` and ``method``.
            With a single expert the "second" entries equal the best ones.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        distances = self._compute_distances(input_vectors, self.centroids)
        # One argsort yields both the ranking and, via take_along_axis, the
        # sorted values, so ids and distances are guaranteed to correspond.
        sorted_indices = np.argsort(distances, axis=1)
        sorted_distances = np.take_along_axis(distances, sorted_indices, axis=1)

        runner_up = 1 if distances.shape[1] > 1 else 0
        min_distances = sorted_distances[:, 0]
        second_min_distances = sorted_distances[:, runner_up]
        best_expert_ids = sorted_indices[:, 0]
        second_best_expert_ids = sorted_indices[:, runner_up]

        eps = 1e-10  # guards the divisions against a zero smallest distance
        if method == "margin":
            scores = second_min_distances - min_distances
        elif method == "ratio":
            scores = second_min_distances / (min_distances + eps)
        elif method == "gap_normalized":
            scores = (second_min_distances - min_distances) / (min_distances + eps)
        elif method == "std":
            scores = np.std(distances, axis=1)
        else:
            raise ValueError(f"Unsupported method: {method}")

        return {
            "scores": scores,
            "min_distances": min_distances,
            "second_min_distances": second_min_distances,
            "best_expert_ids": best_expert_ids,
            "second_best_expert_ids": second_best_expert_ids,
            "all_distances": distances,
            "method": method
        }

    def get_distance_based_ranking(self,
                                   input_vectors: np.ndarray,
                                   method: Literal["margin", "ratio", "gap_normalized", "std"] = "margin",
                                   return_top_k: Optional[int] = None) -> Dict[str, np.ndarray]:
        """Rank inputs from most to least distinguishable.

        Args:
            input_vectors: One input per row.
            method: Scoring rule passed to
                ``compute_distance_based_distinguishability``.
            return_top_k: Number of indices to return at each end of the
                ranking. ``None`` returns every index; ``0`` returns empty
                arrays.

        Returns:
            Dictionary with ``most_distinguishable_indices`` (highest score
            first), ``least_distinguishable_indices`` (lowest score first),
            ``scores``, ``sorted_indices`` (descending by score) and the
            per-input distance/expert arrays of the underlying scoring call.

        Raises:
            ValueError: If ``return_top_k`` is negative, or ``method`` is
                unsupported.
        """
        if return_top_k is not None and return_top_k < 0:
            raise ValueError("return_top_k must be non-negative")

        result = self.compute_distance_based_distinguishability(input_vectors, method)
        scores = result["scores"]
        sorted_indices = np.argsort(scores)[::-1]
        ascending_indices = sorted_indices[::-1]

        # Compare against None explicitly: a truthiness test would treat
        # return_top_k=0 as "no limit" and return every index.
        if return_top_k is None:
            most_distinguishable = sorted_indices
            least_distinguishable = ascending_indices
        else:
            most_distinguishable = sorted_indices[:return_top_k]
            least_distinguishable = ascending_indices[:return_top_k]

        return {
            "most_distinguishable_indices": most_distinguishable,
            "least_distinguishable_indices": least_distinguishable,
            "scores": scores,
            "sorted_indices": sorted_indices,
            "min_distances": result["min_distances"],
            "second_min_distances": result["second_min_distances"],
            "best_expert_ids": result["best_expert_ids"],
            "second_best_expert_ids": result["second_best_expert_ids"],
            "method": method
        }

    def get_routing_statistics(self,
                              input_vectors: np.ndarray,
                              expert_ids: np.ndarray,
                              distances: np.ndarray,
                              routing_weights: Optional[np.ndarray] = None,
                              include_distinguishability: bool = True,
                              distinguishability_method: Literal["margin", "ratio", "gap_normalized", "std"] = "margin") -> dict:
        """Summarize a routing decision for a batch of inputs.

        Args:
            input_vectors: The routed inputs, one per row.
            expert_ids: Selected expert index for every input.
            distances: Distance of every input to its selected expert.
            routing_weights: Optional per-input routing weights; when given,
                entropy statistics of the weights are included.
            include_distinguishability: When True, distances are recomputed
                from ``input_vectors`` to add distinguishability statistics.
            distinguishability_method: Scoring rule for those statistics.

        Returns:
            Dictionary with ``num_inputs``, ``num_experts``,
            ``distance_stats``, ``load_balance`` and ``routing_type``, plus
            ``routing_entropy`` and ``distinguishability`` when requested.
        """
        expert_counts = np.bincount(expert_ids, minlength=self.num_experts)
        lightest_load = np.min(expert_counts)
        heaviest_load = np.max(expert_counts)

        stats = {
            "num_inputs": len(input_vectors),
            "num_experts": self.num_experts,
            "distance_stats": {
                "mean_distance": float(np.mean(distances)),
                "std_distance": float(np.std(distances)),
                "min_distance": float(np.min(distances)),
                "max_distance": float(np.max(distances))
            },
            "load_balance": {
                "expert_counts": expert_counts.tolist(),
                "min_expert_load": int(lightest_load),
                "max_expert_load": int(heaviest_load),
                "load_std": float(np.std(expert_counts)),
                "load_imbalance": int(heaviest_load - lightest_load)
            },
            "routing_type": "soft" if self.use_soft_routing else "hard"
        }

        if routing_weights is not None:
            # The small offset keeps log() finite for weights that are exactly 0.
            weights_safe = np.asarray(routing_weights) + 1e-10
            entropies = -np.sum(weights_safe * np.log(weights_safe), axis=-1)
            stats["routing_entropy"] = {
                "mean": float(np.mean(entropies)),
                "std": float(np.std(entropies)),
                "min": float(np.min(entropies)),
                "max": float(np.max(entropies))
            }

        if include_distinguishability:
            dist_result = self.compute_distance_based_distinguishability(
                input_vectors, method=distinguishability_method
            )
            scores = dist_result["scores"]
            margins = dist_result["second_min_distances"] - dist_result["min_distances"]
            # The median score is the confidence threshold: inputs at or
            # above it count as high confidence, the rest as low confidence.
            median_score = np.median(scores)
            stats["distinguishability"] = {
                "method": distinguishability_method,
                "mean_score": float(np.mean(scores)),
                "std_score": float(np.std(scores)),
                "min_score": float(np.min(scores)),
                "max_score": float(np.max(scores)),
                "median_score": float(median_score),
                "percentile_25": float(np.percentile(scores, 25)),
                "percentile_75": float(np.percentile(scores, 75)),
                "mean_margin": float(np.mean(margins)),
                "std_margin": float(np.std(margins)),
                "min_margin": float(np.min(margins)),
                "max_margin": float(np.max(margins)),
                "num_high_confidence": int(np.sum(scores >= median_score)),
                "num_low_confidence": int(np.sum(scores < median_score))
            }
        return stats

    def update_centroids(self, new_centroids: np.ndarray) -> None:
        """Replace the expert centroids.

        Args:
            new_centroids: New centroids; the number of rows must equal the
                current number of experts.

        Raises:
            ValueError: If the number of centroids differs from
                ``self.num_experts``.
        """
        if new_centroids.shape[0] != self.num_experts:
            raise ValueError(f"Expected {self.num_experts} centroids, "
                           f"got {new_centroids.shape[0]}")
        self.centroids = new_centroids
        logger.info(f"Updated centroids to shape {new_centroids.shape}")

    def set_temperature(self, temperature: float) -> None:
        """Set the softmax temperature.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        self._validate_temperature(temperature)
        self.temperature = temperature
        logger.info(f"Updated temperature to {temperature}")

    def enable_soft_routing(self, temperature: float = 1.0) -> None:
        """Switch to soft routing with the given temperature.

        The temperature is validated before any state changes, so a rejected
        call leaves the routing mode and temperature untouched.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        self._validate_temperature(temperature)
        self.use_soft_routing = True
        self.temperature = temperature
        logger.info(f"Enabled soft routing with temperature {temperature}")

    def disable_soft_routing(self) -> None:
        """Switch back to hard (nearest-centroid) routing."""
        self.use_soft_routing = False
        logger.info("Disabled soft routing, using hard routing")
