"""
Statistics collection and hook management for MoE layers
"""

import torch
import torch.nn.functional as F
from typing import List, Dict, Any, Optional
from .manager import HookManager


class Qwen3MoEStatisticsCollector:
    """Statistics collector for Qwen3 MoE layers"""

    def __init__(self):
        self.clear()

    def clear(self):
        """Clear all collected statistics"""
        self.router_logits = []  # router (gate) outputs, one entry per router forward call
        self.expert_weights = []  # not populated by create_hooks()
        self.expert_indices = []  # not populated by create_hooks()
        self.expert_inputs = {}  # expert_idx -> list of inputs to down_proj (not populated by create_hooks())
        self.moe_inputs = []     # inputs to the whole MoE layer

    def create_hooks(self, moe_layer):
        """Create forward hooks for Qwen3/gpt-oss MoE layer (supports gate or router).

        Records the MoE layer's first positional input into ``moe_inputs`` and the
        router output into ``router_logits``. The router is looked up as
        ``moe_layer.gate`` first, then ``moe_layer.router``. If neither exists, only
        the input hook is registered.

        Args:
            moe_layer: the MoE ``nn.Module`` to instrument.

        Returns:
            List of hook handles returned by ``register_forward_hook``. Call
            ``.remove()`` on each one to detach it.
        """
        hooks = []

        def moe_input_hook(module, inputs, output):
            # Forward hooks only receive positional arguments, so ``inputs`` is empty
            # when the layer is called with keywords only. Skip such calls instead of
            # raising from inside the model's forward pass.
            if inputs and hasattr(inputs[0], 'detach'):
                self.moe_inputs.append(inputs[0].detach())

        hooks.append(moe_layer.register_forward_hook(moe_input_hook))

        def gate_hook(module, inputs, output):
            # Support both tensor and tuple outputs (e.g., gpt-oss returns (scores, indices))
            try:
                if isinstance(output, tuple) and len(output) > 0:
                    logits = output[0]
                else:
                    logits = output
                # Non-tensor outputs are stored as-is for debugging
                self.router_logits.append(logits.detach() if hasattr(logits, 'detach') else logits)
            except Exception:
                # Best-effort: statistics collection must never break the forward pass
                pass

        if hasattr(moe_layer, 'gate'):
            hooks.append(moe_layer.gate.register_forward_hook(gate_hook))
        elif hasattr(moe_layer, 'router'):
            hooks.append(moe_layer.router.register_forward_hook(gate_hook))

        return hooks


class DownProjHookCollector:
    """Specialized collector for down_proj layer inputs with similarity-based dynamic matching"""

    def __init__(self, target_layer, num_experts: int):
        self.target_layer = target_layer
        self.num_experts = num_experts
        self.captured_keys = {}  # expert_idx -> list of captured keys
        self.hook_manager = HookManager()

        # Fields for similarity-based matching
        self.input_hidden_states = None  # Hidden states before MoE layer
        self.expert_input_vectors = {}  # expert_idx -> {internal_idx: vector}
        self.similarity_cache = {}  # Cache for similarity computations
        self.input_hook_handle = None  # Handle for input capture hook

        # Optimization: track which experts we want to collect data for
        self.target_experts = None  # Set of expert indices to collect data for

    def set_target_experts(self, expert_indices: List[int]):
        """Set which experts to collect data for (optimization)"""
        self.target_experts = set(expert_indices) if expert_indices else None



    def setup_input_capture_hook(self):
        """Setup hook to capture input hidden states before MoE layer"""
        def input_hook(module, input, output):
            # Capture hidden states entering the MoE layer (after layer norm)
            self.input_hidden_states = input[0].detach().clone()

        # Register hook on the MoE layer
        self.input_hook_handle = self.target_layer.register_forward_hook(input_hook)

    def setup_hooks(self):
        """Setup hooks for all down_proj layers and input capture"""
        # Setup input capture hook first
        self.setup_input_capture_hook()

        def create_expert_input_hook(expert_idx):
            """Hook to capture 2048-dim input vectors to experts (for similarity matching)"""
            def hook(module, input, output):
                # Skip if this expert is not in our target list (optimization)
                if self.target_experts is not None and expert_idx not in self.target_experts:
                    return

                # input[0] contains the 2048-dim vectors assigned to this expert
                input_tensor = input[0].detach().clone()

                if expert_idx not in self.expert_input_vectors:
                    self.expert_input_vectors[expert_idx] = {}

                # Store 2048-dim input vectors for similarity computation
                # Use a counter to ensure unique indices even in batch processing
                current_count = len(self.expert_input_vectors[expert_idx])
                for i in range(input_tensor.shape[0]):
                    self.expert_input_vectors[expert_idx][current_count + i] = input_tensor[i, :].clone()

            return hook

        def create_down_proj_hook(expert_idx):
            """Hook to capture 768-dim key vectors from down_proj (for knowledge editing)"""
            def hook(module, input, output):
                # Skip if this expert is not in our target list (optimization)
                if self.target_experts is not None and expert_idx not in self.target_experts:
                    return

                # input[0] is the 768-dim key vector (SwiGLU output)
                input_tensor = input[0].detach().clone()

                if expert_idx not in self.captured_keys:
                    self.captured_keys[expert_idx] = []

                # Store all captured keys for this expert
                # In batch processing, we get multiple keys at once
                for i in range(input_tensor.shape[0]):
                    self.captured_keys[expert_idx].append(input_tensor[i, :].clone())

            return hook

        # Register hooks for all experts
        for expert_idx, expert in enumerate(self.target_layer.experts):
            if hasattr(expert, 'up_proj'):
                # Hook on up_proj to capture 2048-dim input vectors
                self.hook_manager.register(expert.up_proj, create_expert_input_hook(expert_idx))

            if hasattr(expert, 'down_proj'):
                # Hook on down_proj to capture 768-dim key vectors
                self.hook_manager.register(expert.down_proj, create_down_proj_hook(expert_idx))

    def setup_hooks_for_experts(self, expert_indices: List[int]):
        """Setup hooks only for specified experts (optimized version)"""
        # Setup input capture hook first
        self.setup_input_capture_hook()

        def create_expert_input_hook(expert_idx):
            """Hook to capture 2048-dim input vectors to experts (for similarity matching)"""
            def hook(module, input, output):
                # input[0] contains the 2048-dim vectors assigned to this expert
                input_tensor = input[0].detach().clone()

                if expert_idx not in self.expert_input_vectors:
                    self.expert_input_vectors[expert_idx] = {}

                # Store 2048-dim input vectors for similarity computation
                for i in range(input_tensor.shape[0]):
                    self.expert_input_vectors[expert_idx][i] = input_tensor[i, :].clone()

            return hook

        def create_down_proj_hook(expert_idx):
            """Hook to capture 768-dim key vectors from down_proj (for knowledge editing)"""
            def hook(module, input, output):
                # input[0] is the 768-dim key vector (SwiGLU output)
                input_tensor = input[0].detach().clone()

                if expert_idx not in self.captured_keys:
                    self.captured_keys[expert_idx] = []

                # Store all captured keys for this expert
                for i in range(input_tensor.shape[0]):
                    self.captured_keys[expert_idx].append(input_tensor[i, :].clone())

            return hook

        # Register hooks only for specified experts
        for expert_idx in expert_indices:
            if expert_idx < len(self.target_layer.experts):
                expert = self.target_layer.experts[expert_idx]
                if hasattr(expert, 'up_proj'):
                    # Hook on up_proj to capture 2048-dim input vectors
                    self.hook_manager.register(expert.up_proj, create_expert_input_hook(expert_idx))

                if hasattr(expert, 'down_proj'):
                    # Hook on down_proj to capture 768-dim key vectors
                    self.hook_manager.register(expert.down_proj, create_down_proj_hook(expert_idx))

    def clear_captured_keys(self):
        """Clear captured keys and similarity data"""
        self.captured_keys.clear()
        self.input_hidden_states = None
        self.expert_input_vectors.clear()
        self.similarity_cache.clear()

    def _compute_similarity_mapping(self, expert_idx: int, target_position: int, batch_idx: int = 0) -> Optional[int]:
        """Compute similarity-based mapping for expert and position with batch support"""

        if (expert_idx not in self.expert_input_vectors or
            self.input_hidden_states is None):
            return None

        # Check cache first
        cache_key = (expert_idx, target_position, batch_idx)
        if cache_key in self.similarity_cache:
            return self.similarity_cache[cache_key]

        # Get target position's input vector for the specific batch sample
        if (batch_idx >= self.input_hidden_states.shape[0] or
            target_position >= self.input_hidden_states.shape[1]):
            return None

        target_vector = self.input_hidden_states[batch_idx, target_position, :].float()

        # Compute similarities with all expert internal vectors
        best_similarity = -1.0
        best_internal_idx = None

        for internal_idx, expert_vector in self.expert_input_vectors[expert_idx].items():
            similarity = F.cosine_similarity(
                target_vector.unsqueeze(0),
                expert_vector.float().unsqueeze(0)
            ).item()

            if similarity > best_similarity:
                best_similarity = similarity
                best_internal_idx = internal_idx

        # Lower similarity threshold for batch processing to ensure we get matches
        similarity_threshold = 0.95 if batch_idx > 0 else 0.99
        if best_similarity > similarity_threshold:
            self.similarity_cache[cache_key] = best_internal_idx
            return best_internal_idx

        return None

    def _fallback_get_key(self, expert_idx: int, position: int):
        """Fallback method when similarity matching fails"""
        if expert_idx not in self.captured_keys or not self.captured_keys[expert_idx]:
            return None

        # Use first available key as fallback
        return self.captured_keys[expert_idx][0].float()

    def _improved_fallback_get_key(self, expert_idx: int, position: int, batch_idx: int = 0):
        """Improved fallback method that considers batch_idx"""
        if expert_idx not in self.captured_keys or not self.captured_keys[expert_idx]:
            return None

        available_keys = self.captured_keys[expert_idx]

        # Try to use batch_idx to select appropriate key
        if batch_idx < len(available_keys):
            return available_keys[batch_idx].float()

        # If batch_idx is out of range, use modulo to cycle through available keys
        key_idx = batch_idx % len(available_keys)
        return available_keys[key_idx].float()

    def get_key_for_expert_and_position(self, expert_idx: int, position: int, d_k: int, batch_idx: int = 0):
        """Get key for specific expert and token position using similarity-based dynamic matching with batch support"""

        if expert_idx not in self.captured_keys or not self.captured_keys[expert_idx]:
            print(f"Warning: No captured keys for expert {expert_idx}")
            return None

        try:
            # Try similarity-based matching first
            matched_internal_idx = self._compute_similarity_mapping(expert_idx, position, batch_idx)

            if matched_internal_idx is not None and matched_internal_idx < len(self.captured_keys[expert_idx]):
                # Found similarity match, get corresponding key
                raw_key = self.captured_keys[expert_idx][matched_internal_idx].float()
                return raw_key

            # Improved fallback: use batch_idx to select appropriate key
            return self._improved_fallback_get_key(expert_idx, position, batch_idx)

        except Exception as e:
            return self._improved_fallback_get_key(expert_idx, position, batch_idx)

    def cleanup(self):
        """Clean up hooks and data"""
        self.hook_manager.remove_all()
        if self.input_hook_handle is not None:
            self.input_hook_handle.remove()
            self.input_hook_handle = None
        self.clear_captured_keys()


class ProjectionMatrixCollector:
    """Collect per-expert down_proj inputs (keys) used to compute projection matrices."""

    def __init__(self, target_layer, d_intermediate: int, device: str):
        self.target_layer = target_layer
        self.d_intermediate = d_intermediate
        self.device = device
        self.captured_keys = {}  # Dict[int, List[Tensor]] - expert_idx -> list of keys
        self.hook_manager = HookManager()

    def setup_hooks(self, expert_indices: Optional[List[int]] = None):
        """Setup hooks to capture down_proj inputs for projection matrix computation

        Args:
            expert_indices: List of expert indices to collect keys for. If None, collect for all experts.
                Indices at or beyond the number of experts, and experts without a
                ``down_proj`` attribute, are skipped.
        """
        if expert_indices is None:
            expert_indices = list(range(len(self.target_layer.experts)))

        def create_down_proj_hook(expert_idx):
            def hook(module, inputs, output):
                # inputs[0] is the down_proj input (the expert's activation output); stored detached, uncopied
                self.captured_keys.setdefault(expert_idx, []).append(inputs[0].detach())
            return hook

        for expert_idx in expert_indices:
            if expert_idx < len(self.target_layer.experts):
                expert = self.target_layer.experts[expert_idx]
                if hasattr(expert, 'down_proj'):
                    self.hook_manager.register(expert.down_proj, create_down_proj_hook(expert_idx))

    def get_keys_for_expert(self, expert_idx: int) -> Optional[torch.Tensor]:
        """Get keys for a specific expert as a concatenated tensor

        Args:
            expert_idx: Expert index

        Returns:
            The captured tensors concatenated along dim 0 (a single capture is
            returned as-is), or None if no keys were captured for the expert.
        """
        keys_list = self.captured_keys.get(expert_idx)
        if not keys_list:
            return None

        # Concatenate tensors instead of stacking to handle different batch sizes
        if len(keys_list) == 1:
            return keys_list[0]
        return torch.cat(keys_list, dim=0)

    def get_all_keys(self):
        """Get all captured keys as a tensor (for backward compatibility).

        Keys are concatenated along dim 0 in ascending expert order. Returns None
        if nothing was captured.
        """
        all_keys = [key for expert_idx in sorted(self.captured_keys) for key in self.captured_keys[expert_idx]]
        if not all_keys:
            return None

        # Concatenate tensors instead of stacking to handle different batch sizes
        if len(all_keys) == 1:
            return all_keys[0]
        return torch.cat(all_keys, dim=0)

    def cleanup(self):
        """Clean up hooks.

        Removes every hook registered through ``hook_manager``. Captured keys are
        not cleared, so they can still be read after the hooks are removed.
        """
        self.hook_manager.remove_all()


class GptOssExpertsForwardKeyCollector:
    """Forward wrapper for GptOssExperts to capture gated outputs (keys) before down_proj.

    This avoids relying on per-expert nn.Linear modules. It re-implements the forward
    exactly as in GptOssExperts and records (up+1)*glu for specified flattened token indices.

    Captured vectors are detached float32 CPU copies stored as
    ``captured[expert_idx][flat_idx]``. On the CPU/training path, only the experts a
    token is routed to (per ``router_indices``) are recorded. On the batched inference
    path, every expert's gated output is computed, so every expert gets an entry for
    each requested index.
    """

    def __init__(self, experts_module):
        self.experts = experts_module
        self._orig_forward = None
        # Whether ``forward`` was already an instance attribute before attach(),
        # so detach() can restore exactly the original state.
        self._had_instance_forward = False
        # Capture: expert_idx -> dict(flat_idx -> tensor[d_k])
        self.captured: Dict[int, Dict[int, torch.Tensor]] = {}
        # Flattened token indices to capture in subsequent forward calls (until clear())
        self.capture_flat_indices: Optional[set[int]] = None
        # Batch size and per-sample token count seen by the most recent wrapped call
        self.last_batch_size: Optional[int] = None
        self.last_seq_len: Optional[int] = None

    def set_capture_positions(self, flat_indices: List[int]):
        """Request capture for these rows of the flattened (num_tokens, hidden_size) input."""
        self.capture_flat_indices = {int(i) for i in flat_indices}

    def clear(self):
        """Drop captured vectors and the requested capture positions."""
        self.captured.clear()
        self.capture_flat_indices = None

    def get_key_for_expert_and_flat_index(self, expert_idx: int, flat_idx: int) -> Optional[torch.Tensor]:
        """Return the captured vector for (expert_idx, flat_idx), or None if absent."""
        return self.captured.get(expert_idx, {}).get(flat_idx)

    def get_all_keys(self) -> Dict[int, Dict[int, torch.Tensor]]:
        """Return the live capture dict (not a copy)."""
        return self.captured

    def attach(self):
        """Replace ``experts.forward`` with a capturing re-implementation (no-op if attached)."""
        if self._orig_forward is not None:
            return
        self._had_instance_forward = 'forward' in vars(self.experts)
        self._orig_forward = self.experts.forward

        def wrapped_forward(hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
            experts = self.experts
            batch_size = hidden_states.shape[0]
            # Flatten tokens to match expert code convention
            hidden_states_flat = hidden_states.reshape(-1, experts.hidden_size)
            num_tokens = hidden_states_flat.shape[0]
            self.last_batch_size = batch_size
            # Matches the -1 dimension of the (batch_size, -1, hidden_size) output view
            self.last_seq_len = num_tokens // batch_size if batch_size else 0
            # routing_weights expected shape: (num_tokens, num_experts)
            num_experts = routing_weights.shape[1]
            # An empty set means "nothing requested"
            capture = self.capture_flat_indices or None

            if hidden_states_flat.device.type == "cpu" or experts.training:
                next_states = torch.zeros_like(hidden_states_flat, dtype=hidden_states_flat.dtype, device=hidden_states_flat.device)
                with torch.no_grad():
                    expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
                    # assumes router_indices is (num_tokens, top_k) -> mask becomes (num_experts, top_k, num_tokens)
                    expert_mask = expert_mask.permute(2, 1, 0)
                    expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
                for hit in expert_hit:
                    expert_idx = hit[0]
                    with torch.no_grad():
                        _, token_idx = torch.where(expert_mask[expert_idx])
                    current_state = hidden_states_flat[token_idx]
                    gate_up = current_state @ experts.gate_up_proj[expert_idx] + experts.gate_up_proj_bias[expert_idx]
                    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                    gate = gate.clamp(min=None, max=experts.limit)
                    up = up.clamp(min=-experts.limit, max=experts.limit)
                    glu = gate * torch.sigmoid(gate * experts.alpha)
                    gated_output = (up + 1) * glu

                    # Capture only requested token indices routed to this expert
                    if capture:
                        for row, fidx in enumerate(token_idx.tolist()):
                            if fidx in capture:
                                self.captured.setdefault(int(expert_idx), {})[int(fidx)] = (
                                    gated_output[row, :].detach().clone().to(torch.float32).cpu()
                                )

                    out = gated_output @ experts.down_proj[expert_idx] + experts.down_proj_bias[expert_idx]
                    weighted_output = out * routing_weights[token_idx, expert_idx, None]
                    next_states.index_add_(0, token_idx, weighted_output.to(hidden_states_flat.dtype))
                return next_states.view(batch_size, -1, experts.hidden_size)

            # Batched inference path: compute every expert for every token at once
            hs_rep = hidden_states_flat.repeat(num_experts, 1).view(num_experts, -1, experts.hidden_size)
            gate_up = torch.bmm(hs_rep, experts.gate_up_proj) + experts.gate_up_proj_bias[..., None, :]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=experts.limit)
            up = up.clamp(min=-experts.limit, max=experts.limit)
            glu = gate * torch.sigmoid(gate * experts.alpha)
            gated_output = (up + 1) * glu  # [num_experts, num_tokens, expert_dim]

            # Capture requested positions without storing the full tensor
            if capture:
                for fidx in capture:
                    if 0 <= fidx < num_tokens:
                        # One device->host transfer per index instead of one per expert
                        vecs = gated_output[:, fidx, :].detach().to(torch.float32).cpu()
                        for e in range(num_experts):
                            self.captured.setdefault(e, {})[int(fidx)] = vecs[e].clone()

            next_states = torch.bmm(gated_output, experts.down_proj)
            next_states = next_states + experts.down_proj_bias[..., None, :]
            next_states = next_states.view(num_experts, batch_size, -1, experts.hidden_size)
            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
            return next_states.sum(dim=0)

        self.experts.forward = wrapped_forward

    def detach(self):
        """Restore the original ``forward`` and drop captured data."""
        if self._orig_forward is not None:
            if self._had_instance_forward:
                self.experts.forward = self._orig_forward
            else:
                # forward came from the class: removing the override re-exposes it
                # without pinning a bound method on the instance.
                vars(self.experts).pop('forward', None)
            self._orig_forward = None
        self.clear()


class MultiLayerDownProjCollector:
    """
    Optimized collector for multiple layers in a single forward pass
    Collects down_proj keys from multiple layers simultaneously
    """

    def __init__(self, target_layers: List, layer_indices: List[int]):
        """
        Initialize multi-layer collector

        Args:
            target_layers: List of target layer modules
            layer_indices: List of corresponding layer indices
        """
        self.target_layers = target_layers
        self.layer_indices = layer_indices
        self.layer_to_index = {id(layer): idx for layer, idx in zip(target_layers, layer_indices)}

        # layer_idx -> expert_idx -> list of captured down_proj inputs
        self.captured_keys = {layer_idx: {} for layer_idx in layer_indices}
        self.hook_manager = HookManager()

    def _make_key_hook(self, layer_idx: int, expert_idx: int):
        """Build a hook that appends a down_proj input to ``captured_keys[layer_idx][expert_idx]``."""
        def hook_fn(module, inputs, output):
            if len(inputs) > 0 and inputs[0] is not None:
                layer_keys = self.captured_keys.setdefault(layer_idx, {})
                layer_keys.setdefault(expert_idx, []).append(inputs[0].detach())
        return hook_fn

    def setup_hooks(self):
        """Setup hooks on every expert's ``down_proj`` across all target layers.

        Experts without a ``down_proj`` attribute are skipped, matching the other
        collectors in this module.
        """
        import logging
        logger = logging.getLogger(__name__)

        total_hooks = 0
        for layer, layer_idx in zip(self.target_layers, self.layer_indices):
            layer_hooks = 0
            for expert_idx, expert in enumerate(layer.experts):
                if not hasattr(expert, 'down_proj'):
                    continue
                self.hook_manager.register(expert.down_proj, self._make_key_hook(layer_idx, expert_idx))
                layer_hooks += 1
            total_hooks += layer_hooks
            logger.info("Setup %d hooks for layer %s", layer_hooks, layer_idx)

        logger.info("Total hooks setup: %d across %d layers", total_hooks, len(self.layer_indices))

    def clear_captured_keys(self):
        """Clear all captured keys"""
        for layer_idx in self.layer_indices:
            self.captured_keys[layer_idx] = {}

    def get_key_for_layer_expert_and_position(
        self,
        layer_idx: int,
        expert_idx: int,
        position: int,
        d_intermediate: int,
        batch_idx: int = 0
    ) -> Optional[torch.Tensor]:
        """
        Get key for specific layer, expert, and position

        Captures are scanned oldest-first. The first capture that contains a row
        for (batch_idx, position) of length ``d_intermediate`` with no NaN wins.

        Args:
            layer_idx: Layer index
            expert_idx: Expert index
            position: Token position (row index for 2D captures)
            d_intermediate: Expected key length
            batch_idx: Batch index (only used for 3D captures)

        Returns:
            Key tensor on CPU, or None if not found
        """
        captured_tensors = self.captured_keys.get(layer_idx, {}).get(expert_idx)
        if not captured_tensors:
            return None

        for tensor in captured_tensors:
            if tensor.dim() == 3:
                # (batch, seq, d) capture
                if batch_idx >= tensor.shape[0] or position >= tensor.shape[1]:
                    continue
                key = tensor[batch_idx, position, :]
            elif tensor.dim() == 2:
                # NOTE(review): if this expert only received its routed tokens, row
                # `position` is the position-th routed token, not the sequence
                # position. Confirm with the caller.
                if position >= tensor.shape[0]:
                    continue
                key = tensor[position, :]
            else:
                continue

            if key.shape[0] == d_intermediate and not torch.isnan(key).any():
                return key.cpu()

        return None

    def cleanup(self):
        """Clean up all hooks"""
        self.hook_manager.remove_all()
        self.clear_captured_keys()
