import torch
import torchvision.transforms as T
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2  # For resizing the attention map
from surrogates.FeatureExtractors.clip import (
    ClipFeatureExtractor,
    HFClipFeatureExtractor,
    OpenClipFeatureExtractor
)  # Import necessary classes
from surrogates.FeatureExtractors.dino import Dinov2FeatureExtractor
from surrogates.FeatureExtractors.blip2 import BlipFeatureExtractor
import math
import os
import hydra
from omegaconf import DictConfig, OmegaConf
from dataclasses import dataclass, field
from typing import List, Tuple
from surrogates.FeatureExtractors.clip import MODEL_DICT
from torch import nn

# --- Imports for Type Checking in get_attention_map ---
from open_clip.transformer import VisionTransformer as OpenClipVisionTransformer
from transformers import CLIPVisionModel
# Corrected import paths based on typical transformers structure
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model 
from transformers.models.blip_2.modeling_blip_2 import Blip2VisionModel


# --- Hydra Configuration Dataclass ---
@dataclass
class VisConfig:
    """Structured settings for a single attention-visualization run."""

    # --- Input images ---
    # Each directory is searched for a file named ``<image_idx>.<ext>``.
    original_dir: str = "resources/images/bigscale/nips17"  # clean source images
    target_dir: str = "resources/images/target_images/1"  # target images
    adversarial_dir: str = "LAT/img/1fffc5d9c66225a8c1fc0224ec80a8b5/nips17"  # adversarial images
    image_idx: str = "1"  # file stem (no extension) looked up in every directory

    # --- Model ---
    # load_model_and_image() picks the model family by substring
    # ("clip", "blip" or "dino"); e.g. "dino_large", "clip_b16", "clip_l14",
    # "clip_b32", "clip_laion_b32", "clip_laion_g14".
    model_name: str = "dino_large"
    layer_index: int = -1  # transformer layer to read attention from (-1 = last)

    # --- Output ---
    output_dir_base: str = "resources/attention_maps"  # root folder for saved figures
    device: str = "cuda:5"  # torch device string, e.g. "cpu" or "cuda:0"
    

# --- Model & Image Loading (Adapted for Config) ---
def load_model_and_image(
    model_name: str, image_path: str
) -> Tuple[nn.Module, torch.Tensor, Image.Image, int]:
    """Build the vision backbone for ``model_name`` and preprocess one image.

    The model family is chosen by substring match on ``model_name``
    (``clip`` / ``blip`` / ``dino``). The extractor's own
    ``normalizer.transforms`` pipeline is then replayed step by step so that
    the PIL image produced just before tensor conversion can be returned for
    overlaying the attention map later.

    Args:
        model_name: Identifier handed to the feature-extractor constructor.
        image_path: Path of the image file to open.

    Returns:
        ``(vision_model, batch, pil_image, input_size)``: the vision module,
        the normalized tensor with a leading batch dimension of 1, the image
        after the PIL-stage transforms, and the size taken from the last
        Resize/CenterCrop in the pipeline (224 when there is none).

    Raises:
        NotImplementedError: ``model_name`` matches no supported family.
        RuntimeError: The extractor or its vision model ended up as ``None``.
        TypeError: The OpenCLIP model has no ``visual`` attribute, or the
            transform pipeline has a step in an unexpected position.
    """
    extractor = None
    vision_model = None

    if 'clip' in model_name:
        if model_name == 'clip_laion_b16':
            # This one name is served by the OpenCLIP wrapper; its vision
            # tower lives under ``.visual`` (expected to be
            # open_clip.transformer.VisionTransformer).
            extractor = OpenClipFeatureExtractor(model_name)
            if not hasattr(extractor.model, 'visual'):
                raise TypeError(
                    f"OpenClip model {model_name} from OpenClipFeatureExtractor "
                    "is missing the 'visual' attribute, which is expected to be the vision transformer."
                )
            vision_model = extractor.model.visual
        else:
            # Every other CLIP name goes through the HuggingFace wrapper,
            # whose ``.model`` is used directly as the vision model.
            extractor = HFClipFeatureExtractor(model_name)
            vision_model = extractor.model
    elif 'blip' in model_name:
        extractor = BlipFeatureExtractor(model_name)
        wrapped = extractor.model
        # Unwrap to the vision sub-module only when the wrapper still holds
        # the full model (i.e. it is not already the vision model itself).
        still_full_model = (
            hasattr(wrapped, 'vision_model')
            and not isinstance(wrapped, nn.Sequential)
            and 'Blip2VisionModel' not in str(type(wrapped))
        )
        if still_full_model:
            extractor.model = wrapped.vision_model
        vision_model = extractor.model
    elif 'dino' in model_name:
        extractor = Dinov2FeatureExtractor(model_name)
        vision_model = extractor.model
    else:
        raise NotImplementedError(
            f"Model loading for attention visualization is not supported for model name {model_name}. "
            "Please use a supported model type (CLIP, BLIP, DINO)."
        )

    if extractor is None or vision_model is None:
        raise RuntimeError(f"Could not initialize feature extractor or core vision model for {model_name}.")

    pipeline = extractor.normalizer.transforms
    source_image = Image.open(image_path).convert("RGB")
    pil_ready = source_image.copy()

    # Split the pipeline into the steps that run on the PIL image and the
    # steps that run on the tensor; ToTensor marks the boundary.
    pil_stage = []
    tensor_stage = []
    input_size = 224  # fallback when no Resize/CenterCrop is present
    seen_to_tensor = False

    for step in pipeline:
        if isinstance(step, (T.Resize, T.CenterCrop)):
            if seen_to_tensor:
                raise TypeError(
                    f"PIL transform {type(step)} found after ToTensor in processor."
                )
            pil_stage.append(step)
            step_size = step.size
            input_size = step_size if isinstance(step_size, int) else step_size[0]
            continue
        if isinstance(step, T.ToTensor):
            seen_to_tensor = True
            continue
        if isinstance(step, T.Normalize):
            tensor_stage.append(step)
            continue
        # Anything else is assumed to be a PIL-level transform, which is only
        # acceptable before the ToTensor boundary.
        if seen_to_tensor:
            raise TypeError(
                f"Unhandled tensor transform {type(step)} found after ToTensor."
            )
        pil_stage.append(step)

    if pil_stage:
        pil_ready = T.Compose(pil_stage)(pil_ready)

    # ToTensor is always applied, whether or not the pipeline listed it.
    tensor = T.ToTensor()(pil_ready)

    if tensor_stage:
        tensor = T.Compose(tensor_stage)(tensor)

    batch = tensor.unsqueeze(0)  # add the batch dimension
    return vision_model, batch, pil_ready, input_size


# --- Attention Extraction (Adapted for OpenCLIP) ---
def _capture_openclip_attention(
    model: nn.Module, processed_image: torch.Tensor, layer: int
) -> torch.Tensor:
    """Run ``model`` once and return the attention weights of one OpenCLIP block.

    The weights are read with a forward hook on the block's ``attn`` module.
    When that module is an ``nn.MultiheadAttention`` a kwargs pre-hook forces
    ``need_weights=True`` and ``average_attn_weights=False`` for this single
    pass, because the block may request no weights at all (in which case the
    second output is ``None``) and the default weights are head-averaged (3-D).

    Raises:
        TypeError: ``model.transformer.resblocks`` is missing or not a
            usable sequence of modules.
        ValueError: The block list is empty or ``layer`` is out of range.
        AttributeError: The selected block has no ``attn`` attribute.
        RuntimeError: No attention weights were captured.
    """
    if not (hasattr(model, 'transformer') and hasattr(model.transformer, 'resblocks')):
        raise TypeError(
            f"OpenCLIP model {type(model)} is missing the expected 'model.transformer.resblocks' path. "
            f"Cannot apply hooks for attention extraction. "
            f"Has 'model.transformer': {hasattr(model, 'transformer')}."
        )

    blocks = model.transformer.resblocks

    if not isinstance(blocks, (nn.Sequential, nn.ModuleList)):
        # Tolerate other containers as long as they behave like a sequence
        # of modules; anything else cannot be indexed reliably.
        actual_type = type(blocks)
        list_like = hasattr(blocks, '__getitem__') and hasattr(blocks, '__len__')
        usable = False
        if list_like and len(blocks) == 0:
            usable = True  # rejected just below by the emptiness check
            print(f"Info: 'model.transformer.resblocks' is of type {actual_type} and is an empty list-like container.")
        elif list_like and all(isinstance(m, nn.Module) for m in blocks):
            usable = True
            print(f"Info: 'model.transformer.resblocks' is of type {actual_type}, "
                  "but appears to be a list-like container of modules. Proceeding with caution.")
        elif list_like:
            print(f"Warning: 'model.transformer.resblocks' is of type {actual_type} and list-like, "
                  "but its elements are not all nn.Modules.")

        if not usable:
            raise TypeError(
                f"OpenCLIP's 'model.transformer.resblocks' is not an nn.Sequential or nn.ModuleList. "
                f"Actual type: {actual_type}. "
                f"Cannot reliably access transformer blocks for hooking. "
                f"Please ensure 'model.transformer.resblocks' is a sequence of ResidualAttentionBlocks."
            )

    num_blocks = len(blocks)
    if num_blocks == 0:
        raise ValueError("Identified 'model.transformer.resblocks' but it contains no blocks.")

    # Resolve negative indices the same way list indexing does.
    actual_layer_idx = layer if layer >= 0 else num_blocks + layer
    if not (0 <= actual_layer_idx < num_blocks):
        raise ValueError(
            f"Layer index {layer} (resolved to {actual_layer_idx}) is out of bounds "
            f"for OpenCLIP model with {num_blocks} ResidualAttentionBlocks. "
            f"Valid range is 0 to {num_blocks - 1} (or -1 to -{num_blocks})."
        )

    target_block = blocks[actual_layer_idx]
    if not hasattr(target_block, 'attn'):
        raise AttributeError(
            f"The ResidualAttentionBlock at layer {actual_layer_idx} (original index {layer}) "
            f"of type {type(target_block)} does not have an 'attn' attribute. "
            "Cannot hook the internal attention mechanism."
        )
    attn_module = target_block.attn

    captured = []

    def force_per_head_weights(_module, args, kwargs):
        # Override whatever the block asked for so the module returns
        # per-head weights of shape (batch, heads, tokens, tokens).
        return args, dict(kwargs, need_weights=True, average_attn_weights=False)

    def capture_weights(_module, _inputs, output):
        # Only a (features, weights) tuple with real weights is usable.
        if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
            captured.append(output[1].detach().cpu())

    handles = []
    try:
        if isinstance(attn_module, nn.MultiheadAttention):
            handles.append(
                attn_module.register_forward_pre_hook(force_per_head_weights, with_kwargs=True)
            )
        handles.append(attn_module.register_forward_hook(capture_weights))
        _ = model(processed_image)  # forward pass to trigger the hooks
    finally:
        # Always detach the hooks, even when the forward pass fails, so the
        # model is left unmodified for later use.
        for handle in handles:
            handle.remove()

    if not captured:
        raise RuntimeError(
            f"Failed to capture attention from OpenCLIP model {type(model)} "
            f"at layer {actual_layer_idx} (original index {layer}). "
            "Hook did not run or did not capture data as expected."
        )
    if len(captured) > 1:
        print(
            f"Warning: Hook for OpenCLIP layer {actual_layer_idx} captured "
            f"{len(captured)} items, expected 1. Using the first captured tensor."
        )
    return captured[0]


def _cls_attention_grid(
    attention: torch.Tensor, model_type: type, layer: int
) -> np.ndarray:
    """Turn a 4-D attention tensor into a square CLS-to-patch map.

    Heads are averaged, the first batch item is taken, and row 0 (assumed to
    be the CLS token) is sliced from column 1 onward and reshaped to a square.

    Raises:
        ValueError: The tensor is not 4-D, has no patch tokens, or the
            patch count is not a perfect square.
    """
    if attention.ndim != 4:
        raise ValueError(
            f"Expected 4D attention tensor (batch, heads, seq_len, seq_len), "
            f"but got {attention.ndim}D shape: {attention.shape}. "
            f"Model type: {model_type}, Layer: {layer}."
        )

    # float32 on CPU so the final .numpy() also works for half/bfloat16 models.
    attention = attention.detach().cpu().float()

    # Mean over heads, first batch item, then CLS row without the CLS column.
    cls_attention_to_patches = attention.mean(dim=1)[0][0, 1:]

    num_patches = cls_attention_to_patches.shape[0]
    if num_patches == 0:
        raise ValueError(
            f"Number of patches is 0 after processing attention for model {model_type}, layer {layer}. "
            "This might indicate an issue with CLS token attention slicing or an unexpected attention format."
        )

    # Exact integer square root avoids float rounding in the squareness test.
    patch_dim = math.isqrt(num_patches)
    if patch_dim * patch_dim != num_patches:
        raise ValueError(
            f"Cannot reshape attention map for {num_patches} patches into a square grid (sqrt={math.sqrt(num_patches):.2f}). "
            f"Model {model_type} (layer {layer}) might have a non-square patch grid, or attention processing is incorrect."
        )

    return cls_attention_to_patches.reshape(patch_dim, patch_dim).numpy()


def get_attention_map(
    model: nn.Module, processed_image: torch.Tensor, layer: int = -1
) -> np.ndarray:
    """Extract the CLS-token attention map of one transformer layer.

    The model is switched to eval mode and run once without gradients.
    HuggingFace models (CLIPVisionModel, Dinov2Model, Blip2VisionModel) are
    queried with ``output_attentions=True``; OpenCLIP vision transformers are
    read through hooks on the selected block's attention module.

    Args:
        model: The vision model to inspect.
        processed_image: Preprocessed image batch; only the first item is used.
        layer: Layer index. For HuggingFace models it indexes
            ``outputs.attentions``; for OpenCLIP it indexes
            ``model.transformer.resblocks``. Negative values count from the end.

    Returns:
        A square 2-D float32 array holding the head-averaged attention from
        the first token (assumed to be CLS) to every remaining token.

    Raises:
        NotImplementedError: ``model`` is of an unsupported type.
        ValueError: No attentions were returned, the layer index is invalid
            (OpenCLIP), or the attention cannot be reshaped to a square grid.
        TypeError, AttributeError, RuntimeError: The OpenCLIP module layout
            is not as expected or no weights could be captured.
    """
    model.eval()  # disable dropout etc.; note this changes the caller's model state

    layer_attention = None
    with torch.no_grad():
        if isinstance(model, (CLIPVisionModel, Dinov2Model, Blip2VisionModel)):
            outputs = model(
                pixel_values=processed_image,
                output_attentions=True,
                output_hidden_states=False,
                return_dict=True,
            )
            if not hasattr(outputs, 'attentions') or not outputs.attentions:
                raise ValueError(
                    f"Model {type(model)} did not return attentions. "
                    "Check if output_attentions=True is supported and effective."
                )
            # One attention tensor per layer; pick the requested one.
            layer_attention = outputs.attentions[layer]
        elif isinstance(model, OpenClipVisionTransformer):
            layer_attention = _capture_openclip_attention(model, processed_image, layer)
        else:
            raise NotImplementedError(
                f"Attention map extraction for model type {type(model)} is not currently supported by this script."
            )

    if layer_attention is None:
        raise ValueError("Failed to obtain attention matrices for the target layer. This should not happen if model type is supported.")

    return _cls_attention_grid(layer_attention, type(model), layer)


# --- Visualization (Adapted for Config) ---
def visualize_attention(
    pil_image: Image.Image,  # Now receives the already-processed image
    attention_map_2d: np.ndarray,
    input_size: int,
    output_path_prefix: str,
) -> None:
    """Save a grayscale attention map and a heatmap overlay on the image.

    Two files are written: ``<output_path_prefix>_raw.png`` (min-max
    normalized attention, gray colormap) and
    ``<output_path_prefix>_overlay.png`` (the "hot" heatmap resized to the
    image and blended 50/50 with it).

    Args:
        pil_image: Image to overlay on; used at its current size.
        attention_map_2d: 2-D attention values.
        input_size: Unused; kept so existing callers keep working.
        output_path_prefix: Path prefix for both output files. Its directory
            part, if any, is created when missing.
    """
    # A bare filename prefix has an empty dirname, which os.makedirs rejects.
    output_dir = os.path.dirname(output_path_prefix)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Min-max normalize; epsilon guards against a constant map.
    min_val, max_val = np.min(attention_map_2d), np.max(attention_map_2d)
    epsilon = 1e-8
    attention_map_normalized = (attention_map_2d - min_val) / (
        max_val - min_val + epsilon
    )

    # Save the raw attention map (grayscale)
    raw_map_path = f"{output_path_prefix}_raw.png"
    plt.imsave(raw_map_path, attention_map_normalized, cmap="gray")

    # The colormap returns RGBA floats in [0, 1]; keep RGB and convert to uint8.
    heatmap = plt.get_cmap("hot")(attention_map_normalized)[:, :, :3]
    heatmap = (heatmap * 255).astype(np.uint8)

    # Force 3-channel RGB up front so grayscale/RGBA/palette inputs blend
    # with the RGB heatmap without any per-shape special-casing.
    image_np = np.array(pil_image.convert("RGB"))
    image_height, image_width = image_np.shape[:2]

    # Debug information
    print(f"Image shape: {image_np.shape}, Heatmap shape: {heatmap.shape}")

    # cv2.resize takes the target size as (width, height).
    heatmap_resized = cv2.resize(heatmap, (image_width, image_height), interpolation=cv2.INTER_LINEAR)

    # Overlay heatmap
    alpha = 0.5
    overlay_img = cv2.addWeighted(heatmap_resized, alpha, image_np, 1 - alpha, 0)

    # Save the overlay image
    overlay_path = f"{output_path_prefix}_overlay.png"
    Image.fromarray(overlay_img).save(overlay_path)

    print("Saved attention maps:")
    print(f"  Raw: {raw_map_path}")
    print(f"  Overlay: {overlay_path}")


def find_image_path(directory: str, image_idx: str) -> str | None:
    """Tries to find an image file based on index, checking .jpg and .png."""

    # Try .jpg first
    path_jpg = os.path.join(directory, f"{image_idx}.jpg")
    if os.path.exists(path_jpg):
        return path_jpg

    # Try .png if .jpg not found
    path_png = os.path.join(directory, f"{image_idx}.png")
    if os.path.exists(path_png):
        return path_png

    # Try .jpeg as well
    path_jpeg = os.path.join(directory, f"{image_idx}.jpeg")
    if os.path.exists(path_jpeg):
        return path_jpeg

    return None # Not found


# --- Main Execution Logic with Hydra ---
@hydra.main(
    version_base=None, config_name="vis_config"
)  # no config_path: the "vis_config" schema is registered in code at the entry point
def run_visualization(cfg: VisConfig):
    """Render attention maps for the original, target and adversarial image.

    For each of the three configured directories the image named
    ``cfg.image_idx`` is located, the model is loaded, the attention of
    ``cfg.layer_index`` is extracted, and the results are written under
    ``<output_dir_base>/<label>/<image_idx>/<model_name>_layer<layer_index>``.
    A missing image or a failure for one label is reported on stdout and the
    loop moves on to the next label.
    """
    print("--- Configuration ---")
    print(OmegaConf.to_yaml(cfg))  # echo the resolved configuration
    print("---------------------")

    device = torch.device(cfg.device)
    print(f"Using device: {device}")

    image_sources = (
        ("original", cfg.original_dir),
        ("target", cfg.target_dir),
        ("adversarial", cfg.adversarial_dir),
    )
    for label, base_dir in image_sources:
        image_path = find_image_path(base_dir, cfg.image_idx)
        if not image_path:
            print(f"Warning: Image file like '{cfg.image_idx}.jpg' or '{cfg.image_idx}.png' not found in {base_dir}. Skipping {label}.")
            continue

        print(f"\nProcessing {label} image: {image_path}")

        try:
            print(f"Loading model: {cfg.model_name}")
            model, batch, overlay_image, input_size = load_model_and_image(
                cfg.model_name, image_path
            )
            # Model and input must live on the same device.
            model = model.to(device)
            batch = batch.to(device)

            print(f"Extracting attention map from layer {cfg.layer_index}...")
            attention_map = get_attention_map(model, batch, layer=cfg.layer_index)

            print("Visualizing and saving attention map...")
            # Layout: <output_dir_base>/<label>/<image_idx>/<model_name>_layer<N>
            output_prefix = os.path.join(
                cfg.output_dir_base,
                label,
                cfg.image_idx,
                f"{cfg.model_name}_layer{cfg.layer_index}",
            )
            visualize_attention(overlay_image, attention_map, input_size, output_prefix)

            print(f"Finished processing {label}/{cfg.image_idx} (found as {os.path.basename(image_path)})")

        except FileNotFoundError:
            print(
                f"Error: Image file not found during processing {image_path}. This shouldn't happen after the initial check."
            )
        except KeyError:
            print(
                f"Error: Model name '{cfg.model_name}' not found in clip.py's MODEL_DICT."
            )
        except NotImplementedError as err:
            print(f"Error: {err}")
        except ValueError as err:
            print(f"Error processing {image_path}: {err}")
        except Exception as err:
            # Last-resort boundary: report with a traceback and continue
            # with the remaining labels.
            print(f"An unexpected error occurred processing {image_path}: {err}")
            import traceback

            traceback.print_exc()

    print("\nVisualization script finished.")


# --- Entry Point ---
if __name__ == "__main__":
    # Import ConfigStore explicitly: reaching it as hydra.core.config_store
    # after a bare `import hydra` only works if that submodule happens to
    # have been imported as a side effect.
    from hydra.core.config_store import ConfigStore

    # Register the dataclass schema under the name @hydra.main looks up,
    # so no YAML config file is required.
    cs = ConfigStore.instance()
    cs.store(name="vis_config", node=VisConfig)
    run_visualization()
