import torch
import numpy as np
from sklearn.cluster import KMeans

class AttentionPathExtractor:
    """Capture per-block attention-module outputs of a ViT and summarize them.

    Forward hooks are attached to ``vit_visual.transformer.resblocks[i].attn``
    for the first ``max_layers`` blocks (all blocks when ``None``). Each hook
    stores the module's first output (``output[0]`` when the module returns a
    tuple) detached and moved to CPU.

    Despite the attribute name ``attention_maps``, what is stored is whatever
    the ``attn`` module returns first. NOTE(review): if ``attn`` is an
    ``nn.MultiheadAttention``, ``output[0]`` is the attended features, not the
    attention-weight matrix.

    Args:
        vit_visual: Visual transformer exposing ``.transformer.resblocks``
            whose elements have an ``.attn`` submodule. It is moved to
            ``device`` and cast to float32.
        device: Device the model and inputs are moved to.
        n_clusters: Number of KMeans clusters used by :meth:`extract`.
        max_layers: Optional cap on how many leading blocks are hooked.
    """

    def __init__(self, vit_visual, device="cuda", n_clusters=4, max_layers=None):
        self.vit_visual = vit_visual.to(device).float()  # run the model in float32
        self.device = device
        self.n_clusters = n_clusters
        self.max_layers = max_layers
        self.attention_maps = []  # filled by _hook during a hooked forward pass
        self.handles = []  # handles of the currently registered hooks
        self.dtype = torch.float  # inputs are cast to match the float32 model

    def _hook(self, module, input, output):
        """Forward-hook callback: store the module's (first) output on CPU."""
        out = output[0] if isinstance(output, tuple) else output
        self.attention_maps.append(out.detach().cpu())

    def register_hooks(self):
        """Attach a forward hook to the ``attn`` module of each selected block.

        Any hooks still registered from an earlier call are removed first, so
        they cannot leak and record duplicate entries. ``attention_maps`` is
        rebound to a fresh list rather than cleared in place, so lists handed
        out by earlier calls (e.g. in :meth:`extract`'s result) stay intact.
        """
        self.remove_hooks()
        self.attention_maps = []
        blocks = self.vit_visual.transformer.resblocks
        n = len(blocks) if self.max_layers is None else min(self.max_layers, len(blocks))
        self.handles = [blocks[i].attn.register_forward_hook(self._hook) for i in range(n)]

    def remove_hooks(self):
        """Remove every hook registered by :meth:`register_hooks`."""
        for h in self.handles:
            h.remove()
        self.handles.clear()

    def _run_hooked(self, *args, **kwargs):
        """Run one no-grad forward pass with hooks attached.

        The hooks are always removed afterwards, even if the forward raises.
        """
        self.register_hooks()
        try:
            with torch.no_grad():
                self.vit_visual(*args, **kwargs)
        finally:
            self.remove_hooks()
        return self.attention_maps

    def extract(self, image_tensor):
        """Cluster patch tokens by their per-layer trajectory and trace a path.

        Only batch element 0 is analyzed. For each hooked layer, every patch
        token (index 0, presumably CLS, is skipped) is reduced to the mean of
        its feature vector. This gives a ``[num_patches, num_layers]``
        trajectory matrix, which is clustered with KMeans. For each layer, the
        path records the cluster of the patch with the largest mean at that
        layer.

        NOTE(review): indexing ``[token, batch, feature]`` assumes the
        captured tensors are sequence-first ``(N, B, D)``. Confirm this for
        the wrapped model.

        Returns:
            dict with ``"path"`` (list of int cluster ids, one per layer),
            ``"patch_clusters"`` (KMeans labels per patch),
            ``"patch_tracks"`` (float64 array ``[num_patches, num_layers]``),
            and ``"attention_maps"`` (the captured tensors of this call).

        Raises:
            RuntimeError: if no layer output was captured (e.g. ``max_layers=0``).
        """
        maps = self._run_hooked(image_tensor.to(self.device, dtype=self.dtype))
        if not maps:
            raise RuntimeError(
                "no attention outputs were captured; check max_layers and model structure"
            )

        # One [num_patches] vector of feature means per layer (batch element 0).
        per_layer = [m[1:, 0, :].mean(-1).numpy() for m in maps]
        # patch_tracks[p, l] = mean feature of patch p at layer l.
        patch_tracks = np.stack(per_layer, axis=1).astype(np.float64)

        kmeans = KMeans(n_clusters=self.n_clusters, random_state=0).fit(patch_tracks)
        patch_clusters = kmeans.labels_

        # Row p of patch_tracks and patch_clusters[p] both refer to token p + 1,
        # so the argmax over per_layer[l] indexes patch_clusters directly.
        path = [int(patch_clusters[layer_means.argmax()]) for layer_means in per_layer]

        return {
            "path": path,
            "patch_clusters": patch_clusters,
            "patch_tracks": patch_tracks,
            "attention_maps": maps,
        }

    def extract_batch_avg_path(self, image_tensor):
        """Return the batch-averaged mean patch-token output of each hooked layer.

        The wrapped model is called with ``prompt_ids`` (a placeholder of one
        zero id per sample) and ``batch_weight=None``. For each layer, tokens
        ``1:`` are averaged per sample and then averaged over the batch.

        NOTE(review): this slices tokens on dim 1, i.e. it assumes ``(B, N, D)``
        captured tensors, whereas :meth:`extract` indexes them as ``(N, B, D)``.
        Both cannot hold for the same model. Confirm which layout applies.

        Returns:
            Tensor of shape ``[num_layers, D]``.
        """
        batch_size = image_tensor.shape[0]
        dummy_prompt_ids = torch.zeros(batch_size, 1, dtype=torch.int, device=self.device)
        maps = self._run_hooked(
            image_tensor.to(self.device, dtype=self.dtype),
            prompt_ids=dummy_prompt_ids,
            batch_weight=None,
        )

        per_layer = [m[:, 1:, :].mean(dim=1) for m in maps]  # each [B, D]
        path_tensor = torch.stack(per_layer, dim=1)  # [B, L, D]
        return path_tensor.mean(dim=0)  # [L, D]
    
