from typing import Dict, Optional, Tuple

import numpy as np
import torch
from loguru import logger
from scipy.spatial.distance import cosine
from scipy.stats import pearsonr, spearmanr
from sklearn.preprocessing import normalize

from .base_calculator import BaseCalculator


class RSIMCalculator(BaseCalculator):
    """Compare the pairwise-similarity structure of two paired sets of vectors.

    Each input matrix is (optionally) normalised / mean-centred, turned into a
    Gram matrix of row-by-row dot products, flattened, and the two flattened
    Gram matrices are compared. Row ``i`` of ``X`` is paired with row ``i`` of
    ``Y``, so both inputs must have the same number of rows. Their feature
    dimensions may differ, because only the n x n Gram matrices are compared.
    """

    def __init__(self, normalize_vectors: bool = True, mean_center: bool = True):
        """Configure preprocessing.

        Args:
            normalize_vectors: L2-normalise each row before centring and, if
                ``mean_center`` is also set, again after centring.
            mean_center: Subtract the per-feature (column) mean over rows.
        """
        super().__init__()
        self.normalize_vectors = normalize_vectors
        self.mean_center = mean_center

    def preprocess_vectors(self, vectors: np.ndarray) -> np.ndarray:
        """Return a preprocessed copy of ``vectors``; the input is not modified.

        The steps run in this order:

        1. Optional row-wise L2 normalisation.
        2. Optional subtraction of the column means.
        3. When both options are enabled, a second row-wise L2 normalisation.
        """
        processed_vectors = vectors.copy()
        if self.normalize_vectors:
            processed_vectors = normalize(processed_vectors, axis=1, norm='l2')
        if self.mean_center:
            # Centre each feature across the rows (axis 0).
            processed_vectors = processed_vectors - processed_vectors.mean(0)
            if self.normalize_vectors:
                # Centring changes row norms; re-normalise so the rows are unit length again.
                processed_vectors = normalize(processed_vectors, axis=1, norm='l2')
        return processed_vectors

    def calculate_second_order_distances(self, vectors: np.ndarray) -> np.ndarray:
        """Return the flattened Gram matrix ``vectors @ vectors.T``.

        Despite the name, the entries are dot-product similarities, not
        distances. The product is computed in float32 with torch, on CUDA
        when it is available. For ``n`` input rows the result has ``n * n``
        entries: the diagonal plus both (symmetric) triangles.

        NOTE(review): the diagonal holds each row's self-similarity, which is
        1.0 for non-zero rows when normalisation is on. This may bias the
        correlations upward. Many RSA-style metrics use only the strict upper
        triangle. The full matrix is kept here to preserve existing scores;
        confirm this is intended.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # torch.from_numpy rejects negative-stride arrays; make the buffer contiguous first.
        tensor_vectors = torch.from_numpy(np.ascontiguousarray(vectors)).float().to(device)
        dot_product_matrix = torch.matmul(tensor_vectors, tensor_vectors.T)
        return dot_product_matrix.cpu().numpy().flatten()

    @staticmethod
    def _check_same_rows(X: np.ndarray, Y: np.ndarray) -> None:
        """Raise ValueError unless X and Y have the same number of rows."""
        if X.shape[0] != Y.shape[0]:
            raise ValueError("X and Y must have the same number of samples")

    @staticmethod
    def _subsample(
        X: np.ndarray,
        Y: np.ndarray,
        sample_size: Optional[int],
        random_seed: Optional[int],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Select the same random rows (without replacement) from X and Y."""
        if sample_size is None:
            return X, Y
        total_samples = X.shape[0]
        sample_size = min(sample_size, total_samples)
        if random_seed is not None:
            # A private RandomState seeded this way draws the same indices as
            # np.random.seed(seed) followed by np.random.choice(...). Unlike that
            # pair of calls, it leaves the global NumPy RNG state untouched.
            random_indices = np.random.RandomState(random_seed).choice(
                total_samples, size=sample_size, replace=False
            )
        else:
            random_indices = np.random.choice(total_samples, size=sample_size, replace=False)
        return X[random_indices], Y[random_indices]

    def _paired_similarities(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        sample_size: Optional[int],
        random_seed: Optional[int],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Subsample, preprocess, and return the flattened Gram matrices of X and Y."""
        X, Y = self._subsample(X, Y, sample_size, random_seed)
        X_distances = self.calculate_second_order_distances(self.preprocess_vectors(X))
        Y_distances = self.calculate_second_order_distances(self.preprocess_vectors(Y))
        return X_distances, Y_distances

    def calculate(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int] = 1000, random_seed: Optional[int] = None) -> Dict[str, float]:
        """Return the Pearson correlation between the similarity structures of X and Y.

        Args:
            X: Matrix whose row ``i`` is paired with row ``i`` of ``Y``.
            Y: Matrix with the same number of rows as ``X``. Its feature
                dimension may differ.
            sample_size: If not None, use at most this many randomly chosen
                rows, the same rows for both inputs. This bounds the size of
                the n x n Gram matrices. None uses every row.
            random_seed: Seed for row sampling. It does not modify the global
                NumPy RNG.

        Returns:
            ``{"RSIM_score": r}``, where ``r`` is the Pearson correlation of the
            two flattened Gram matrices.

        Raises:
            ValueError: If X and Y have different numbers of rows.
        """
        self._check_same_rows(X, Y)
        X_distances, Y_distances = self._paired_similarities(X, Y, sample_size, random_seed)
        correlation, _ = pearsonr(X_distances, Y_distances)
        return {
            "RSIM_score": float(correlation),
        }

    def calculate_suite(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int] = 1000, random_seed: Optional[int] = None) -> Dict[str, float]:
        """Return several statistics comparing the flattened Gram matrices of X and Y.

        Takes the same arguments as :meth:`calculate`.

        Returns:
            A dict of Python floats:

            - ``pearson_correlation``: Pearson r of the flattened Gram matrices.
            - ``spearman_correlation``: Spearman rank correlation of the same.
            - ``cosine_similarity``: cosine similarity of the two flattened
              vectors.
            - ``mean_dot_product_diff``: mean of X's entries minus mean of Y's.
            - ``std_dot_product_diff``: std of X's entries minus std of Y's.

        Raises:
            ValueError: If X and Y have different numbers of rows.
        """
        self._check_same_rows(X, Y)
        logger.info(f"Calculating RSIM suite with sample size {sample_size}")
        X_distances, Y_distances = self._paired_similarities(X, Y, sample_size, random_seed)

        pearson_corr, _ = pearsonr(X_distances, Y_distances)
        spearman_corr, _ = spearmanr(X_distances, Y_distances)
        cosine_sim = 1 - cosine(X_distances, Y_distances)
        # The Gram matrices are float32, so np.mean/np.std return np.float32.
        # Convert every statistic to a plain float for a uniform, serialisable result.
        return {
            "pearson_correlation": float(pearson_corr),
            "spearman_correlation": float(spearman_corr),
            "cosine_similarity": float(cosine_sim),
            "mean_dot_product_diff": float(np.mean(X_distances) - np.mean(Y_distances)),
            "std_dot_product_diff": float(np.std(X_distances) - np.std(Y_distances)),
        }
