import numpy as np
import torch
from typing import Dict, Optional
from loguru import logger
from tqdm import tqdm
from .base_calculator import BaseCalculator
class CosineDeltaStructCalculator(BaseCalculator):
    """Compare the pairwise cosine-distance structure of two row-aligned embeddings.

    For source embeddings ``X`` and target embeddings ``Y`` (row ``k`` of each is
    taken to describe the same item), the ``n x n`` cosine-distance matrices
    ``D_X`` and ``D_Y`` are computed. The element-wise ratio ``D_X / (D_Y + eps)``
    is summarised over all off-diagonal pairs, and the Frobenius norm of
    ``D_X - D_Y`` (absolute and relative to ``||D_Y||``) is reported alongside.

    Small inputs are computed in one shot. Larger inputs are computed block-wise
    on ``self.device``. Very large inputs, or inputs whose device buffers cannot
    be allocated, keep the full matrix in host memory while blocks are processed
    on the device.
    """

    # Added to the target distances before dividing, to avoid division by zero.
    _EPSILON = 1e-8
    # Estimated footprint (GB) above which the full matrix is kept in host memory.
    _MAX_DEVICE_MATRIX_GB = 8.0
    # Inputs with at least this many rows are computed block-wise.
    _BLOCKWISE_THRESHOLD = 200
    # Inputs with at least this many rows get row-band (streamed) ratio statistics.
    _STREAMING_STATS_THRESHOLD = 2000
    # Upper bound on the block edge used by the hybrid host/device path.
    _MAX_HYBRID_BATCH = 500

    def __init__(self, batch_size: int = 1000, use_gpu: bool = True):
        """Configure block size and device selection.

        Args:
            batch_size: Edge length of the square blocks used by the block-wise
                path. The hybrid path uses at most half of it, capped at 500.
                Must be at least 1.
            use_gpu: Use CUDA when it is available; otherwise run on the CPU.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        gpu_available = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if gpu_available else "cpu")
        if gpu_available:
            logger.info(f"Using GPU device: {self.device}")
        else:
            logger.info("Using CPU for computation")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_float_tensor(vectors: np.ndarray) -> torch.Tensor:
        """Return ``vectors`` as a contiguous float32 CPU tensor.

        Unlike ``torch.from_numpy``, this also accepts arrays with negative
        strides (e.g. ``x[::-1]``) and plain array-likes.
        """
        return torch.as_tensor(np.ascontiguousarray(vectors), dtype=torch.float32)

    @staticmethod
    def _cosine_distance_block(rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
        """Return ``1 - cos_sim`` for every (row, col) pair of L2-normalised vectors.

        The similarity is clamped to [-1, 1] first, so round-off cannot push a
        distance outside [0, 2].
        """
        similarities = torch.mm(rows, cols.t())
        return 1.0 - torch.clamp(similarities, min=-1.0, max=1.0)

    @staticmethod
    def _is_out_of_memory(error: RuntimeError) -> bool:
        """Return True if ``error`` is an allocator out-of-memory failure."""
        return "out of memory" in str(error).lower()

    @staticmethod
    def _pack_results(
        max_ratio: float,
        min_ratio: float,
        mean_ratio: float,
        std_ratio: float,
        frobenius_diff: float,
        relative_frobenius: float,
    ) -> Dict[str, float]:
        """Build the result dictionary with the keys callers expect."""
        return {
            "cosine_delta_struct_max": max_ratio,
            "cosine_delta_struct_min": min_ratio,
            "cosine_delta_struct_mean": mean_ratio,
            "cosine_delta_struct_std": std_ratio,
            "cosine_frobenius_diff": frobenius_diff,
            "cosine_relative_frobenius_diff": relative_frobenius,
        }

    def _release_device_cache(self) -> None:
        """Release cached, unused CUDA memory (no-op when running on the CPU)."""
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

    def _to_device_or_none(self, tensor: torch.Tensor) -> Optional[torch.Tensor]:
        """Move ``tensor`` to ``self.device``; return None if that runs out of memory.

        Non-OOM ``RuntimeError``s propagate unchanged.
        """
        try:
            return tensor.to(self.device)
        except RuntimeError as e:
            if not self._is_out_of_memory(e):
                raise
        return None

    def _device_distance_block(self, rows_device: torch.Tensor, cols_cpu: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute one distance block on ``self.device`` and return it on the CPU.

        Returns None if the device runs out of memory; other errors propagate.
        Returning (rather than handling OOM inline) lets the failed attempt's
        device tensors be freed before the caller falls back to the CPU.
        """
        cols_device = self._to_device_or_none(cols_cpu)
        if cols_device is None:
            return None
        try:
            return self._cosine_distance_block(rows_device, cols_device).cpu()
        except RuntimeError as e:
            if not self._is_out_of_memory(e):
                raise
        return None

    # ------------------------------------------------------- distance matrices

    def _compute_cosine_distance_matrix_batched(self, vectors: np.ndarray, desc: str = "Computing cosine distances") -> torch.Tensor:
        """Compute the full cosine-distance matrix block by block on ``self.device``.

        Falls back to the hybrid host/device path in two cases: when the
        estimated footprint exceeds ``_MAX_DEVICE_MATRIX_GB``, or when allocating
        the device buffers runs out of memory.

        Args:
            vectors: Embeddings, one row per sample (assumed 2-D).
            desc: Progress-bar label.

        Returns:
            ``(n_samples, n_samples)`` float32 tensor. It is on ``self.device``,
            or on the CPU when the hybrid path was taken.
        """
        n_samples = vectors.shape[0]
        mem_estimate = self.get_memory_usage_estimate(n_samples)
        required_gb = mem_estimate["estimated_total_mb"] / 1024
        logger.info(f"Estimated memory requirement: {required_gb:.2f} GB for {n_samples}x{n_samples} matrix")
        if required_gb > self._MAX_DEVICE_MATRIX_GB:
            logger.info("Using hybrid CPU-GPU approach: matrix in CPU memory, batches on GPU")
            return self._compute_cosine_distance_matrix_hybrid(vectors, desc)

        normalized = None
        cosine_distance_matrix = None
        try:
            # Normalise every row once here rather than once per block pair;
            # normalisation is row-wise, so the result is identical.
            normalized = torch.nn.functional.normalize(self._to_float_tensor(vectors).to(self.device), p=2, dim=1)
            self._release_device_cache()
            cosine_distance_matrix = torch.zeros(n_samples, n_samples, device=self.device)
        except RuntimeError as e:
            if not self._is_out_of_memory(e):
                raise
            logger.warning("GPU OOM during matrix allocation, using hybrid approach")
        if cosine_distance_matrix is None:
            # Fall back outside the except clause, and drop the device copy of the
            # inputs first, so neither stays resident during the hybrid run.
            del normalized
            self._release_device_cache()
            return self._compute_cosine_distance_matrix_hybrid(vectors, desc)

        block = self.batch_size
        blocks_per_side = (n_samples + block - 1) // block
        with tqdm(total=blocks_per_side * blocks_per_side, desc=desc, unit="batch") as pbar:
            for i in range(0, n_samples, block):
                end_i = min(i + block, n_samples)
                rows = normalized[i:end_i]
                for j in range(0, n_samples, block):
                    end_j = min(j + block, n_samples)
                    cosine_distance_matrix[i:end_i, j:end_j] = self._cosine_distance_block(rows, normalized[j:end_j])
                    pbar.update(1)
        return cosine_distance_matrix

    def _compute_cosine_distance_matrix_hybrid(self, vectors: np.ndarray, desc: str = "Computing cosine distances") -> torch.Tensor:
        """Compute the cosine-distance matrix in host memory, one device block at a time.

        The full result lives on the CPU. Each block is computed on
        ``self.device`` and copied back. OOM handling:

        - A block whose device computation runs out of memory is recomputed on
          the CPU.
        - A whole row band is computed on the CPU if its rows cannot be moved
          to the device.

        Args:
            vectors: Embeddings, one row per sample (assumed 2-D).
            desc: Progress-bar label.

        Returns:
            ``(n_samples, n_samples)`` float32 CPU tensor.
        """
        n_samples = vectors.shape[0]
        cosine_distance_matrix = torch.zeros(n_samples, n_samples, device="cpu", dtype=torch.float32)
        # Normalise once on the host; blocks are sliced from this and moved as needed.
        normalized_cpu = torch.nn.functional.normalize(self._to_float_tensor(vectors), p=2, dim=1)
        # max(1, ...) keeps the loop step positive even for batch_size == 1.
        gpu_batch_size = max(1, min(self.batch_size // 2, self._MAX_HYBRID_BATCH))
        blocks_per_side = (n_samples + gpu_batch_size - 1) // gpu_batch_size
        logger.info(f"Hybrid computation: CPU matrix storage, GPU batch processing (batch_size={gpu_batch_size})")
        with tqdm(total=blocks_per_side * blocks_per_side, desc=desc, unit="batch") as pbar:
            for i in range(0, n_samples, gpu_batch_size):
                end_i = min(i + gpu_batch_size, n_samples)
                rows_cpu = normalized_cpu[i:end_i]
                rows_device = self._to_device_or_none(rows_cpu)
                if rows_device is None:
                    logger.warning(f"GPU OOM for i-batch {i}:{end_i}, using CPU for entire i-batch")
                for j in range(0, n_samples, gpu_batch_size):
                    end_j = min(j + gpu_batch_size, n_samples)
                    cols_cpu = normalized_cpu[j:end_j]
                    dist_block = None
                    if rows_device is not None:
                        dist_block = self._device_distance_block(rows_device, cols_cpu)
                        if dist_block is None:
                            logger.warning(f"GPU OOM in batch ({i}:{end_i}, {j}:{end_j}), using CPU")
                            self._release_device_cache()
                    if dist_block is None:
                        dist_block = self._cosine_distance_block(rows_cpu, cols_cpu)
                    cosine_distance_matrix[i:end_i, j:end_j] = dist_block
                    pbar.update(1)
                del rows_device
                self._release_device_cache()
        return cosine_distance_matrix

    def _compute_cosine_distance_matrix_efficient(self, vectors: np.ndarray, desc: str = "Computing cosine distances") -> torch.Tensor:
        """Compute the cosine-distance matrix, choosing a strategy by input size.

        Inputs with at least ``_BLOCKWISE_THRESHOLD`` rows are delegated to the
        block-wise path. Smaller ones are computed directly on ``self.device``.
        The input is only moved to the device in the direct case; the block-wise
        path does its own transfer.

        Args:
            vectors: Embeddings, one row per sample (assumed 2-D).
            desc: Progress-bar label.

        Returns:
            ``(n_samples, n_samples)`` float32 tensor of cosine distances.
        """
        n_samples = vectors.shape[0]
        if n_samples >= self._BLOCKWISE_THRESHOLD:
            logger.info(f"Large matrix detected ({n_samples}x{n_samples}), using block-wise computation")
            return self._compute_cosine_distance_matrix_batched(vectors, desc)
        logger.info(f"Computing {n_samples}x{n_samples} cosine distance matrix directly")
        vectors_tensor = self._to_float_tensor(vectors).to(self.device)
        with tqdm(total=1, desc=desc, unit="matrix") as pbar:
            normalized_vectors = torch.nn.functional.normalize(vectors_tensor, p=2, dim=1)
            cosine_distances = self._cosine_distance_block(normalized_vectors, normalized_vectors)
            pbar.update(1)
        return cosine_distances

    # ------------------------------------------------------------- statistics

    def _compute_ratio_statistics_direct(self, dist_matrix_X: torch.Tensor, dist_matrix_Y: torch.Tensor) -> Dict[str, float]:
        """Summarise ``dist_matrix_X / (dist_matrix_Y + eps)`` in one shot.

        Only off-diagonal entries enter the ratio statistics, so at least 2 rows
        are required. The std is ``torch.std``'s default sample (n - 1) estimate.
        """
        epsilon = self._EPSILON
        # The matrices can live on different devices (e.g. one produced by the
        # hybrid path after a device OOM), so align them before combining.
        dist_matrix_Y = dist_matrix_Y.to(dist_matrix_X.device)
        logger.info("Computing cosine distance matrix ratios...")
        with tqdm(total=1, desc="Computing ratios", unit="matrix") as pbar:
            ratios = dist_matrix_X / (dist_matrix_Y + epsilon)
            pbar.update(1)
        off_diagonal = ~torch.eye(ratios.shape[0], dtype=torch.bool, device=ratios.device)
        valid_ratios = ratios[off_diagonal]
        frobenius_diff = torch.norm(dist_matrix_X - dist_matrix_Y, p="fro").item()
        relative_frobenius = frobenius_diff / (torch.norm(dist_matrix_Y, p="fro").item() + epsilon)
        return self._pack_results(
            max_ratio=torch.max(valid_ratios).item(),
            min_ratio=torch.min(valid_ratios).item(),
            mean_ratio=torch.mean(valid_ratios).item(),
            std_ratio=torch.std(valid_ratios).item(),
            frobenius_diff=frobenius_diff,
            relative_frobenius=relative_frobenius,
        )

    def _compute_ratio_statistics_batched(self, dist_matrix_X: torch.Tensor, dist_matrix_Y: torch.Tensor) -> Dict[str, float]:
        """Summarise ``dist_matrix_X / (dist_matrix_Y + eps)`` one band of rows at a time.

        Only off-diagonal entries enter the ratio statistics. Per band, the
        running mean and sum of squared deviations are merged with the pairwise
        (Chan et al.) update. This avoids the cancellation of the
        ``E[x^2] - E[x]^2`` formula. The reported std is the sample (n - 1)
        estimate, matching the direct path. The two matrices may live on
        different devices.

        Returns:
            Result dict; if there are no off-diagonal entries, mean is 1.0, std
            is 0.0 and max/min are -inf/inf.
        """
        n_samples = dist_matrix_X.shape[0]
        epsilon = self._EPSILON
        max_ratio = float("-inf")
        min_ratio = float("inf")
        count = 0
        mean_ratio = 0.0
        m2 = 0.0  # running sum of squared deviations from mean_ratio
        frob_diff = 0.0
        frob_Y = 0.0
        logger.info(f"Computing ratio statistics for {n_samples}x{n_samples} matrix in batches")
        batch_size = min(self.batch_size, 1000)
        total_batches = (n_samples + batch_size - 1) // batch_size
        with tqdm(total=total_batches, desc="Computing ratios", unit="batch") as pbar:
            for i in range(0, n_samples, batch_size):
                end_i = min(i + batch_size, n_samples)
                batch_X = dist_matrix_X[i:end_i, :]
                batch_Y = dist_matrix_Y[i:end_i, :].to(batch_X.device)
                ratios = batch_X / (batch_Y + epsilon)
                # Mask out the diagonal: global row i + r sits at column i + r.
                local_rows = torch.arange(end_i - i, device=ratios.device)
                off_diagonal = torch.ones_like(ratios, dtype=torch.bool)
                off_diagonal[local_rows, local_rows + i] = False
                valid = ratios[off_diagonal]
                if valid.numel() > 0:
                    batch_count = valid.numel()
                    batch_mean = valid.mean().item()
                    batch_m2 = torch.sum((valid - batch_mean) ** 2).item()
                    max_ratio = max(max_ratio, valid.max().item())
                    min_ratio = min(min_ratio, valid.min().item())
                    delta = batch_mean - mean_ratio
                    total = count + batch_count
                    mean_ratio += delta * batch_count / total
                    m2 += batch_m2 + delta * delta * count * batch_count / total
                    count = total
                frob_diff += torch.sum((batch_X - batch_Y) ** 2).item()
                frob_Y += torch.sum(batch_Y ** 2).item()
                pbar.update(1)
        if count == 0:
            mean_ratio = 1.0
            std_ratio = 0.0
        else:
            std_ratio = float(np.sqrt(m2 / max(count - 1, 1)))
        frobenius_diff = float(np.sqrt(frob_diff))
        relative_frobenius = frobenius_diff / (float(np.sqrt(frob_Y)) + epsilon)
        return self._pack_results(
            max_ratio=max_ratio,
            min_ratio=min_ratio,
            mean_ratio=mean_ratio,
            std_ratio=std_ratio,
            frobenius_diff=frobenius_diff,
            relative_frobenius=relative_frobenius,
        )

    # ------------------------------------------------------------- public API

    def calculate(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int] = None) -> Dict[str, float]:
        """Compute cosine delta-struct statistics between source ``X`` and target ``Y``.

        Args:
            X: Source embeddings, one row per item (assumed 2-D).
            Y: Target embeddings, row-aligned with ``X``.
            sample_size: If given, evaluate on this many rows drawn without
                replacement via ``np.random.choice`` (all rows if it exceeds the
                row count). The same rows are taken from ``X`` and ``Y``.

        Returns:
            Dict with ``cosine_delta_struct_{max,min,mean,std}`` over the
            off-diagonal ratios ``D_X / (D_Y + eps)``, plus
            ``cosine_frobenius_diff`` and ``cosine_relative_frobenius_diff``.

        Raises:
            ValueError: If ``X`` and ``Y`` have different row counts, or fewer
                than 2 rows remain after sampling.
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        if X.shape[0] != Y.shape[0]:
            raise ValueError(f"X and Y must have the same number of rows, got {X.shape[0]} and {Y.shape[0]}")
        if sample_size is not None:
            indices = np.random.choice(X.shape[0], size=min(sample_size, X.shape[0]), replace=False)
            X = X[indices]
            Y = Y[indices]
        n_samples = X.shape[0]
        if n_samples < 2:
            raise ValueError(f"At least 2 samples are required to compare pairwise distances, got {n_samples}")
        logger.info(f"Computing cosine delta_struct for {n_samples} samples")
        logger.info("Computing source embedding cosine distance matrix...")
        cosine_dist_matrix_X = self._compute_cosine_distance_matrix_efficient(X, "Source cosine distances")
        logger.info("Computing target embedding cosine distance matrix...")
        cosine_dist_matrix_Y = self._compute_cosine_distance_matrix_efficient(Y, "Target cosine distances")
        if n_samples >= self._STREAMING_STATS_THRESHOLD:
            logger.info("Using memory-efficient ratio computation for large matrices")
            results = self._compute_ratio_statistics_batched(cosine_dist_matrix_X, cosine_dist_matrix_Y)
        else:
            results = self._compute_ratio_statistics_direct(cosine_dist_matrix_X, cosine_dist_matrix_Y)
        logger.info(f"Cosine delta struct calculation completed. Max ratio: {results['cosine_delta_struct_max']:.4f}")
        return results

    def calculate_suite(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int] = None) -> Dict[str, float]:
        """Alias for :meth:`calculate` (same arguments and result)."""
        return self.calculate(X, Y, sample_size)

    def get_memory_usage_estimate(self, n_samples: int) -> Dict[str, float]:
        """Estimate memory for float32 ``n_samples x n_samples`` distance matrices.

        Returns:
            Dict with the following keys:

            - ``single_matrix_mb``: size of one such matrix, in MiB.
            - ``estimated_total_mb``: three matrices' worth, in MiB (presumably
              D_X, D_Y and the ratio matrix — TODO confirm intent).
            - ``recommended_batch_size``: the heuristic
              ``max(100, int(sqrt(500 MiB / 4 B / n_samples)))``.
        """
        matrix_size_mb = n_samples * n_samples * 4 / (1024 * 1024)
        # max(..., 1) avoids a ZeroDivisionError for an empty input.
        recommended_batch_size = max(100, int(np.sqrt(500 * 1024 * 1024 / 4 / max(n_samples, 1))))
        return {
            "single_matrix_mb": matrix_size_mb,
            "estimated_total_mb": matrix_size_mb * 3,
            "recommended_batch_size": recommended_batch_size,
        }
