import numpy as np
import pickle
from typing import Tuple, Optional
import warnings


class ReasoningRouter:
    """Rank candidate models from paired thinking/answer embeddings.

    ``fit`` estimates, for every model:

    * ``sigma_thinking`` / ``sigma_answer`` -- quality scores where larger is
      better, recovered by triangulating sample-averaged pairwise squared
      distances between models;
    * ``alpha_causal`` -- the least-squares slope (through the origin) of the
      model's answer deviation on its thinking deviation, both measured from a
      precision-weighted consensus embedding;

    and blends them into ``composite_scores`` with weights ``w_T``, ``w_A`` and
    ``w_C`` (scores are normalized to sum to 1 whenever their sum is positive).

    Args:
        n_models: Number of candidate models (size of the data's model axis).
        dim: Embedding dimensionality (size of the data's last axis).
        w_T: Weight of the thinking-quality term.
        w_A: Weight of the answer-quality term.
        w_C: Weight of the coherence (causal-strength) term.

    Raises:
        ValueError: If ``w_T + w_A + w_C`` is not 1 (within 1e-6).
    """

    def __init__(self, n_models: int, dim: int, w_T: float = 0.4,
                 w_A: float = 0.4, w_C: float = 0.2):
        # `not (x < tol)` instead of `x >= tol` so NaN weights are rejected too.
        if not abs(w_T + w_A + w_C - 1.0) < 1e-6:
            raise ValueError("Weights must sum to 1")

        self.n_models = n_models
        self.dim = dim
        self.w_T = w_T
        self.w_A = w_A
        self.w_C = w_C

        # Per-model parameters; these are the initial values, `fit` may
        # overwrite them. Larger sigma means higher estimated quality.
        self.sigma_thinking = np.ones(n_models)
        self.sigma_answer = np.ones(n_models)
        self.alpha_causal = np.zeros(n_models)
        self.composite_scores = np.ones(n_models)

        # Consensus ("latent truth") embeddings, set by `fit`.
        self.z_star_thinking: Optional[np.ndarray] = None
        self.z_star_answer: Optional[np.ndarray] = None

        # Thinking-vs-answer mixing weights of the two distance blends used in
        # triangulation; they must differ for the 2x2 solve to be well posed.
        self.omega_1 = 0.8
        self.omega_2 = 0.2

    def fit(self, thinking_embeddings: np.ndarray,
            answer_embeddings: np.ndarray) -> "ReasoningRouter":
        """Estimate all per-model parameters from paired embeddings.

        Args:
            thinking_embeddings: Array of shape ``(n_samples, n_models, dim)``.
            answer_embeddings: Array with the same shape as
                ``thinking_embeddings``.

        Returns:
            ``self`` (so calls can be chained).

        Raises:
            ValueError: If the arrays are not 3-D with trailing shape
                ``(self.n_models, self.dim)``, have no samples, or differ in shape.
        """
        thinking_embeddings = np.asarray(thinking_embeddings)
        answer_embeddings = np.asarray(answer_embeddings)

        if (thinking_embeddings.ndim != 3
                or thinking_embeddings.shape[1:] != (self.n_models, self.dim)):
            raise ValueError(
                f"thinking_embeddings must have shape (n_samples, {self.n_models}, "
                f"{self.dim}); got {thinking_embeddings.shape}"
            )
        if thinking_embeddings.shape[0] == 0:
            raise ValueError("at least one sample is required to fit")
        if answer_embeddings.shape != thinking_embeddings.shape:
            raise ValueError(
                f"answer_embeddings shape {answer_embeddings.shape} does not match "
                f"thinking_embeddings shape {thinking_embeddings.shape}"
            )

        # Order matters: the consensus is weighted by the sigmas, and the causal
        # slopes are measured relative to that consensus.
        self._estimate_marginal_qualities(thinking_embeddings, answer_embeddings)
        self._estimate_latent_truth(thinking_embeddings, answer_embeddings)
        self._estimate_causal_strength(thinking_embeddings, answer_embeddings)
        self._compute_composite_scores()
        return self

    @staticmethod
    def _mean_squared_distances(embeddings: np.ndarray) -> np.ndarray:
        """Return the symmetric ``(n_models, n_models)`` matrix of squared L2
        distances between models, averaged over samples (zero diagonal)."""
        n_models = embeddings.shape[1]
        distances = np.zeros((n_models, n_models))
        for i in range(n_models):
            for j in range(i + 1, n_models):
                diff = embeddings[:, i, :] - embeddings[:, j, :]
                distances[i, j] = distances[j, i] = np.mean(np.sum(diff * diff, axis=1))
        return distances

    def _compute_weighted_distances(self, thinking_embeddings: np.ndarray,
                                    answer_embeddings: np.ndarray,
                                    omega: float) -> np.ndarray:
        """Return ``omega * D_thinking + (1 - omega) * D_answer`` of the
        sample-averaged pairwise squared distances."""
        return (omega * self._mean_squared_distances(thinking_embeddings)
                + (1 - omega) * self._mean_squared_distances(answer_embeddings))

    @staticmethod
    def _triangulated_share(distances: np.ndarray, i: int) -> float:
        """Average ``0.5 * (d[i, j] + d[i, k] - d[j, k])`` over all pairs
        ``j < k`` of models other than ``i``."""
        others = np.delete(np.arange(distances.shape[0]), i)
        rows, cols = np.triu_indices(len(others), k=1)
        j_idx, k_idx = others[rows], others[cols]
        return float(np.mean(
            0.5 * (distances[i, j_idx] + distances[i, k_idx] - distances[j_idx, k_idx])
        ))

    def _estimate_marginal_qualities(self, thinking_embeddings: np.ndarray,
                                     answer_embeddings: np.ndarray) -> None:
        """Estimate ``sigma_thinking`` and ``sigma_answer`` by triangulation.

        For each blend weight ``c`` in ``(omega_1, omega_2)`` a per-model share
        ``s(c)`` is triangulated from the blended distance matrix and rescaled by
        ``2 / dim``. The code then treats ``s(c) = c * x + (1 - c) * y`` and
        solves the resulting 2x2 system for the inverse thinking quality ``x``
        and inverse answer quality ``y``. (Presumably this assumes independent
        per-model noise, so distances are additive in per-model terms.)
        Needs at least 3 models; otherwise warns once and keeps current sigmas.
        """
        n_models = self.n_models
        if n_models < 3:
            # Triangulation needs two *other* models for every model.
            warnings.warn("Need at least 3 models for triangulation")
            return

        # Compute each base matrix once and blend, instead of once per omega.
        d_thinking = self._mean_squared_distances(thinking_embeddings)
        d_answer = self._mean_squared_distances(answer_embeddings)
        c1, c2 = self.omega_1, self.omega_2
        distances_omega1 = c1 * d_thinking + (1 - c1) * d_answer
        distances_omega2 = c2 * d_thinking + (1 - c2) * d_answer

        denominator = c1 * (1 - c2) - c2 * (1 - c1)  # algebraically c1 - c2
        if abs(denominator) <= 1e-8:
            # Singular system (omega_1 == omega_2): fall back to neutral quality.
            self.sigma_thinking[:] = 1.0
            self.sigma_answer[:] = 1.0
            return

        for i in range(n_models):
            s_omega1 = (2.0 / self.dim) * self._triangulated_share(distances_omega1, i)
            s_omega2 = (2.0 / self.dim) * self._triangulated_share(distances_omega2, i)

            inv_sigma_T = (s_omega1 * (1 - c2) - s_omega2 * (1 - c1)) / denominator
            inv_sigma_A = (c1 * s_omega2 - c2 * s_omega1) / denominator

            # Squared distances need not satisfy the triangle inequality (e.g. a
            # model lying between two others), so the inverse estimate can be
            # <= 0; it is clamped to 1e-8, i.e. treated as near-noiseless.
            # Resulting sigmas lie in [1e-6, 1e8].
            self.sigma_thinking[i] = max(1.0 / max(inv_sigma_T, 1e-8), 1e-6)
            self.sigma_answer[i] = max(1.0 / max(inv_sigma_A, 1e-8), 1e-6)

    def _estimate_latent_truth(self, thinking_embeddings: np.ndarray,
                               answer_embeddings: np.ndarray) -> None:
        """Set the consensus embeddings ``z_star_thinking`` / ``z_star_answer``.

        Models are averaged with weights ``sigma_thinking`` / ``sigma_answer``
        (``np.average`` normalizes the weights), then the result is averaged over
        samples, giving one ``(dim,)`` vector per embedding space.
        """
        self.z_star_thinking = np.average(
            thinking_embeddings, axis=1, weights=self.sigma_thinking
        ).mean(axis=0)
        self.z_star_answer = np.average(
            answer_embeddings, axis=1, weights=self.sigma_answer
        ).mean(axis=0)

    def _estimate_causal_strength(self, thinking_embeddings: np.ndarray,
                                  answer_embeddings: np.ndarray) -> None:
        """Set ``alpha_causal[i]`` to the least-squares slope (through the origin)
        of model ``i``'s answer deviation on its thinking deviation.

        Deviations are measured from the consensus embeddings (so this must run
        after ``_estimate_latent_truth``) and pooled over samples and
        dimensions. A (near-)zero thinking deviation yields a slope of 0.
        """
        for i in range(thinking_embeddings.shape[1]):
            thinking_dev = thinking_embeddings[:, i, :] - self.z_star_thinking
            answer_dev = answer_embeddings[:, i, :] - self.z_star_answer

            denominator = np.sum(thinking_dev * thinking_dev)
            if denominator > 1e-8:
                self.alpha_causal[i] = np.sum(thinking_dev * answer_dev) / denominator
            else:
                self.alpha_causal[i] = 0.0

    def _compute_composite_scores(self) -> None:
        """Blend normalized qualities and coherence into ``composite_scores``.

        Coherence is ``|alpha| / (1 + |alpha|)`` scaled by the smaller of the
        model's two normalized sigmas, then normalized to sum to 1 (if positive).
        The weighted blend is itself normalized to sum to 1 when positive.
        """
        sigma_T_norm = self.sigma_thinking / self.sigma_thinking.sum()
        sigma_A_norm = self.sigma_answer / self.sigma_answer.sum()

        abs_alpha = np.abs(self.alpha_causal)
        coherence_scores = abs_alpha / (1 + abs_alpha)
        # A strong thinking->answer link only counts as much as the weaker of
        # the model's two quality estimates.
        coherence_scores *= np.minimum(sigma_T_norm, sigma_A_norm)
        if coherence_scores.sum() > 0:
            coherence_scores = coherence_scores / coherence_scores.sum()

        self.composite_scores = (self.w_T * sigma_T_norm
                                 + self.w_A * sigma_A_norm
                                 + self.w_C * coherence_scores)
        if self.composite_scores.sum() > 0:
            self.composite_scores = self.composite_scores / self.composite_scores.sum()

    def predict(self) -> int:
        """Return the index of the highest-scoring model (lowest index on ties)."""
        return int(np.argmax(self.composite_scores))

    def get_model_rankings(self) -> np.ndarray:
        """Return model indices ordered from best to worst composite score.

        Ties keep ascending index order, so the first entry always equals
        ``predict()``.
        """
        return np.argsort(-self.composite_scores, kind="stable")


def generate_synthetic_reasoning_data(n_samples: int = 100, n_models: int = 5,
                                      dim: int = 64, seed: int = 42
                                      ) -> Tuple[np.ndarray, np.ndarray]:
    """Generate synthetic thinking/answer embeddings with known per-model quality.

    Each model's thinking embedding is a shared random "true" thinking vector plus
    Gaussian noise with per-dimension variance ``1 / (2 * q_T)``. Its answer
    embedding is a shared "true" answer vector, plus ``causal_strength`` times the
    thinking deviation, plus independent noise with variance ``1 / (2 * q_A)``.
    ``q_T``, ``q_A`` and ``causal_strength`` come from fixed five-entry tables;
    the first ``n_models`` entries are used.

    Args:
        n_samples: Number of samples.
        n_models: Number of models (at most 5, the size of the tables).
        dim: Embedding dimensionality.
        seed: Seed of a private ``RandomState``. The default reproduces the
            output obtained when this function seeded NumPy's global RNG with 42,
            but the global RNG is no longer touched.

    Returns:
        ``(thinking_embeddings, answer_embeddings)``, each of shape
        ``(n_samples, n_models, dim)``.

    Raises:
        ValueError: If ``n_models`` exceeds the number of tabulated models.
    """
    model_qualities_thinking = np.array([2.0, 1.4, 1.0, 1.6, 1.2])
    model_qualities_answer = np.array([1.6, 1.2, 1.8, 1.4, 1.0])
    causal_strengths = np.array([0.8, 0.3, 0.1, 0.6, 0.2])
    if n_models > len(causal_strengths):
        raise ValueError(
            f"n_models={n_models} exceeds the {len(causal_strengths)} tabulated models"
        )

    rng = np.random.RandomState(seed)
    true_thinking = rng.randn(dim)
    true_answer = rng.randn(dim)

    # One draw laid out as (sample, model, [thinking, answer], dim) consumes the
    # random stream in the same order as a per-sample, per-model loop drawing
    # thinking noise then answer noise, so outputs match that loop exactly.
    noise = rng.randn(n_samples, n_models, 2, dim)
    thinking_scale = np.sqrt(2 * model_qualities_thinking[:n_models])[None, :, None]
    answer_scale = np.sqrt(2 * model_qualities_answer[:n_models])[None, :, None]

    thinking_embeddings = true_thinking + noise[:, :, 0, :] / thinking_scale
    thinking_deviation = thinking_embeddings - true_thinking
    causal_influence = causal_strengths[:n_models, None] * thinking_deviation
    answer_embeddings = true_answer + causal_influence + noise[:, :, 1, :] / answer_scale

    return thinking_embeddings, answer_embeddings


def main():
    """Fit a ReasoningRouter on synthetic embeddings and print the chosen model."""
    # The full code and evaluation dataset are to be released upon publication;
    # until then this demo runs on synthetically generated embeddings.
    thinking, answers = generate_synthetic_reasoning_data(
        n_samples=200, n_models=5, dim=64
    )

    _, model_count, embed_dim = thinking.shape
    router = ReasoningRouter(
        n_models=model_count, dim=embed_dim, w_T=0.4, w_A=0.4, w_C=0.2
    )
    router.fit(thinking, answers)
    print(router.predict())


if __name__ == "__main__":
    main()
