"""
Attention map visualization tool for multimodal video analysis.

This script generates and visualizes attention maps from a trained AVLLM model
to understand which visual regions the model focuses on when processing video
content with textual queries. It supports various output formats including
raw attention maps, heatmaps, and blended visualizations.
"""

import os
import cv2
import argparse
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer
import matplotlib.pyplot as plt

from configs.config import Config
from utils.utils import *
from models import ENCODER, POOLER
from datasets import MultiModalDataset


def read_video(video_path):
    """
    Read every frame of a video file into memory.

    Extracts all frames from the input video file along with the frame rate
    reported by the container, for temporal analysis and visualization.

    Args:
        video_path (str | os.PathLike): Path to the input video file.

    Returns:
        tuple: (frames, fps)
            - frames (list): Decoded frames as numpy arrays in OpenCV's BGR
              channel order, in playback order.
            - fps (float): Frame rate reported by ``cv2.CAP_PROP_FPS``
              (may be 0.0 if the container does not report one).

    Raises:
        IOError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        # VideoCapture does not raise on a missing/unsupported file; it just
        # yields no frames, which would otherwise surface later in the caller
        # as an opaque IndexError.
        if not cap.isOpened():
            raise IOError(f"Could not open video file: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                # End of stream (or a decode failure): stop reading.
                break
            frames.append(frame)
    finally:
        # Always release the decoder handle, even if reading is interrupted.
        cap.release()
    return frames, fps


def _make_output_dir(enabled, root, name):
    """Create ``root/name`` (with parents) and return it as a Path, or return None when disabled."""
    if not enabled:
        return None
    out_dir = Path(root) / name
    out_dir.mkdir(exist_ok=True, parents=True)
    return out_dir


def _filename_safe(text, max_len=100):
    """Replace every character that is not alphanumeric, '-' or '_' with '_' and truncate to ``max_len``."""
    # Questions may contain '/', '?', '<', '>' etc., which break or are illegal
    # in file paths; long questions can also exceed the filesystem name limit.
    return "".join(c if (c.isalnum() or c in "-_") else "_" for c in text)[:max_len]


def _build_models(cfg, video_encoder_name, visual_pooler_name, peft_ckpt, device):
    """
    Build the video encoder and the visual pooler (with its CLIP text branch).

    The pooler weights are loaded from ``<peft_ckpt>/visual_pooler.pt``, and the
    pooler is cast to half precision, moved to ``device`` and set to eval mode.

    Returns:
        tuple: (video_encoder, visual_pooler)
    """
    # Initialize video encoder for feature extraction
    print(f'Building Video Encoder')
    video_encoder = ENCODER.build(
        type=video_encoder_name,
        configs=cfg.config.video_encoders,
        device=device
    )

    # Initialize CLIP text encoder for text-visual alignment
    clip_tokenizer = CLIPTokenizer.from_pretrained(cfg.config.pooling.model_name_or_path)
    clip_text_model = CLIPTextModel.from_pretrained(cfg.config.pooling.model_name_or_path).text_model.to(device)
    clip_text_model.eval()

    # Initialize visual pooler with attention capability
    print(f'Building Visual Pooler')
    visual_embeds_dim = getattr(cfg.config.video_encoders, video_encoder_name).embeds_dim
    visual_pooler = POOLER.build(
        type=visual_pooler_name,
        text_tokenizer=clip_tokenizer,
        text_model=clip_text_model,
        visual_embeds_dim=visual_embeds_dim,
        **dict(getattr(cfg.config.pooling, visual_pooler_name))
    )

    # Load to CPU first so a checkpoint saved on a different GPU index still
    # loads; the module is moved to `device` right after.
    state_dict = torch.load(os.path.join(peft_ckpt, "visual_pooler.pt"), map_location="cpu")
    visual_pooler.load_state_dict(state_dict)
    visual_pooler = visual_pooler.half().to(device)
    visual_pooler.eval()
    return video_encoder, visual_pooler


def _attention_to_heatmap(attn):
    """
    Convert one 2-D attention slice (a CPU tensor) into a JET-colored BGR uint8 image.

    The slice is min-max normalized to [0, 1], square-rooted to lift low
    activations, scaled to [0, 255] and passed through ``cv2.COLORMAP_JET``.
    """
    # Normalize in float32: the pooler runs in half precision, and doing the
    # min-max normalization (with its 1e-6 epsilon) in float16 loses precision.
    heatmap = attn.float().numpy()
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)
    heatmap = ((heatmap ** 0.5) * 255).astype(np.uint8)
    return cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)


def main():
    """
    Main function for attention map generation and visualization.

    Reads the module-level ``args`` namespace set up by the ``__main__`` block and:
    1. Loads the configuration and creates the requested output directories
       (``raw``, ``heatmap``, ``img`` optionally; ``blend`` always).
    2. Builds the video encoder, CLIP text encoder and visual pooler.
    3. Runs each video sample through the model to obtain attention maps.
    4. Upsamples the attention maps to the original video resolution.
    5. Saves the raw attention tensor (optional) and, for every 10th
       attention time step, a blended image plus optional standalone
       heatmap and original frame.
    """
    # Load configuration and parse command line options
    cfg = Config(args.cfg_path, args.options, [])

    # Setup output directories based on user preferences (None = disabled)
    attn_map_root = _make_output_dir(args.save_attn_map, args.save_dir, "raw")
    heatmap_img_root = _make_output_dir(args.save_heat_map, args.save_dir, "heatmap")
    ori_img_root = _make_output_dir(args.save_ori_img, args.save_dir, "img")
    blend_img_root = _make_output_dir(True, args.save_dir, "blend")

    alpha = args.alpha
    device = "cuda"
    # Visualize every `frame_stride`-th step along the attention time axis.
    frame_stride = 10

    # Configure model settings for attention map generation; this must happen
    # before the pooler is built, since its kwargs come from this config.
    cfg.config.run.batch_size_eval = 1
    cfg.config.pooling.visual_align_pooler.output_attention = True
    video_size = cfg.config.llm.video_size

    video_encoder, visual_pooler = _build_models(
        cfg, args.video_encoder, args.visual_pooler, args.peft_ckpt, device
    )

    # Setup dataset for video processing
    mm_dataset = MultiModalDataset(
        modality="video",
        data_json_path=args.json_path,
        training=False,
        return_raw_audios=cfg.config.datasets.return_raw_audios,
        seg_len=cfg.config.datasets.seg_len,
        audio_resampling=cfg.config.datasets.audio_resampling,
        audio_sampling_rate=cfg.config.datasets.audio_sampling_rate,
    )
    print(f"Found {len(mm_dataset)} data")

    # Create data loader (batch size was forced to 1 above)
    loader = get_dataloader(
        mm_dataset, cfg.config.run, is_train=False, use_distributed=False
    )

    for i, samples in enumerate(loader):
        # Limit processing to specified number of samples if provided
        if args.samples is not None and i >= args.samples:
            break

        samples = prepare_sample(samples, cuda_enabled=True)
        output_texts = samples["output_texts"]
        video_inputs = samples["video_data"]
        video_path = samples["video_path"][0]
        video_name = os.path.basename(video_path)
        # Use the stem rather than str.replace(".mp4", ...) so non-.mp4 inputs
        # still get proper ".pt"/".jpg" output names.
        video_stem = Path(video_path).stem

        with torch.no_grad():
            video_embeds, _ = video_encoder.encode(video_inputs)
            # Reshape to [batch, frames, video_size, features]
            video_embeds = video_embeds.view(video_embeds.size(0), -1, video_size, video_embeds.size(-1))
            # Per the original author, shape is [N, T, W, H] (N == 1 here) — TODO confirm
            attention_map = visual_pooler(video_embeds, output_texts, output_attention=True)

        img_list, fps = read_video(video_path)
        if not img_list:
            print(f"Skip {video_name}: no frames could be decoded")
            continue
        H, W = img_list[0].shape[:2]

        # Interpolate attention maps to the original video resolution -> [T, H, W]
        attention_map = F.interpolate(
            attention_map.squeeze(0).unsqueeze(1),
            size=(H, W), mode='bilinear', align_corners=False
        ).squeeze(1).cpu()

        if attn_map_root is not None:
            torch.save(attention_map, str(attn_map_root / f"{video_stem}.pt"))

        # The query text is the same for every frame of this sample.
        text_tag = _filename_safe(output_texts[0][0]["value"])
        last_frame_idx = len(img_list) - 1

        for index in range(0, attention_map.shape[0], frame_stride):
            attn = attention_map[index]
            max_attn_val = round(attn.max().item(), 4)
            min_attn_val = round(attn.min().item(), 4)

            heatmap_bgr = _attention_to_heatmap(attn)

            # Map the attention time step to a video frame. The `fps / 2`
            # factor presumes attention steps are sampled at 2 per second —
            # TODO confirm. Clamp so rounding or a short decode can't index
            # past the last frame.
            frame_idx = min(int(index * fps / 2), last_frame_idx)
            img = img_list[frame_idx]

            # addWeighted is per-element, so blending directly in BGR is
            # identical to converting both images to RGB, blending, and back.
            blend_img_bgr = cv2.addWeighted(heatmap_bgr, alpha, img, 1 - alpha, 0)

            filename = f"{video_stem}-{frame_idx}-{max_attn_val}-{min_attn_val}-{text_tag}.jpg"

            # Blended image is always saved
            cv2.imwrite(str(blend_img_root / filename), blend_img_bgr)

            if heatmap_img_root is not None:
                cv2.imwrite(str(heatmap_img_root / filename), heatmap_bgr)

            if ori_img_root is not None:
                cv2.imwrite(str(ori_img_root / filename), img)

        print(f"Save {video_name} results")
            
    
if __name__ == "__main__":
    """
    Command line interface for attention map visualization.
    
    This script accepts various arguments to control the visualization process,
    including input/output paths, model configurations, and visualization options.
    """
    parser = argparse.ArgumentParser(description='Attention Map Visualization')
    
    # Input data configuration
    parser.add_argument(
        "--json_path", 
        type=str, 
        default='./MUSIC-AVQA/annotations/avllm_json_small/avqa-train-small.json', 
        help='Path to dataset JSON file containing video-text pairs'
    )
    parser.add_argument(
        "--peft_ckpt", 
        type=str, 
        help='Directory path to PEFT checkpoint containing trained model weights'
    )
    parser.add_argument(
        "--save_dir", 
        type=str, 
        help='Root directory path where visualization results will be saved', 
        required=True
    )
    parser.add_argument(
        "--cfg_path", 
        type=str, 
        default='configs/config.yaml', 
        help='Path to model configuration YAML file'
    )
    
    # Model architecture configuration
    parser.add_argument(
        "--video_encoder", 
        type=str, 
        default='internvideo2',
        choices=['internvideo2'], 
        help='Type of video encoder to use for feature extraction',
    )
    parser.add_argument(
        "--visual_pooler", 
        type=str, 
        default='visual_align_pooler',
        help='Type of visual pooler for attention-guided feature aggregation'
    )
    
    # Runtime configuration
    parser.add_argument(
        "--options", 
        nargs="+",
        help="Override configuration settings using key=value pairs. "
             "These will be merged into the config file to modify default settings.",
    )
    parser.add_argument(
        "--samples", 
        type=int,
        help='Limit the number of video samples to process (useful for testing)'
    )
    
    # Output format options
    parser.add_argument(
        "--save_attn_map", 
        action="store_true",
        help='Save raw attention map tensors as .pt files for further analysis'
    )
    parser.add_argument(
        "--save_heat_map", 
        action="store_true",
        help='Save standalone heatmap visualizations as colored images'
    )
    parser.add_argument(
        "--save_ori_img", 
        action="store_true",
        help='Save original video frames corresponding to attention maps'
    )
    
    # Visualization parameters
    parser.add_argument(
        "--alpha", 
        type=float,
        default=0.5,
        help='Blending coefficient for overlaying heatmaps on original images (0.0-1.0). '
             'Higher values emphasize the heatmap, lower values emphasize the original image.'
    )
    
    args = parser.parse_args()
    main()