import numpy as np
from typing import Dict, List, Tuple, Optional
from scipy.linalg import orthogonal_procrustes
from sklearn.preprocessing import normalize
from .base_calculator import BaseCalculator
class OrthogonalErrorCalculator(BaseCalculator):
    """Measure how well one set of vectors maps onto another by a scaled rotation.

    Row ``i`` of ``X`` is paired with row ``i`` of ``Y``. After optional L2 row
    normalisation and mean-centring, an orthogonal matrix ``R`` and an isotropic
    scale ``s`` are fitted with orthogonal Procrustes. The residual
    ``s * X @ R - Y`` is then summarised by MSE, flattened cosine similarity and
    Frobenius norm.

    Args:
        normalize_vectors: L2-normalise each row. When ``mean_center`` is also
            set, rows are re-normalised after centring.
        mean_center: Subtract the per-column mean from every row.
    """

    def __init__(self, normalize_vectors: bool = True, mean_center: bool = True):
        super().__init__()
        self.normalize_vectors = normalize_vectors
        self.mean_center = mean_center

    def preprocess_vectors(self, vectors: np.ndarray) -> np.ndarray:
        """Return a preprocessed copy of ``vectors``; the input is not modified.

        The steps are controlled by the constructor flags:

        1. ``normalize_vectors``: L2-normalise every row.
        2. ``mean_center``: subtract the per-column mean. If
           ``normalize_vectors`` is also set, re-normalise the rows so they are
           unit length again after centring.
        """
        processed = vectors.copy()
        if self.normalize_vectors:
            processed = normalize(processed, axis=1, norm='l2')
        if self.mean_center:
            processed = processed - processed.mean(0)
            if self.normalize_vectors:
                processed = normalize(processed, axis=1, norm='l2')
        return processed

    def find_orthogonal_transformation(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, float]:
        """Fit orthogonal ``R`` and scalar ``s`` minimising ``||s * X @ R - Y||_F``.

        Only the first ``min(len(X), len(Y))`` rows are used, so the pairs line up.

        Returns:
            ``(R, scale)``. ``R`` is the orthogonal matrix from
            ``scipy.linalg.orthogonal_procrustes``. ``scale`` is the
            least-squares optimal isotropic scale for that ``R``.

        Note:
            ``orthogonal_procrustes`` returns the *sum of the singular values*
            of ``X.T @ Y`` as its second output. That value is not a multiplier
            for ``X``. Minimising over ``s`` gives
            ``s = sum(singular values) / ||X||_F**2``. Using the raw sum would
            inflate ``X`` roughly in proportion to the number of rows.
        """
        n_rows = min(X.shape[0], Y.shape[0])
        X_subset = X[:n_rows]
        Y_subset = Y[:n_rows]
        R, singular_value_sum = orthogonal_procrustes(X_subset, Y_subset)
        x_energy = float(np.sum(X_subset ** 2))
        # If X is all zeros, every scale gives the same residual; keep identity.
        scale = float(singular_value_sum) / x_energy if x_energy > 0 else 1.0
        return R, scale

    def calculate_alignment_error(self, X: np.ndarray, Y: np.ndarray, R: np.ndarray, scale: float) -> float:
        """Return the element-wise mean squared error of ``scale * X @ R`` vs ``Y``.

        Only the first ``min(len(X), len(Y))`` rows are compared.
        """
        n_rows = min(X.shape[0], Y.shape[0])
        residual = scale * X[:n_rows] @ R - Y[:n_rows]
        return float(np.mean(residual ** 2))

    @staticmethod
    def _select_rows(X: np.ndarray, Y: np.ndarray, sample_size: Optional[int],
                     random_seed: Optional[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Align X and Y to their common row count, then optionally subsample pairs.

        The same random row indices are applied to both arrays, so pairs are kept.
        """
        n_rows = min(X.shape[0], Y.shape[0])
        X = X[:n_rows]
        Y = Y[:n_rows]
        if sample_size is None:
            return X, Y
        # A private RandomState produces exactly the same draw as
        # ``np.random.seed(seed); np.random.choice(...)``, but does not reset
        # the global NumPy RNG for unrelated code.
        rng = np.random.RandomState(random_seed) if random_seed is not None else np.random
        indices = rng.choice(n_rows, size=min(sample_size, n_rows), replace=False)
        return X[indices], Y[indices]

    def _compute_metrics(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int],
                         random_seed: Optional[int]) -> Dict[str, float]:
        """Run the shared pipeline behind ``calculate`` and ``calculate_suite``.

        The steps are: sample, preprocess, fit the Procrustes transform, then
        score the residual.
        """
        from scipy.spatial.distance import cosine

        X, Y = self._select_rows(X, Y, sample_size, random_seed)
        X_processed = self.preprocess_vectors(X)
        Y_processed = self.preprocess_vectors(Y)
        R, scale = self.find_orthogonal_transformation(X_processed, Y_processed)
        mse = self.calculate_alignment_error(X_processed, Y_processed, R, scale)

        # Rows are already aligned by _select_rows, so no further truncation is needed.
        X_transformed = scale * X_processed @ R
        # Cosine similarity between the two whole matrices, treated as flat vectors.
        cosine_sim = 1 - cosine(X_transformed.ravel(), Y_processed.ravel())
        frobenius_norm = np.linalg.norm(X_transformed - Y_processed, ord='fro')
        return {
            "orthogonal_error": float(mse),
            "orthogonal_error_cosine_sim": float(cosine_sim),
            "orthogonal_error_frobenius": float(frobenius_norm),
            "orthogonal_scale": float(scale),
        }

    def calculate(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int] = 1000, random_seed: Optional[int] = None) -> Dict[str, float]:
        """Compute orthogonal-alignment error metrics between paired rows of X and Y.

        Args:
            X: Source vectors, one per row.
            Y: Target vectors, one per row, paired with ``X`` by row index.
            sample_size: Number of paired rows to draw at random (capped at the
                common row count). ``None`` uses every paired row.
            random_seed: Seed for the row sampling. ``None`` uses NumPy's
                global generator.

        Returns:
            Dict with keys ``orthogonal_error`` (the MSE),
            ``orthogonal_error_mse`` (alias of the MSE),
            ``orthogonal_error_cosine_sim``, ``orthogonal_error_frobenius`` and
            ``orthogonal_scale``.
        """
        metrics = self._compute_metrics(X, Y, sample_size, random_seed)
        return {
            "orthogonal_error": metrics["orthogonal_error"],
            "orthogonal_error_mse": metrics["orthogonal_error"],
            "orthogonal_error_cosine_sim": metrics["orthogonal_error_cosine_sim"],
            "orthogonal_error_frobenius": metrics["orthogonal_error_frobenius"],
            "orthogonal_scale": metrics["orthogonal_scale"],
        }

    def calculate_suite(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int] = 1000, random_seed: Optional[int] = None) -> Dict[str, float]:
        """Same as ``calculate`` but without the ``orthogonal_error_mse`` alias key.

        See ``calculate`` for the parameter meanings.
        """
        return self._compute_metrics(X, Y, sample_size, random_seed)
