"""
Simplified configuration for MoE Knowledge Editing
"""

from dataclasses import dataclass


@dataclass
class MoEEditConfig:
    """Unified configuration for MoE Knowledge Editing"""
    # Model settings
    model_path: str = "/root/autodl-tmp/Qwen3-30B-A3B"
    device: str = "cuda"

    # Editor settings
    layer_idx: int = 13  # For backward compatibility with single-layer editing
    layers: list = None  # For multi-layer editing, e.g., [11, 12, 13]
    lambda_reg: float = 0.15
    projection_threshold: float = 0.02
    nullspace_threshold: float = 0.02
    stats_dir: str = "./qwen3_moe_stats"
    # Caches
    target_cache_dir: str = "./target_vector_cache"
    projection_cache_dir: str = "./projection_matrix_cache"  # Directory to cache per-expert projection matrices
    # Dataset name for target vector cache partitioning (e.g., 'zsre', 'counterfact')
    dataset_name: str | None = None
    num_layers: int = 48
    num_experts: int = 128
    num_experts_per_tok: int = 8
    d_model: int = 2048
    d_intermediate: int = 6144
    d_hidden: int = 768

    # Multi-layer editing settings
    multi_layer_enabled: bool = False
    target_layer_for_unified_vector: int = None  # Which layer to use for unified target vector computation
    layer_update_weights: dict = None  # Custom weights for each layer, e.g., {11: 1/3, 12: 1/2, 13: 1/1}

    # Projection settings
    use_identity: bool = False
    num_samples: int = 1000
    use_expert_projection: bool = True
    # When True, multi-layer editor will use the per-layer unified projection matrix
    # (broadcast to all experts) if available; otherwise falls back to per-expert matrices.
    use_unified_projection: bool = False

    # Projection computation mode
    # If True, accumulate per-batch second-moment (covariance) statistics instead of storing all keys,
    # and compute the projection matrices once at the end using the aggregated stats.
    streaming_covariance: bool = True

    # Optimization settings
    method: str = "bcd"
    num_passes: int = 5
    learning_rate: float = 0.0001
    verify_targets: bool = False

    # Prefix settings
    prefix_enabled: bool = True
    enable_prefix_expansion: bool = True  # Alternative field name for compatibility
    num_prefixes: int = 50

    # Evaluation settings
    eval_batch_size: int = 4
    compare_pre_post: bool = True

    # Advanced editing settings
    incremental_editing: bool = False
    update_projection_matrices: bool = False

    # Token positioning strategy (AlphaEdit compatible)
    fact_token: str = "subject_last"  # Options: "last", "subject_last", "subject_first", "subject_first_after_last"

    def get_editing_layers(self):
        """Get list of layers to edit"""
        if self.multi_layer_enabled and self.layers:
            return self.layers
        else:
            return [self.layer_idx]

    def get_target_layer_for_vector_computation(self):
        """Get the layer to use for unified target vector computation"""
        if self.multi_layer_enabled:
            if self.target_layer_for_unified_vector is not None:
                return self.target_layer_for_unified_vector
            else:
                # Default to the highest layer in editing layers
                editing_layers = self.get_editing_layers()
                return max(editing_layers)
        else:
            return self.layer_idx

    def is_multi_layer_enabled(self):
        """Check if multi-layer editing is enabled and properly configured"""
        return (self.multi_layer_enabled and
                self.layers is not None and
                len(self.layers) > 1)

    @classmethod
    def from_json(cls, json_path: str):
        """Load configuration from JSON file"""
        import json
        with open(json_path, 'r') as f:
            data = json.load(f)

        # Handle layers field - ensure it's a list
        if 'layers' in data and isinstance(data['layers'], list):
            data['multi_layer_enabled'] = len(data['layers']) > 1

        # Backward/alias compatibility: allow 'unified' to control 'use_unified_projection'
        if 'unified' in data and 'use_unified_projection' not in data:
            try:
                data['use_unified_projection'] = bool(data['unified'])
            except Exception:
                pass

        # Create config object
        config = cls()
        for key, value in data.items():
            if hasattr(config, key):
                setattr(config, key, value)

        return config







