"""
Multi-layer knowledge editor for Qwen3-30B-A3B MoE model
Implements sequential multi-layer editing with unified target vector computation
"""

import torch
import torch.nn.functional as F
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import pickle
import math
import re

from .utils import EditRequest, TargetVectorResult as UtilsTargetVectorResult
from .stats_collector import Qwen3MoEStatisticsCollector, DownProjHookCollector
from .optim import BlockCoordinateDescent
from .config import MoEEditConfig
from .compute_target_vectors import TargetVectorComputer, AlphaEditHyperParams, TargetVectorResult as ComputeTargetVectorResult
from .logger import get_logger, info, debug, error, warning
from .manager import HookManager, DownProjCollector, ProjectionMatrixManager
from .editor import Qwen3MoEKnowledgeEditor


class MultiLayerMoEKnowledgeEditor:
    """
    Multi-layer MoE Knowledge Editor

    Implements the multi-layer update strategy:
    1. Use target layer (e.g., layer 13) to compute unified target vector
    2. Compute projection matrices for all editing layers (e.g., 11, 12, 13) before editing
    3. For each batch, update layers sequentially:
       - Layer 11: r_11 = (target_vector - h_11) / 3
       - Layer 12: r_12 = (target_vector - h_12) / 2  (after layer 11 update)
       - Layer 13: r_13 = (target_vector - h_13) / 1  (after layer 11, 12 updates)
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        config: MoEEditConfig,
        device: Optional[str] = None,
    ):
        """
        Build the multi-layer editor: one single-layer editor per editing layer
        plus a shared target-vector computer.

        Args:
            model: Causal LM to edit
            tokenizer: Tokenizer paired with the model
            config: Edit configuration; must have multi_layer_enabled set
            device: Optional device override; config.device is used when falsy

        Raises:
            ValueError: If multi-layer editing is disabled in config, fewer than
                two editing layers are configured, or the target layer is not
                one of the editing layers
        """
        def _filesystem_safe(identifier: str) -> str:
            # Collapse every run of characters outside [A-Za-z0-9._-] into "_"
            return re.sub(r'[^A-Za-z0-9._-]+', '_', identifier)

        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = device or config.device
        self.logger = get_logger()

        # Model identifier for per-model cache subdirectories: first truthy
        # candidate wins, reduced to its final path component.
        identifier_sources = (
            getattr(config, 'model_name', None),
            getattr(config, 'model_path', None),
            getattr(model, 'name_or_path', None),
            getattr(tokenizer, 'name_or_path', None),
        )
        raw_model_id = 'unknown_model'
        for candidate in identifier_sources:
            if candidate:
                raw_model_id = candidate
                break
        self.model_id = _filesystem_safe(Path(str(raw_model_id)).name)

        # Dataset identifier for per-dataset target vector cache subdirectories
        raw_dataset_id = getattr(config, 'dataset_name', None) or 'default'
        self.dataset_id = _filesystem_safe(str(raw_dataset_id))

        # Base cache roots only; the per-layer editors receive these same roots
        self.base_target_cache_dir = Path(config.target_cache_dir)
        self.base_projection_cache_dir = Path(
            getattr(config, 'projection_cache_dir', './projection_matrix_cache')
        )

        # --- Validate multi-layer configuration ---
        if not config.multi_layer_enabled:
            raise ValueError("MultiLayerMoEKnowledgeEditor requires multi_layer_enabled=True in config")

        self.editing_layers = config.get_editing_layers()
        self.target_layer_idx = config.get_target_layer_for_vector_computation()

        if len(self.editing_layers) < 2:
            raise ValueError("Multi-layer editing requires at least 2 layers")

        if self.target_layer_idx not in self.editing_layers:
            raise ValueError(f"Target layer {self.target_layer_idx} must be in editing layers {self.editing_layers}")

        # Ascending order so layers are processed sequentially
        self.editing_layers = sorted(self.editing_layers)

        info("Multi-layer editor initialized")
        info(f"Editing layers: {self.editing_layers}")
        info(f"Target layer for vector computation: {self.target_layer_idx}")

        # One single-layer editor per editing layer. The raw `device` argument
        # (possibly None) is forwarded unchanged, not the resolved self.device.
        self.layer_editors = {
            layer_idx: Qwen3MoEKnowledgeEditor(
                model=model,
                tokenizer=tokenizer,
                config=self._create_layer_config(config, layer_idx),
                device=device
            )
            for layer_idx in self.editing_layers
        }

        # Target vector computer, cached under <target_cache_dir>/<model_id>/<dataset_id>
        tv_cache_dir = self.base_target_cache_dir / self.model_id / self.dataset_id
        tv_cache_dir.mkdir(parents=True, exist_ok=True)
        optimisation_hparams = AlphaEditHyperParams(
            v_lr=getattr(config, 'v_lr', 1e-1),
            v_num_grad_steps=getattr(config, 'v_num_grad_steps', 50),
            v_loss_layer=getattr(config, 'v_loss_layer', -1),
            v_weight_decay=getattr(config, 'v_weight_decay', 1e-3),
            kl_factor=getattr(config, 'kl_factor', 0.0625),
            clamp_norm_factor=getattr(config, 'clamp_norm_factor', 4.0),
            target_boost=getattr(config, 'target_boost', 3.0)
        )
        self.target_computer = TargetVectorComputer(
            model=model,
            tokenizer=tokenizer,
            cache_dir=str(tv_cache_dir),
            device=self.device,
            fact_token_strategy=config.fact_token,
            hparams=optimisation_hparams
        )

        # Session caches for target vectors and projection matrices
        self.target_vectors_cache = {}
        self.projection_matrices_cache = {}

        # Set once compute_all_projection_matrices has run in this session
        self._projection_matrices_computed = False

    def _create_layer_config(self, base_config: MoEEditConfig, layer_idx: int) -> MoEEditConfig:
        """Derive the config handed to the single-layer editor of one layer.

        The result mirrors the base config except that it targets exactly one
        layer (`layer_idx` / `layers=[layer_idx]`) and has multi-layer editing
        switched off.

        Args:
            base_config: Multi-layer configuration to copy settings from
            layer_idx: Layer the single-layer editor will edit

        Returns:
            A new MoEEditConfig for that single layer
        """
        # Settings copied verbatim; a missing attribute raises AttributeError
        mirrored_fields = (
            'model_path',
            'device',
            'lambda_reg',
            'projection_threshold',
            'nullspace_threshold',
            'stats_dir',
            'target_cache_dir',  # base root, passed through unchanged
            'num_layers',
            'num_experts',
            'num_experts_per_tok',
            'd_model',
            'd_intermediate',
            'd_hidden',
            'use_identity',
            'num_samples',
            'use_expert_projection',
            'method',
            'num_passes',
            'learning_rate',
            'verify_targets',
            'prefix_enabled',
            'enable_prefix_expansion',
            'num_prefixes',
            'eval_batch_size',
            'compare_pre_post',
            'incremental_editing',
            'update_projection_matrices',
            'fact_token',
        )
        settings = {name: getattr(base_config, name) for name in mirrored_fields}

        # Settings that may be absent on the base config and therefore carry a default
        settings['projection_cache_dir'] = getattr(
            base_config, 'projection_cache_dir', './projection_matrix_cache'
        )
        settings['dataset_name'] = getattr(base_config, 'dataset_name', None)
        settings['target_layer_for_unified_vector'] = getattr(
            base_config, 'target_layer_for_unified_vector', None
        )

        # Per-layer overrides: one layer only, multi-layer mode disabled
        settings['layer_idx'] = layer_idx
        settings['layers'] = [layer_idx]
        settings['multi_layer_enabled'] = False

        return MoEEditConfig(**settings)

    def compute_unified_target_vectors(
        self,
        requests: List[EditRequest],
        force_recompute: bool = False
    ) -> Dict[int, ComputeTargetVectorResult]:
        """
        Compute the unified target vectors on the configured target layer and
        merge them into this editor's session cache.

        Args:
            requests: List of edit requests
            force_recompute: Forwarded to the target vector computer

        Returns:
            The mapping returned by the target vector computer
        """
        target_layer = self.target_layer_idx
        info(f"Computing unified target vectors using layer {target_layer}")

        computer_kwargs = {
            'requests': requests,
            'layer_idx': target_layer,
            # None when the config does not define v_loss_layer
            'loss_layer': getattr(self.config, 'v_loss_layer', None),
            'force_recompute': force_recompute,
        }
        unified_vectors = self.target_computer.compute_target_vectors(**computer_kwargs)

        # Entries already cached under the same key are overwritten
        self.target_vectors_cache.update(unified_vectors)
        info(f"Computed {len(unified_vectors)} unified target vectors")

        return unified_vectors

    def compute_all_projection_matrices(
        self,
        requests: List[EditRequest],
        num_samples: Optional[int] = None,  # Will use config value if None
        force_recompute: bool = False
    ) -> Dict[int, Dict[int, torch.Tensor]]:
        """
        Compute (or load from cache) projection matrices for all editing layers
        and install them on the per-layer editors.

        The session cache (`self.projection_matrices_cache`) always stores the
        matrices that were actually returned and installed, including any
        unified-matrix substitution, so a later call that short-circuits on the
        session cache returns the same matrices as the first call.

        Args:
            requests: List of edit requests (not used by this method; kept for
                interface compatibility)
            num_samples: Number of samples for projection matrix computation;
                config.num_samples is used when None
            force_recompute: Force recomputation even if cached

        Returns:
            Dictionary mapping layer_idx -> expert_idx -> projection_matrix
        """
        # Use config value if num_samples is not explicitly provided
        if num_samples is None:
            num_samples = self.config.num_samples
        use_unified = getattr(self.config, 'use_unified_projection', False)

        info(f"Computing projection matrices for layers {self.editing_layers}")
        info(f"🔍 DEBUG: compute_all_projection_matrices called with num_samples = {num_samples} (from config: {self.config.num_samples})")
        info(f"🔧 use_unified_projection={use_unified}")

        # Check if we've already computed projection matrices for this session
        if self._projection_matrices_computed and not force_recompute:
            info("✅ Projection matrices already computed in this session, using cached results")
            return self.projection_matrices_cache

        layers_to_compute = self._layers_missing_projection_cache(force_recompute)
        projection_matrices: Dict[int, Dict[int, torch.Tensor]] = {}

        if not layers_to_compute:
            info("All projection matrices are cached, skipping computation")
            for layer_idx in self.editing_layers:
                layer_matrices = self._load_cached_layer_projection_matrices(layer_idx)
                if use_unified:
                    layer_matrices = self._substitute_unified_projection(
                        layer_idx, layer_matrices, fallback_action="using"
                    )
                projection_matrices[layer_idx] = layer_matrices
                self._install_layer_projection_matrices(layer_idx, layer_matrices, from_cache=True)

            self.projection_matrices_cache.update(projection_matrices)
            self._projection_matrices_computed = True
            return projection_matrices

        info(f"Need to compute projection matrices for layers: {layers_to_compute}")

        info("Using streaming covariance accumulation per batch (default)")
        all_layer_covs = self._collect_covariances_for_all_layers(
            layers_to_compute, num_samples, force_recompute
        )

        # Compute projection matrices for each layer using collected data
        pending_layers = set(layers_to_compute)
        layer_progress = tqdm(
            self.editing_layers,
            desc="Computing projection matrices by layer",
            unit="layer"
        )

        for layer_idx in layer_progress:
            layer_progress.set_description(f"Processing layer {layer_idx}")

            if layer_idx in pending_layers:
                layer_editor = self.layer_editors[layer_idx]
                expert_indices = list(range(layer_editor.num_experts))

                layer_covs = all_layer_covs.get(layer_idx, {})
                layer_matrices = self._compute_projection_matrices_from_covariances(
                    layer_editor, expert_indices, layer_covs, force_recompute
                )
                layer_progress.set_postfix({
                    'experts': f"{len(layer_covs)}/{len(expert_indices)}",
                    'matrices': len(layer_matrices)
                })
            else:
                # Every expert of this layer is already cached
                layer_matrices = self._load_cached_layer_projection_matrices(layer_idx)
                layer_progress.set_postfix({'status': 'cached'})

            projection_matrices[layer_idx] = layer_matrices

        # If requested, replace per-expert matrices with unified per-layer matrices
        if use_unified:
            for layer_idx in self.editing_layers:
                projection_matrices[layer_idx] = self._substitute_unified_projection(
                    layer_idx, projection_matrices[layer_idx], fallback_action="keeping"
                )

        # Update the session cache only AFTER the unified substitution so the
        # cache matches what is returned and installed below.
        self.projection_matrices_cache.update(projection_matrices)
        self._projection_matrices_computed = True

        # IMPORTANT: Set use_identity_projection = False for all layer editors
        # so they will actually use the computed projection matrices
        for layer_idx in self.editing_layers:
            self._install_layer_projection_matrices(
                layer_idx, projection_matrices[layer_idx], from_cache=False
            )

        info("✅ Projection matrices computation completed and cached")
        return projection_matrices

    def _layers_missing_projection_cache(self, force_recompute: bool) -> List[int]:
        """Return the editing layers that need projection matrices computed.

        With force_recompute every editing layer is returned; otherwise only
        layers with at least one expert whose cached matrix is missing.
        """
        layers_to_compute: List[int] = []
        for layer_idx in self.editing_layers:
            if force_recompute:
                layers_to_compute.append(layer_idx)
                continue

            layer_editor = self.layer_editors[layer_idx]
            num_experts = layer_editor.num_experts
            missing_experts = [
                expert_idx for expert_idx in range(num_experts)
                if layer_editor.projection_manager.get_projection_matrix(expert_idx) is None
            ]

            if missing_experts:
                debug(f"Layer {layer_idx}: experts missing cached projection matrices: {missing_experts}")
                info(f"🔍 Layer {layer_idx}: Missing cache for {len(missing_experts)} experts: {missing_experts[:10]}{'...' if len(missing_experts) > 10 else ''}")
                layers_to_compute.append(layer_idx)
            else:
                info(f"✅ Layer {layer_idx}: All {num_experts} experts have cached projection matrices")

        return layers_to_compute

    def _load_cached_layer_projection_matrices(self, layer_idx: int) -> Dict[int, torch.Tensor]:
        """Fetch the per-expert projection matrices of one layer from its projection manager."""
        layer_editor = self.layer_editors[layer_idx]
        manager = layer_editor.projection_manager
        return {
            expert_idx: manager.get_projection_matrix(expert_idx)
            for expert_idx in range(layer_editor.num_experts)
        }

    def _substitute_unified_projection(
        self,
        layer_idx: int,
        layer_matrices: Dict[int, torch.Tensor],
        fallback_action: str
    ) -> Dict[int, torch.Tensor]:
        """Return one clone of the layer's unified matrix per expert, or `layer_matrices` if none exists.

        `fallback_action` only affects the wording of the warning that is
        logged when no unified matrix is available.
        """
        layer_editor = self.layer_editors[layer_idx]
        unified = layer_editor.projection_manager.get_unified_projection_matrix()
        if unified is None:
            warning(f"Unified projection matrix not found for layer {layer_idx}; {fallback_action} per-expert matrices")
            return layer_matrices

        num_experts = layer_editor.num_experts
        info(f"🔁 Using unified projection matrix for layer {layer_idx} for all {num_experts} experts (shape {tuple(unified.shape)})")
        return {expert_idx: unified.clone() for expert_idx in range(num_experts)}

    def _install_layer_projection_matrices(
        self,
        layer_idx: int,
        layer_matrices: Dict[int, torch.Tensor],
        from_cache: bool
    ) -> None:
        """Attach matrices to a layer editor and disable its identity projection."""
        layer_editor = self.layer_editors[layer_idx]
        layer_editor.projection_matrices = layer_matrices
        layer_editor.use_identity_projection = False
        origin = "cached " if from_cache else ""
        info(f"✅ Layer {layer_idx}: Set use_identity_projection=False, loaded {len(layer_matrices)} {origin}projection matrices")

    def _collect_keys_for_all_layers(
        self,
        layer_indices: List[int],
        num_samples: int,
        force_recompute: bool = False
    ) -> Dict[int, Dict[int, torch.Tensor]]:
        """
        Optimized key collection for multiple layers in a single forward pass

        Args:
            layer_indices: List of layer indices to collect keys for
            num_samples: Number of samples to use
            force_recompute: Whether to force recomputation

        Returns:
            Dictionary mapping layer_idx -> expert_idx -> keys_tensor
        """
        from .stats_collector import MultiLayerDownProjCollector
        import torch

        info(f"🚀 Optimized key collection for {len(layer_indices)} layers in single pass")
        info(f"🔍 DEBUG: Requested num_samples = {num_samples}")

        # Get Wikipedia-style texts (use first layer's manager for data loading)
        first_layer_manager = self.layer_editors[layer_indices[0]].projection_manager
        texts = first_layer_manager.get_wikipedia_style_texts(num_samples)
        info(f"🔍 DEBUG: Actually loaded {len(texts)} texts")

        # Create multi-layer collector
        target_layers = [self.layer_editors[layer_idx].target_layer for layer_idx in layer_indices]
        collector = MultiLayerDownProjCollector(target_layers, layer_indices)

        try:
            # Setup hooks for all layers
            collector.setup_hooks()

            # Process texts in batches
            batch_size = 100
            all_layer_keys = {layer_idx: {expert_idx: [] for expert_idx in range(128)}
                            for layer_idx in layer_indices}

            total_batches = (len(texts) + batch_size - 1) // batch_size
            info(f"Processing {len(texts)} texts in {total_batches} batches")

            # Add progress bar for key collection
            from tqdm import tqdm
            from .logger import get_tqdm_kwargs

            tqdm_kwargs = get_tqdm_kwargs(
                desc="Collecting keys from batches",
                total=total_batches,
                unit="batch"
            )
            batch_progress = tqdm(range(0, len(texts), batch_size), **tqdm_kwargs)

            for batch_start in batch_progress:
                batch_end = min(batch_start + batch_size, len(texts))
                batch_texts = texts[batch_start:batch_end]

                # Process batch and collect keys for all layers
                batch_keys = self._process_batch_for_all_layers(
                    batch_texts, layer_indices, collector
                )

                # Accumulate keys
                for layer_idx in layer_indices:
                    for expert_idx in range(128):
                        if expert_idx in batch_keys.get(layer_idx, {}):
                            all_layer_keys[layer_idx][expert_idx].extend(
                                batch_keys[layer_idx][expert_idx]
                            )

                # Update progress bar description with current stats
                total_keys_so_far = sum(
                    len(keys) for layer_keys in all_layer_keys.values()
                    for keys in layer_keys.values()
                )
                batch_progress.set_postfix({
                    'total_keys': total_keys_so_far,
                    'batch_size': len(batch_texts)
                })

            # Convert lists to tensors
            result = {}
            for layer_idx in layer_indices:
                layer_result = {}
                for expert_idx in range(128):
                    if all_layer_keys[layer_idx][expert_idx]:
                        keys_tensor = torch.cat(all_layer_keys[layer_idx][expert_idx], dim=0)
                        layer_result[expert_idx] = keys_tensor
                result[layer_idx] = layer_result

                # Log statistics
                experts_with_data = len(layer_result)
                total_keys = sum(keys.shape[0] for keys in layer_result.values())
                info(f"Layer {layer_idx}: {experts_with_data}/128 experts, {total_keys} total keys")

            return result

        finally:
            collector.cleanup()



    def _collect_covariances_for_all_layers(
        self,
        layer_indices: List[int],
        num_samples: int,
        force_recompute: bool = False
    ) -> Dict[int, Dict[int, Dict[str, Any]]]:
        """
        Streaming accumulation of per-expert covariance (second-moment) matrices across batches
        for multiple layers in a single forward pass.

        Returns:
            Dict[layer_idx][expert_idx] = {
                'sum_cov': Tensor (d x d),  # sum of K_b^T @ K_b across batches (not normalized)
                'count': int                # total number of key vectors aggregated
            }
        """
        from .stats_collector import MultiLayerDownProjCollector
        import torch

        info(f"🚀 Streaming covariance accumulation for {len(layer_indices)} layers")
        # Load corpus once (use first layer's manager for dataset)
        first_layer_manager = self.layer_editors[layer_indices[0]].projection_manager
        texts = first_layer_manager.get_wikipedia_style_texts(num_samples)
        info(f"🔍 DEBUG: Actually loaded {len(texts)} texts")

        # Setup multi-layer collector hooks
        target_layers = [self.layer_editors[layer_idx].target_layer for layer_idx in layer_indices]
        collector = MultiLayerDownProjCollector(target_layers, layer_indices)

        # Storage for covariance sums and counts
        all_layer_covs: Dict[int, Dict[int, Dict[str, Any]]] = {layer_idx: {} for layer_idx in layer_indices}

        try:
            collector.setup_hooks()

            batch_size = 100
            total_batches = (len(texts) + batch_size - 1) // batch_size
            from .logger import get_tqdm_kwargs
            from tqdm import tqdm
            tqdm_kwargs = get_tqdm_kwargs(desc="Collecting streaming covariances", total=total_batches, unit="batch")

            for batch_start in tqdm(range(0, len(texts), batch_size), **tqdm_kwargs):
                batch_end = min(batch_start + batch_size, len(texts))
                batch_texts = texts[batch_start:batch_end]

                # Reuse existing batch processing to obtain per-layer/per-expert keys for subject positions
                batch_keys = self._process_batch_for_all_layers(batch_texts, layer_indices, collector)

                # For each layer/expert, accumulate this batch's K^T K and counts
                for layer_idx in layer_indices:
                    layer_dict = all_layer_covs[layer_idx]
                    for expert_idx in range(128):
                        keys_list = batch_keys.get(layer_idx, {}).get(expert_idx, [])
                        if not keys_list:
                            continue
                        K_b = torch.cat(keys_list, dim=0).float()  # [n_b, d]
                        cov_b = K_b.T @ K_b                          # [d, d]
                        n_b = K_b.shape[0]

                        entry = layer_dict.get(expert_idx)
                        if entry is None:
                            layer_dict[expert_idx] = {'sum_cov': cov_b, 'count': int(n_b)}
                        else:
                            entry['sum_cov'] = entry['sum_cov'] + cov_b
                            entry['count'] += int(n_b)

        finally:
            collector.cleanup()

        # Log summary
        for layer_idx in layer_indices:
            experts_with_data = len(all_layer_covs[layer_idx])
            total_keys = sum(v['count'] for v in all_layer_covs[layer_idx].values())
            info(f"Layer {layer_idx}: {experts_with_data}/128 experts, {total_keys} total keys (streaming)")

        return all_layer_covs

    def _compute_projection_matrices_from_covariances(
        self,
        layer_editor,
        expert_indices: List[int],
        layer_covs: Dict[int, Dict[str, Any]],
        force_recompute: bool = False
    ) -> Dict[int, torch.Tensor]:
        """
        Turn accumulated per-expert covariance sums into projection matrices.

        For every expert, in order of preference: reuse the matrix stored by
        the projection manager (unless force_recompute), compute one from the
        accumulated covariance, or fall back to an identity matrix. Computed
        and identity matrices are saved through the projection manager. A
        unified matrix is then derived from the pooled covariance sums.

        Args:
            layer_editor: Single layer editor
            expert_indices: List of expert indices
            layer_covs: Dict[expert_idx] -> {'sum_cov': Tensor, 'count': int}
            force_recompute: Skip the per-expert cache lookup when True

        Returns:
            Dict[expert_idx] -> projection matrix
        """
        from tqdm import tqdm
        import torch

        manager = layer_editor.projection_manager
        results: Dict[int, torch.Tensor] = {}

        # Progress-bar counters, one per way an expert's matrix was obtained
        tally = {'cached': 0, 'computed': 0, 'identity': 0}

        # Pooled statistics for the unified matrix.
        # NOTE(review): only experts computed in this call are pooled; experts
        # served from cache are skipped, so with a partial cache the unified
        # matrix is built from a subset of experts. Confirm this is intended.
        pooled_sum = None
        pooled_count = 0

        progress = tqdm(
            expert_indices,
            desc=f"Layer {layer_editor.layer_idx} projection matrices (streaming)",
            unit="expert"
        )

        for expert_idx in progress:
            stored = None
            if not force_recompute:
                stored = manager._load_projection_matrix(expert_idx)

            if stored is not None:
                # Cache hit
                results[expert_idx] = stored
                tally['cached'] += 1
            elif expert_idx in layer_covs and layer_covs[expert_idx]['count'] > 0:
                # Normalise the accumulated sum into a second-moment matrix
                cov_sum = layer_covs[expert_idx]['sum_cov']
                n = layer_covs[expert_idx]['count']
                cov = cov_sum / max(n, 1)

                proj, _ = manager._compute_null_space_projection_from_cov(cov)
                manager._save_projection_matrix(expert_idx, proj)
                manager._save_covariance_matrix(expert_idx, cov)

                results[expert_idx] = proj
                tally['computed'] += 1

                # Clone on first use so the caller's tensor is never mutated
                if pooled_sum is None:
                    pooled_sum = cov_sum.clone()
                else:
                    pooled_sum += cov_sum
                pooled_count += n
            else:
                # No statistics for this expert: identity fallback
                # (only the first five occurrences are logged)
                if tally['identity'] < 5:
                    info(f"No covariance accumulated for expert {expert_idx}, using identity")
                identity_matrix = torch.eye(
                    layer_editor.d_hidden, device=layer_editor.device, dtype=torch.float32
                )
                manager._save_projection_matrix(expert_idx, identity_matrix)
                results[expert_idx] = identity_matrix
                tally['identity'] += 1

            progress.set_postfix(dict(tally))

        # Unified matrix from the pooled covariances; failures are logged, not raised
        try:
            if pooled_sum is None or pooled_count <= 0:
                info(f"Layer {layer_editor.layer_idx}: No unified covariance (streaming) to compute")
            else:
                unified_cov = pooled_sum / pooled_count
                unified_proj, _ = manager._compute_null_space_projection_from_cov(unified_cov)
                manager._save_unified_projection_matrix(unified_proj)
                manager._save_unified_covariance_matrix(unified_cov)
                info(f"✅ Layer {layer_editor.layer_idx}: Unified projection (streaming) saved, shape {tuple(unified_proj.shape)}")
        except Exception as e:
            warning(f"Failed to compute/save streaming unified matrix for layer {layer_editor.layer_idx}: {e}")

        return results


    def _process_batch_for_all_layers(
        self,
        batch_texts: List[str],
        layer_indices: List[int],
        collector
    ) -> Dict[int, Dict[int, List[torch.Tensor]]]:
        """
        Process a batch of texts and collect keys for all layers

        Args:
            batch_texts: List of texts to process
            layer_indices: List of layer indices
            collector: Multi-layer collector

        Returns:
            Dictionary mapping layer_idx -> expert_idx -> list of key tensors
        """
        from .utils import find_subject_token_positions
        import torch

        # Prepare batch inputs
        batch_subject_positions = []
        valid_texts = []

        for text in batch_texts:
            try:
                # For Wikipedia-style texts, use a simple heuristic to find subject position
                # Tokenize the text to find a reasonable subject position
                tokens = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
                # Use the middle token as a reasonable subject position for key collection
                seq_len = tokens['input_ids'].shape[1]
                subject_pos = min(seq_len // 2, seq_len - 1)

                if subject_pos >= 0 and seq_len > 1:
                    batch_subject_positions.append(subject_pos)
                    valid_texts.append(text)
            except:
                continue

        if not valid_texts:
            return {layer_idx: {} for layer_idx in layer_indices}

        # Tokenize batch
        batch_inputs = self.tokenizer(
            valid_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        # Clear collector state
        collector.clear_captured_keys()

        # Single forward pass for all layers
        with torch.no_grad():
            self.model(**batch_inputs)

        # Extract keys for each layer and expert
        batch_keys = {layer_idx: {expert_idx: [] for expert_idx in range(128)}
                     for layer_idx in layer_indices}

        for sample_idx, subject_pos in enumerate(batch_subject_positions):
            if subject_pos < 0:
                continue

            for layer_idx in layer_indices:
                for expert_idx in range(128):
                    key = collector.get_key_for_layer_expert_and_position(
                        layer_idx, expert_idx, subject_pos,
                        self.layer_editors[layer_idx].d_hidden,
                        batch_idx=sample_idx
                    )

                    if key is not None:
                        if key.dim() == 1:
                            key = key.unsqueeze(0)
                        batch_keys[layer_idx][expert_idx].append(key)

        return batch_keys

    def _collect_cov_stats_for_all_layers(
        self,
        layer_indices: List[int],
        num_samples: int,
        force_recompute: bool = False
    ) -> Dict[int, Dict[str, Any]]:
        """
        Streamed second-moment (covariance) accumulation for multiple layers.
        For each batch, update per-expert and unified sum of X^T X and counts, avoiding
        storing all keys in memory. Final projection matrices are computed later from these stats.

        Returns per-layer stats dict with keys:
            - expert_sum_cov: Dict[expert_idx, Tensor[d,d]]
            - expert_counts: Dict[expert_idx, int]
            - unified_sum_cov: Tensor[d,d]
            - unified_count: int
        """
        from .stats_collector import MultiLayerDownProjCollector
        import torch

        info(f"🚀 Streaming covariance accumulation for {len(layer_indices)} layers")
        info(f"🔍 DEBUG: Requested num_samples = {num_samples}")

        # Get Wikipedia-style texts (use first layer's manager for data loading)
        first_layer_manager = self.layer_editors[layer_indices[0]].projection_manager
        texts = first_layer_manager.get_wikipedia_style_texts(num_samples)
        info(f"🔍 DEBUG: Actually loaded {len(texts)} texts")

        # Initialize per-layer stats structures
        layer_stats: Dict[int, Dict[str, Any]] = {}
        for layer_idx in layer_indices:
            d = self.layer_editors[layer_idx].d_intermediate
            layer_stats[layer_idx] = {
                'expert_sum_cov': {},      # expert_idx -> Tensor[d,d]
                'expert_counts': {},       # expert_idx -> int
                'unified_sum_cov': torch.zeros(d, d),  # accumulate on CPU by default
                'unified_count': 0,
            }

        # Create multi-layer collector
        target_layers = [self.layer_editors[layer_idx].target_layer for layer_idx in layer_indices]
        collector = MultiLayerDownProjCollector(target_layers, layer_indices)

        try:
            # Setup hooks for all layers
            collector.setup_hooks()

            # Process texts in batches
            batch_size = 100
            total_batches = (len(texts) + batch_size - 1) // batch_size
            info(f"Processing {len(texts)} texts in {total_batches} batches (streaming mode)")

            from tqdm import tqdm
            from .logger import get_tqdm_kwargs
            tqdm_kwargs = get_tqdm_kwargs(
                desc="Collecting streaming cov stats",
                total=total_batches,
                unit="batch"
            )
            batch_progress = tqdm(range(0, len(texts), batch_size), **tqdm_kwargs)

            for batch_start in batch_progress:
                batch_end = min(batch_start + batch_size, len(texts))
                batch_texts = texts[batch_start:batch_end]

                # Process batch and collect per-layer per-expert keys (lists of vectors)
                batch_keys = self._process_batch_for_all_layers(
                    batch_texts, layer_indices, collector
                )

                # Update streamed covariance accumulators
                for layer_idx in layer_indices:
                    for expert_idx in range(128):
                        key_list = batch_keys.get(layer_idx, {}).get(expert_idx, [])
                        if not key_list:
                            continue
                        Kb = torch.cat(key_list, dim=0).float()  # [n_b, d]
                        cov_inc = Kb.T @ Kb  # [d, d]

                        # Per-expert accumulators
                        exp_sum_cov = layer_stats[layer_idx]['expert_sum_cov'].get(expert_idx)
                        if exp_sum_cov is None:
                            layer_stats[layer_idx]['expert_sum_cov'][expert_idx] = cov_inc.cpu()
                        else:
                            layer_stats[layer_idx]['expert_sum_cov'][expert_idx] = exp_sum_cov + cov_inc.cpu()

                        prev_cnt = layer_stats[layer_idx]['expert_counts'].get(expert_idx, 0)
                        layer_stats[layer_idx]['expert_counts'][expert_idx] = prev_cnt + Kb.shape[0]

                        # Unified accumulators per layer
                        layer_stats[layer_idx]['unified_sum_cov'] += cov_inc.cpu()
                        layer_stats[layer_idx]['unified_count'] += Kb.shape[0]

                # Progress info
                total_counts = sum(
                    sum(layer_stats[l]['expert_counts'].values()) for l in layer_indices
                )
                batch_progress.set_postfix({
                    'samples_seen': total_counts,
                    'batch_size': len(batch_texts)
                })

            # Log statistics per layer
            for layer_idx in layer_indices:
                expert_counts = layer_stats[layer_idx]['expert_counts']
                experts_with_data = sum(1 for c in expert_counts.values() if c > 0)
                total_samples = sum(expert_counts.values())
                info(f"Layer {layer_idx}: {experts_with_data}/128 experts, {total_samples} total keys (streaming)")

            return layer_stats

        finally:
            collector.cleanup()


    def _compute_projection_matrices_from_keys(
        self,
        layer_editor,
        expert_indices: List[int],
        layer_keys: Dict[int, torch.Tensor],
        force_recompute: bool = False
    ) -> Dict[int, torch.Tensor]:
        """
        Compute projection matrices from pre-collected keys

        For every expert the result comes from, in order of preference:
        the projection manager's cache (skipped when ``force_recompute``),
        a null-space projection computed from the expert's keys, or an
        identity matrix when the expert has no usable keys. Newly computed
        and identity matrices are written back through the projection manager.
        Afterwards a unified projection/covariance is computed from the
        concatenated keys of all listed experts and saved; a failure in that
        step is logged as a warning and does not affect the returned dict.

        Args:
            layer_editor: Single layer editor
            expert_indices: List of expert indices
            layer_keys: Pre-collected keys for this layer. An entry that is
                missing, None, or an empty tensor counts as "no keys".
            force_recompute: Whether to force recomputation

        Returns:
            Dictionary mapping expert_idx -> projection_matrix
        """
        manager = layer_editor.projection_manager
        projection_matrices: Dict[int, torch.Tensor] = {}
        tally = {'cached': 0, 'computed': 0, 'identity': 0}

        def _usable_keys(expert_idx: int) -> Optional[torch.Tensor]:
            """Return the expert's keys, or None when absent or empty."""
            keys = layer_keys.get(expert_idx)
            if keys is None or keys.numel() == 0:
                return None
            return keys

        expert_progress = tqdm(
            expert_indices,
            desc=f"Layer {layer_editor.layer_idx} projection matrices",
            unit="expert"
        )

        for expert_idx in expert_progress:
            # Check cache first (unless force_recompute)
            cached_matrix = None
            if not force_recompute:
                cached_matrix = manager._load_projection_matrix(expert_idx)

            keys = None if cached_matrix is not None else _usable_keys(expert_idx)

            if cached_matrix is not None:
                projection_matrices[expert_idx] = cached_matrix
                tally['cached'] += 1
            elif keys is not None:
                # Use pre-collected keys
                projection_matrix, _null_space_dim, cov_matrix = \
                    manager._compute_null_space_projection_with_cov(keys)

                # Save to cache
                manager._save_projection_matrix(expert_idx, projection_matrix)
                manager._save_covariance_matrix(expert_idx, cov_matrix)

                projection_matrices[expert_idx] = projection_matrix
                tally['computed'] += 1
            else:
                # Fallback to identity if no keys collected (reduce logging)
                if tally['identity'] < 5:  # Only log first few identity matrices
                    info(f"No keys collected for expert {expert_idx}, using identity matrix")
                # NOTE(review): the identity is sized by d_hidden; confirm this
                # matches the dimension of projections computed from keys.
                identity_matrix = torch.eye(
                    layer_editor.d_hidden,
                    device=layer_editor.device,
                    dtype=torch.float32
                )

                # IMPORTANT: Save identity matrix to cache to avoid recomputation
                manager._save_projection_matrix(expert_idx, identity_matrix)

                projection_matrices[expert_idx] = identity_matrix
                tally['identity'] += 1

            # Update progress bar (single place for all three outcomes)
            expert_progress.set_postfix(dict(tally))

        # Compute and save unified projection matrix from all collected keys.
        # Best-effort: any failure here is reported and swallowed on purpose.
        try:
            usable = [(idx, _usable_keys(idx)) for idx in expert_indices]
            usable = [(idx, keys) for idx, keys in usable if keys is not None]

            if usable:
                all_keys = torch.cat([keys for _, keys in usable], dim=0)
                total_keys = all_keys.shape[0]
                num_experts = len({idx for idx, _ in usable})
                info(f"Computing unified projection matrix for layer {layer_editor.layer_idx} from {total_keys} total keys across {num_experts} experts")

                unified_matrix, _, unified_cov = manager._compute_null_space_projection_with_cov(all_keys)
                manager._save_unified_projection_matrix(unified_matrix)
                manager._save_unified_covariance_matrix(unified_cov)
                info(f"✅ Layer {layer_editor.layer_idx}: Unified projection matrix computed and saved (shape: {unified_matrix.shape})")
            else:
                info(f"No keys collected for layer {layer_editor.layer_idx}, skipping unified projection matrix computation")
        except Exception as e:
            warning(f"Failed to compute/save unified projection matrix for layer {layer_editor.layer_idx}: {e}")

        return projection_matrices

    def edit_knowledge_multi_layer(
        self,
        requests: List[EditRequest],
        method: Optional[str] = None,  # Use config default
        num_passes: Optional[int] = None,  # Use config default
        verify_targets: Optional[bool] = None,  # Use config default
        num_prefixes: Optional[int] = None,  # Use config default
        use_prefix_expansion: Optional[bool] = None,  # Use config default
        incremental: Optional[bool] = None,  # Use config default
        force_recompute_targets: bool = False,
        force_recompute_projections: bool = False
    ) -> None:
        """
        Main multi-layer knowledge editing entry point

        Implements the sequential multi-layer update strategy:
        1. Compute unified target vectors using target layer
        2. Compute projection matrices for all editing layers
           (skipped when ``config.use_identity`` is set)
        3. Update the editing layers one after another; the layer with k
           layers still to update (itself included) gets residual weight
           1/sqrt(k)

        Args:
            requests: List of edit requests
            method: Optimization method ('bcd' or 'gd')
            num_passes: Number of optimization passes per layer
            verify_targets: Whether to verify target vectors
            num_prefixes: Number of prefixes for expansion
            use_prefix_expansion: Whether to use prefix expansion
            incremental: Whether to use incremental updates
            force_recompute_targets: Force recomputation of target vectors
            force_recompute_projections: Force recomputation of projection matrices

        Any of the ``Optional`` arguments left as None is taken from the
        editor config, falling back to a built-in default when the config
        does not define the attribute.
        """

        def _from_config(value, config_attr, fallback):
            """Return value unless it is None, else the config attribute (or fallback)."""
            return getattr(self.config, config_attr, fallback) if value is None else value

        # Use config defaults for None parameters
        method = _from_config(method, 'method', 'bcd')
        num_passes = _from_config(num_passes, 'num_passes', 2)
        verify_targets = _from_config(verify_targets, 'verify_targets', False)
        num_prefixes = _from_config(num_prefixes, 'num_prefixes', 10)
        use_prefix_expansion = _from_config(use_prefix_expansion, 'enable_prefix_expansion', True)
        incremental = _from_config(incremental, 'incremental_editing', False)

        info(f"Starting multi-layer knowledge editing for {len(requests)} requests")
        info(f"Editing layers: {self.editing_layers}")
        info(f"Target layer: {self.target_layer_idx}")

        # Step 1: Compute unified target vectors using target layer
        info("Step 1: Computing unified target vectors...")
        target_vectors = self.compute_unified_target_vectors(
            requests=requests,
            force_recompute=force_recompute_targets
        )

        # Step 2: Compute projection matrices for all editing layers (if not using identity)
        if self.config.use_identity:
            info("Step 2: Using identity projection matrices (skipping computation)")
        else:
            info("Step 2: Computing projection matrices for all editing layers...")
            # The returned matrices are not consumed in this method; presumably
            # the call populates state the per-layer editors read later —
            # TODO confirm against compute_all_projection_matrices.
            self.compute_all_projection_matrices(
                requests=requests,
                num_samples=None,  # Use config value
                force_recompute=force_recompute_projections
            )

        # Step 3: Sequential layer updates
        info("Step 3: Performing sequential layer updates...")

        total_layers = len(self.editing_layers)

        for layer_position, layer_idx in enumerate(self.editing_layers, start=1):
            # How many layers left to update (including current)
            remaining_layers = total_layers - layer_position + 1
            info(f"Updating layer {layer_idx} ({layer_position}/{total_layers})")

            # Calculate residual weight for this layer
            # For layers [11, 12, 13] with target layer 13:
            # Layer 11: r_11 = (target_vector - h_11) / sqrt(3) (3 layers remaining)
            # Layer 12: r_12 = (target_vector - h_12) / sqrt(2) (2 layers remaining)
            # Layer 13: r_13 = (target_vector - h_13) / sqrt(1) (1 layer remaining)
            residual_weight = 1.0 / math.sqrt(remaining_layers)

            info(f"Layer {layer_idx}: residual weight = 1/sqrt({remaining_layers}) = {residual_weight:.3f}")

            # Perform single-layer edit with modified target vectors
            self._edit_single_layer_with_residual_weight(
                layer_idx=layer_idx,
                requests=requests,
                target_vectors=target_vectors,
                residual_weight=residual_weight,
                method=method,
                num_passes=num_passes,
                verify_targets=verify_targets,
                num_prefixes=num_prefixes,
                use_prefix_expansion=use_prefix_expansion,
                incremental=incremental
            )

            info(f"Completed update for layer {layer_idx}")

        info("Multi-layer knowledge editing completed successfully")

    def _edit_single_layer_with_residual_weight(
        self,
        layer_idx: int,
        requests: List[EditRequest],
        target_vectors: Dict[int, ComputeTargetVectorResult],
        residual_weight: float,
        method: str = 'bcd',
        num_passes: int = 2,
        verify_targets: bool = False,
        num_prefixes: int = 10,
        use_prefix_expansion: Optional[bool] = None,
        incremental: bool = False
    ) -> None:
        """
        Edit a single layer with weighted residual computation

        The residual weight and the unified target vectors are handed to the
        layer-specific editor as temporary attributes
        (``current_residual_weight`` and ``unified_target_vectors``) for the
        duration of its ``edit_knowledge`` call and removed afterwards, also
        when that call raises.

        Args:
            layer_idx: Layer index to edit
            requests: List of edit requests
            target_vectors: Unified target vectors computed from target layer
            residual_weight: Weight for residual computation
                (e.g., 1/sqrt(3), 1/sqrt(2), 1 as passed by
                ``edit_knowledge_multi_layer``)
            Other args: Forwarded unchanged to the layer editor's edit_knowledge
        """
        info(f"Editing layer {layer_idx} with residual weight {residual_weight}")

        # Get the layer-specific editor
        layer_editor = self.layer_editors[layer_idx]

        temporary_attrs = ('current_residual_weight', 'unified_target_vectors')

        try:
            # Store the residual weight and target vectors on the editor.
            # This stands in for r_i = (target_vector - h_i) * residual_weight;
            # how the editor consumes them is not visible from this method.
            layer_editor.current_residual_weight = residual_weight
            layer_editor.unified_target_vectors = target_vectors

            # Perform the edit using the layer-specific editor
            layer_editor.edit_knowledge(
                requests=requests,
                method=method,
                num_passes=num_passes,
                verify_targets=verify_targets,
                num_prefixes=num_prefixes,
                use_prefix_expansion=use_prefix_expansion,
                incremental=incremental
            )
        finally:
            # Always clean up, so a failed edit cannot leave a stale weight or
            # stale targets on the editor for a later edit to pick up.
            for attr_name in temporary_attrs:
                if hasattr(layer_editor, attr_name):
                    delattr(layer_editor, attr_name)
