import torch
from typing import Dict, Optional
from .base import BaseMetric

class XAttnEntropyMetric(BaseMetric):
    """Cross-attention entropy metric following Ren et al. (2024).

    Reads the cross-attention maps cached by an attention controller for the
    first cached step, averages them into one score per token for a single
    target layer, normalizes those scores into a distribution and reports its
    Shannon entropy (natural log).
    """

    # Sec 4.3: "The fourth and fifteenth layers have clearer separation, which
    # can distinguish memorization and non-memorization better."
    # Sec 6.2: "For all the detection, we use **l=4**".
    # "down_3" is the controller key used for that 4th layer.
    target_layer_name: str = "down_3"

    @property
    def name(self) -> str:
        """Display name of this metric."""
        return "CrossAttention_Entropy"

    @property
    def metric_type(self) -> str:
        """Granularity of the metric: one value per seed."""
        return "per_seed"

    @torch.no_grad()
    def measure(self, **kwargs) -> Dict:
        """Calculate the attention entropy, as per Ren et al. (2024).

        Uses the faster variant from the paper (E_{t=T}^{l}): only the maps
        cached for the first step (key ``0``) of layer ``target_layer_name``
        are used.

        Keyword Args:
            controller: Object exposing ``latest_attention_maps``, a dict-like
                mapping of step key -> dict-like mapping of layer name ->
                attention tensor. Other keyword arguments are ignored.

        Returns:
            ``{"entropy": float}``. ``{"entropy": 0}`` is returned as a
            sentinel when there is no controller or the required maps are
            missing; note that 0 is also a legitimate (minimum) entropy value.
        """
        controller = kwargs.get('controller', None)
        if controller is None:
            print("[XAttnEntropyMetric] No controller!")
            return {"entropy": 0}

        layer = self.target_layer_name
        cached_steps = controller.latest_attention_maps
        print(f"[XAttnEntropyMetric] Controller Cached Steps {cached_steps.keys()}")

        # Get attention maps from the first step (t=T). Missing step -> sentinel
        # instead of an unhandled KeyError.
        first_step_maps = cached_steps.get(0)
        if first_step_maps is None:
            print("[XAttnEntropyMetric] Warning: No attention maps cached for step 0.")
            return {"entropy": 0}

        print("[XAttnEntropyMetric] Controller Provides XA from")
        for key, tensor in first_step_maps.items():
            print(f"{key}: {tensor.shape}")

        attn_maps = first_step_maps.get(layer, None)
        if attn_maps is None:
            print(
                "[XAttnEntropyMetric] Warning: Could not find target attention map "
                f"for key '{layer}'."
            )
            return {"entropy": 0}

        # Keep only the conditional half of dim 0. Assumes classifier-free
        # guidance batching with the unconditional half first — TODO confirm
        # against the controller (without CFG this would drop conditional rows).
        cond_attn_map = attn_maps[attn_maps.shape[0] // 2:]

        # Upcast before reducing: in fp16 the 1e-9 epsilons below round to 0,
        # which turns any zero probability into 0 * log(0) = NaN.
        cond_attn_map = cond_attn_map.float()

        # Average over every axis except the last to get one score per token
        # (patches and heads, per the paper). Assumes the last axis indexes text
        # tokens, e.g. (batch*heads, h*w, tokens) — TODO confirm. For a 3-D map
        # this is exactly mean(dim=(0, 1)).
        num_tokens = cond_attn_map.shape[-1]
        avg_per_token_scores = cond_attn_map.reshape(-1, num_tokens).mean(dim=0)

        # Normalize to a probability distribution over tokens.
        dist = avg_per_token_scores / (avg_per_token_scores.sum() + 1e-9)
        print("[XAttnEntropyMetric] Shape of Distribution", dist.shape)

        # Shannon entropy in nats; the epsilon keeps log() finite at zero.
        entropy = -torch.sum(dist * torch.log(dist + 1e-9)).item()

        print(self.name, entropy)
        return {"entropy": entropy}