"""
Random Projection Utilities for Influence Function Computation

This module provides utilities for creating random projection matrices and performing
dimension reduction for efficient influence function computation.
"""

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Tuple, Union, Dict, List, Callable

if TYPE_CHECKING:
    from collections.abc import Callable

import warnings
import numpy as np
import torch
from torch import Tensor


class ProjectionType(str, Enum):
    """Kinds of random projection matrix supported by the projectors.

    The enum also derives from ``str``, so a member compares equal to its
    plain string value (``ProjectionType.NORMAL == "normal"``). This is what
    lets the projectors accept either a member or the raw string.
    """
    
    # Matrix entries are drawn with ``torch.randn`` (before scaling).
    NORMAL = "normal"
    # Matrix entries are -1 or +1, built from ``torch.randint`` (before scaling).
    RADEMACHER = "rademacher"


class AbstractProjector(ABC):
    """
    Interface shared by all random projectors.

    A projector is configured with an input dimension, an output dimension,
    a seed, a projection type and a device. Concrete subclasses must provide
    ``__init__``, ``project`` and ``free_memory``.
    """

    @abstractmethod
    def __init__(
        self,
        feature_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: Union[str, ProjectionType],
        device: Union[str, torch.device],
    ) -> None:
        """
        Record the projector configuration on the instance.

        Although abstract, this method has a body so that subclasses can
        call ``super().__init__(...)`` to store the shared settings.

        Args:
            feature_dim: Number of entries in an input feature vector
            proj_dim: Number of entries in a projected vector
            seed: Seed controlling the random projection
            proj_type: Projection kind, as an enum member or its string value
            device: Device the projector works on
        """
        # Values are stored as given; no validation or conversion happens here.
        self.feature_dim, self.proj_dim, self.seed = feature_dim, proj_dim, seed
        self.proj_type, self.device = proj_type, device

    @abstractmethod
    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """
        Project ``features`` to the lower-dimensional space.

        Args:
            features: A feature tensor, or a dictionary of feature tensors
            ensemble_id: Identifier of the ensemble member being projected

        Returns:
            The projected features
        """

    @abstractmethod
    def free_memory(self) -> None:
        """Release the memory held by the projector."""


class CpuProjector(AbstractProjector):
    """
    Random projector backed by a single dense projection matrix.

    The matrix has shape ``(feature_dim, proj_dim)``, is created once at
    construction time on ``device`` with the requested ``dtype``, and is
    applied to one flattened feature vector per ``project`` call.
    """

    def __init__(
        self,
        feature_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: Union[str, ProjectionType],
        device: torch.device,
        block_size: int = 100,
        dtype: torch.dtype = torch.float32,
        ensemble_id: int = 0,
    ) -> None:
        """
        Store the configuration and build the projection matrix.

        Args:
            feature_dim: Length of the flattened input feature vector
            proj_dim: Length of the projected vector
            seed: Base random seed
            proj_type: Projection kind (enum member or its string value)
            device: Device on which the matrix is created
            block_size: Stored on the instance; not used by this class
            dtype: Data type of the projection matrix
            ensemble_id: Added to ``seed`` when drawing the matrix

        Raises:
            ValueError: If ``proj_type`` is not a known projection type
        """
        super().__init__(feature_dim, proj_dim, seed, proj_type, device)
        self.block_size = block_size
        self.dtype = dtype
        self.ensemble_id = ensemble_id

        self._generate_projection_matrix()

    def _generate_projection_matrix(self) -> None:
        """
        Draw the projection matrix and store it in ``self.projection_matrix``.

        Raises:
            ValueError: If ``self.proj_type`` is not a known projection type
        """
        # A generator private to this call keeps the draw reproducible
        # without reseeding the process-wide RNG, which would silently
        # perturb any other random code running in the same program.
        generator = torch.Generator(device=self.device)
        generator.manual_seed(self.seed + self.ensemble_id)

        shape = (self.feature_dim, self.proj_dim)
        if self.proj_type == ProjectionType.NORMAL:
            matrix = torch.randn(
                shape, generator=generator,
                device=self.device, dtype=self.dtype,
            )
        elif self.proj_type == ProjectionType.RADEMACHER:
            # randint yields 0/1; the affine map turns that into -1/+1.
            matrix = torch.randint(
                0, 2, shape, generator=generator,
                device=self.device, dtype=self.dtype,
            ) * 2 - 1
        else:
            raise ValueError(f"Unknown projection type: {self.proj_type}")

        # NOTE(review): the scale uses feature_dim rather than proj_dim;
        # confirm this is the intended normalization.
        self.projection_matrix = matrix / math.sqrt(self.feature_dim)

    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """
        Project one feature vector with the stored matrix.

        The input is flattened completely (dictionary values are flattened
        and concatenated in iteration order), so it must contain exactly
        ``feature_dim`` entries in total.

        Args:
            features: A feature tensor, or a dictionary of feature tensors
            ensemble_id: Ignored; the matrix is fixed at construction time

        Returns:
            Tensor of shape ``(proj_dim,)`` with the matrix's dtype and device
        """
        if isinstance(features, dict):
            flat = torch.cat([f.flatten() for f in features.values()])
        else:
            flat = features.flatten()

        # matmul requires matching dtype and device; align the (small) input
        # to the matrix. This is a no-op when they already agree.
        matrix = self.projection_matrix
        flat = flat.to(device=matrix.device, dtype=matrix.dtype)

        return torch.matmul(flat, matrix)

    def free_memory(self) -> None:
        """Drop the projection matrix and ask torch to empty the CUDA cache."""
        if hasattr(self, 'projection_matrix'):
            del self.projection_matrix
        torch.cuda.empty_cache()


class CudaProjector(AbstractProjector):
    """
    Random projector whose dense projection matrix lives on ``device``.

    The matrix has shape ``(feature_dim, proj_dim)`` and dtype ``float32``.
    Feature vectors longer than ``max_batch_size`` are multiplied slice by
    slice to bound the size of each individual matmul.
    """

    def __init__(
        self,
        feature_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: ProjectionType,
        device: str,
        max_batch_size: int,
    ) -> None:
        """
        Store the configuration and build the projection matrix.

        Args:
            feature_dim: Length of the flattened input feature vector
            proj_dim: Length of the projected vector
            seed: Random seed
            proj_type: Projection kind
            device: Device on which the matrix is created
            max_batch_size: Largest number of feature entries multiplied in
                a single matmul; longer inputs are processed in slices

        Raises:
            ValueError: If ``proj_type`` is not a known projection type
        """
        super().__init__(feature_dim, proj_dim, seed, proj_type, device)
        self.max_batch_size = max_batch_size

        self._generate_projection_matrix()

    def _generate_projection_matrix(self) -> None:
        """
        Draw the projection matrix and store it in ``self.projection_matrix``.

        Raises:
            ValueError: If ``self.proj_type`` is not a known projection type
        """
        # Use a private generator so that building a projector does not
        # reseed the process-wide RNG as a side effect.
        generator = torch.Generator(device=self.device)
        generator.manual_seed(self.seed)

        shape = (self.feature_dim, self.proj_dim)
        if self.proj_type == ProjectionType.NORMAL:
            matrix = torch.randn(
                shape, generator=generator,
                device=self.device, dtype=torch.float32,
            )
        elif self.proj_type == ProjectionType.RADEMACHER:
            # randint yields 0/1; the affine map turns that into -1/+1.
            matrix = torch.randint(
                0, 2, shape, generator=generator,
                device=self.device, dtype=torch.float32,
            ) * 2 - 1
        else:
            raise ValueError(f"Unknown projection type: {self.proj_type}")

        self.projection_matrix = matrix / math.sqrt(self.feature_dim)

    def project(
        self,
        features: Union[dict, Tensor],
        ensemble_id: int,
    ) -> Tensor:
        """
        Project one feature vector with the stored matrix.

        The input is flattened completely (dictionary values are flattened
        and concatenated in iteration order), so it must contain exactly
        ``feature_dim`` entries in total.

        Args:
            features: A feature tensor, or a dictionary of feature tensors
            ensemble_id: Ignored; the matrix is fixed at construction time

        Returns:
            Tensor of shape ``(proj_dim,)`` with the matrix's dtype and device

        Raises:
            ValueError: If the input must be sliced but ``max_batch_size``
                is not positive, or the input length is not ``feature_dim``
        """
        if isinstance(features, dict):
            flat = torch.cat([f.flatten() for f in features.values()])
        else:
            flat = features.flatten()

        # Align device and dtype with the matrix. `.to` returns the tensor
        # itself when nothing has to change, so no explicit comparison
        # (which is unreliable between `str` and `torch.device`) is needed.
        matrix = self.projection_matrix
        flat = flat.to(device=matrix.device, dtype=matrix.dtype)

        num_features = flat.size(0)
        if num_features <= self.max_batch_size:
            return torch.matmul(flat, matrix)

        if self.max_batch_size < 1:
            raise ValueError(
                f"max_batch_size must be positive, got {self.max_batch_size}"
            )
        if num_features != matrix.size(0):
            raise ValueError(
                f"Expected {matrix.size(0)} feature entries, got {num_features}"
            )

        # x @ P equals the sum over row blocks of x[block] @ P[block].
        # Each feature slice must therefore be paired with the matching rows
        # of the matrix, and the partial products accumulated.
        projected = torch.zeros(
            matrix.size(1), device=matrix.device, dtype=matrix.dtype
        )
        for start in range(0, num_features, self.max_batch_size):
            stop = start + self.max_batch_size
            projected += torch.matmul(flat[start:stop], matrix[start:stop])

        return projected

    def free_memory(self) -> None:
        """Drop the projection matrix and ask torch to empty the CUDA cache."""
        if hasattr(self, 'projection_matrix'):
            del self.projection_matrix
        torch.cuda.empty_cache()


class ChunkedCudaProjector:
    """
    Projector that splits a long feature vector across several projectors.

    The flattened input is cut into consecutive chunks whose lengths are
    given by ``dim_per_chunk``. Chunk ``i`` is projected by
    ``projector_per_chunk[i]`` and the per-chunk results are concatenated.
    """

    def __init__(
        self,
        projector_per_chunk: list,
        max_chunk_size: int,
        dim_per_chunk: list,
        feature_batch_size: int,
        proj_max_batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """
        Store the configuration and allocate the input buffers.

        Args:
            projector_per_chunk: One projector per chunk; each must provide
                ``project(features, ensemble_id)``
            max_chunk_size: Stored on the instance; not used by this class
            dim_per_chunk: Number of feature entries in each chunk
            feature_batch_size: Second dimension of the input buffers
            proj_max_batch_size: Stored on the instance; not used by this class
            device: Device on which the input buffers are allocated
            dtype: Data type of the input buffers
        """
        self.projector_per_chunk = projector_per_chunk
        self.max_chunk_size = max_chunk_size
        self.dim_per_chunk = dim_per_chunk
        self.feature_batch_size = feature_batch_size
        self.proj_max_batch_size = proj_max_batch_size
        self.device = device
        self.dtype = dtype

        self._allocate_input()

    def _allocate_input(self) -> None:
        """Allocate one zeroed ``(chunk_size, feature_batch_size)`` buffer per chunk."""
        self.input_buffers = [
            torch.zeros(
                chunk_size, self.feature_batch_size,
                device=self.device, dtype=self.dtype,
            )
            for chunk_size in self.dim_per_chunk
        ]

    def free_memory(self) -> None:
        """
        Drop the input buffers and ask torch to empty the CUDA cache.

        The per-chunk projectors are not freed here.
        """
        if hasattr(self, 'input_buffers'):
            del self.input_buffers
        torch.cuda.empty_cache()

    def _project_flat(self, flat: Tensor, ensemble_id: int) -> Tensor:
        """Project consecutive slices of a 1-D tensor and concatenate the results."""
        projected_chunks = []
        start = 0
        for projector, chunk_size in zip(self.projector_per_chunk, self.dim_per_chunk):
            # Slicing clamps at the end of the tensor, so the slice may be
            # shorter than chunk_size, or empty once the input is exhausted.
            chunk = flat[start:start + chunk_size]

            missing = chunk_size - chunk.size(0)
            if missing > 0:
                # Zero-pad so every projector sees a full-length chunk.
                padding = torch.zeros(missing, device=chunk.device, dtype=chunk.dtype)
                chunk = torch.cat([chunk, padding])

            projected_chunks.append(projector.project(chunk, ensemble_id))
            start += chunk_size

        return torch.cat(projected_chunks, dim=0)

    def dict_project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """
        Project a dictionary of features chunk by chunk.

        The dictionary values are flattened and concatenated in iteration
        order into a single vector, which is then chunked exactly like a
        tensor input. A tensor argument is accepted as well.

        Args:
            features: Dictionary of feature tensors (or a feature tensor)
            ensemble_id: Identifier forwarded to every chunk projector

        Returns:
            Concatenation of the per-chunk projections
        """
        if isinstance(features, dict):
            # Chunk boundaries refer to positions in the concatenated vector,
            # so the dictionary has to be joined before it is sliced.
            flat = torch.cat([value.flatten() for value in features.values()])
        else:
            flat = features.flatten()
        return self._project_flat(flat, ensemble_id)

    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """
        Project features chunk by chunk.

        Inputs shorter than the total chunk length are zero-padded; entries
        beyond the total chunk length are ignored.

        Args:
            features: A feature tensor, or a dictionary of feature tensors
            ensemble_id: Identifier forwarded to every chunk projector

        Returns:
            Concatenation of the per-chunk projections
        """
        if isinstance(features, dict):
            return self.dict_project(features, ensemble_id)
        return self._project_flat(features.flatten(), ensemble_id)


def make_random_projector(
    param_shape_list: List,
    feature_batch_size: int,
    proj_dim: int,
    proj_max_batch_size: int,
    device: str,
    proj_seed: int = 0,
    *,
    use_half_precision: bool = True,
) -> Tensor:
    """
    Build a dense Gaussian random projection matrix.

    The input dimension is the total number of elements described by
    ``param_shape_list``. The matrix is drawn from a standard normal
    distribution and scaled by ``1 / sqrt(total_dim)``.

    Args:
        param_shape_list: List of parameter shapes
        feature_batch_size: Accepted for interface compatibility; unused
        proj_dim: Projection dimension
        proj_max_batch_size: Accepted for interface compatibility; unused
        device: Device on which the matrix is created
        proj_seed: Random seed
        use_half_precision: Use ``float16`` if true, ``float32`` otherwise

    Returns:
        Projection matrix of shape ``(total_dim, proj_dim)``
    """
    total_dim = sum(math.prod(shape) for shape in param_shape_list)
    dtype = torch.float16 if use_half_precision else torch.float32

    # Seed a generator private to this call instead of `torch.manual_seed`,
    # so that building a projector does not reset the process-wide RNG.
    generator = torch.Generator(device=device)
    generator.manual_seed(proj_seed)

    projection_matrix = torch.randn(
        total_dim, proj_dim,
        generator=generator, device=device, dtype=dtype,
    )

    return projection_matrix / math.sqrt(total_dim)


def random_project(
    feature: Union[Dict[str, Tensor], Tensor],
    feature_batch_size: int,
    proj_dim: int,
    proj_max_batch_size: int,
    device: str,
    proj_seed: int = 0,
    *,
    use_half_precision: bool = True,
) -> Callable:
    """
    Create a function that randomly projects features shaped like ``feature``.

    The shapes found in ``feature`` determine the input dimension of a
    projection matrix built once by ``make_random_projector``. The returned
    function reuses that matrix on every call.

    Args:
        feature: Example features (tensor or dictionary of tensors) whose
            shapes define the input dimension
        feature_batch_size: Forwarded to ``make_random_projector``
        proj_dim: Projection dimension
        proj_max_batch_size: Forwarded to ``make_random_projector``
        device: Device on which the projection matrix is created
        proj_seed: Random seed
        use_half_precision: Whether the matrix uses half precision

    Returns:
        Function ``f(feature, ensemble_id=0)`` returning the projected features
    """
    if isinstance(feature, dict):
        param_shapes = [f.shape for f in feature.values()]
    else:
        param_shapes = [feature.shape]

    projection_matrix = make_random_projector(
        param_shapes, feature_batch_size, proj_dim, proj_max_batch_size,
        device, proj_seed, use_half_precision=use_half_precision
    )

    def _random_project_func(
        feature: Union[Dict[str, Tensor], Tensor],
        ensemble_id: int = 0,
    ) -> Tensor:
        """
        Project one feature vector with the captured matrix.

        The input is flattened completely (dictionary values are flattened
        and concatenated in iteration order).

        Args:
            feature: A feature tensor, or a dictionary of feature tensors
            ensemble_id: Ignored by this implementation

        Returns:
            Projected features, with the matrix's dtype and device
        """
        if isinstance(feature, dict):
            feature_tensor = torch.cat([f.flatten() for f in feature.values()])
        else:
            feature_tensor = feature.flatten()

        # The matrix may be half precision and on another device, while
        # matmul needs both operands to match. Align the (small) feature
        # vector to the matrix; this is a no-op when they already agree.
        feature_tensor = feature_tensor.to(
            device=projection_matrix.device, dtype=projection_matrix.dtype
        )

        return torch.matmul(feature_tensor, projection_matrix)

    return _random_project_func