#!/usr/bin/env python
# coding: utf-8

import operator
import warnings

import torch
import torch.nn as nn

class InterventionModule(nn.Module):
    """Apply a masking-based intervention to attention scores.

    The intervention is a simple reading of a do-operation from causal
    inference: the intervened node's value is treated as externally fixed,
    so attention *to* that node is blocked by setting its score column to
    negative infinity.
    """

    def __init__(self, d_model=None):
        """Initialize the InterventionModule.

        Args:
            d_model (int, optional): Dimension of the model. Accepted for
                interface compatibility; this masking-only version has no
                parameters and does not use it.
        """
        super().__init__()

    def forward(self, scores, intervention, G):
        """Mask attention *to* the intervened node.

        Args:
            scores (torch.Tensor): Attention scores whose last dimension
                indexes the attended-to nodes, typically
                (batch_size, n_heads, seq_len, seq_len).
            intervention (dict or None): Intervention details. Expected keys:
                - 'node_idx': index of the node to intervene on. Any
                  integer-like value accepted by ``operator.index`` works
                  (Python int, NumPy integer, 0-d integer tensor).
                - other keys (e.g. 'value') are ignored here.
                If None, or if 'node_idx' is missing, nothing is changed.
            G (torch.Tensor): Learned causal graph. Not used by this masking
                version; kept in the signature for future extensions.

        Returns:
            torch.Tensor: ``scores`` itself when no (valid) intervention is
            given; otherwise a copy of ``scores`` in which column
            ``node_idx`` of the last dimension is ``-inf``.

        Note:
            If ``scores`` has a single column and that column is masked,
            every row becomes entirely ``-inf``; callers applying a softmax
            afterwards should be prepared for that case.
        """
        if intervention is None or 'node_idx' not in intervention:
            return scores  # No intervention specified or details missing

        node_idx = intervention['node_idx']

        # The mask is written into the last dimension, so that is the
        # dimension the index has to be validated against (rows and columns
        # may differ in length).
        num_columns = scores.shape[-1]

        # Normalise integer-like indices to a plain int; anything that is
        # not integer-like (float, str, None, ...) is treated as invalid.
        try:
            column = operator.index(node_idx)
        except TypeError:
            column = None

        if column is None or not (0 <= column < num_columns):
            # Best-effort behaviour: report the problem and leave the scores
            # untouched instead of aborting the forward pass.
            warnings.warn(
                f"Intervention node index {node_idx} is invalid for seq_len "
                f"{num_columns}. Skipping intervention."
            )
            return scores

        # Work on a copy so the caller's tensor is never modified in place.
        modified_scores = scores.clone()
        modified_scores[..., column] = float('-inf')

        # Note: this is a simple interpretation. More complex interventions
        # might use intervention['value'], modify attention *from* the
        # intervened node based on G, or implement specific causal
        # adjustment formulas (front-door / back-door).
        return modified_scores

# Example Usage
if __name__ == '__main__':
    # Small demo dimensions; d_model is passed to the constructor only.
    batch_size, num_heads, seq_len = 2, 4, 5
    d_model = 64

    intervention_module = InterventionModule(d_model)
    scores = torch.randn(batch_size, num_heads, seq_len, seq_len)
    G = torch.rand(batch_size, seq_len, seq_len)  # Placeholder graph argument

    print("--- Testing InterventionModule --- ")

    # Case 1: passing None as the intervention must leave the scores as-is.
    untouched = intervention_module(scores.clone(), None, G)
    assert torch.equal(untouched, scores)
    print("Test Case 1 (No Intervention): Passed")

    # Case 2: intervening on node 2 must mask exactly column 2.
    target = 2
    masked = intervention_module(scores.clone(), {'node_idx': target}, G)

    print(f"\nOriginal scores sample (batch 0, head 0, row 0):\n{scores[0, 0, 0, :]}")
    print(f"Scores after intervention on node 2 sample (batch 0, head 0, row 0):\n{masked[0, 0, 0, :]}")

    # Every entry of the target column is -inf, for all batches and heads.
    assert torch.isneginf(masked[..., target]).all()
    print("Check: Column 2 is masked. Passed.")

    # Every other column still matches the original scores.
    other_columns = torch.arange(seq_len) != target
    assert torch.equal(masked[..., other_columns], scores[..., other_columns])
    print("Check: Other columns are unchanged. Passed.")
    print("Test Case 2 (Intervention on node 2): Passed")

    # Case 3: an out-of-bounds index is skipped and the scores come back unchanged.
    bad_index = 10
    print(f"\nTesting invalid index {bad_index}...")
    after_invalid = intervention_module(scores.clone(), {'node_idx': bad_index}, G)
    assert torch.equal(after_invalid, scores)
    print(f"Test Case 3 (Invalid Index {bad_index}): Passed (Scores unchanged as expected)")

    # Case 4: a dict without 'node_idx' is treated as "no intervention".
    intervention_missing = {'value': 1}
    print(f"\nTesting intervention dict missing 'node_idx': {intervention_missing}")
    after_missing = intervention_module(scores.clone(), intervention_missing, G)
    assert torch.equal(after_missing, scores)
    print("Test Case 4 (Missing 'node_idx'): Passed (Scores unchanged as expected)")


