import argparse
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
from torch.utils.data import DataLoader
from collections import defaultdict
import random

import models
from util.datasets import build_dataset
import util.misc as misc

# Restrict this process to a single visible GPU.
# NOTE(review): this presumably has to run before CUDA is first initialized to
# take effect — confirm nothing imported above touches CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Change this to your preferred GPU

# Module-level store filled by the forward hooks defined below:
#   kv_dict[<attention module name>] = {'K': tensor, 'V': tensor}
# The hooks assign into the same slots on every forward pass, so the store
# always reflects the most recent batch only.
kv_dict = {}


# Hook function to capture K matrix
def get_k_hook(parent_name, parent):
    """Build a forward hook that stores a module's output as the 'K' entry.

    The captured tensor is split into attention heads and written to
    ``kv_dict[parent_name]['K']`` (detached, on CPU), replacing whatever the
    previous forward pass stored there.

    Args:
        parent_name: Key under which the tensor is stored in ``kv_dict``.
        parent: Attention module that owns the hooked sub-module. Its
            ``num_heads`` attribute is used when present; otherwise 8 heads
            are assumed (the value that used to be hard-coded).

    Returns:
        A ``hook(module, input, output)`` callable for
        ``register_forward_hook``.
    """
    # Use the attention module's real head count instead of a fixed constant.
    num_heads = getattr(parent, "num_heads", 8)

    def hook(module, input, output):
        # The unpack below requires a 4-D (B, C, H, W) output, i.e. a tensor
        # without a leading time dimension.
        B, C, H, W = output.shape
        N = H * W
        # (B, C, H, W) -> (B, C, N) -> (B, N, C) -> (B, N, heads, C // heads)
        # -> (B, heads, N, C // heads)
        per_head = (
            output.flatten(2)
            .transpose(-1, -2)
            .reshape(B, N, num_heads, C // num_heads)
            .permute(0, 2, 1, 3)
            .contiguous()
        )
        kv_dict.setdefault(parent_name, {})['K'] = per_head.detach().cpu()

    return hook


# Hook function to capture V matrix
def get_v_hook(parent_name, parent):
    """Build a forward hook that stores a module's output as the 'V' entry.

    The captured tensor is split into attention heads and written to
    ``kv_dict[parent_name]['V']`` (detached, on CPU), replacing whatever the
    previous forward pass stored there. In this file ``register_kv_hooks``
    attaches this hook to the ``q_spike`` sub-module.

    Args:
        parent_name: Key under which the tensor is stored in ``kv_dict``.
        parent: Attention module that owns the hooked sub-module. Its
            ``num_heads`` attribute is used when present; otherwise 8 heads
            are assumed (the value that used to be hard-coded).

    Returns:
        A ``hook(module, input, output)`` callable for
        ``register_forward_hook``.
    """
    # Use the attention module's real head count instead of a fixed constant.
    num_heads = getattr(parent, "num_heads", 8)

    def hook(module, input, output):
        # The unpack below requires a 4-D (B, C, H, W) output, i.e. a tensor
        # without a leading time dimension.
        B, C, H, W = output.shape
        N = H * W
        # (B, C, H, W) -> (B, C, N) -> (B, N, C) -> (B, N, heads, C // heads)
        # -> (B, heads, N, C // heads)
        per_head = (
            output.flatten(2)
            .transpose(-1, -2)
            .reshape(B, N, num_heads, C // num_heads)
            .permute(0, 2, 1, 3)
            .contiguous()
        )
        kv_dict.setdefault(parent_name, {})['V'] = per_head.detach().cpu()

    return hook


# Register hooks for all MS_Attention_linear layers
def register_kv_hooks(model):
    """Attach the capture hooks to every ``MS_Attention_linear`` in ``model``.

    For each matching module, the ``k_spike`` sub-module (if present) gets the
    hook built by ``get_k_hook`` and the ``q_spike`` sub-module (if present)
    gets the hook built by ``get_v_hook``. One log line is printed per
    registered hook.

    NOTE(review): the entry labelled "V" is fed from ``q_spike``, not from a
    value branch — presumably intentional so the two captured tensors share a
    head dimension; confirm against the model definition.

    Args:
        model: Module whose sub-modules are scanned via ``named_modules()``.
    """
    from models import MS_Attention_linear  # Import the attention class

    # (sub-module attribute, hook factory, label used in the log line)
    hook_specs = (
        ("k_spike", get_k_hook, "K"),
        ("q_spike", get_v_hook, "V"),
    )

    for parent_name, parent_module in model.named_modules():
        if not isinstance(parent_module, MS_Attention_linear):
            continue
        for attr_name, make_hook, label in hook_specs:
            if not hasattr(parent_module, attr_name):
                continue
            target = getattr(parent_module, attr_name)
            target.register_forward_hook(make_hook(parent_name, parent_module))
            print(f"Registered {label} hook for {parent_name}")


# gaussian_filter smooths the resized heatmap inside plot_heatmap_overlay
from scipy.ndimage import gaussian_filter

# Plot the attention heatmap next to, and overlaid on, the original image
def _fit_heatmap_to_image(heatmap, img_h, img_w):
    """Bilinearly resample a 2-D heatmap to exactly ``(img_h, img_w)``."""
    from scipy.ndimage import zoom

    h_patch, w_patch = heatmap.shape
    resized = zoom(heatmap, (img_h / h_patch, img_w / w_patch), order=1)

    # zoom() rounds the output size, so it can be off by a pixel. Crop or
    # edge-pad instead of np.resize, which would re-flow the flattened data
    # and destroy the spatial layout.
    resized = resized[:img_h, :img_w]
    missing_h = img_h - resized.shape[0]
    missing_w = img_w - resized.shape[1]
    if missing_h or missing_w:
        resized = np.pad(resized, ((0, missing_h), (0, missing_w)), mode='edge')
    return resized


def plot_heatmap_overlay(img_tensor, heatmap, save_path, title='Attention Heatmap', filter_threshold=0.7, apply_noise_filter=True):
    """
    Save a three-panel figure: original image, filtered heatmap, and overlay.

    Args:
        img_tensor: Tensor containing the original image [C, H, W]; values are
            clipped to [0, 1] for display
        heatmap: 2-D numpy array containing the attention heatmap
        save_path: Path to save the visualization
        title: Title for the whole figure
        filter_threshold: Fraction of the smallest values to zero out
            (0.7 means the bottom 70% are removed); 0 or less disables filtering
        apply_noise_filter: Apply a Gaussian filter (sigma=1) to the resized
            heatmap before thresholding
    """
    # Convert image tensor to a [H, W, C] numpy array in displayable range
    img = img_tensor.detach().cpu().permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)

    # Bring the patch-level heatmap up to the image resolution
    heatmap_resized = _fit_heatmap_to_image(heatmap, img.shape[0], img.shape[1])

    # Apply noise filter (Gaussian filter) if requested
    if apply_noise_filter:
        heatmap_resized = gaussian_filter(heatmap_resized, sigma=1)

    # Zero out everything below the requested percentile
    if filter_threshold > 0:
        threshold_value = np.percentile(heatmap_resized, filter_threshold * 100)
        filtered_heatmap = np.where(heatmap_resized < threshold_value, 0, heatmap_resized)
    else:
        filtered_heatmap = heatmap_resized

    # Normalize to [0, 1]; the epsilon guards against a constant heatmap
    lowest = filtered_heatmap.min()
    value_range = filtered_heatmap.max() - lowest
    filtered_heatmap = (filtered_heatmap - lowest) / (value_range + 1e-8)

    # round() rather than int(): (1 - 0.9) * 100 is 9.999... in floating point
    kept_percent = int(round((1 - filter_threshold) * 100))

    # Create figure with three subplots side by side
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    try:
        # Plot 1: Original image
        ax1.imshow(img)
        ax1.set_title('Original Image')
        ax1.axis('off')

        # Plot 2: Filtered heatmap only
        hmap = ax2.imshow(filtered_heatmap, cmap='jet')
        ax2.set_title(f'Filtered Attention Heatmap (top {kept_percent}%)')
        ax2.axis('off')
        fig.colorbar(hmap, ax=ax2, fraction=0.046, pad=0.04)

        # Plot 3: Filtered heatmap blended over the image
        ax3.imshow(img)
        overlay = ax3.imshow(filtered_heatmap, cmap='jet', alpha=0.5)
        ax3.set_title('Filtered Overlay Visualization')
        ax3.axis('off')
        fig.colorbar(overlay, ax=ax3, fraction=0.046, pad=0.04)

        fig.suptitle(title, fontsize=16)
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting or saving fails
        plt.close(fig)

    print(f"Saved visualization to {save_path}")



# Unnormalize function for ImageNet normalization
def unnormalize(img):
    """Undo per-channel ImageNet normalization.

    Args:
        img: Tensor whose channel axis (size 3) is third from last, so that it
            broadcasts against a (3, 1, 1) tensor.

    Returns:
        ``img * std + mean`` as a new tensor on ``img``'s device.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    # Shape the statistics as (3, 1, 1) so they broadcast over height and width
    shift = torch.tensor(imagenet_mean, device=img.device).reshape(3, 1, 1)
    scale = torch.tensor(imagenet_std, device=img.device).reshape(3, 1, 1)
    return img * scale + shift


def get_args(argv=None):
    """Parse the command-line options of the visualization script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what ``get_args()`` has
            always done.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # The text is passed as the description (it used to land in ``prog``), and
    # the default -h/--help stays enabled because this parser is used directly
    # rather than as a parent parser.
    parser = argparse.ArgumentParser(description='MAE visualization script')

    # Model parameters
    parser.add_argument('--model', default='spikformer', type=str, metavar='MODEL', help='Name of model to train')
    parser.add_argument('--model_mode', default='ms', type=str, help='Mode of model to train')
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')

    # Dataset parameters
    parser.add_argument('--data_path', default='/path/to/imagenet', type=str, help='dataset path')
    parser.add_argument('--nb_classes', default=1000, type=int, help='number of classes')
    parser.add_argument('--batch_size', default=16, type=int, help='Batch size per GPU')
    parser.add_argument('--num_workers', default=10, type=int)
    # --pin_mem / --no_pin_mem share one destination; pinning is on by default
    parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # Output parameters
    parser.add_argument('--output_dir', default='./attention_maps', help='path where to save outputs')

    # Checkpoint parameters
    parser.add_argument('--resume', default='', help='resume from checkpoint')

    # Device parameters
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)

    # Distributed parameters
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # Time steps parameter
    parser.add_argument('--time_steps', default=1, type=int)

    # Selection limits for the visualized images
    parser.add_argument('--max_per_class', default=2, type=int, help='Maximum number of images to show per class')
    parser.add_argument('--max_images', default=10, type=int, help='Maximum total number of images to show')

    return parser.parse_args(argv)


def exp2_softmax_with_ste(tensor, k=1.0, epsilon=1e-8):
    """
    Base-2 exponential of the max-shifted input along the last dimension.

    Computes ``2 ** clamp(k * (tensor - max), -30, 30)`` where ``max`` is taken
    over the last dimension. The result is not divided by its sum, and
    ``epsilon`` is accepted but not used.

    Args:
        tensor: Input scores.
        k: Scale applied to the shifted scores before exponentiation.
        epsilon: Unused.

    Returns:
        Tensor with the same shape as ``tensor``.
    """
    # Subtract the row-wise maximum
    row_max = tensor.max(dim=-1, keepdim=True).values
    shifted = tensor - row_max

    # Compute 2^(k * shifted) with the exponent limited to [-30, 30]
    exponent = torch.clamp(k * shifted, min=-30.0, max=30.0)
    return torch.pow(2, exponent)


def _load_checkpoint(model, checkpoint_path):
    """Load weights from ``checkpoint_path`` into ``model``; return False if the file is missing."""
    if not os.path.isfile(checkpoint_path):
        print(f"No checkpoint found at: {checkpoint_path}")
        return False

    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Handle different checkpoint formats
    if 'model' in checkpoint:
        checkpoint_model = checkpoint['model']
    elif 'state_dict' in checkpoint:
        checkpoint_model = checkpoint['state_dict']
    else:
        checkpoint_model = checkpoint

    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(f"Checkpoint loaded with message: {msg}")
    return True


def _patch_heatmaps_for_image(img_idx):
    """Return ``{layer name: 2-D numpy heatmap}`` for one image of the latest forward pass."""
    heatmaps = {}
    for mod_name, kv in kv_dict.items():
        if 'K' not in kv or 'V' not in kv:
            continue

        # Slice the image out first so the (heads, N, N) score matrix is built
        # for one sample instead of the whole batch.
        k_heads = kv['K'][img_idx]  # (num_heads, N, head_dim)
        v_heads = kv['V'][img_idx]  # (num_heads, N, head_dim)

        scores = v_heads @ k_heads.transpose(-2, -1)
        attn_map = 1. - exp2_softmax_with_ste(scores)

        # Average over the last dimension, then over heads -> one score per patch
        patch_scores = attn_map.mean(dim=-1).mean(dim=0)

        # Reshape to a 2-D grid; this requires a square patch layout
        side = int(np.sqrt(patch_scores.numel()))
        heatmaps[mod_name] = patch_scores.reshape(side, side).numpy()
    return heatmaps


def _collect_correct_predictions(model, data_loader, device, max_per_class, max_images, num_batches=5):
    """Run up to ``num_batches`` batches and gather correctly classified images per class.

    Returns a dict mapping class id to a list of
    ``(batch_idx, img_idx, unnormalized image on CPU, heatmaps per layer)``.
    """
    correct_by_class = defaultdict(list)
    stored_images = 0

    # One iterator for the whole pass: the loader is never restarted and
    # fast-forwarded to reach a later batch.
    data_iter = iter(data_loader)
    for batch_idx in range(num_batches):
        try:
            images, labels = next(data_iter)
        except StopIteration:
            print(f"Reached the end of the dataset after {batch_idx} batches")
            break

        normalized_images = images.to(device)
        labels = labels.to(device)

        # Forward pass; this also refreshes kv_dict through the hooks
        with torch.no_grad():
            outputs = model(normalized_images)

        _, predicted = torch.max(outputs, 1)
        correct_indices = torch.nonzero(predicted == labels).flatten().tolist()

        # Images are shown in their original value range
        unnormalized_images = unnormalize(normalized_images)

        for idx in correct_indices:
            class_id = labels[idx].item()
            bucket = correct_by_class[class_id]
            # Only add if we haven't reached the max per class yet
            if len(bucket) < max_per_class:
                # Heatmaps are computed now, from the same forward pass that
                # produced the prediction, so no batch has to be re-run later.
                bucket.append((batch_idx, idx, unnormalized_images[idx].cpu(),
                               _patch_heatmaps_for_image(idx)))
                stored_images += 1

        print(f"Batch {batch_idx + 1}: Found {len(correct_indices)} correctly predicted images")
        print(f"Total unique classes so far: {len(correct_by_class)}")

        # If we have enough images or classes, stop processing batches
        if stored_images >= max_images or len(correct_by_class) >= max_images // max_per_class:
            break

    return correct_by_class


def _select_images(correct_by_class, max_images):
    """Pick at most ``max_images`` records: one per class first, then round-robin extras."""
    # First, one image from each class, never more than max_images in total
    selected = [(class_id, records[0])
                for class_id, records in correct_by_class.items() if records]
    selected = selected[:max_images]

    remaining_slots = max_images - len(selected)
    if remaining_slots > 0:
        classes_with_multiple = [class_id for class_id, records in correct_by_class.items()
                                 if len(records) > 1]

        # Shuffle to ensure randomness
        random.shuffle(classes_with_multiple)

        # Round-robin by position in each class list. Walking positions (rather
        # than matching in-batch indices) cannot mistake two images from
        # different batches for the same one, and always terminates.
        deepest = max((len(correct_by_class[c]) for c in classes_with_multiple), default=0)
        for rank in range(1, deepest):
            for class_id in classes_with_multiple:
                if remaining_slots == 0:
                    break
                records = correct_by_class[class_id]
                if rank < len(records):
                    selected.append((class_id, records[rank]))
                    remaining_slots -= 1
            if remaining_slots == 0:
                break

    return selected


def main():
    """Visualize per-layer attention heatmaps for correctly classified validation images.

    Builds the validation loader and the model, registers the capture hooks,
    collects correctly predicted images from the first batches, selects a
    class-balanced subset and writes one figure per (layer, image) below
    ``args.output_dir``.

    Raises:
        ValueError: If ``--max_per_class`` or ``--max_images`` is below 1.
    """
    args = get_args()

    if args.max_per_class < 1 or args.max_images < 1:
        raise ValueError("--max_per_class and --max_images must both be at least 1")

    print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(", ", ",\n"))

    device = torch.device(args.device)

    # Fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Create output directory
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Build dataset
    dataset_val = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    # Create model
    print(f"Creating model: {args.model}")
    model = models.__dict__[args.model]()
    model.T = args.time_steps

    # Load checkpoint; a missing file aborts the run
    if args.resume and not _load_checkpoint(model, args.resume):
        return

    model.to(device)
    model.eval()

    # Register hooks
    register_kv_hooks(model)

    # Process up to 5 batches to find diverse classes
    correctly_predicted_by_class = _collect_correct_predictions(
        model, data_loader_val, device, args.max_per_class, args.max_images, num_batches=5)

    # If no images were correctly predicted, inform the user
    if len(correctly_predicted_by_class) == 0:
        print("No images were correctly predicted. Cannot generate visualizations.")
        return

    print(f"Found {len(correctly_predicted_by_class)} different classes with correctly predicted images")

    selected_images = _select_images(correctly_predicted_by_class, args.max_images)

    print(
        f"Selected {len(selected_images)} images from {len(set(class_id for class_id, _ in selected_images))} different classes")

    # One sub-directory per captured layer, one figure per selected image
    layer_names = [name for name, kv in kv_dict.items() if 'K' in kv and 'V' in kv]
    for mod_name in layer_names:
        layer_dir = os.path.join(args.output_dir, mod_name.replace('.', '_'))
        os.makedirs(layer_dir, exist_ok=True)

        for position, (class_id, (batch_idx, img_idx, original_img, heatmaps)) in enumerate(
                selected_images, start=1):
            patch_heatmap = heatmaps.get(mod_name)
            if patch_heatmap is None:
                # Layer was not captured for this image's batch
                continue

            plot_title = f"Layer: {mod_name} - Class: {class_id}"
            save_path = os.path.join(layer_dir, f"class_{class_id}_batch_{batch_idx}_image_{img_idx}.png")
            plot_heatmap_overlay(original_img, patch_heatmap, save_path, plot_title)

            print(f"Processed image {position}/{len(selected_images)} (Class {class_id})")

    # Report zero instead of one when no layer produced any figure
    total_processed = len(selected_images) if layer_names else 0
    print(f"Total processed: {total_processed} images")


# Run the visualization only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()