from typing import List, Optional, Union, Dict, Any, Tuple
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F


class FFNSequenceAttributor:
    """Analyze FFN neuron importance for sequence generation using Integrated Gradients.

    For each target token and each entry of ``model.model.layers``, a forward
    hook on that layer's ``gate_proj`` replaces the last-position activation
    with ``num_steps`` scaled copies (zero baseline, ``alpha = i / num_steps``
    for ``i = 0 .. num_steps - 1``). All copies are evaluated in one batched
    forward pass. The positive part of the gradient of the target-token
    probability with respect to those scaled activations is summed over the
    steps and max-normalized.
    """

    def __init__(
        self,
        model: nn.Module,
        steps: int = 10,
        intermediate_size: int = 4096,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Args:
            model: model whose ``model.model.layers`` are iterated and whose
                forward output exposes ``.logits``.
            steps: number of integration steps, i.e. batch rows per
                forward pass.
            intermediate_size: expected ``gate_proj`` output width. Kept for
                backward compatibility; the hook reads the real width from
                the activation itself.
            device: stored as ``self.device``. Defaults to the device of the
                model's first parameter.
        """
        self.model = model
        self.num_steps = steps
        self.device = device or next(model.parameters()).device
        self.hooks = {}  # qualified module name -> hook handle
        self.intermediate_size = intermediate_size

        self.current_layer = None
        self.ffn_activation = None
        self.batch_weights = None  # scaled activations written by scaling_hook
        self.step = None  # per-step increment computed by scaling_hook

    def scaling_hook(self, module, input, output):
        """Forward hook: overwrite the last position with scaled activations.

        ``output`` is indexed as ``(batch, seq, width)`` with
        ``batch == num_steps``; ``compute_sequence_importance`` repeats a
        single prompt ``num_steps`` times to produce that batch. The
        last-token activation of the final batch row is scaled from a zero
        baseline. Row ``i`` of the last position becomes ``i / num_steps`` of
        it. The scaled tensor is kept in ``self.batch_weights`` so gradients
        can be taken with respect to it.
        """
        num_points = self.num_steps
        # Use the real width so a mismatched ``intermediate_size`` cannot
        # break the reshape below.
        width = output.size(-1)
        ffn_weights = output[-1, -1, :]
        baseline = torch.zeros_like(ffn_weights)
        step = (ffn_weights - baseline) / num_points
        self.step = step
        res = torch.cat(
            [torch.add(baseline, step * i) for i in range(num_points)], dim=0
        ).reshape(num_points, width)
        self.batch_weights = res
        # In-place write: everything downstream now depends on ``res``, so
        # autograd can differentiate the logits w.r.t. ``self.batch_weights``.
        output[:, -1, :] = res
        return output

    def register_hooks(self, layer_idx):
        """Attach ``scaling_hook`` to the ``gate_proj`` Linear(s) of one layer.

        Previously registered hooks are removed first. A module is hooked
        when it is an ``nn.Linear``, its qualified name contains
        ``"gate_proj"``, and the name contains the segment ``".{layer_idx}."``.
        """
        self.remove_hooks()
        layer_tag = f".{layer_idx}."
        for name, module in self.model.named_modules():
            if (
                "gate_proj" in name
                and isinstance(module, nn.Linear)
                and layer_tag in name
            ):
                self.hooks[name] = module.register_forward_hook(self.scaling_hook)

    def remove_hooks(self):
        """Remove all hooks registered by this instance."""
        for hook in self.hooks.values():
            hook.remove()
        self.hooks = {}

    def compute_sequence_importance(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Score FFN neurons for every (target position, layer) pair.

        Target tokens are teacher-forced: after scoring position ``pos``,
        ``target_ids[:, pos]`` is appended to the input before the next
        position is scored. For each layer, all ``num_steps`` scaled inputs
        are evaluated in a single batched forward pass.

        Args:
            input_ids: prompt token ids, ``(1, prompt_len)``. Batch size must
                be 1 because the target token is read with ``.item()``.
            target_ids: ids of the tokens to attribute, ``(1, seq_len)``.
            attention_mask: optional mask for ``input_ids``; defaults to
                all ones.

        Returns:
            CPU tensor of shape ``(1, seq_len * num_layers, width)``. Rows are
            ordered position-major: all layers for position 0, then
            position 1, and so on. Each row is the clamped, step-summed
            gradient divided by its maximum. A row stays all zeros when that
            maximum is zero, instead of becoming NaN.
        """
        self.step = None
        self.model.eval()
        sequence_length = target_ids.size(1)
        num_layers = len(self.model.model.layers)

        current_input = input_ids
        current_mask = (
            attention_mask if attention_mask is not None else torch.ones_like(input_ids)
        )

        layer_scores = []

        for pos in tqdm(range(sequence_length)):
            target_token = target_ids[:, pos]
            target_index = target_token.item()

            for layer_idx in range(num_layers):
                self.step = None
                self.batch_weights = None
                self.register_hooks(layer_idx)
                try:
                    self.model.zero_grad()
                    outputs = self.model(
                        current_input.repeat(self.num_steps, 1),
                        attention_mask=current_mask.repeat(self.num_steps, 1),
                    )
                    # Probability of the target token at the last position,
                    # one value per integration step.
                    probs = F.softmax(outputs.logits[:, -1, :], dim=-1)
                    target_probs = probs[:, target_index]
                    # Differentiating the sum is the same as seeding every
                    # per-step probability with a gradient of 1.
                    grads = torch.autograd.grad(
                        target_probs.sum(), self.batch_weights
                    )[0]
                finally:
                    # Never leave a hook attached after an error: it would keep
                    # rewriting gate_proj outputs on later forward passes.
                    self.remove_hooks()

                gradient = grads.clamp(min=0).sum(dim=0)
                max_val = gradient.max()
                # All-zero gradients (every raw gradient <= 0) would give 0/0.
                if max_val > 0:
                    gradient = gradient / max_val
                layer_scores.append(gradient.unsqueeze(0).detach().cpu())

            with torch.no_grad():
                current_input = torch.cat(
                    [current_input, target_token.unsqueeze(1)], dim=1
                )
                # Match the existing mask's dtype/device instead of defaulting
                # to float32 on self.device.
                current_mask = torch.cat(
                    [
                        current_mask,
                        torch.ones(
                            current_mask.size(0),
                            1,
                            dtype=current_mask.dtype,
                            device=current_mask.device,
                        ),
                    ],
                    dim=1,
                )

        self.remove_hooks()
        return torch.cat(layer_scores, dim=0).unsqueeze(0)