import numpy as np
import torch
from typing import Dict, Optional
from loguru import logger
from tqdm import tqdm
from .base_calculator import BaseCalculator
class DeltaStructCalculator(BaseCalculator):
    """Compare the pairwise Euclidean distance structure of two embeddings.

    For two arrays ``X`` and ``Y`` holding the same samples row by row, the
    calculator builds both pairwise distance matrices and reports statistics
    of the element-wise ratio ``dist_X / (dist_Y + eps)`` over all off-diagonal
    entries, together with the Frobenius norm of ``dist_X - dist_Y``.

    Distance matrices are computed directly for small inputs, block-wise on
    ``self.device`` for larger ones, and with the matrix held in CPU memory
    ("hybrid") when the memory estimate is too large or the device runs out
    of memory.
    """

    # Added to denominators so that zero distances do not divide by zero.
    _EPSILON = 1e-8
    # Inputs with at least this many rows use block-wise distance computation.
    _BLOCKWISE_DISTANCE_MIN_SAMPLES = 200
    # Inputs with at least this many rows use block-wise ratio statistics.
    _BLOCKWISE_STATS_MIN_SAMPLES = 2000
    # Estimated footprint (GB) above which the matrix is kept in CPU memory.
    _HYBRID_THRESHOLD_GB = 8.0
    # Upper bound on the block edge used by the hybrid and statistics loops.
    _MAX_SMALL_BLOCK = 500

    def __init__(self, batch_size: int = 1000, use_gpu: bool = True):
        """Configure block size and pick the compute device.

        Args:
            batch_size: Edge length of the square blocks used by the
                block-wise computations. Must be at least 1.
            use_gpu: Use CUDA when it is available; otherwise fall back to CPU.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        super().__init__()
        # A non-positive block size would only fail later as a zero/negative
        # ``range`` step deep inside the loops; reject it up front instead.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            logger.info(f"Using GPU device: {self.device}")
        else:
            logger.info("Using CPU for computation")

    @staticmethod
    def _is_oom(error: RuntimeError) -> bool:
        """Return True when *error* looks like an out-of-memory failure."""
        return "out of memory" in str(error).lower()

    @staticmethod
    def _block_distances(
        rows: torch.Tensor,
        cols: torch.Tensor,
        rows_sq_norm: Optional[torch.Tensor] = None,
        zero_diagonal: bool = False,
    ) -> torch.Tensor:
        """Return the Euclidean distances between every row of *rows* and *cols*.

        Uses ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b`` so the work is a single
        matrix product per block.

        Args:
            rows: 2-D tensor, one vector per row.
            cols: 2-D tensor on the same device as ``rows``.
            rows_sq_norm: Optional precomputed ``sum(rows ** 2, dim=1,
                keepdim=True)`` so callers can hoist it out of an inner loop.
            zero_diagonal: Set when ``rows`` and ``cols`` are the same slice;
                the diagonal is then forced to exactly zero, because the
                expansion above leaves small non-zero self-distances.
        """
        if rows_sq_norm is None:
            rows_sq_norm = torch.sum(rows ** 2, dim=1, keepdim=True)
        cols_sq_norm = torch.sum(cols ** 2, dim=1, keepdim=True)
        squared = rows_sq_norm + cols_sq_norm.t() - 2 * torch.mm(rows, cols.t())
        # Rounding can push tiny squared distances below zero; clamp before sqrt.
        distances = torch.sqrt(torch.clamp(squared, min=0))
        if zero_diagonal:
            distances.fill_diagonal_(0.0)
        return distances

    def _compute_distance_matrix_batched(self, vectors: np.ndarray, desc: str = "Computing distances") -> torch.Tensor:
        """Compute the full pairwise distance matrix block by block.

        The result lives on ``self.device`` unless the memory estimate exceeds
        the hybrid threshold or allocation on the device runs out of memory,
        in which case the hybrid path is used and the result is a CPU tensor.

        Args:
            vectors: 2-D array with one sample per row.
            desc: Label for the progress bar.

        Returns:
            An ``(n_samples, n_samples)`` float32 tensor of distances.
        """
        n_samples = vectors.shape[0]
        mem_estimate = self.get_memory_usage_estimate(n_samples)
        required_gb = mem_estimate["estimated_total_mb"] / 1024
        logger.info(f"Estimated memory requirement: {required_gb:.2f} GB for {n_samples}x{n_samples} matrix")
        if required_gb > self._HYBRID_THRESHOLD_GB:
            logger.info("Using hybrid CPU-GPU approach: matrix in CPU memory, batches on GPU")
            return self._compute_distance_matrix_hybrid(vectors, desc)

        vectors_tensor = None
        try:
            vectors_tensor = torch.from_numpy(vectors).float().to(self.device)
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
            distance_matrix = torch.zeros(n_samples, n_samples, device=self.device)
        except RuntimeError as e:
            if not self._is_oom(e):
                raise
            logger.warning("GPU OOM during matrix allocation, using hybrid approach")
            # Drop the device copy of the vectors (if it was made) so the
            # hybrid path starts with as much free device memory as possible.
            vectors_tensor = None
            return self._compute_distance_matrix_hybrid(vectors, desc)

        num_batches = (n_samples + self.batch_size - 1) // self.batch_size
        with tqdm(total=num_batches * num_batches, desc=desc, unit="batch") as pbar:
            for i in range(0, n_samples, self.batch_size):
                end_i = min(i + self.batch_size, n_samples)
                batch_i = vectors_tensor[i:end_i]
                # Invariant across the inner loop, so compute it once per row block.
                batch_i_norm = torch.sum(batch_i ** 2, dim=1, keepdim=True)
                for j in range(0, n_samples, self.batch_size):
                    end_j = min(j + self.batch_size, n_samples)
                    distance_matrix[i:end_i, j:end_j] = self._block_distances(
                        batch_i,
                        vectors_tensor[j:end_j],
                        rows_sq_norm=batch_i_norm,
                        zero_diagonal=(i == j),
                    )
                    pbar.update(1)
        return distance_matrix

    def _compute_distance_matrix_hybrid(self, vectors: np.ndarray, desc: str = "Computing distances") -> torch.Tensor:
        """Compute the distance matrix in CPU memory using device-side blocks.

        Each block is computed on ``self.device`` and copied into a CPU
        matrix. A block that runs out of device memory is recomputed on CPU.

        Args:
            vectors: 2-D array with one sample per row.
            desc: Label for the progress bar (``" (Hybrid)"`` is appended).

        Returns:
            An ``(n_samples, n_samples)`` float32 CPU tensor of distances.
        """
        n_samples = vectors.shape[0]
        distance_matrix = torch.zeros(n_samples, n_samples, device='cpu', dtype=torch.float32)
        vectors_cpu = torch.from_numpy(vectors).float()
        # ``batch_size // 2`` is 0 for batch_size == 1; a zero step would make
        # ``range`` raise, so never go below one row per block.
        gpu_batch_size = max(1, min(self.batch_size // 2, self._MAX_SMALL_BLOCK))
        num_batches = (n_samples + gpu_batch_size - 1) // gpu_batch_size
        logger.info(f"Hybrid computation: CPU matrix storage, GPU batch processing (batch_size={gpu_batch_size})")
        on_cuda = self.device.type == 'cuda'
        if on_cuda:
            torch.cuda.empty_cache()
        with tqdm(total=num_batches * num_batches, desc=f"{desc} (Hybrid)", unit="batch") as pbar:
            for i in range(0, n_samples, gpu_batch_size):
                end_i = min(i + gpu_batch_size, n_samples)
                batch_i_cpu = vectors_cpu[i:end_i]
                batch_i_dev = batch_i_cpu.to(self.device)
                batch_i_norm_dev = torch.sum(batch_i_dev ** 2, dim=1, keepdim=True)
                for j in range(0, n_samples, gpu_batch_size):
                    end_j = min(j + gpu_batch_size, n_samples)
                    batch_j_cpu = vectors_cpu[j:end_j]
                    same_block = i == j
                    try:
                        # The host-to-device copy is inside the try so that an
                        # OOM during the transfer also takes the CPU fallback.
                        block = self._block_distances(
                            batch_i_dev,
                            batch_j_cpu.to(self.device),
                            rows_sq_norm=batch_i_norm_dev,
                            zero_diagonal=same_block,
                        ).cpu()
                    except RuntimeError as e:
                        if not self._is_oom(e):
                            raise
                        logger.warning(f"GPU OOM in batch ({i}:{end_i}, {j}:{end_j}), using CPU")
                        if on_cuda:
                            torch.cuda.empty_cache()
                        block = self._block_distances(batch_i_cpu, batch_j_cpu, zero_diagonal=same_block)
                    distance_matrix[i:end_i, j:end_j] = block
                    pbar.update(1)
                # Release cached device memory once per row block rather than
                # after every block: emptying the cache on each inner
                # iteration only adds overhead.
                del batch_i_dev, batch_i_norm_dev
                if on_cuda:
                    torch.cuda.empty_cache()
        return distance_matrix

    def _compute_distance_matrix_efficient(self, vectors: np.ndarray, desc: str = "Computing distances") -> torch.Tensor:
        """Compute the pairwise distance matrix, choosing a strategy by size.

        Inputs with 200 or more rows are delegated to the block-wise path;
        smaller inputs are computed in one broadcasted operation on
        ``self.device``.

        Args:
            vectors: 2-D array with one sample per row.
            desc: Label for the progress bar.

        Returns:
            An ``(n_samples, n_samples)`` tensor of distances.
        """
        n_samples = vectors.shape[0]
        if n_samples >= self._BLOCKWISE_DISTANCE_MIN_SAMPLES:
            logger.info(f"Large matrix detected ({n_samples}x{n_samples}), using block-wise computation")
            return self._compute_distance_matrix_batched(vectors, desc)

        logger.info(f"Computing {n_samples}x{n_samples} distance matrix directly")
        # Only the direct path needs this copy; the block-wise path makes its own.
        vectors_tensor = torch.from_numpy(vectors).float().to(self.device)
        with tqdm(total=1, desc=desc, unit="matrix") as pbar:
            diff = vectors_tensor.unsqueeze(1) - vectors_tensor.unsqueeze(0)
            distances = torch.norm(diff, dim=2)
            pbar.update(1)
        return distances

    def _compute_ratio_statistics_batched(self, dist_matrix_X: torch.Tensor, dist_matrix_Y: torch.Tensor) -> Dict[str, float]:
        """Compute ratio and Frobenius statistics block by block.

        Only blocks on or above the block diagonal are visited. Because the
        distance matrices are symmetric, every strictly-upper block also
        stands in for its mirror image and is therefore weighted twice; this
        makes the mean, standard deviation and Frobenius norms agree with a
        computation over the full matrix.

        Args:
            dist_matrix_X: Square distance matrix of the source embedding.
            dist_matrix_Y: Square distance matrix of the target embedding,
                same shape as ``dist_matrix_X``.

        Returns:
            Dict with ``delta_struct_max``, ``delta_struct_min``,
            ``delta_struct_mean``, ``delta_struct_std``, ``frobenius_diff``
            and ``relative_frobenius_diff``.
        """
        n_samples = dist_matrix_X.shape[0]
        epsilon = self._EPSILON
        max_ratio = float('-inf')
        min_ratio = float('inf')
        sum_ratios = 0.0
        sum_squared_ratios = 0.0
        count = 0
        frob_diff_sq = 0.0
        frob_Y_sq = 0.0
        block_size = max(1, min(self.batch_size, self._MAX_SMALL_BLOCK))
        total_pairs = n_samples * (n_samples - 1) // 2
        with tqdm(total=total_pairs, desc="Computing ratio statistics", unit="pairs") as pbar:
            for i in range(0, n_samples, block_size):
                end_i = min(i + block_size, n_samples)
                for j in range(i, n_samples, block_size):
                    end_j = min(j + block_size, n_samples)
                    block_X = dist_matrix_X[i:end_i, j:end_j]
                    # The two matrices can sit on different devices when only
                    # one of them fell back to the hybrid path.
                    block_Y = dist_matrix_Y[i:end_i, j:end_j].to(block_X.device)
                    on_block_diagonal = i == j
                    weight = 1 if on_block_diagonal else 2

                    block_diff = (block_X - block_Y).double()
                    frob_diff_sq += weight * torch.sum(block_diff ** 2).item()
                    frob_Y_sq += weight * torch.sum(block_Y.double() ** 2).item()

                    block_ratios = block_X / (block_Y + epsilon)
                    if on_block_diagonal:
                        # Self-pairs carry no information; drop the diagonal.
                        off_diagonal = ~torch.eye(
                            block_ratios.shape[0], dtype=torch.bool, device=block_ratios.device
                        )
                        valid_ratios = block_ratios[off_diagonal]
                    else:
                        valid_ratios = block_ratios.flatten()

                    n_valid = valid_ratios.numel()
                    if n_valid == 0:
                        continue
                    max_ratio = max(max_ratio, torch.max(valid_ratios).item())
                    min_ratio = min(min_ratio, torch.min(valid_ratios).item())
                    # Accumulate in float64 to limit cancellation in the
                    # sum-of-squares variance formula below.
                    ratios64 = valid_ratios.double()
                    sum_ratios += weight * torch.sum(ratios64).item()
                    sum_squared_ratios += weight * torch.sum(ratios64 ** 2).item()
                    count += weight * n_valid
                    # The bar counts unordered pairs; a diagonal block holds
                    # each of its pairs twice (both triangles).
                    pbar.update(n_valid // 2 if on_block_diagonal else n_valid)

        mean_ratio = sum_ratios / count if count > 0 else 0.0
        if count > 1:
            # Bessel-corrected, matching ``torch.std`` on the direct path.
            variance = (sum_squared_ratios - count * mean_ratio ** 2) / (count - 1)
        else:
            variance = 0.0
        std_ratio = float(np.sqrt(max(0.0, variance)))
        frobenius_diff = float(np.sqrt(frob_diff_sq))
        relative_frobenius = frobenius_diff / (float(np.sqrt(frob_Y_sq)) + epsilon)
        return {
            "delta_struct_max": max_ratio,
            "delta_struct_min": min_ratio,
            "delta_struct_mean": mean_ratio,
            "delta_struct_std": std_ratio,
            "frobenius_diff": frobenius_diff,
            "relative_frobenius_diff": relative_frobenius,
        }

    def _compute_ratio_statistics_direct(self, dist_matrix_X: torch.Tensor, dist_matrix_Y: torch.Tensor) -> Dict[str, float]:
        """Compute ratio and Frobenius statistics on the whole matrices at once."""
        epsilon = self._EPSILON
        # An out-of-memory fallback can leave one matrix on CPU and the other
        # on the device; bring both to CPU before combining them.
        if dist_matrix_X.device != dist_matrix_Y.device:
            dist_matrix_X = dist_matrix_X.cpu()
            dist_matrix_Y = dist_matrix_Y.cpu()
        logger.info("Computing distance matrix ratios...")
        with tqdm(total=1, desc="Computing ratios", unit="matrix") as pbar:
            ratios = dist_matrix_X / (dist_matrix_Y + epsilon)
            pbar.update(1)
        # Build the mask where the data actually is, not on ``self.device``.
        off_diagonal = ~torch.eye(ratios.shape[0], dtype=torch.bool, device=ratios.device)
        valid_ratios = ratios[off_diagonal]
        frobenius_diff = torch.norm(dist_matrix_X - dist_matrix_Y, p='fro').item()
        relative_frobenius = frobenius_diff / (torch.norm(dist_matrix_Y, p='fro').item() + epsilon)
        return {
            "delta_struct_max": torch.max(valid_ratios).item(),
            "delta_struct_min": torch.min(valid_ratios).item(),
            "delta_struct_mean": torch.mean(valid_ratios).item(),
            "delta_struct_std": torch.std(valid_ratios).item(),
            "frobenius_diff": frobenius_diff,
            "relative_frobenius_diff": relative_frobenius,
        }

    def calculate(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int] = None) -> Dict[str, float]:
        """Compute delta-struct statistics between two embeddings.

        Args:
            X: 2-D array of source embeddings, one sample per row.
            Y: 2-D array of target embeddings with the same number of rows.
            sample_size: If given, a random subset of at most this many rows
                (drawn without replacement, same indices for ``X`` and ``Y``)
                is used instead of the full input.

        Returns:
            Dict with ``delta_struct_max``, ``delta_struct_min``,
            ``delta_struct_mean``, ``delta_struct_std``, ``frobenius_diff``
            and ``relative_frobenius_diff``.

        Raises:
            AssertionError: If ``X`` and ``Y`` have different numbers of rows.
            ValueError: If fewer than two samples remain after sampling.
        """
        # Raised explicitly (instead of via ``assert``) so the check survives
        # ``python -O`` while callers still see the same exception type.
        if X.shape[0] != Y.shape[0]:
            raise AssertionError("X and Y must have the same number of samples")
        if sample_size is not None:
            n_samples = min(sample_size, X.shape[0])
            indices = np.random.choice(X.shape[0], size=n_samples, replace=False)
            X = X[indices]
            Y = Y[indices]
        # With fewer than two rows there is no off-diagonal pair to form a ratio from.
        if X.shape[0] < 2:
            raise ValueError(
                f"delta_struct needs at least 2 samples to form pairwise ratios, got {X.shape[0]}"
            )
        logger.info(f"Computing delta_struct for {X.shape[0]} samples")
        logger.info("Computing source embedding distance matrix...")
        dist_matrix_X = self._compute_distance_matrix_efficient(X, "Source distances")
        logger.info("Computing target embedding distance matrix...")
        dist_matrix_Y = self._compute_distance_matrix_efficient(Y, "Target distances")
        if X.shape[0] >= self._BLOCKWISE_STATS_MIN_SAMPLES:
            logger.info("Using memory-efficient ratio computation for large matrices")
            results = self._compute_ratio_statistics_batched(dist_matrix_X, dist_matrix_Y)
        else:
            results = self._compute_ratio_statistics_direct(dist_matrix_X, dist_matrix_Y)
        logger.info(f"Delta struct calculation completed. Max ratio: {results['delta_struct_max']:.4f}")
        return results

    def calculate_suite(self, X: np.ndarray, Y: np.ndarray, sample_size: Optional[int] = None) -> Dict[str, float]:
        """Alias for :meth:`calculate`; forwards all arguments unchanged."""
        return self.calculate(X, Y, sample_size)

    def get_memory_usage_estimate(self, n_samples: int) -> Dict[str, float]:
        """Estimate the memory needed for ``n_samples`` x ``n_samples`` float32 matrices.

        Args:
            n_samples: Number of rows/columns of the square distance matrix.

        Returns:
            Dict with ``single_matrix_mb`` (one matrix at 4 bytes per entry),
            ``estimated_total_mb`` (three times that) and
            ``recommended_batch_size`` (never below 100).
        """
        matrix_size_bytes = n_samples * n_samples * 4
        matrix_size_mb = matrix_size_bytes / (1024 * 1024)
        total_mb = matrix_size_mb * 3
        # Guard the divisor so an empty input does not raise ZeroDivisionError.
        safe_n = max(1, n_samples)
        recommended = max(100, int(np.sqrt(500 * 1024 * 1024 / 4 / safe_n)))
        return {
            "single_matrix_mb": matrix_size_mb,
            "estimated_total_mb": total_mb,
            "recommended_batch_size": recommended,
        }
