"""
Unified manager for hooks and projection matrices in MoE knowledge editing
"""

import torch
import torch.nn as nn
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
from dataclasses import dataclass
from tqdm import tqdm

from .logger import get_logger, info, debug, warning, error
from .utils import find_subject_token_positions


# ============================================================================
# Data Classes
# ============================================================================

@dataclass
class ProjectionMatrixCache:
    """Pickled record describing one cached projection matrix.

    The manager's save helpers wrap each matrix in this record before writing
    it to disk; the load helpers unwrap ``matrix`` again.
    """

    # The projection matrix (the save helpers store a CPU copy).
    matrix: torch.Tensor
    # Expert the matrix belongs to; the layer-wide (unified) matrix uses -1.
    expert_idx: int
    # ISO-8601 timestamp taken when the record was built.
    computation_time: str
    # Shape of ``matrix`` at save time, kept for quick inspection.
    matrix_shape: Tuple[int, int]


# ============================================================================
# Hook Management
# ============================================================================

class HookManager:
    """Keep track of forward hooks so they can be detached together.

    Also usable as a context manager: every hook registered through the
    manager is removed when the ``with`` block exits.
    """

    def __init__(self):
        # Handles returned by ``register_forward_hook``, in registration order.
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []

    def register(self, module: nn.Module, hook_fn) -> torch.utils.hooks.RemovableHandle:
        """Attach ``hook_fn`` as a forward hook on ``module`` and remember its handle."""
        registered = module.register_forward_hook(hook_fn)
        self.hooks.append(registered)
        return registered

    def remove_all(self):
        """Detach every tracked hook, then forget all handles."""
        for registered in self.hooks:
            registered.remove()
        self.hooks.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Exceptions raised inside the ``with`` block are not suppressed.
        self.remove_all()
        return None


class DownProjCollector:
    """Capture the inputs passed to each selected expert's ``down_proj`` module.

    Args:
        target_layer: MoE layer exposing an indexable ``experts`` container.
        expert_indices: experts to hook. ``None`` (the default) selects every
            expert in ``target_layer.experts``; an explicit empty list selects none.
    """

    def __init__(self, target_layer, expert_indices=None):
        self.target_layer = target_layer
        # Compare against None explicitly: ``expert_indices or ...`` turned an
        # explicit empty selection into "hook all experts".
        if expert_indices is None:
            expert_indices = range(len(target_layer.experts))
        self.expert_indices = list(expert_indices)
        # expert index -> list of captured down_proj inputs (one per forward call)
        self.captured_keys = {}
        self.hook_manager = HookManager()

    def setup_hooks(self):
        """Register a forward hook on ``down_proj`` of every selected expert.

        Any hooks from a previous call are removed first, so calling this twice
        does not make each forward pass record every key twice. Indices outside
        ``[0, len(experts))`` and experts without a ``down_proj`` attribute are
        skipped.
        """
        self.hook_manager.remove_all()
        num_experts = len(self.target_layer.experts)
        for expert_idx in self.expert_indices:
            # Reject negative indices as well: Python indexing would otherwise
            # wrap them around and hook a different expert.
            if not 0 <= expert_idx < num_experts:
                continue
            expert = self.target_layer.experts[expert_idx]
            if hasattr(expert, 'down_proj'):
                hook_fn = self._create_down_proj_hook(expert_idx)
                self.hook_manager.register(expert.down_proj, hook_fn)
                self.captured_keys[expert_idx] = []

    def _create_down_proj_hook(self, expert_idx):
        """Build a forward hook that stores a detached copy of the module's first input."""
        def hook(module, input, output):
            # After ``clear()`` the expert is no longer tracked, so nothing is recorded.
            if expert_idx in self.captured_keys:
                key = input[0].detach().clone()
                self.captured_keys[expert_idx].append(key)
        return hook

    def get_captured_keys(self):
        """Return the live mapping expert index -> list of captured input tensors."""
        return self.captured_keys

    def clear(self):
        """Drop all captured data and detach every registered hook."""
        self.captured_keys.clear()
        self.hook_manager.remove_all()


# Backward compatibility functions
def hook_context(hook_manager):
    """Return ``hook_manager`` unchanged; kept for backward compatibility.

    ``HookManager`` is itself a context manager, so ``with hook_context(m):``
    removes the manager's hooks when the block exits.
    """
    manager = hook_manager
    return manager


def with_hook_cleanup(func):
    """Pass-through decorator kept for backward compatibility.

    It performs no hook cleanup of its own; it simply calls ``func`` with the
    given arguments. ``functools.wraps`` keeps the wrapped function's name,
    docstring and ``__wrapped__`` reference, so introspection still sees
    ``func``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


# ============================================================================
# Projection Matrix Management
# ============================================================================

class ProjectionMatrixManager:
    """
    Unified projection matrix manager:
    - Collect keys at last subject token positions
    - Use Wikipedia-style datasets
    - Expert-wise null space computation
    - Simplified caching system
    """

    def __init__(self,
                 model,
                 tokenizer,
                 target_layer,
                 num_experts: int,
                 d_intermediate: int,
                 layer_idx: int,
                 device: str = "cuda",
                 cache_dir: str = "./projection_matrix_cache",
                 nullspace_threshold: float = 2e-2,
                 fact_token_strategy: str = "subject_last"):
        """Store the editing context and prepare the on-disk cache directory.

        Args:
            model: model run forward (``self.model(**inputs)``) to collect keys
                and routing information.
            tokenizer: tokenizer used to encode texts for key collection.
            target_layer: MoE layer whose ``experts`` are hooked.
            num_experts: stored as ``self.num_experts``.
            d_intermediate: dimensionality of the expert keys; the fallback
                identity projection has this size.
            layer_idx: layer index, embedded in every cache file name.
            device: device that inputs, keys and loaded matrices are moved to.
            cache_dir: directory for cached projection/covariance matrices;
                created (including missing parent directories) if absent.
            nullspace_threshold: singular values of the key second-moment
                matrix at or below this value count as null-space directions.
            fact_token_strategy: forwarded to the AlphaEdit-style subject token
                locator.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.target_layer = target_layer
        self.num_experts = num_experts
        self.d_intermediate = d_intermediate
        self.layer_idx = layer_idx
        self.device = device
        self.nullspace_threshold = nullspace_threshold
        self.fact_token_strategy = fact_token_strategy

        # Setup cache directory. parents=True so a nested path such as
        # "runs/exp1/cache" works on first use instead of raising FileNotFoundError.
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Initialize components
        self.logger = get_logger()

        # In-memory caches of matrices and sample counts (presumably keyed by expert index).
        self.projection_matrices: Dict[int, torch.Tensor] = {}
        self.covariance_matrices: Dict[int, torch.Tensor] = {}
        self.sample_counts: Dict[int, int] = {}

        info(f"Projection manager initialized with cache: {self.cache_dir}")
        # NOTE(review): the null-space routines visible in this file use covariance
        # SVD for any sample count; confirm the "> 2000" wording of this message.
        info("Covariance SVD enabled for samples > 2000")

    def get_wikipedia_style_texts(self, num_samples: int) -> List[str]:
        """Load real Wikipedia texts for key collection"""
        local_wiki_path = "/root/autodl-tmp/datasets/wikipedia/20220301.en"

        try:
            if Path(local_wiki_path).exists():
                return self._load_from_local_cache(local_wiki_path, num_samples)
            else:
                return self._load_from_huggingface(num_samples)

        except Exception as e:
            error(f"Failed to load Wikipedia dataset: {e}")
            raise RuntimeError(f"Cannot load Wikipedia dataset: {e}")

    def _load_from_local_cache(self, cache_path: str, num_samples: int) -> List[str]:
        """Load Wikipedia data from local cache"""
        import glob

        try:
            # Import HuggingFace datasets library
            import datasets
            load_from_disk = datasets.load_from_disk
            Dataset = datasets.Dataset
        except ImportError as e:
            raise ImportError(f"Failed to import HuggingFace datasets library: {e}")

        # Find arrow files in the cache directory
        arrow_pattern = f"{cache_path}/**/wikipedia-train-*.arrow"
        arrow_files = glob.glob(arrow_pattern, recursive=True)

        if not arrow_files:
            raise FileNotFoundError(f"No arrow files found in {cache_path}")

        texts = []
        for arrow_file in sorted(arrow_files):
            if len(texts) >= num_samples:
                break
            try:
                dataset = Dataset.from_file(arrow_file)
                for sample in dataset:
                    if len(texts) >= num_samples:
                        break
                    text = sample.get("text", "")
                    if len(text.strip()) > 50:
                        texts.append(text.strip()[:512])
            except Exception:
                continue

        info(f"Loaded {len(texts)} samples from local cache")
        return texts

    def _load_from_huggingface(self, num_samples: int) -> List[str]:
        """Stream Wikipedia texts (``20220301.en`` train split) from the HuggingFace hub.

        Args:
            num_samples: maximum number of texts to return.

        Returns:
            Up to ``num_samples`` stripped texts longer than 50 characters, each
            truncated to 512 characters.

        Raises:
            ImportError: if the HuggingFace ``datasets`` package is unavailable.
        """
        try:
            from datasets import load_dataset
        except ImportError as e:
            raise ImportError(f"Failed to import HuggingFace datasets library: {e}") from e

        dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
        texts: List[str] = []
        for sample in dataset:
            # Count accepted texts, not raw stream samples. Short articles are
            # rejected below, so counting every sample returned fewer than
            # num_samples texts (and disagreed with the local-cache loader).
            if len(texts) >= num_samples:
                break
            text = sample.get("text", "").strip()
            if len(text) > 50:
                texts.append(text[:512])

        info(f"Loaded {len(texts)} samples from HuggingFace")
        return texts

    def _get_cache_path(self, expert_idx: int) -> Path:
        """Get cache file path for expert with layer information"""
        return self.cache_dir / f"projection_matrix_layer_{self.layer_idx}_expert_{expert_idx}.pkl"

    def _save_projection_matrix(self, expert_idx: int, matrix: torch.Tensor):
        """Pickle ``matrix`` for ``expert_idx`` into this layer's cache file.

        A CPU copy of the tensor is wrapped in a ``ProjectionMatrixCache``
        record with a timestamp and its shape. Failures are logged as warnings
        and never raised.
        """
        try:
            target = self._get_cache_path(expert_idx)
            record = ProjectionMatrixCache(
                matrix=matrix.cpu(),
                expert_idx=expert_idx,
                computation_time=datetime.now().isoformat(),
                matrix_shape=tuple(matrix.shape),
            )
            with target.open('wb') as handle:
                pickle.dump(record, handle)
            debug(f"Saved projection matrix for expert {expert_idx}")
        except Exception as e:
            warning(f"Failed to save projection matrix for expert {expert_idx}: {e}")

    def _load_projection_matrix(self, expert_idx: int) -> Optional[torch.Tensor]:
        """Load projection matrix from cache"""
        try:
            cache_path = self._get_cache_path(expert_idx)
            if not cache_path.exists():
                return None

            with open(cache_path, 'rb') as f:
                cache_data = pickle.load(f)

            # Handle both old and new cache formats
            if isinstance(cache_data, ProjectionMatrixCache):
                matrix = cache_data.matrix
            else:
                # Old format - just the matrix
                matrix = cache_data

            # Move to device
            matrix = matrix.to(self.device)
            debug(f"Loaded cached projection matrix for expert {expert_idx}")
            return matrix
        except Exception as e:
            warning(f"Failed to load projection matrix for expert {expert_idx}: {e}")
            return None

    def _get_unified_cache_path(self) -> Path:
        """Get cache file path for unified projection matrix with layer information"""
        return self.cache_dir / f"projection_matrix_layer_{self.layer_idx}_unified.pkl"

    def _get_covariance_cache_path(self, expert_idx: int) -> Path:
        """Get cache file path for expert covariance matrix with layer information"""
        return self.cache_dir / f"covariance_matrix_layer_{self.layer_idx}_expert_{expert_idx}.pkl"

    def _get_unified_covariance_cache_path(self) -> Path:
        """Get cache file path for unified covariance matrix with layer information"""
        return self.cache_dir / f"covariance_matrix_layer_{self.layer_idx}_unified.pkl"

    def _save_unified_projection_matrix(self, matrix: torch.Tensor):
        """Pickle this layer's unified projection matrix into the cache.

        The record uses ``expert_idx=-1`` to mark the layer-wide matrix.
        Failures are logged as warnings and never raised.
        """
        try:
            target = self._get_unified_cache_path()
            record = ProjectionMatrixCache(
                matrix=matrix.cpu(),
                expert_idx=-1,
                computation_time=datetime.now().isoformat(),
                matrix_shape=tuple(matrix.shape),
            )
            with target.open('wb') as handle:
                pickle.dump(record, handle)
            debug("Saved unified projection matrix")
        except Exception as e:
            warning(f"Failed to save unified projection matrix: {e}")

    def _load_unified_projection_matrix(self) -> Optional[torch.Tensor]:
        """Load unified projection matrix from cache"""
        try:
            cache_path = self._get_unified_cache_path()
            if not cache_path.exists():
                return None
            with open(cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            if isinstance(cache_data, ProjectionMatrixCache):
                matrix = cache_data.matrix
            else:
                matrix = cache_data
            return matrix.to(self.device)
        except Exception as e:
            warning(f"Failed to load unified projection matrix: {e}")
            return None

    def get_unified_projection_matrix(self) -> Optional[torch.Tensor]:
        """Public getter for unified projection matrix (loads from cache if present)"""
        return self._load_unified_projection_matrix()

    def _save_covariance_matrix(self, expert_idx: int, cov_matrix: torch.Tensor):
        """Pickle ``expert_idx``'s covariance matrix into this layer's cache file.

        The payload is a plain dict holding a CPU copy of the matrix, the
        expert index, a timestamp and the shape. Failures are logged as
        warnings and never raised.
        """
        try:
            target = self._get_covariance_cache_path(expert_idx)
            payload = dict(
                covariance_matrix=cov_matrix.cpu(),
                expert_idx=expert_idx,
                computation_time=datetime.now().isoformat(),
                matrix_shape=tuple(cov_matrix.shape),
            )
            with target.open('wb') as handle:
                pickle.dump(payload, handle)
            debug(f"Saved covariance matrix for expert {expert_idx}")
        except Exception as e:
            warning(f"Failed to save covariance matrix for expert {expert_idx}: {e}")

    def _save_unified_covariance_matrix(self, cov_matrix: torch.Tensor):
        """Pickle this layer's unified covariance matrix into the cache.

        The dict payload uses ``expert_idx=-1`` to mark the layer-wide matrix.
        Failures are logged as warnings and never raised.
        """
        try:
            target = self._get_unified_covariance_cache_path()
            payload = dict(
                covariance_matrix=cov_matrix.cpu(),
                expert_idx=-1,
                computation_time=datetime.now().isoformat(),
                matrix_shape=tuple(cov_matrix.shape),
            )
            with target.open('wb') as handle:
                pickle.dump(payload, handle)
            debug("Saved unified covariance matrix")
        except Exception as e:
            warning(f"Failed to save unified covariance matrix: {e}")

    def _load_covariance_matrix(self, expert_idx: int) -> Optional[torch.Tensor]:
        """Load expert covariance matrix from cache"""
        try:
            cache_path = self._get_covariance_cache_path(expert_idx)
            if not cache_path.exists():
                return None
            with open(cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            return cache_data['covariance_matrix'].to(self.device)
        except Exception as e:
            warning(f"Failed to load covariance matrix for expert {expert_idx}: {e}")
            return None

    def _load_unified_covariance_matrix(self) -> Optional[torch.Tensor]:
        """Load unified covariance matrix from cache"""
        try:
            cache_path = self._get_unified_covariance_cache_path()
            if not cache_path.exists():
                return None
            with open(cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            return cache_data['covariance_matrix'].to(self.device)
        except Exception as e:
            warning(f"Failed to load unified covariance matrix: {e}")
            return None

    def _can_hook_expert_down_projs(self) -> bool:
        """Detect whether per-expert down_proj modules exist (Qwen/Mixtral style) or not (gpt-oss packed experts)."""
        try:
            experts = getattr(self.target_layer, 'experts', None)
            if experts is None:
                return False
            # If experts is a list/ModuleList of modules, check first element
            if hasattr(experts, '__len__') and len(experts) > 0:
                try:
                    first = experts[0]
                    return hasattr(first, 'down_proj')
                except Exception:
                    return False
            # Not an indexable container => likely packed experts
            return False
        except Exception:
            return False

    def collect_keys_at_subject_positions(self, texts: List[str], expert_indices: List[int], subjects: Optional[List[str]] = None, batch_size: int = 100) -> Dict[int, torch.Tensor]:
        """Collect expert ``down_proj`` inputs ("keys") at each text's subject token.

        Texts are processed ``batch_size`` at a time. For every sample, the keys
        of the experts routed at the subject position (restricted to
        ``expert_indices``) are gathered.

        Args:
            texts: input texts.
            expert_indices: experts whose keys should be collected.
            subjects: optional subject string per text. Texts without a matching
                entry (``None``/empty, or past the end of a shorter list) fall
                back to the capitalized-word heuristic in
                ``_process_batch_for_keys``; surplus subjects are ignored.
            batch_size: number of texts per forward pass; must be positive.

        Returns:
            Mapping expert index -> keys concatenated along dim 0. Experts with
            no collected keys are omitted. An empty dict is returned when the
            layer has no per-expert ``down_proj`` modules.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        from .stats_collector import DownProjHookCollector

        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        expert_keys = {idx: [] for idx in expert_indices}

        # Pair each text with its subject. zip() silently dropped texts past the
        # end of a shorter ``subjects`` list; pad with None instead so those
        # texts still contribute via the subject heuristic.
        if subjects:
            if len(subjects) != len(texts):
                warning(f"Got {len(subjects)} subjects for {len(texts)} texts; "
                        f"unmatched texts use the subject heuristic and extra subjects are ignored")
            text_subject_pairs = [
                (text, subjects[i] if i < len(subjects) else None)
                for i, text in enumerate(texts)
            ]
        else:
            text_subject_pairs = [(text, None) for text in texts]

        # Process in batches for efficiency
        total_batches = (len(text_subject_pairs) + batch_size - 1) // batch_size

        # Detect if experts expose per-expert down_proj; if not (e.g., gpt-oss packed experts), skip key collection
        if not self._can_hook_expert_down_projs():
            info("Detected packed experts (no per-expert down_proj modules). Skipping key collection and using identity projection.")
            return {}

        # Create collector once and reuse across all batches
        collector = DownProjHookCollector(self.target_layer, expert_indices)

        try:
            # Setup hooks for all potential experts once (will be filtered during collection)
            collector.setup_hooks_for_experts(expert_indices)

            from .logger import get_tqdm_kwargs
            tqdm_kwargs = get_tqdm_kwargs(desc="Processing batches for key collection")

            for batch_idx in tqdm(range(total_batches), **tqdm_kwargs):
                start_idx = batch_idx * batch_size
                batch_pairs = text_subject_pairs[start_idx:start_idx + batch_size]

                batch_expert_keys = self._process_batch_for_keys(
                    batch_pairs, expert_indices, collector
                )

                batch_experts_with_data = sum(1 for keys in batch_expert_keys.values() if keys)
                batch_total_keys = sum(len(keys) for keys in batch_expert_keys.values())
                info(f"   Batch {batch_idx+1}: {len(batch_pairs)} samples → {batch_experts_with_data} experts → {batch_total_keys} keys")

                # Merge batch results into overall results
                for expert_idx, keys in batch_expert_keys.items():
                    expert_keys[expert_idx].extend(keys)

        finally:
            collector.cleanup()

        # Convert collected keys to tensors (experts without keys are omitted)
        result = {
            expert_idx: torch.cat(expert_keys[expert_idx], dim=0)
            for expert_idx in expert_indices
            if expert_keys[expert_idx]
        }

        # Print detailed statistics about expert activation
        total_samples_processed = len(text_subject_pairs)
        experts_with_data = len(result)
        total_keys_collected = sum(keys.shape[0] for keys in result.values())

        info(f"📊 Batch Processing Summary:")
        info(f"   Total samples processed: {total_samples_processed}")
        info(f"   Experts with data: {experts_with_data}/{len(expert_indices)}")
        info(f"   Total keys collected: {total_keys_collected}")
        info(f"   Average keys per expert: {total_keys_collected/experts_with_data if experts_with_data > 0 else 0:.1f}")

        # Show distribution of keys per expert
        if result:
            key_counts = [keys.shape[0] for keys in result.values()]
            info(f"   Key distribution: min={min(key_counts)}, max={max(key_counts)}, avg={sum(key_counts)/len(key_counts):.1f}")

        info(f"Collected keys for {len(result)}/{len(expert_indices)} experts using batch processing optimization")
        return result

    def _process_batch_for_keys(self, batch_pairs: List[tuple], expert_indices: List[int], collector) -> Dict[int, List[torch.Tensor]]:
        """Run one batch through the model and gather keys at each subject position.

        Args:
            batch_pairs: ``(text, subject_or_None)`` pairs. A missing subject is
                replaced by the first capitalized word longer than two
                characters (trailing ``.,!?`` stripped). Samples where none is
                found are skipped.
            expert_indices: experts whose keys are wanted.
            collector: down_proj hook collector with hooks already registered;
                ``clear_captured_keys`` and ``get_key_for_expert_and_position``
                are called on it.

        Returns:
            Mapping expert index -> list of keys (each at least 2-D) from this
            batch; every index in ``expert_indices`` is present.
        """
        batch_texts = []
        batch_subjects = []
        batch_subject_positions = []

        # Extract subjects and prepare texts
        for text, specified_subject in batch_pairs:
            if specified_subject:
                subject = specified_subject
            else:
                # Extract subject from text (simple heuristic)
                words = text.split()
                subject = next((word.strip('.,!?') for word in words
                                if word and word[0].isupper() and len(word) > 2), "")

            if not subject:
                continue

            batch_texts.append(text)
            batch_subjects.append(subject)

        if not batch_texts:
            return {idx: [] for idx in expert_indices}

        # Batch tokenization
        batch_inputs = self.tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256
        )
        batch_inputs = {k: v.to(self.device) for k, v in batch_inputs.items()}

        # Find subject positions for each sample in batch using AlphaEdit-style strategy
        for i, (text, subject) in enumerate(zip(batch_texts, batch_subjects)):
            try:
                # Create single-sample inputs for position finding
                single_inputs = {k: v[i:i+1] for k, v in batch_inputs.items()}

                # Use AlphaEdit-style token positioning if available
                if hasattr(self, 'fact_token_strategy'):
                    from .utils import get_last_subject_token_position_alphaedit
                    subject_end_pos = get_last_subject_token_position_alphaedit(
                        self.tokenizer, text, subject, single_inputs,
                        fact_token_strategy=self.fact_token_strategy, verbose=False
                    )
                else:
                    # Fallback to original method (module-level import)
                    _, subject_end_pos = find_subject_token_positions(
                        self.tokenizer, text, subject, single_inputs
                    )
            except Exception:
                # Fall back to this row's last real token. The padded batch width
                # (shape[1] - 1) pointed at a padding token for shorter rows.
                subject_end_pos = self._last_attended_position(batch_inputs, i)

            batch_subject_positions.append(subject_end_pos if subject_end_pos is not None else -1)

        # Get routing information for the entire batch
        batch_activated_experts = self._get_batch_activated_experts(
            batch_inputs, batch_subject_positions, expert_indices
        )

        # The routing pass above ran a forward pass while the collector's hooks
        # were registered; presumably that is why captures are cleared here
        # before the key-collection pass.
        collector.clear_captured_keys()

        # Single forward pass for the entire batch
        with torch.no_grad():
            self.model(**batch_inputs)

        # Extract keys for each sample and expert
        batch_expert_keys = {idx: [] for idx in expert_indices}
        requested = set(expert_indices)

        for sample_idx, subject_pos in enumerate(batch_subject_positions):
            if subject_pos < 0:
                continue

            for expert_idx in batch_activated_experts.get(sample_idx, []):
                if expert_idx not in requested:
                    continue
                # batch_idx selects this sample's row inside the captured batch
                key = collector.get_key_for_expert_and_position(
                    expert_idx, subject_pos, self.d_intermediate, batch_idx=sample_idx
                )
                if key is not None:
                    # Ensure key is 2D for concatenation
                    if key.dim() == 1:
                        key = key.unsqueeze(0)
                    batch_expert_keys[expert_idx].append(key)

        return batch_expert_keys

    @staticmethod
    def _last_attended_position(batch_inputs: Dict[str, torch.Tensor], row: int) -> int:
        """Return the index of the last non-padding token of ``row`` in a padded batch.

        Uses ``attention_mask`` when present (assuming the HuggingFace convention
        of 1 on real tokens), which works for both left and right padding.
        Otherwise it falls back to the last column of ``input_ids``.
        """
        mask = batch_inputs.get('attention_mask')
        if mask is not None:
            attended = torch.nonzero(mask[row], as_tuple=False)
            if attended.numel() > 0:
                return int(attended[-1].item())
        return batch_inputs['input_ids'].shape[1] - 1

    def _get_batch_activated_experts(self, batch_inputs: Dict[str, torch.Tensor], batch_positions: List[int], expert_indices: List[int]) -> Dict[int, List[int]]:
        """Return, per sample, the requested experts routed at that sample's position.

        Runs an extra forward pass with routing hooks on ``self.target_layer``
        and takes the top-k router logits at ``batch_positions[sample]``.

        Args:
            batch_inputs: tokenized batch passed straight to the model.
            batch_positions: token position per sample. Negative or
                out-of-range positions yield an empty list.
            expert_indices: experts of interest; other routed experts are dropped.

        Returns:
            Mapping sample index -> activated experts within ``expert_indices``
            (in descending router-logit order). If routing cannot be captured,
            every sample maps to ``expert_indices``.
        """
        from .stats_collector import Qwen3MoEStatisticsCollector

        try:
            # Create statistics collector to get routing information
            collector = Qwen3MoEStatisticsCollector()
            hooks = collector.create_hooks(self.target_layer)
            try:
                with torch.no_grad():
                    _ = self.model(**batch_inputs, output_hidden_states=True)
            finally:
                # Always detach the routing hooks; a failed forward pass must
                # not leave them attached to the layer.
                for hook in hooks:
                    hook.remove()

            if not collector.router_logits:
                # Fallback: return all requested experts for all samples
                return {i: expert_indices for i in range(len(batch_positions))}

            router_logits = collector.router_logits[-1]

            if router_logits.dim() == 2:
                input_ids = batch_inputs.get('input_ids')
                if (input_ids is not None and input_ids.dim() == 2 and input_ids.shape[0] > 1
                        and router_logits.shape[0] == input_ids.shape[0] * input_ids.shape[1]):
                    # NOTE(review): gate logits are presumably recorded flattened to
                    # [batch*seq, num_experts] (the single-position variant indexes
                    # them as [seq, num_experts]); restore the batch axis so samples
                    # other than the first are not discarded. Confirm against
                    # Qwen3MoEStatisticsCollector.
                    router_logits = router_logits.reshape(input_ids.shape[0], input_ids.shape[1], -1)
                else:
                    # Single sample case: [seq_len, num_experts] -> [1, seq_len, num_experts]
                    router_logits = router_logits.unsqueeze(0)

            # Use the layer's own routing width when it exposes one (8 otherwise),
            # never more than the number of experts.
            layer_top_k = getattr(self.target_layer, 'top_k', None)
            top_k = layer_top_k if isinstance(layer_top_k, int) and layer_top_k > 0 else 8
            top_k = min(top_k, router_logits.shape[-1])

            requested = set(expert_indices)
            batch_activated_experts = {}

            for sample_idx, position in enumerate(batch_positions):
                if (position < 0 or sample_idx >= router_logits.shape[0]
                        or position >= router_logits.shape[1]):
                    batch_activated_experts[sample_idx] = []
                    continue

                # Rank in float32 directly. The former "+ 10" offset was applied in
                # the logits' own dtype (e.g. bfloat16), where rounding can merge
                # nearby logits and change the selected top-k.
                sample_logits = router_logits[sample_idx, position, :].float()
                top_indices = torch.topk(sample_logits, top_k).indices.tolist()
                batch_activated_experts[sample_idx] = [e for e in top_indices if e in requested]

            return batch_activated_experts

        except Exception as e:
            warning(f"Failed to get batch activated experts: {e}")
            # Fallback: return all requested experts for all samples
            return {i: expert_indices for i in range(len(batch_positions))}

    def _get_activated_experts_at_position(self, inputs: Dict[str, torch.Tensor], position: int, expert_indices: List[int]) -> List[int]:
        """Return the requested experts routed at one token position of a single input.

        Runs a forward pass with routing hooks on ``self.target_layer`` and takes
        the top-k router logits at ``position``.

        Args:
            inputs: tokenized single-sample input passed straight to the model.
            position: token index; follows Python indexing, and values at or
                beyond the sequence length fall back to ``expert_indices``.
            expert_indices: experts of interest.

        Returns:
            Routed experts within ``expert_indices`` (descending logit order), or
            ``expert_indices`` itself when routing is unavailable, the position
            is out of range, none of the routed experts were requested, or an
            error occurs.
        """
        from .stats_collector import Qwen3MoEStatisticsCollector

        try:
            # Create statistics collector to get routing information
            collector = Qwen3MoEStatisticsCollector()
            hooks = collector.create_hooks(self.target_layer)
            try:
                with torch.no_grad():
                    _ = self.model(**inputs, output_hidden_states=True)
            finally:
                # Always detach the routing hooks, even if the forward pass fails.
                for hook in hooks:
                    hook.remove()

            if not collector.router_logits:
                return expert_indices  # Fallback to all requested experts

            router_logits = collector.router_logits[-1]  # expected [seq_len, num_experts]

            # Check if position is valid
            if position >= router_logits.shape[0]:
                return expert_indices  # Fallback to all requested experts

            # Use the layer's own routing width when it exposes one (8 otherwise).
            layer_top_k = getattr(self.target_layer, 'top_k', None)
            top_k = layer_top_k if isinstance(layer_top_k, int) and layer_top_k > 0 else 8
            top_k = min(top_k, router_logits.shape[-1])

            # Rank in float32 directly; the former "+ 10" offset was applied in the
            # logits' own dtype, where rounding can merge nearby logits.
            pos_logits = router_logits[position, :].float()
            top_indices = torch.topk(pos_logits, top_k).indices.tolist()

            requested = set(expert_indices)
            activated_experts = [e for e in top_indices if e in requested]

            return activated_experts if activated_experts else expert_indices  # Fallback if none found

        except Exception as e:
            warning(f"Failed to get activated experts at position {position}: {e}")
            return expert_indices  # Fallback to all requested experts



    def _compute_null_space_projection(self, keys: torch.Tensor) -> tuple[torch.Tensor, int]:
        """Compute a projection onto the (approximate) null space of ``keys``.

        The uncentered second moment ``K^T K / n`` (AlphaEdit style) is
        decomposed with SVD. Directions whose singular value is at most
        ``self.nullspace_threshold`` span the null space. If none qualifies,
        the 10% of directions with the smallest singular values (at least one)
        are used instead.

        Args:
            keys: key vectors, one per row. A single 1-D key vector is treated
                as one row.

        Returns:
            tuple: (projection_matrix, null_space_dimension). For empty input or
            an SVD failure, the ``d_intermediate`` identity and 0 are returned.
        """
        # Convert to float32 on the target device to avoid BFloat16 SVD issues
        keys = keys.to(self.device).float()

        if keys.numel() == 0:
            warning("Empty keys tensor, returning identity matrix")
            return torch.eye(self.d_intermediate, device=self.device, dtype=torch.float32), 0

        if keys.dim() < 2:
            # A lone key vector: without this, keys.T @ keys is a scalar dot
            # product, the SVD fails and the caller silently gets the identity.
            keys = keys.reshape(1, -1)

        actual_n_samples = keys.shape[0]

        try:
            info(f"Using covariance SVD for {actual_n_samples} samples (memory efficient)")

            # Second moment matrix E[x x^T] (no centering), shape (dim, dim)
            cov_matrix = (keys.T @ keys) / actual_n_samples

            U, S, _ = torch.linalg.svd(cov_matrix)

            # Small singular values of the second moment mark null-space directions
            null_mask = S <= self.nullspace_threshold

            if not null_mask.any():
                warning("No null space found, using smallest eigenvalue directions")
                k = max(1, int(0.1 * len(S)))
                null_mask = torch.zeros_like(S, dtype=torch.bool)
                # torch.linalg.svd returns S in descending order, so the tail is smallest
                null_mask[-k:] = True

            # Null space projection: P = V_null @ V_null^T
            V_null = U[:, null_mask]
            projection = V_null @ V_null.T

            null_space_dim = null_mask.sum().item()
            # Only log null space info for unusual cases
            if null_space_dim < 10 or null_space_dim > self.d_intermediate * 0.9:
                info(f"Covariance SVD: Null space dimension: {null_space_dim}/{self.d_intermediate}")

            return projection, null_space_dim

        except Exception as e:
            error(f"SVD computation failed: {e}, using identity matrix")
            return torch.eye(self.d_intermediate, device=self.device, dtype=torch.float32), 0

    def _compute_null_space_projection_with_cov(self, keys: torch.Tensor) -> tuple[torch.Tensor, int, torch.Tensor]:
        """Compute the null-space projection of ``keys`` and also return their second moment.

        The second moment is non-centred, ``C = K^T K / N`` (AlphaEdit style, no mean
        subtraction). ``C`` is decomposed with SVD. Directions whose singular value is
        ``<= self.nullspace_threshold`` form the null space. If no singular value is below the
        threshold, the smallest 10% of directions (at least one) are used instead.

        Args:
            keys: Key vectors, one per row. Presumably shape ``(N, d_intermediate)`` (TODO:
                confirm against the key collector). A single 1-D key vector is treated as
                one sample.

        Returns:
            tuple: (projection_matrix, null_space_dimension, covariance_matrix). The
            projection is ``V_null @ V_null.T``. For empty input, or if the decomposition
            raises, an identity matrix is returned for both the projection and the covariance,
            together with a null-space dimension of 0.
        """
        # Move keys to the working device and convert to float32 to avoid BFloat16 issues
        keys = keys.to(self.device).float()

        if keys.numel() == 0:
            warning("Empty keys tensor, returning identity matrix")
            identity = torch.eye(self.d_intermediate, device=self.device, dtype=torch.float32)
            return identity, 0, identity

        if keys.dim() == 1:
            # A single key vector: treat it as one sample rather than failing on dim=1 below
            keys = keys.unsqueeze(0)

        n_samples = keys.shape[0]
        if n_samples < 10:
            warning(f"Very few samples ({n_samples}) for expert - may lead to low-rank projection")

        # Near-zero keys contribute nothing to the second moment and reduce its rank
        zero_keys = (torch.norm(keys, dim=1) < 1e-8).sum().item()
        if zero_keys > 0:
            warning(f"Found {zero_keys} near-zero keys out of {n_samples} - may cause rank deficiency")

        try:
            info(f"Using covariance SVD for {n_samples} samples (memory efficient)")

            # Non-centred second moment E[x x^T], matching AlphaEdit's SecondMoment
            # accumulation (mom2 += a.t().mm(a)); shape (d_intermediate, d_intermediate)
            cov_matrix = (keys.T @ keys) / n_samples

            # C = K^T K / N is symmetric PSD, so its left singular vectors are eigenvectors
            # and S (sorted in descending order by torch.linalg.svd) holds its eigenvalues
            U, S, _ = torch.linalg.svd(cov_matrix)

            null_mask = S <= self.nullspace_threshold
            if not null_mask.any():
                warning("No null space found, using smallest eigenvalue directions")
                # S is descending, so the last k entries are the smallest eigenvalues
                k = max(1, int(0.1 * len(S)))
                null_mask = torch.zeros_like(S, dtype=torch.bool)
                null_mask[-k:] = True

            # Null space projection: P = V_null @ V_null^T
            V_null = U[:, null_mask]  # Shape: (d_intermediate, null_dim)
            projection = V_null @ V_null.T

            null_space_dim = null_mask.sum().item()
            # Only report unusually small or unusually large null spaces
            if null_space_dim < 10 or null_space_dim > self.d_intermediate * 0.9:
                info(f"Covariance SVD: Null space dimension: {null_space_dim}/{self.d_intermediate}")

            return projection, null_space_dim, cov_matrix

        except Exception as e:
            # Best-effort fallback: an identity projection (no constraint). The identity is
            # also returned in place of the covariance.
            error(f"SVD computation failed: {e}, using identity matrix")
            identity = torch.eye(self.d_intermediate, device=self.device, dtype=torch.float32)
            return identity, 0, identity


    def _compute_null_space_projection_from_cov(self, cov_matrix: torch.Tensor) -> tuple[torch.Tensor, int]:
        """Compute null space projection given a precomputed covariance (second-moment) matrix.

        Args:
            cov_matrix: Tensor of shape (d_intermediate, d_intermediate), typically (X^T X) / N

        Returns:
            (projection_matrix, null_space_dimension)
        """
        try:
            cov = cov_matrix.to(self.device).float()
            U, S, Vh = torch.linalg.svd(cov)

            # Determine null space using threshold on eigenvalues (consistent with other path)
            null_mask = S <= self.nullspace_threshold
            if not null_mask.any():
                # Fallback: take the smallest 10% directions as null space
                k = max(1, int(0.1 * len(S)))
                null_mask = torch.zeros_like(S, dtype=torch.bool)
                null_mask[-k:] = True

            V_null = U[:, null_mask]
            projection = V_null @ V_null.T
            null_space_dim = null_mask.sum().item()
            return projection, null_space_dim
        except Exception as e:
            error(f"SVD on provided covariance failed: {e}; returning identity")
            identity = torch.eye(self.d_intermediate, device=self.device, dtype=torch.float32)
            return identity, 0

    def compute_expert_projection_matrices(self,
                                         expert_indices: List[int],
                                         num_samples: int = 200,
                                         force_recompute: bool = False,
                                         dataset_name: str = "wikipedia") -> Dict[int, torch.Tensor]:
        """Compute, or load from cache, projection matrices for ``expert_indices``.

        The steps are:

        - Cached per-expert matrices are loaded unless ``force_recompute`` is True.
        - For the remaining experts, keys are collected at subject positions over
          ``num_samples`` texts.
        - For each of those experts, the null-space projection and second-moment matrix are
          computed and saved.
        - A unified projection and covariance are computed from the concatenation of all keys
          collected in this call, and saved as single files.

        Args:
            expert_indices: Experts to obtain projection matrices for.
            num_samples: Number of texts requested from ``get_wikipedia_style_texts``.
            force_recompute: Ignore per-expert cache files and recompute every matrix.
            dataset_name: Not used by this method; kept for API compatibility.

        Returns:
            ``self.projection_matrices``: every matrix currently held in memory, keyed by
            expert index.
        """
        info(f"Computing projection matrices for {len(expert_indices)} experts")

        # Reuse cached per-expert matrices unless a recompute is forced
        matrices_to_compute = []
        for expert_idx in expert_indices:
            if not force_recompute:
                cached_matrix = self._load_projection_matrix(expert_idx)
                if cached_matrix is not None:
                    self.projection_matrices[expert_idx] = cached_matrix
                    continue
            matrices_to_compute.append(expert_idx)

        # Load the texts and collect keys only when something actually has to be computed.
        # When every matrix came from cache, the dataset is never touched.
        expert_keys = {}
        if matrices_to_compute:
            texts = self.get_wikipedia_style_texts(num_samples)
            expert_keys = self.collect_keys_at_subject_positions(texts, matrices_to_compute)

        # Null-space dimensions per expert, for the summary table
        null_space_stats = {}

        # Compute per-expert matrices and save covariance matrices
        for expert_idx in tqdm(matrices_to_compute, desc="Computing expert projection matrices"):
            if expert_idx in expert_keys:
                keys = expert_keys[expert_idx]
                projection_matrix, null_space_dim, cov_matrix = self._compute_null_space_projection_with_cov(keys)
                null_space_stats[expert_idx] = {
                    'null_space_dim': null_space_dim,
                    'total_dim': self.d_intermediate,
                    'num_samples': keys.shape[0]
                }
                self._save_covariance_matrix(expert_idx, cov_matrix)
            else:
                # No keys were collected for this expert: fall back to an unconstrained identity
                projection_matrix = torch.eye(self.d_intermediate, device=self.device, dtype=torch.float32)
                null_space_stats[expert_idx] = {
                    'null_space_dim': 0,
                    'total_dim': self.d_intermediate,
                    'num_samples': 0
                }

            self.projection_matrices[expert_idx] = projection_matrix
            self._save_projection_matrix(expert_idx, projection_matrix)

        # Compute the unified matrix from all keys collected in this call and save it.
        # NOTE(review): expert_keys only holds experts recomputed in this call. When some
        # experts were loaded from cache, the unified matrix (which overwrites the saved one)
        # reflects only the recomputed subset. Confirm this is intended.
        try:
            all_keys_list = []
            expert_key_counts = {}
            for idx in expert_indices:
                if idx in expert_keys:
                    expert_key_counts[idx] = expert_keys[idx].shape[0]
                    all_keys_list.append(expert_keys[idx])

            if all_keys_list:
                all_keys = torch.cat(all_keys_list, dim=0)
                total_keys = all_keys.shape[0]
                info(f"Computing unified projection matrix from {total_keys} total keys across {len(expert_key_counts)} experts")
                info(f"Expert key distribution: {dict(list(expert_key_counts.items())[:5])}...")  # Show first 5

                unified_matrix, _, unified_cov = self._compute_null_space_projection_with_cov(all_keys)
                self._save_unified_projection_matrix(unified_matrix)
                self._save_unified_covariance_matrix(unified_cov)
                info(f"✅ Unified projection matrix computed and saved (shape: {unified_matrix.shape})")
        except Exception as e:
            # The unified matrix is best-effort; per-expert results are already stored
            warning(f"Failed to compute/save unified projection matrix: {e}")

        self._print_null_space_summary(null_space_stats)

        return self.projection_matrices

    def _print_null_space_summary(self, null_space_stats: Dict[int, Dict[str, int]]):
        """Log a summary table of null-space dimensions for the experts just processed.

        Args:
            null_space_stats: Mapping ``expert_idx -> {'null_space_dim', 'total_dim', 'num_samples'}``.
                Experts with ``num_samples == 0`` are listed as using the identity matrix and
                are excluded from the aggregate statistics.
        """
        if not null_space_stats:
            return

        info("=" * 80)
        info("📊 NULL SPACE DIMENSION SUMMARY")
        info("=" * 80)

        # Aggregate statistics only over experts that actually received keys
        total_experts = len(null_space_stats)
        with_data = [stats for stats in null_space_stats.values() if stats['num_samples'] > 0]
        experts_with_data = len(with_data)
        experts_without_data = total_experts - experts_with_data

        null_dims = [stats['null_space_dim'] for stats in with_data]
        sample_counts = [stats['num_samples'] for stats in with_data]

        if null_dims:
            avg_null_dim = sum(null_dims) / len(null_dims)
            min_null_dim = min(null_dims)
            max_null_dim = max(null_dims)
            avg_samples = sum(sample_counts) / len(sample_counts)
            total_samples = sum(sample_counts)
        else:
            avg_null_dim = min_null_dim = max_null_dim = avg_samples = total_samples = 0

        info("📈 Overall Statistics:")
        info(f"   Total experts: {total_experts}")
        info(f"   Experts with data: {experts_with_data}")
        info(f"   Experts without data: {experts_without_data}")
        info(f"   Total samples collected: {total_samples}")
        info(f"   Average samples per expert: {avg_samples:.1f}")
        info("")

        if null_dims:
            info(f"🎯 Null Space Dimensions (out of {self.d_intermediate} total):")
            info(f"   Average: {avg_null_dim:.1f}")
            info(f"   Range: {min_null_dim} - {max_null_dim}")
            info(f"   Percentage: {avg_null_dim/self.d_intermediate*100:.1f}%")
            info("")

        info("📋 Detailed Breakdown:")
        info(f"{'Expert':<8} {'Samples':<8} {'Null Dim':<10} {'Total Dim':<10} {'Percentage':<12}")
        info("-" * 60)

        for expert_idx in sorted(null_space_stats):
            stats = null_space_stats[expert_idx]
            null_dim = stats['null_space_dim']
            total_dim = stats['total_dim']
            num_samples = stats['num_samples']
            percentage = (null_dim / total_dim * 100) if total_dim > 0 else 0

            # The '%' sign is formatted together with the number so it is not pushed past
            # the column padding (previously "12.5        %")
            if num_samples > 0:
                info(f"{expert_idx:<8} {num_samples:<8} {null_dim:<10} {total_dim:<10} {percentage:.1f}%")
            else:
                info(f"{expert_idx:<8} {'N/A':<8} {'Identity':<10} {total_dim:<10} 0.0%")

        info("=" * 80)

    def get_projection_matrix(self, expert_idx: int) -> Optional[torch.Tensor]:
        """Get projection matrix for a specific expert"""
        if expert_idx in self.projection_matrices:
            return self.projection_matrices[expert_idx]

        # Try to load from cache
        cached_matrix = self._load_projection_matrix(expert_idx)
        if cached_matrix is not None:
            self.projection_matrices[expert_idx] = cached_matrix
            return cached_matrix

        return None

    def get_all_projection_matrices(self, expert_indices: List[int]) -> Dict[int, torch.Tensor]:
        """Get all projection matrices, computing if necessary"""
        missing_experts = [idx for idx in expert_indices if self.get_projection_matrix(idx) is None]

        if missing_experts:
            self.compute_expert_projection_matrices(missing_experts)

        return {idx: self.projection_matrices[idx] for idx in expert_indices
                if idx in self.projection_matrices}

    def clear_cache(self):
        """Delete projection and covariance cache files, then empty the in-memory store.

        Three kinds of file in ``self.cache_dir`` are removed:

        - per-expert projection files (layer-qualified naming);
        - the unified projection file;
        - every covariance file (always removed, not optionally).
        """
        stale_patterns = (
            "projection_matrix_layer_*_expert_*.pkl",  # per-expert projection matrices
            "projection_matrix_layer_*_unified.pkl",   # unified projection matrix
            "covariance_matrix_layer_*_*.pkl",         # covariance caches
        )
        for pattern in stale_patterns:
            for stale_file in self.cache_dir.glob(pattern):
                stale_file.unlink()
        self.projection_matrices.clear()
        info("Cleared all cached projection/covariance matrices")

    def get_cache_info(self) -> Dict[str, Any]:
        """Get information about cached matrices"""
        cache_files = list(self.cache_dir.glob("projection_matrix_expert_*.pkl"))
        cached_experts = []

        for cache_file in cache_files:
            try:
                expert_idx = int(cache_file.stem.split('_')[-1])
                cached_experts.append(expert_idx)
            except ValueError:
                continue

        return {
            "cache_dir": str(self.cache_dir),
            "num_cached_experts": len(cached_experts),
            "cached_expert_indices": sorted(cached_experts),
            "cache_size_mb": sum(f.stat().st_size for f in cache_files) / (1024 * 1024)
        }


# ============================================================================
# Backward Compatibility Aliases
# ============================================================================

# Alias kept for backward compatibility: code that still imports the older name
# AlphaEditStyleProjectionManager gets the very same class object as ProjectionMatrixManager.
AlphaEditStyleProjectionManager = ProjectionMatrixManager
