"""
Main knowledge editor for Qwen3-30B-A3B MoE model
"""

import torch
import torch.nn.functional as F
from pathlib import Path
from typing import List, Dict, Any, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import pickle
import re

from .utils import EditRequest, TargetVectorResult as UtilsTargetVectorResult, find_subject_token_positions, get_last_subject_token_position_alphaedit
from .stats_collector import Qwen3MoEStatisticsCollector, DownProjHookCollector, ProjectionMatrixCollector, GptOssExpertsForwardKeyCollector
from .optim import BlockCoordinateDescent

from .config import MoEEditConfig
from .compute_target_vectors import TargetVectorComputer, AlphaEditHyperParams, TargetVectorResult as ComputeTargetVectorResult
from .logger import get_logger, info, debug, error, warning
from .manager import HookManager, DownProjCollector, ProjectionMatrixManager


class Qwen3MoEKnowledgeEditor:
    """
    Qwen3-30B-A3B MoE Knowledge Editor
    Specifically targets down_proj layers in MoE experts
    Uses last subject token position for precise knowledge editing
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        config: Optional[MoEEditConfig] = None,
        layer_idx: Optional[int] = None,
        device: Optional[str] = None,
        lambda_reg: Optional[float] = None,
        projection_threshold: Optional[float] = None,
        stats_dir: Optional[str] = None,
    ):
        """Bind the editor to one MoE layer of ``model`` and set up its helpers.

        Args:
            model: Causal LM whose ``model.layers[layer_idx].mlp`` must expose
                ``experts`` plus ``gate`` (Qwen3/Mixtral style) or ``router``
                (gpt-oss style).
            tokenizer: Tokenizer paired with ``model``.
            config: Required editing configuration. Supplies architecture sizes,
                cache directories and every setting not overridden below.
            layer_idx: Optional override for ``config.layer_idx`` (0 is honored).
            device: Optional override for ``config.device``.
            lambda_reg: Optional override for ``config.lambda_reg`` (0.0 is honored).
            projection_threshold: Optional override for
                ``config.projection_threshold`` (0.0 is honored).
            stats_dir: Optional override for ``config.stats_dir``.

        Raises:
            ValueError: if ``config`` is None, the resolved ``layer_idx`` is
                missing or outside ``[0, num_layers)``, or the selected layer
                is not a supported MoE layer.
        """
        if config is None:
            raise ValueError("A MoEEditConfig must be provided; keyword arguments only override its values")

        def _override(value, default):
            # ``is None`` (not ``or``) so explicit falsy overrides such as
            # layer_idx=0 or lambda_reg=0.0 are honored instead of being
            # silently replaced by the config value.
            return default if value is None else value

        self.config = config
        # Empty strings keep meaning "not given" for these path-like settings.
        self.device = device or config.device
        self.lambda_reg = _override(lambda_reg, config.lambda_reg)
        self.projection_threshold = _override(projection_threshold, config.projection_threshold)
        self.layer_idx = _override(layer_idx, config.layer_idx)
        self.stats_dir = Path(stats_dir or config.stats_dir)

        # Resolve a filesystem-safe model identifier for per-model cache subdirectories
        model_name_candidates = [
            getattr(config, 'model_name', None),
            getattr(config, 'model_path', None),
            getattr(model, 'name_or_path', None),
            getattr(tokenizer, 'name_or_path', None),
        ]
        model_id_raw = next((c for c in model_name_candidates if c), 'unknown_model')
        model_id_base = Path(str(model_id_raw)).name
        self.model_id = re.sub(r'[^A-Za-z0-9._-]+', '_', model_id_base)

        # Resolve a filesystem-safe dataset identifier for per-dataset cache subdirectories
        dataset_raw = getattr(config, 'dataset_name', None) or 'default'
        self.dataset_id = re.sub(r'[^A-Za-z0-9._-]+', '_', str(dataset_raw))

        # Per-model/per-dataset cache directory for target vectors (created eagerly)
        self.target_cache_dir = Path(config.target_cache_dir) / self.model_id / self.dataset_id
        self.target_cache_dir.mkdir(parents=True, exist_ok=True)

        # Model architecture parameters from config
        self.num_layers = config.num_layers
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.d_model = config.d_model
        self.d_intermediate = config.d_intermediate
        self.d_hidden = config.d_hidden

        self.model = model
        self.tokenizer = tokenizer
        self.stats_dir.mkdir(parents=True, exist_ok=True)
        self.logger = get_logger()

        # Validate and set target layer
        if self.layer_idx is None or not (0 <= self.layer_idx < self.num_layers):
            raise ValueError(f"layer_idx must be between 0 and {self.num_layers-1}")
        self.target_layer = self.model.model.layers[self.layer_idx].mlp
        # Supported layouts: Qwen3/Mixtral style (gate + ModuleList of experts)
        # and gpt-oss style (router + packed experts module).
        has_experts = hasattr(self.target_layer, 'experts')
        has_gate_or_router = hasattr(self.target_layer, 'gate') or hasattr(self.target_layer, 'router')
        if not (has_experts and has_gate_or_router):
            raise ValueError(f"Layer {self.layer_idx} is not a supported MoE layer (requires experts and gate/router)")

        # Get fact_token strategy from config (AlphaEdit compatible)
        self.fact_token_strategy = getattr(config, 'fact_token', 'subject_last')
        print(f"🎯 Token positioning strategy: {self.fact_token_strategy}")

        # Log initialization
        info(f"MoE Editor initialized")
        # Expert count for logging only: works for list-style experts and packed
        # experts (down_proj stored as a 3-D tensor, e.g. gpt-oss).
        exp_mod = getattr(self.target_layer, 'experts', None)
        num_exp = getattr(exp_mod, 'num_experts', None)
        if num_exp is None:
            try:
                if hasattr(exp_mod, '__len__'):
                    num_exp = len(exp_mod)
                else:
                    dp = getattr(exp_mod, 'down_proj', None)
                    if hasattr(dp, 'shape') and len(dp.shape) == 3:
                        num_exp = int(dp.shape[0])
            except Exception:
                # Best-effort only: the config value below is used instead.
                pass
        num_exp = num_exp or self.num_experts
        info(f"Editing layer {self.layer_idx} with {num_exp} experts")
        info(f"Dimensions: d_model={self.d_model}, d_intermediate={self.d_intermediate}")
        info(f"Lambda: {self.lambda_reg}, Device: {self.device}")

        # Clear GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Per-expert key dimension d_k (the down_proj input size): the experts
        # module's ``expert_dim`` when it exposes one (e.g. gpt-oss), otherwise
        # config.d_hidden. The same value sizes the identity projection, the
        # BCD optimizer and the projection manager so their shapes agree.
        expert_key_dim = getattr(exp_mod, 'expert_dim', self.d_hidden)

        # Projection matrices (per expert); filled by compute_projection_matrices
        self.projection_matrices: Dict[int, torch.Tensor] = {}
        # Single representative matrix kept for backward compatibility
        self.projection_matrix = torch.eye(expert_key_dim, device=self.device, dtype=torch.float32)
        self.use_identity_projection = True

        # Original projection matrices (never updated after computation)
        self.original_projection_matrices: Dict[int, torch.Tensor] = {}

        # Accumulated historical keys per expert for cumulative projection updates
        self.historical_keys: Dict[int, List[torch.Tensor]] = {
            j: [] for j in range(self.num_experts)
        }

        # Cache for target vector results (case_id -> result)
        self.target_results = {}

        # Optional AlphaEdit-style hyperparameters. _load_target_vectors_from_cache
        # reads them via getattr(..., default), so None means "use defaults".
        # Preserve a value a subclass may have set before calling this __init__.
        self.hparams = getattr(self, 'hparams', None)

        # Initialize optimizer with the exact d_k (avoid assuming top-k=8)
        self.bcd_optimizer = BlockCoordinateDescent(
            self.num_experts, self.d_model, expert_key_dim, self.device, self.lambda_reg
        )

        # Initialize ProjectionMatrixManager (always created, even if using identity)
        self.projection_manager = ProjectionMatrixManager(
            model=self.model,
            tokenizer=self.tokenizer,
            target_layer=self.target_layer,
            num_experts=self.num_experts,
            d_intermediate=expert_key_dim,
            layer_idx=self.layer_idx,
            device=self.device,
            cache_dir=str(Path(getattr(self.config, 'projection_cache_dir', './projection_matrix_cache')) / self.model_id),
            nullspace_threshold=getattr(self.config, 'nullspace_threshold', 2e-2),
            fact_token_strategy=self.fact_token_strategy,
        )

    @classmethod
    def from_config(cls, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, config: MoEEditConfig):
        """Create editor from configuration"""
        return cls(model=model, tokenizer=tokenizer, config=config)

    def compute_projection_matrix(self, num_samples: int = 5000, use_identity: bool = True):
        """Compute unified null space projection matrix for all experts' down_proj layers (backward compatibility)"""
        self.compute_projection_matrices(num_samples, use_identity, unified=True)
        return self.projection_matrix

    def compute_projection_matrices(self, num_samples: int = 5000, use_identity: bool = True, unified: bool = False):
        """Compute null space projection matrices for each expert's down_proj layer

        Args:
            num_samples: Number of samples to collect for computing projection matrices
            use_identity: If True, use identity matrices for all experts
            unified: If True, compute a single unified matrix for all experts (backward compatibility)

        Returns:
            Dict[int, torch.Tensor] mapping expert_idx to projection matrix
        """

        # Check configuration for projection matrix updates
        update_projection_enabled = getattr(self.config, 'update_projection_matrices', False)
        print(f"📊 Projection matrix updates: {'ENABLED' if update_projection_enabled else 'DISABLED'}")

        # Default to using identity matrices
        if use_identity:
            print("Using identity matrices as projection matrices")
            _expert_dim_default = getattr(getattr(self.target_layer, 'experts', None), 'expert_dim', self.d_intermediate//8)
            identity_matrix = torch.eye(_expert_dim_default, device=self.device, dtype=torch.float32)
            self.projection_matrices = {i: identity_matrix.clone() for i in range(self.num_experts)}
            self.original_projection_matrices = {i: identity_matrix.clone() for i in range(self.num_experts)}
            self.projection_matrix = identity_matrix  # Backward compatibility
            self.use_identity_projection = True
            return self.projection_matrices

        # ProjectionMatrixManager is now created during initialization

        if unified:
            # Prefer loading unified matrix; if missing, compute per-expert (will also save unified), then load
            unified_P = self.projection_manager.get_unified_projection_matrix()
            if unified_P is None:
                print(f"Unified projection matrix not found in cache; computing from expert keys...")
                expert_indices = list(range(self.num_experts))
                _ = self.projection_manager.compute_expert_projection_matrices(
                    expert_indices=expert_indices,
                    num_samples=num_samples,
                    force_recompute=False,
                    dataset_name="wikipedia_style",
                )
                unified_P = self.projection_manager.get_unified_projection_matrix()
                if unified_P is None:
                    print("Warning: Unified projection matrix still missing; using identity")
                    unified_P = torch.eye(self.d_hidden, device=self.device, dtype=torch.float32)

            print(f"Using unified projection matrix with shape {unified_P.shape}")
            # Broadcast unified to all experts
            self.projection_matrices = {i: unified_P.clone() for i in range(self.num_experts)}
            self.original_projection_matrices = {i: unified_P.clone() for i in range(self.num_experts)}
            self.projection_matrix = unified_P
            self.use_identity_projection = False
            return self.projection_matrices

        # Non-unified: load/compute per-expert matrices (manager uses cache where possible)
        expert_indices = list(range(self.num_experts))
        self.projection_matrices = self.projection_manager.compute_expert_projection_matrices(
            expert_indices=expert_indices,
            num_samples=num_samples,
            force_recompute=False,
            dataset_name="wikipedia_style",
        )
        # Save original projection matrices (deep copy to avoid reference issues)
        self.original_projection_matrices = {
            i: matrix.clone() for i, matrix in self.projection_matrices.items()
        }
        # For backward compatibility, set a representative matrix
        if self.projection_matrices:
            self.projection_matrix = next(iter(self.projection_matrices.values()))
        else:
            self.projection_matrix = torch.eye(self.d_intermediate//8, device=self.device, dtype=torch.float32)
        self.use_identity_projection = False
        print(f"Computed/loaded projection matrices for {len(self.projection_matrices)} experts")
        return self.projection_matrices

    def _compute_projection_from_keys(self, keys_tensor: torch.Tensor) -> torch.Tensor:
        """Compute projection matrix from keys tensor"""
        # Center the data and convert to float32 for numerical stability
        keys_tensor = keys_tensor.float()
        keys_centered = keys_tensor - keys_tensor.mean(dim=0)
        cov = (keys_centered.T @ keys_centered) / keys_tensor.shape[0]

        # SVD to find null space
        U, S, _ = torch.linalg.svd(cov.float())

        # Find null space (small singular values)
        null_indices = S < self.projection_threshold
        if null_indices.sum() == 0:
            print("Warning: No null space found, using smallest 20% singular values")
            k = max(1, int(0.2 * len(S)))
            null_indices[-k:] = True

        V_null = U[:, null_indices]

        # Projection matrix P = V_null @ V_null.T
        projection_matrix = V_null @ V_null.T

        null_dim = null_indices.sum().item()
        print(f"Projection matrix computed. Null space dimension: {null_dim}/{self.d_intermediate//8}")

        return projection_matrix

    def _load_target_vectors_from_cache(self, requests: List[EditRequest]) -> Dict[int, ComputeTargetVectorResult]:
        """Load target vectors from cache files with batch computation for missing ones

        If a unified target layer is specified in config (target_layer_for_unified_vector),
        always use that layer's target vectors across all editing layers. Otherwise,
        fall back to per-layer target vectors using self.layer_idx.
        """
        results = {}
        missing_requests = []

        # Decide unified-vs-per-layer cache layer index
        unified_layer = getattr(self.config, 'target_layer_for_unified_vector', None)
        cache_layer_idx = unified_layer if isinstance(unified_layer, int) else self.layer_idx
        if unified_layer is not None and cache_layer_idx != self.layer_idx:
            print(f"🔁 Using unified target vectors from layer {cache_layer_idx} for editing layer {self.layer_idx}")

        # First pass: try to load cached vectors
        print(f"🔍 Checking cache for {len(requests)} requests...")
        for request in requests:
            cache_key = request.case_id
            cache_file = self.target_cache_dir / f"target_vector_layer_{cache_layer_idx}_{cache_key}.pkl"

            if cache_file.exists():
                try:
                    with open(cache_file, 'rb') as f:
                        data = pickle.load(f)
                        result = data
                        result.target_vector = result.target_vector.to(self.device)
                        results[request.case_id] = result
                        print(f"✅ Loaded cached target vector for case {request.case_id}")
                except Exception as e:
                    print(f"❌ Failed to load cache for case {request.case_id}: {e}")
                    missing_requests.append(request)
            else:
                print(f"📝 Cache missing for case {request.case_id}")
                missing_requests.append(request)

        # Second pass: batch compute missing target vectors (at cache_layer_idx)
        if missing_requests:
            print(f"🔄 Computing {len(missing_requests)} missing target vectors in batch (layer {cache_layer_idx})...")
            try:
                # Create target vector computer
                target_computer = TargetVectorComputer(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    cache_dir=str(self.target_cache_dir),
                    device=self.device,
                    fact_token_strategy=getattr(self, 'fact_token_strategy', 'subject_last'),
                    hparams=AlphaEditHyperParams(
                        v_lr=getattr(self.hparams, 'v_lr', 1e-1),
                        v_num_grad_steps=getattr(self.hparams, 'v_num_grad_steps', 50),
                        v_loss_layer=getattr(self.hparams, 'v_loss_layer', -1),
                        v_weight_decay=getattr(self.hparams, 'v_weight_decay', 1e-3),
                        kl_factor=getattr(self.hparams, 'kl_factor', 0.0625),
                        clamp_norm_factor=getattr(self.hparams, 'clamp_norm_factor', 4.0),
                        target_boost=getattr(self.hparams, 'target_boost', 3.0)
                    )
                )

                # Batch compute target vectors for missing requests
                batch_results = target_computer.compute_target_vectors(
                    requests=missing_requests,
                    layer_idx=cache_layer_idx,
                    loss_layer=getattr(self.hparams, 'v_loss_layer', None),
                    force_recompute=False
                )

                # Add batch results to main results
                results.update(batch_results)
                print(f"✅ Successfully computed {len(batch_results)} target vectors in batch")

            except Exception as e:
                print(f"❌ Batch target vector computation failed: {e}")
                import traceback
                traceback.print_exc()

                # Try individual computation for failed requests
                print(f"🔄 Attempting individual computation for {len(missing_requests)} requests...")
                for request in missing_requests:
                    if request.case_id not in results:
                        try:
                            # Try individual computation
                            individual_computer = TargetVectorComputer(
                                model=self.model,
                                tokenizer=self.tokenizer,
                                cache_dir=str(self.target_cache_dir),
                                device=self.device,
                                fact_token_strategy=getattr(self, 'fact_token_strategy', 'subject_last'),
                                hparams=AlphaEditHyperParams(
                                    v_lr=getattr(self.hparams, 'v_lr', 1e-1),
                                    v_num_grad_steps=getattr(self.hparams, 'v_num_grad_steps', 50),
                                    v_loss_layer=getattr(self.hparams, 'v_loss_layer', -1),
                                    v_weight_decay=getattr(self.hparams, 'v_weight_decay', 1e-3),
                                    kl_factor=getattr(self.hparams, 'kl_factor', 0.0625),
                                    clamp_norm_factor=getattr(self.hparams, 'clamp_norm_factor', 4.0),
                                    target_boost=getattr(self.hparams, 'target_boost', 3.0)
                                )
                            )
                            individual_results = individual_computer.compute_target_vectors(
                                requests=[request],
                                layer_idx=cache_layer_idx,
                                loss_layer=getattr(self.hparams, 'v_loss_layer', None),
                                force_recompute=False
                            )
                            if request.case_id in individual_results:
                                results[request.case_id] = individual_results[request.case_id]
                                print(f"✅ Individual computation succeeded for case {request.case_id}")
                            else:
                                raise Exception("Individual computation returned no result")

                        except Exception as individual_error:
                            print(f"❌ Individual computation also failed for case {request.case_id}: {individual_error}")
                            print(f"   Using zero vector fallback for case {request.case_id}")
                            result = UtilsTargetVectorResult(
                                target_vector=torch.zeros(self.d_model, device=self.device),
                                case_id=request.case_id,
                                metadata={'fallback': True, 'batch_error': str(e), 'individual_error': str(individual_error)}
                            )
                            results[request.case_id] = result

        print(f"📊 Target vector summary: {len(results)} total, {len(missing_requests)} computed, {len(requests) - len(missing_requests)} cached")
        return results

    def _verify_target_vectors(self, target_results: Dict[int, Any], requests: List[EditRequest]) -> Dict[int, Any]:
        """Screen target vectors by their effect on the target token's probability.

        For every request that has an entry in ``target_results``, the unedited
        probability (``pre_edit_prob``) is compared with the probability after
        steering the layer output toward the target vector
        (``_test_target_vector_effectiveness``):

        * ratio > 1.1   -> kept (effective)
        * ratio > 0.9   -> kept (marginal effect)
        * otherwise     -> replaced by a zero vector flagged as a fallback
        * baseline <= 0 -> kept unchanged (ratio undefined)

        Args:
            target_results: case_id -> target vector result (must have ``target_vector``).
            requests: Edit requests to verify.

        Returns:
            Dict case_id -> kept or replaced result. Requests without a target
            vector are skipped and absent from the returned dict.
        """
        from .utils import pre_edit_prob

        # Acceptance thresholds on test_prob / baseline_prob
        effective_ratio = 1.1  # at least 10% improvement
        keep_ratio = 0.9  # no worse than 10% degradation

        print("🔍 Verifying target vectors effectiveness...")
        verified_results = {}

        for request in requests:
            case_id = request.case_id
            if case_id not in target_results:
                print(f"⚠️  No target vector found for case {case_id}, skipping verification")
                continue

            target_result = target_results[case_id]

            # Baseline probability before any editing
            baseline_prob = pre_edit_prob(self.model, self.tokenizer, request, self.device)

            # Simplified check: add the implied delta to the residual stream and
            # measure the target probability (the real edit uses optimization).
            print(f"📋 Testing target vector for case {case_id}...")
            test_prob = self._test_target_vector_effectiveness(request, target_result.target_vector)

            if baseline_prob <= 0:
                print(f"  ⚠️  Baseline probability too low, keeping original target vector")
                verified_results[case_id] = target_result
                continue

            improvement_ratio = test_prob / baseline_prob
            improvement_change = test_prob - baseline_prob

            print(f"  Baseline probability: {baseline_prob:.6f}")
            print(f"  Test probability: {test_prob:.6f}")
            print(f"  Improvement: {improvement_change:+.6f} ({improvement_ratio:.2f}x)")

            if improvement_ratio > effective_ratio:
                print(f"  ✅ Target vector is effective")
                verified_results[case_id] = target_result
            elif improvement_ratio > keep_ratio:
                print(f"  ⚠️  Target vector shows marginal effect, keeping")
                verified_results[case_id] = target_result
            else:
                print(f"  ❌ Target vector appears ineffective, using fallback")
                verified_results[case_id] = UtilsTargetVectorResult(
                    target_vector=torch.zeros_like(target_result.target_vector),
                    case_id=case_id,
                    # 'fallback': True matches the zero-vector fallbacks built in
                    # _load_target_vectors_from_cache, so consumers checking
                    # metadata['fallback'] (e.g. the FALLBACK/OPTIMIZED label in
                    # collect_edit_statistics) recognize this vector too.
                    metadata={'fallback': True, 'verification_failed': True, 'original_improvement': improvement_ratio},
                )

        print(f"✅ Target vector verification completed: {len(verified_results)}/{len(requests)} vectors verified")
        return verified_results

    def _test_target_vector_effectiveness(self, request: EditRequest, target_vector: torch.Tensor) -> float:
        """Estimate P(first target token) after steering toward ``target_vector``.

        Runs two no-grad forward passes over the subject-filled prompt:

        1. record the output of decoder layer ``self.layer_idx`` at the subject
           token position;
        2. add ``target_vector - recorded_state`` at that same position and read
           the softmax probability of the first token of ``request.target_new``
           at the final position.

        Returns 0.0 when ``target_new`` tokenizes to nothing or the hidden state
        at the subject position could not be recorded.
        """
        from .utils import find_subject_token_positions

        prompt_text = request.prompt.format(request.subject) if '{}' in request.prompt else request.prompt
        encoded = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)

        # Subject position: AlphaEdit-style strategy when configured, legacy helper otherwise
        if not hasattr(self, 'fact_token_strategy'):
            _, position = find_subject_token_positions(
                self.tokenizer, request.prompt, request.subject, encoded
            )
        else:
            position = get_last_subject_token_position_alphaedit(
                self.tokenizer, request.prompt, request.subject, encoded,
                fact_token_strategy=self.fact_token_strategy, verbose=False
            )

        # Same target tokenization as compute_target_vector, for consistency
        target_token_ids = self.tokenizer(request.target_new, return_tensors="pt").input_ids.to(self.device)
        if target_token_ids.shape[1] == 0:
            return 0.0
        first_target_id = target_token_ids[0, 0]

        def _layer_hidden(out):
            # Decoder layers may return a ModelOutput, a tuple, or a bare tensor
            if hasattr(out, 'last_hidden_state'):
                return out.last_hidden_state
            return out[0] if isinstance(out, tuple) else out

        def _position_in_range(hidden):
            return hidden.dim() == 3 and position < hidden.shape[1]

        captured = {}

        def _record(module, inputs, output):
            del module, inputs  # unused hook arguments
            hidden = _layer_hidden(output)
            if _position_in_range(hidden):
                captured['state'] = hidden[0, position, :].clone()
            return output

        decoder_layer = self.model.model.layers[self.layer_idx]

        # Pass 1: record the unmodified hidden state
        with torch.no_grad():
            handle = decoder_layer.register_forward_hook(_record)
            try:
                _ = self.model(**encoded)
            finally:
                handle.remove()

        baseline_state = captured.get('state')
        if baseline_state is None:
            return 0.0

        # target_vector is the desired layer output, so the shift is its difference
        shift = target_vector - baseline_state

        def _inject(module, inputs, output):
            del module, inputs  # unused hook arguments
            hidden = _layer_hidden(output)
            if _position_in_range(hidden):
                # In-place edit so the returned output carries the shift
                hidden[0, position, :] += shift.to(hidden.dtype)
            return output

        # Pass 2: apply the shift and score the first target token
        with torch.no_grad():
            handle = decoder_layer.register_forward_hook(_inject)
            try:
                final_logits = self.model(**encoded).logits[0, -1, :]
                probability = torch.softmax(final_logits, dim=-1)[first_target_id].item()
            finally:
                handle.remove()

        return probability

    def collect_edit_statistics(
        self,
        requests: List[EditRequest],
        num_prefixes: int = 50,
        use_prefix_expansion: bool = True,
        verify_targets: bool = False
    ) -> Dict[str, Any]:
        """Collect statistics for BCD optimization using last subject token position with optional prefix expansion"""
        if self.projection_matrix is None or not self.projection_matrices:
            self.compute_projection_matrices()

        all_requests = requests

        if use_prefix_expansion:
            # AlphaEdit-style context templates: wrap the full prompt rather than simple prefixes
            context_templates = [
                "{}",
                "In general, {}",
                "According to available information, {}",
                "It is well known that {}",
                "People often say that {}",
                "Research indicates that {}",
                "From reports, {}",
            ]
            # Truncate to requested number
            if num_prefixes < len(context_templates):
                context_templates = context_templates[:num_prefixes]

            expanded_samples = []
            for request in all_requests:
                base_prompt = request.prompt.format(request.subject) if '{}' in request.prompt else request.prompt
                for j, tmpl in enumerate(context_templates):
                    wrapped_prompt = tmpl.format(base_prompt)
                    expanded_samples.append((request, j, wrapped_prompt))

            batch_size = len(expanded_samples)  # Now batch_size = |requests| * num_prefixes
            print(f"Context template expansion enabled: {len(all_requests)} requests → {batch_size} samples ({len(context_templates)} templates each)")
        else:
            # Use original requests without prefix expansion
            expanded_samples = [(req, 0, req.prompt.format(req.subject) if '{}' in req.prompt else req.prompt)
                              for req in all_requests]
            batch_size = len(expanded_samples)
            print(f"Prefix expansion disabled: using {batch_size} original samples")

        stats = {
            'gating': torch.zeros(batch_size, self.num_experts, device=self.device),
            'keys': [],  # down_proj inputs (raw k_{i,j})
            'values': [],  # projected keys (v_{i,j} = P_j @ k_{i,j})
            'residuals': [],
            'targets': [],
            'is_new': [],  # Will be filled based on original request type
            'subject_positions': [],  # Store subject token positions for each sample
            'entity_mapping': [],  # Maps sample index t to original entity index i
            'prefix_mapping': []   # Maps sample index t to prefix index j
        }

        collector = Qwen3MoEStatisticsCollector()
        hooks = collector.create_hooks(self.target_layer)

        # Setup hook collector for down_proj inputs (skip if packed experts like gpt-oss)
        down_proj_collector = None
        try:
            down_proj_collector = DownProjHookCollector(self.target_layer, self.num_experts)
            down_proj_collector.setup_hooks()
        except Exception as e:
            print(f"[MoEEdit] Skip down_proj hooks (likely packed experts such as gpt-oss): {e}")

        # Detect gpt-oss packed experts and prepare forward wrapper for key capture
        wrapper = None
        try:
            if hasattr(self.target_layer, 'experts') and hasattr(self.target_layer.experts, 'gate_up_proj'):
                wrapper = GptOssExpertsForwardKeyCollector(self.target_layer.experts)
        except Exception:
            wrapper = None

        print("Collecting edit statistics...")

        # Load target vectors from cache
        print(f"🔍 Loading target vectors for {len(all_requests)} requests...")
        self.target_results = self._load_target_vectors_from_cache(all_requests)
        print(f"📊 Target vector status: {len(self.target_results)} loaded from cache")

        # Verify target vectors if enabled
        if verify_targets:
            # Lightweight verification only when explicitly requested
            self.target_results = self._verify_target_vectors(self.target_results, all_requests)

        try:
            with torch.no_grad():
                for t, (original_req, prefix_idx, prefixed_prompt) in enumerate(expanded_samples):
                    collector.clear()
                    if down_proj_collector is not None:
                        down_proj_collector.clear_captured_keys()

                    # All samples are new since we removed historical_requests
                    original_idx = all_requests.index(original_req)
                    is_new_sample = True  # All samples are new
                    stats['is_new'].append(is_new_sample)
                    stats['entity_mapping'].append(original_idx)  # Map t -> i
                    stats['prefix_mapping'].append(prefix_idx)    # Map t -> j

                    # Get pre-computed target vector (same for all prefixes of the same entity)
                    if original_req.case_id in self.target_results:
                        target_result = self.target_results[original_req.case_id]
                        z_vec = target_result.target_vector.to(self.device)

                        # Check if this is a fallback target vector
                        is_fallback = hasattr(target_result, 'metadata') and target_result.metadata and target_result.metadata.get('fallback', False)
                        vector_type = "FALLBACK" if is_fallback else "OPTIMIZED"
                        print(f"Using {vector_type} target vector for sample {t+1} (entity {original_idx}, prefix {prefix_idx}) - norm: {z_vec.norm().item():.2f}")
                    else:
                        print(f"⚠ No target vector found for sample {t+1} (case_id: {original_req.case_id})")
                        print(f"   Available target results: {list(self.target_results.keys())}")
                        print(f"   Attempting to compute target vector on-the-fly...")

                        # Try to compute target vector on-the-fly
                        try:
                            target_computer = TargetVectorComputer(
                                model=self.model,
                                tokenizer=self.tokenizer,
                                cache_dir=str(self.target_cache_dir),
                                device=self.device,
                                fact_token_strategy=getattr(self, 'fact_token_strategy', 'subject_last'),
                                hparams=AlphaEditHyperParams(
                                    v_lr=getattr(self.hparams, 'v_lr', 1e-1),
                                    v_num_grad_steps=getattr(self.hparams, 'v_num_grad_steps', 50),
                                    v_loss_layer=getattr(self.hparams, 'v_loss_layer', -1),
                                    v_weight_decay=getattr(self.hparams, 'v_weight_decay', 1e-3),
                                    kl_factor=getattr(self.hparams, 'kl_factor', 0.0625),
                                    clamp_norm_factor=getattr(self.hparams, 'clamp_norm_factor', 4.0),
                                    target_boost=getattr(self.hparams, 'target_boost', 3.0)
                                )
                            )
                            on_the_fly_results = target_computer.compute_target_vectors(
                                requests=[original_req],
                                layer_idx=self.layer_idx,
                                loss_layer=getattr(self.hparams, 'v_loss_layer', None),
                                force_recompute=False
                            )
                            if original_req.case_id in on_the_fly_results:
                                target_result = on_the_fly_results[original_req.case_id]
                                self.target_results[original_req.case_id] = target_result
                                z_vec = target_result.target_vector.to(self.device)
                                print(f"✅ Successfully computed target vector on-the-fly for case {original_req.case_id}")
                            else:
                                print(f"❌ On-the-fly computation failed, using zero vector")
                                z_vec = torch.zeros(self.d_model, device=self.device)
                        except Exception as e:
                            print(f"❌ On-the-fly computation error: {e}")
                            print(f"   Using zero vector as final fallback")
                            z_vec = torch.zeros(self.d_model, device=self.device)

                    # Use prefixed prompt
                    inputs = self.tokenizer(prefixed_prompt, return_tensors="pt",
                                          padding=True, truncation=True).to(self.device)

                    # Find subject token positions using AlphaEdit-style strategy
                    if hasattr(self, 'fact_token_strategy'):
                        subject_end_pos = get_last_subject_token_position_alphaedit(
                            self.tokenizer, original_req.prompt, original_req.subject, inputs,
                            fact_token_strategy=self.fact_token_strategy, verbose=False
                        )
                        # For compatibility, also compute start position
                        if self.fact_token_strategy == "subject_last":
                            subject_start_pos = get_last_subject_token_position_alphaedit(
                                self.tokenizer, original_req.prompt, original_req.subject, inputs,
                                fact_token_strategy="subject_first", verbose=False
                            )
                        else:
                            subject_start_pos = subject_end_pos
                    else:
                        # Fallback to original method
                        subject_start_pos, subject_end_pos = find_subject_token_positions(
                            self.tokenizer, original_req.prompt, original_req.subject, inputs
                        )
                    stats['subject_positions'].append((subject_start_pos, subject_end_pos))

                    # If using gpt-oss wrapper, attach and set capture positions (flat index for batch size 1)
                    if wrapper is not None and subject_end_pos is not None and subject_end_pos >= 0:
                        try:
                            wrapper.clear()
                            wrapper.set_capture_positions([int(subject_end_pos)])
                            wrapper.attach()
                        except Exception as _e:
                            pass

                    with torch.no_grad():
                        # Single forward pass to capture router and down_proj inputs via hooks/wrapper
                        outputs = self.model(**inputs, output_hidden_states=True)

                        # Extract routing information (subject token only) and fill gating
                        if collector.router_logits and collector.moe_inputs:
                            router_logits = collector.router_logits[-1]  # [seq_len, num_experts] or [batch, seq_len, num_experts]
                            if router_logits.dim() == 3:
                                router_logits = router_logits[0]
                            routing_probs = router_logits[subject_end_pos, :].float()
                            routing_probs = routing_probs + 10  # keep behavior
                            routing_probs = F.softmax(routing_probs, dim=-1)
                            top_k_weights, top_k_indices = torch.topk(routing_probs, self.num_experts_per_tok)
                            renormalized_weights = top_k_weights / top_k_weights.sum()

                            for k in range(self.num_experts_per_tok):
                                idx_val = top_k_indices[k].item()
                                weight_val = renormalized_weights[k].item()
                                if idx_val < self.num_experts:
                                    stats['gating'][t, idx_val] = weight_val

                            # Use hook-captured real keys and compute projected values
                            expert_keys = []
                            expert_values = []
                            d_k = getattr(getattr(self.target_layer, 'experts', None), 'expert_dim', self.d_intermediate//8)
                            for j in range(self.num_experts):
                                if stats['gating'][t, j] > 0:
                                    key = None
                                    if down_proj_collector is not None:
                                        try:
                                            key = down_proj_collector.get_key_for_expert_and_position(j, subject_end_pos, d_k)
                                        except Exception:
                                            key = None
                                    # Fallback to gpt-oss forward wrapper if available
                                    if key is None and wrapper is not None and subject_end_pos is not None and subject_end_pos >= 0:
                                        try:
                                            key = wrapper.get_key_for_expert_and_flat_index(j, int(subject_end_pos))
                                        except Exception:
                                            key = None
                                    if key is not None:
                                        # Ensure key is on the same device as projection matrices/model
                                        key = key.to(self.device)
                                        if j in self.projection_matrices:
                                            P_j = self.projection_matrices[j].to(self.device)
                                            value = (P_j @ key.float()).to(dtype=key.dtype)
                                        else:
                                            value = key.clone()
                                        expert_keys.append(key)
                                        expert_values.append(value)
                                    else:
                                        expert_keys.append(None)
                                        expert_values.append(None)
                                else:
                                    expert_keys.append(None)
                                    expert_values.append(None)

                            stats['keys'].append(expert_keys)
                            stats['values'].append(expert_values)

                            # Compute target and residual
                            _ = outputs.logits[0, -1, :]  # logits not used in current implementation

                            # Get target token (same for all prefixes of the same entity)
                            target_ids = self.tokenizer(original_req.target_new, return_tensors="pt").input_ids

                            # get current hidden state at subject token position
                            current_hidden_state = outputs.hidden_states[self.layer_idx+1][0, subject_end_pos, :]
                            print(f"Sample {t+1}: Current hid state: {[f'{x:.4f}' for x in current_hidden_state[:8].tolist()]}")

                            # Compute residual with optional residual weight for multi-layer editing
                            residual = current_hidden_state - z_vec.float()

                            # Apply residual weight if this is part of multi-layer editing
                            if hasattr(self, 'current_residual_weight') and self.current_residual_weight is not None:
                                residual = residual * self.current_residual_weight
                                print(f"Sample {t+1}: Applied residual weight {self.current_residual_weight:.3f}")

                            stats['residuals'].append(residual)
                            stats['targets'].append(target_ids.clone().detach())

        finally:
            # Remove hooks
            for hook in hooks:
                hook.remove()
            if down_proj_collector is not None:
                down_proj_collector.cleanup()
            # Detach gpt-oss wrapper if used
            if 'wrapper' in locals() and wrapper is not None:
                try:
                    wrapper.detach()
                except Exception:
                    pass

        return stats

    def optimize(self, stats: Dict, method: str = 'bcd', **kwargs) -> List[torch.Tensor]:
        """Compute per-expert weight deltas from collected edit statistics.

        Only block coordinate descent (``method='bcd'``) is supported; it delegates to
        ``self.bcd_optimizer.optimize`` together with the current projection matrices.

        Args:
            stats: Edit statistics passed straight through to the optimizer.
            method: Optimization method name; anything other than ``'bcd'`` is rejected.
            **kwargs: ``num_passes`` (default 2) and ``incremental`` (default False) are
                forwarded; any other keyword is ignored.

        Returns:
            Whatever ``bcd_optimizer.optimize`` returns (annotated as a list of tensors).

        Raises:
            ValueError: If ``method`` is not ``'bcd'``.
        """
        if method != 'bcd':
            raise ValueError(f"Unknown optimization method: {method}")

        passes = kwargs.get('num_passes', 2)
        use_incremental = kwargs.get('incremental', False)
        return self.bcd_optimizer.optimize(
            stats,
            num_passes=passes,
            projection_matrices=self.projection_matrices,
            incremental=use_incremental,
        )

    def apply_weight_updates(self, deltas: List[torch.Tensor], use_expert_projection: bool = True):
        """Add per-expert deltas to the target layer's down_proj weights, in place.

        Two expert layouts are handled:

        * Indexable experts (e.g. a ``ModuleList``, Qwen/Mixtral style), where
          ``experts[j].down_proj.weight`` receives the update directly.
        * Packed experts (e.g. gpt-oss) exposing one rank-3 ``down_proj``
          ``nn.Parameter`` indexed by expert; the update is transposed before being
          added to slice ``[j]``.

        Any other layout is skipped without error.

        Args:
            deltas: One delta per expert, indexed by expert id, shaped like the
                per-expert down_proj weight (described in this file as (d_model, d_k)).
                ``None`` entries and deltas whose norm is below 1e-6 are skipped.
            use_expert_projection: If True, expert ``j`` has a projection matrix, and
                identity projection is not in use, the applied update is
                ``delta @ P_j`` (W_j^new = W_j + Delta_j P_j). Otherwise ``delta`` is
                applied unchanged.
        """
        print("Applying weight updates to down_proj layers...")

        # The experts container does not change while updates are applied; resolve once.
        experts_mod = getattr(self.target_layer, 'experts', None)
        is_packed = hasattr(experts_mod, 'down_proj')
        is_indexable = hasattr(experts_mod, '__len__') and not is_packed

        def _update_for(j: int, delta: torch.Tensor, target: torch.Tensor, packed: bool) -> torch.Tensor:
            """Return the (optionally projected) delta cast to ``target``'s dtype/device."""
            if use_expert_projection and j in self.projection_matrices and not self.use_identity_projection:
                P_j = self.projection_matrices[j].to(device=target.device, dtype=torch.float32)
                # Multiply in float32 on the weight's device. Otherwise a delta whose
                # dtype or device differs from P_j's (e.g. bf16/float64 vs fp32, or
                # CPU vs GPU) makes the matmul raise.
                projected = delta.to(device=target.device, dtype=torch.float32) @ P_j
                tag = " (packed)" if packed else ""
                print(f"Expert {j}: projected norm {projected.norm().item():.4f}{tag}")
                return projected.to(dtype=target.dtype)
            tag = " (packed, no projection)" if packed else " (no projection)"
            print(f"Expert {j}: delta norm {delta.norm().item():.4f}{tag}")
            return delta.to(dtype=target.dtype, device=target.device)

        updated_count = 0
        with torch.no_grad():
            for j, delta in enumerate(deltas):
                if delta is None or delta.norm().item() < 1e-6:
                    continue

                if is_indexable:
                    # Case 1: list/ModuleList experts (Qwen/Mixtral style)
                    try:
                        expert = experts_mod[j]
                    except (IndexError, KeyError, TypeError):
                        continue
                    down_proj = getattr(expert, 'down_proj', None)
                    weight = getattr(down_proj, 'weight', None)
                    if weight is None:
                        continue
                    weight.data += _update_for(j, delta, weight, packed=False)
                    updated_count += 1

                elif is_packed:
                    # Case 2: packed experts sharing one parameter, described in this
                    # file as [E, d_k, d_model].
                    dp_param = experts_mod.down_proj
                    if not (isinstance(dp_param, torch.nn.Parameter)
                            and dp_param.dim() == 3 and j < dp_param.shape[0]):
                        continue
                    upd = _update_for(j, delta, dp_param, packed=True)
                    # delta is (d_model, d_k); transpose to match the [d_k, d_model] slice.
                    dp_param.data[j].add_(upd.T.contiguous())
                    updated_count += 1
                # Any other layout: no expert down_proj can be located, so skip.

        print(f"Updated {updated_count}/{self.num_experts} down_proj layers")

    def edit_knowledge(
        self,
        requests: List["EditRequest"],
        method: str = None,  # Use config default
        num_passes: int = None,  # Use config default
        verify_targets: bool = None,  # Use config default
        num_prefixes: int = None,  # Use config default
        use_prefix_expansion: bool = None,  # Whether to use prefix expansion (None = use config)
        auto_verify: bool = False,  # Default off for speed (currently unused)
        incremental: bool = None,  # Use config default
        save_to_history: bool = None  # Whether to save current edits to historical data (None = same as incremental)
    ):
        """Run one end-to-end knowledge edit over ``requests``.

        Pipeline:
        1. Collect statistics via ``collect_edit_statistics``.
        2. Compute deltas via ``optimize``.
        3. Apply the deltas to the model via ``apply_weight_updates``.
        4. Optionally update the projection matrices.
        5. Optionally store the statistics in the BCD optimizer's history.

        Every ``None`` argument falls back to the matching ``self.config`` attribute, or
        to the literal default shown in the code when that attribute is absent.

        Args:
            requests: Edit requests to apply.
            method: ``'bcd'`` or ``'gd'``; ``'gd'`` converts ``num_passes`` into
                ``num_steps = num_passes * 50``. NOTE: ``optimize`` in this class only
                accepts ``'bcd'``, so ``'gd'`` raises there unless it is overridden.
            num_passes: Optimizer passes.
            verify_targets: Forwarded to ``collect_edit_statistics``.
            num_prefixes: Number of prefixes per request when prefix expansion is on.
            use_prefix_expansion: Enable prefix expansion; ``None`` reads
                ``config.enable_prefix_expansion``, then ``config.prefix_enabled``
                (default True).
            auto_verify: Accepted for compatibility; currently unused.
            incremental: Incremental editing flag forwarded to ``optimize``.
            save_to_history: Save ``stats`` into the BCD optimizer's history;
                ``None`` means "same as ``incremental``".

        Returns:
            The deltas returned by ``optimize`` (already applied), or ``[]`` if it
            returned ``None``.

        Raises:
            ValueError: If ``method`` is neither ``'bcd'`` nor ``'gd'``. This is raised
                before any statistics are collected.
        """
        # Use config defaults for None parameters
        if method is None:
            method = getattr(self.config, 'method', 'bcd')
        # Fail fast, before the expensive statistics pass, on methods we cannot dispatch.
        if method not in ('bcd', 'gd'):
            raise ValueError(f"Unknown method: {method}")
        if num_passes is None:
            num_passes = getattr(self.config, 'num_passes', 2)
        if verify_targets is None:
            verify_targets = getattr(self.config, 'verify_targets', False)
        if num_prefixes is None:
            num_prefixes = getattr(self.config, 'num_prefixes', 10)
        if incremental is None:
            incremental = getattr(self.config, 'incremental_editing', False)

        # Ensure save_to_history is consistent with incremental setting
        if save_to_history is None:
            save_to_history = incremental

        print(f"🔧 Editing mode: incremental={incremental}, save_to_history={save_to_history}")

        # Determine prefix expansion setting from config if not explicitly provided
        if use_prefix_expansion is None:
            # Try multiple possible config field names for compatibility
            use_prefix_expansion = getattr(self.config, 'enable_prefix_expansion',
                                           getattr(self.config, 'prefix_enabled', True))

        print(f"🔧 Prefix expansion setting: use_prefix_expansion={use_prefix_expansion}, num_prefixes={num_prefixes}")

        # Log editing configuration
        if use_prefix_expansion:
            print(f"🚀 Starting batch knowledge editing for {len(requests)} requests with {num_prefixes} prefixes each...")
            print(f"📊 Total samples to process: {len(requests) * num_prefixes}")
            print(f"🔧 Prefix expansion: ENABLED")
        else:
            print(f"🚀 Starting batch knowledge editing for {len(requests)} requests (prefix expansion disabled)...")
            print(f"📊 Total samples to process: {len(requests)}")
            print(f"🔧 Prefix expansion: DISABLED")

        # Collect statistics with configurable prefix expansion
        stats = self.collect_edit_statistics(requests, num_prefixes, use_prefix_expansion, verify_targets)

        # Run optimization ('method' was validated above)
        if method == 'bcd':
            deltas = self.optimize(stats, method='bcd', num_passes=num_passes, incremental=incremental)
        else:  # 'gd': convert passes to steps
            deltas = self.optimize(stats, method='gd', num_steps=num_passes * 50, incremental=incremental)

        self.apply_weight_updates(deltas, use_expert_projection=not self.use_identity_projection)

        # Update projection matrices using AlphaEdit-style formula.
        # The default must match update_projection_matrices(), which treats a missing
        # config flag as disabled. Otherwise the log below would report an update
        # that never ran.
        update_projection_enabled = getattr(self.config, 'update_projection_matrices', False)
        if update_projection_enabled and not self.use_identity_projection:
            self.update_projection_matrices(stats)
        else:
            if self.use_identity_projection:
                print("📊 Projection matrix updates skipped (using identity matrices)")
            else:
                print("📊 Projection matrix updates disabled by configuration")

        # NOTE: ``auto_verify`` is currently unused; no post-edit verification
        # (e.g. ``.utils.verify_edits``) is run here.

        # Save current editing data to historical storage for future incremental edits
        if save_to_history and hasattr(self, 'bcd_optimizer'):
            self.bcd_optimizer.add_historical_data(stats)
            print(f"💾 Current editing data saved to historical storage")
        elif not save_to_history:
            print(f"📊 Historical data saving DISABLED (incremental editing disabled)")

        if update_projection_enabled and not self.use_identity_projection:
            print(f"📊 Projection matrix updates completed")

        # Return deltas for compatibility with evaluation framework
        # Note: deltas were already applied to the model in apply_weight_updates
        return deltas if deltas is not None else []

    def reset_optimization_state(self):
        """Return the editor to a clean optimization state.

        Clears the BCD optimizer's cumulative state (when an optimizer is attached) and
        drops every accumulated historical key. If a snapshot of the original projection
        matrices exists, it also rebuilds the working projection matrices from fresh
        clones of that snapshot.
        """
        if hasattr(self, 'bcd_optimizer'):
            self.bcd_optimizer.reset_cumulative_state()

        # The cumulative strategy recomputes from all historical keys; start that store over.
        self.clear_historical_keys()

        if hasattr(self, 'original_projection_matrices'):
            restored = {}
            for expert_idx, original in self.original_projection_matrices.items():
                # Clone so later in-place changes never touch the snapshot.
                restored[expert_idx] = original.clone()
            self.projection_matrices = restored
            print("🔄 Reset projection matrices to original state")

        print("🔄 Optimization state and historical keys reset")

    def save_optimization_state(self, filepath: str):
        """Save the BCD optimizer's cumulative state to ``filepath``.

        Delegates to ``self.bcd_optimizer.save_state``. When no optimizer is attached,
        nothing is written and a warning is printed. Previously this was a silent no-op
        that left callers believing the state had been persisted.
        """
        if not hasattr(self, 'bcd_optimizer'):
            print(f"⚠️  No BCD optimizer attached; optimization state NOT saved to {filepath}")
            return
        self.bcd_optimizer.save_state(filepath)
        print(f"💾 Optimization state saved to {filepath}")

    def load_optimization_state(self, filepath: str):
        """Load the BCD optimizer's cumulative state from ``filepath``.

        Delegates to ``self.bcd_optimizer.load_state``. When no optimizer is attached,
        nothing is loaded and a warning is printed instead of silently doing nothing.
        """
        if not hasattr(self, 'bcd_optimizer'):
            print(f"⚠️  No BCD optimizer attached; optimization state NOT loaded from {filepath}")
            return
        self.bcd_optimizer.load_state(filepath)
        print(f"📂 Optimization state loaded from {filepath}")

    def update_projection_matrices(self, stats: Dict[str, Any]):
        """Recompute each expert's projection matrix from P₀ and all accumulated keys.

        The current matrices are not updated incrementally. Every call starts from the
        original matrices P₀ (``self.original_projection_matrices``) and the full set of
        historical keys (this call's plus all earlier ones):

            P' = P₀ - R (Kᵀ R)† Rᵀ,   R = P₀ K

        K holds one expert's significance-filtered keys as columns. For a symmetric P₀
        this equals P₀ - P₀ K (Kᵀ P₀ K)† Kᵀ P₀.

        Side effects:
            * If ``original_projection_matrices`` is missing or empty, it is initialized
              from clones of ``projection_matrices``. If those are missing too, the
              method returns early. This happens before the config check.
            * Returns without updating when ``config.update_projection_matrices`` is
              falsy or absent.
            * Appends this call's keys to ``self.historical_keys``.
            * Replaces ``self.projection_matrices[j]`` (on CPU, in P₀'s dtype) for every
              expert that passes the filters below.

        Args:
            stats: Statistics for the current epoch. In this file only ``stats['keys']``
                is read (via ``_extract_keys_from_stats``).
        """
        # Check and initialize original projection matrices if needed
        if not hasattr(self, 'original_projection_matrices') or not self.original_projection_matrices:
            print("⚠️  No original projection matrices found, initializing from current matrices...")
            if hasattr(self, 'projection_matrices') and self.projection_matrices:
                # Use current projection matrices as original (fallback)
                self.original_projection_matrices = {
                    i: matrix.clone() for i, matrix in self.projection_matrices.items()
                }
                print(f"✅ Initialized original projection matrices for {len(self.original_projection_matrices)} experts")
            else:
                print("❌ No projection matrices available, skipping update")
                return

        # Check if projection matrix updates are enabled
        update_enabled = getattr(self.config, 'update_projection_matrices', False)
        if not update_enabled:
            print("📊 Projection matrix updates disabled by configuration")
            return

        print("🔄 Updating projection matrices using cumulative recomputation strategy...")

        # Extract this epoch's keys and fold them into the historical store.
        current_epoch_keys = self._extract_keys_from_stats(stats)
        self._accumulate_historical_keys(current_epoch_keys)

        updated_count = 0
        skipped_count = 0
        filtered_count = 0

        for j in range(self.num_experts):
            if j not in self.original_projection_matrices:
                continue

            all_historical_keys = self.historical_keys[j]
            if not all_historical_keys:
                continue  # No keys for this expert

            # Keys are captured in the model's dtype (often bf16/fp16), while P₀ is
            # usually fp32. Mixing them makes the matmuls raise, and pinv has no
            # half-precision kernel, so compute in a common float dtype (at least fp32).
            P0_stored = self.original_projection_matrices[j]
            work_dtype = torch.promote_types(P0_stored.dtype, torch.float32)

            # Stack all historical keys into matrix K_all: (d_k, num_total_samples)
            K_all = torch.stack(all_historical_keys, dim=1).to(device=self.device, dtype=work_dtype)
            if K_all.numel() == 0:
                continue

            # Original projection matrix (never modified here)
            P0 = P0_stored.to(device=self.device, dtype=work_dtype)

            R = P0 @ K_all  # (d_k, num_total_samples)
            if torch.norm(R).item() < 1e-10:
                continue  # Keys lie (almost) entirely outside P₀'s range

            # === Significance filtering ===
            # Relative energy of each key inside P₀'s range: ρ(k) = ||P₀ k|| / ||k||
            k_norms = torch.norm(K_all, dim=0) + 1e-12  # epsilon avoids division by zero
            rho = torch.norm(R, dim=0) / k_norms

            # projection_threshold acts as the significance threshold τ_rel
            tau_rel = float(self.projection_threshold)
            mask = rho >= tau_rel

            # Keep at least min(3, #keys) keys for numerical stability
            min_keys = min(3, K_all.shape[1])
            if mask.sum().item() < min_keys:
                _, top_indices = torch.topk(rho, k=min_keys)
                mask = torch.zeros_like(rho, dtype=torch.bool)
                mask[top_indices] = True

            mean_rho = rho.mean().item()
            if mean_rho < 0.5 * tau_rel or mask.sum().item() == 0:
                # Most keys carry very little energy in P₀'s range; leave this expert as is.
                skipped_count += 1
                continue

            original_num_keys = K_all.shape[1]
            K_filtered = K_all[:, mask]
            R_filtered = R[:, mask]
            filtered_num_keys = K_filtered.shape[1]
            if filtered_num_keys < original_num_keys:
                filtered_count += 1

            # Gram matrix G = Kᵀ P₀ K over the filtered keys: (filtered, filtered)
            G = K_filtered.T @ R_filtered

            try:
                G_dag = torch.linalg.pinv(G)
            except RuntimeError as exc:  # torch.linalg.LinAlgError subclasses RuntimeError
                print(f"⚠️  Failed to compute pseudoinverse for expert {j}, skipping ({exc})")
                continue

            # Cumulative update P' = P₀ - R G† Rᵀ  -> (d_k, d_k)
            update_matrix = R_filtered @ G_dag @ R_filtered.T
            P_updated = P0 - update_matrix

            # Replace the working matrix (P₀ itself stays untouched), keeping P₀'s dtype.
            self.projection_matrices[j] = P_updated.to(dtype=P0_stored.dtype).cpu()
            updated_count += 1

            update_norm = torch.norm(update_matrix).item()
            current_epoch_keys_count = len(current_epoch_keys.get(j, []))
            print(f"  Expert {j}: Updated projection matrix (total keys: {filtered_num_keys}/{original_num_keys}, "
                  f"current epoch: {current_epoch_keys_count}, mean_ρ: {mean_rho:.4f}, update_norm: {update_norm:.6f})")

        print(f"✅ Updated {updated_count}/{self.num_experts} projection matrices using cumulative strategy")
        if skipped_count > 0:
            print(f"📊 Skipped {skipped_count} experts (projection already well-adapted)")
        if filtered_count > 0:
            print(f"📊 Applied significance filtering to {filtered_count} experts")

    def _extract_keys_from_stats(self, stats: Dict[str, Any]) -> Dict[int, List[torch.Tensor]]:
        """Extract keys from statistics for each expert"""
        current_epoch_keys = {j: [] for j in range(self.num_experts)}

        for sample_keys in stats['keys']:
            for j, key in enumerate(sample_keys):
                if key is not None:
                    current_epoch_keys[j].append(key.clone())

        return current_epoch_keys

    def _accumulate_historical_keys(self, current_epoch_keys: Dict[int, List[torch.Tensor]]):
        """Accumulate current epoch keys into historical storage"""
        for j, keys_list in current_epoch_keys.items():
            if keys_list:
                self.historical_keys[j].extend(keys_list)
                print(f"  Expert {j}: Added {len(keys_list)} keys, total historical: {len(self.historical_keys[j])}")

    def clear_historical_keys(self):
        """Reset ``self.historical_keys`` to an empty list for every expert index."""
        fresh: Dict[int, List[torch.Tensor]] = {}
        for expert_idx in range(self.num_experts):
            fresh[expert_idx] = []
        self.historical_keys = fresh
        print("🔄 Cleared all historical keys")

    def initialize_original_projection_matrices(self):
        """Snapshot the current projection matrices as the originals (P₀).

        Stores clones of ``self.projection_matrices`` in
        ``self.original_projection_matrices``. Later updates to the working matrices
        therefore leave the snapshot untouched. Prints a message and does nothing when
        the current matrices are missing or empty.
        """
        current = getattr(self, 'projection_matrices', None)
        if not current:
            print("❌ No current projection matrices to copy from")
            return

        self.original_projection_matrices = {idx: mat.clone() for idx, mat in current.items()}
        print(f"✅ Initialized original projection matrices for {len(self.original_projection_matrices)} experts")

    def get_projection_matrix_statistics(self, compute_eigenvals: bool = False) -> Dict[str, Any]:
        """Get statistics about current projection matrices

        Args:
            compute_eigenvals: Whether to compute expensive eigenvalue statistics (default: False)
        """
        if not hasattr(self, 'projection_matrices') or not self.projection_matrices:
            return {}

        stats = {}
        for j, P in self.projection_matrices.items():
            P_norm = torch.norm(P).item()
            P_rank = torch.linalg.matrix_rank(P).item()

            expert_stats = {
                "norm": P_norm,
                "rank": P_rank,
                "trace": torch.trace(P).item(),
            }

            # Only compute expensive eigenvalue statistics if requested
            if compute_eigenvals:
                eigenvals = torch.linalg.eigvals(P.float()).real
                expert_stats.update({
                    "eigenvalue_min": eigenvals.min().item(),
                    "eigenvalue_max": eigenvals.max().item(),
                    "condition_number": (eigenvals.max() / (eigenvals.min() + 1e-10)).item()
                })

            stats[f"expert_{j}"] = expert_stats

        return stats
