# This project uses Stable Diffusion, a model developed by Stability AI and released under the CreativeML Open RAIL-M license.

import os
import argparse
import json
from pathlib import Path
from typing import List, Optional, Dict, Any
from tqdm import tqdm
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from diffusers import StableDiffusionPipeline
from trace import trace
from utils import set_seed, auto_device, cache_dir


class AttentionHeadExtractor:
    """Collect attention probabilities from attention modules via forward hooks.

    Captured values are stored in ``attention_heads`` keyed by the module's
    dotted name; the hook handles are kept in ``hooks`` so they can be
    removed again with :meth:`clear_hooks`.
    """

    def __init__(self):
        # name -> captured attention probabilities (detached, moved to CPU)
        self.attention_heads = {}
        # handles returned by ``register_forward_hook``
        self.hooks = []

    def register_hooks(self, model):
        """Attach a capturing forward hook to every attention module of ``model``.

        A module is hooked when its dotted name contains ``attn``
        (case-insensitive) and its class name contains ``Attention``.

        Args:
            model: Object exposing ``named_modules()``; each selected module
                must expose ``register_forward_hook``.
        """

        def create_hook(name):
            def hook_fn(module, inputs, output):
                # Prefer probabilities exposed by the module's processor and
                # fall back to the module itself. A processor that exposes
                # the attribute but holds ``None`` does NOT trigger the
                # fallback.
                if hasattr(module, 'processor') and hasattr(module.processor, 'attention_probs'):
                    holder = module.processor
                elif hasattr(module, 'attention_probs'):
                    holder = module
                else:
                    return

                probs = holder.attention_probs
                if probs is not None:
                    self.attention_heads[name] = probs.detach().cpu()

            return hook_fn

        for name, module in model.named_modules():
            if 'attn' not in name.lower():
                continue
            if 'Attention' not in module.__class__.__name__:
                continue
            handle = module.register_forward_hook(create_hook(name))
            self.hooks.append(handle)

    def clear_hooks(self):
        """Remove every registered hook and forget its handle."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()

    def clear_attention_data(self):
        """Drop all captured attention data."""
        self.attention_heads.clear()


class CustomAttentionProcessor(nn.Module):
    """Attention processor wrapper that records attention probabilities.

    The wrapper delegates the actual computation to ``original_processor``
    and afterwards copies ``attn.attention_probs`` (when the attention module
    exposes such an attribute) to ``self.attention_probs``.

    NOTE(review): capture only works if the wrapped attention module sets an
    ``attention_probs`` attribute during its forward pass; whether the
    attention class in use does so should be confirmed.
    """

    def __init__(self, original_processor):
        """
        Args:
            original_processor: Callable invoked as
                ``original_processor(attn, hidden_states,
                encoder_hidden_states, attention_mask, **kwargs)``.
        """
        super().__init__()
        self.original_processor = original_processor
        # Last captured attention probabilities; None until something is captured.
        self.attention_probs = None

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        """Run the wrapped processor and capture attention probabilities.

        Args:
            attn: The attention module this processor is attached to.
            hidden_states: Forwarded unchanged to the wrapped processor.
            encoder_hidden_states: Forwarded unchanged to the wrapped processor.
            attention_mask: Forwarded unchanged to the wrapped processor.
            **kwargs: Extra keyword arguments forwarded to the wrapped processor.

        Returns:
            Whatever the wrapped processor returns.
        """
        # Keep a reference to the attention module WITHOUT registering it as a
        # child module. A plain ``self.attn_module = attn`` would go through
        # ``nn.Module.__setattr__`` and add ``attn`` to ``self._modules``; if
        # ``attn`` in turn holds this processor as a child, the module tree
        # becomes cyclic and un-memoised traversals (``.to()``, ``.half()``,
        # ``state_dict()``) can recurse without end.
        object.__setattr__(self, 'attn_module', attn)

        # Delegate the real attention computation.
        result = self.original_processor(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)

        # Capture attention probabilities if the module exposes them.
        if hasattr(attn, 'attention_probs'):
            self.attention_probs = attn.attention_probs

        return result


class DAAMDataCreator:
    """Generate images for one prompt together with per-concept heatmaps.

    For every sample the creator runs the diffusion pipeline inside a
    ``trace`` context, saves the image, optionally saves captured attention
    tensors as ``.pt`` files, saves one heatmap image per concept and appends
    a metadata record that is flushed to ``metadata.json``.

    Note:
        ``save_attention_heads`` is a boolean instance attribute. The method
        that writes attention tensors to disk is therefore named
        :meth:`save_attention_head_files`; giving it the same name as the
        attribute would make it uncallable on instances.
    """

    def __init__(self,
                 prompt: str,
                 concepts: List[str],
                 output_dir: str,
                 num_samples: int = 10,
                 guidance_scale: float = 7.5,
                 num_inference_steps: int = 50,
                 height: int = 512,
                 width: int = 512,
                 low_memory: bool = False,
                 base_seed: Optional[int] = 42,
                 model_path: str = "CompVis/stable-diffusion-v1-4",
                 save_attention_heads: bool = True):
        """
        Initialize DAAM Data Creator for generating multiple images and their attention heatmaps.

        Creates the output directories and loads the diffusion pipeline.

        Args:
            prompt: The text prompt to generate images from.
            concepts: List of words to generate heatmaps for.
            output_dir: Directory to save the outputs.
            num_samples: Number of samples to generate.
            guidance_scale: Guidance scale for the diffusion model.
            num_inference_steps: Number of denoising steps.
            height: Height of the generated images.
            width: Width of the generated images.
            low_memory: Whether to use low memory mode (forwarded to ``trace``).
            base_seed: Starting seed for reproducibility; sample ``i`` uses
                ``base_seed + i``. ``None`` disables explicit seeding.
            model_path: Path to the pretrained model.
            save_attention_heads: Whether to save attention heads as .pt files.
        """
        self.prompt = prompt
        self.concepts = concepts
        self.output_dir = Path(output_dir)
        self.num_samples = num_samples
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps
        self.height = height
        self.width = width
        self.low_memory = low_memory
        self.base_seed = base_seed
        self.model_path = model_path
        self.save_attention_heads = save_attention_heads

        # Images are written straight into the output directory.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.images_dir = self.output_dir

        # Attention tensors get their own sub-directory.
        if self.save_attention_heads:
            self.attention_dir = self.output_dir / "attention_heads"
            self.attention_dir.mkdir(exist_ok=True)

        # Heatmaps of the first two concepts live next to the images.
        self.first_concept_dir = self.images_dir if len(concepts) > 0 else None
        self.second_concept_dir = self.images_dir if len(concepts) > 1 else None
        # Any further concept gets "<output_dir>/other_concepts/<concept>/",
        # created lazily when the first heatmap for it is written.
        self.other_viz_dir = self.output_dir / "other_concepts"

        # Set up metadata
        self.metadata = {
            "prompt": prompt,
            "concepts": concepts,
            "samples": [],
            "save_attention_heads": save_attention_heads,
        }

        # Set up the model; half precision is only requested when CUDA is available.
        print(f"Loading model from {model_path}...")
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        self.pipe = auto_device(self.pipe)
        self.pipe.safety_checker = None

        # Initialize attention head extractor
        if self.save_attention_heads:
            self.attention_extractor = AttentionHeadExtractor()

    def setup_attention_extraction(self):
        """Install capturing attention processors and forward hooks.

        The method is idempotent: processors that are already wrapped are left
        alone and hooks are only registered when none are active, so calling
        it once per sample neither nests wrappers nor duplicates hooks.
        """
        if not self.save_attention_heads:
            return

        current_processors = self.pipe.unet.attn_processors
        needs_wrapping = any(
            not isinstance(processor, CustomAttentionProcessor)
            for processor in current_processors.values()
        )
        if needs_wrapping:
            self.pipe.unet.set_attn_processor({
                name: processor if isinstance(processor, CustomAttentionProcessor)
                else CustomAttentionProcessor(processor)
                for name, processor in current_processors.items()
            })

        # Hooks act as a second capture path next to the processors.
        if not self.attention_extractor.hooks:
            self.attention_extractor.register_hooks(self.pipe.unet)

    def extract_attention_heads(self):
        """Collect attention tensors captured during the last forward pass.

        Returns:
            Dict mapping a processor/module name to a detached CPU tensor.
            Entries captured by the hook-based extractor override processor
            entries with the same name. Empty when saving is disabled.
        """
        if not self.save_attention_heads:
            return {}

        attention_data = {}

        # Get attention weights from custom processors
        for name, processor in self.pipe.unet.attn_processors.items():
            if hasattr(processor, 'attention_probs') and processor.attention_probs is not None:
                attention_data[name] = processor.attention_probs.detach().cpu()

        # Also get from hook-based extractor
        attention_data.update(self.attention_extractor.attention_heads)

        return attention_data

    def save_attention_head_files(self, attention_data: Dict[str, torch.Tensor], sample_idx: int):
        """Save attention heads to .pt files.

        Each tensor is written to
        ``<attention_dir>/sample_<idx:04d>/<layer name with '.' and '/' replaced by '_'>.pt``
        as a dict with the keys ``attention_weights``, ``layer_name``,
        ``shape`` and ``sample_idx``.

        Args:
            attention_data: Mapping of layer name to attention tensor.
            sample_idx: Index of the sample the tensors belong to.

        Returns:
            Dict mapping layer name to the saved file path, relative to the
            output directory. Empty when ``attention_data`` is empty.
        """
        if not attention_data:
            return {}

        attention_paths = {}
        sample_attention_dir = self.attention_dir / f"sample_{sample_idx:04d}"
        sample_attention_dir.mkdir(exist_ok=True)

        for layer_name, attention_tensor in attention_data.items():
            # Clean layer name for filename
            clean_name = layer_name.replace('.', '_').replace('/', '_')
            attention_path = sample_attention_dir / f"{clean_name}.pt"

            torch.save({
                'attention_weights': attention_tensor,
                'layer_name': layer_name,
                'shape': attention_tensor.shape,
                'sample_idx': sample_idx
            }, attention_path)

            attention_paths[layer_name] = str(attention_path.relative_to(self.output_dir))

        return attention_paths

    def save_clean_heatmap(self, heatmap_data: torch.Tensor, save_path: Path, cmap: str = 'jet'):
        """
        Save a clean heatmap image without borders, axes, or titles.

        Args:
            heatmap_data: The heatmap tensor data
            save_path: Path to save the heatmap image
            cmap: Colormap to use for visualization
        """
        # Create figure with no padding or margins
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)

            # Plot heatmap with no axes
            plt.imshow(heatmap_data.cpu().numpy(), cmap=cmap)
            plt.axis('off')

            # Save with tight bounding box and no padding
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=100)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

    def _concept_output_dir(self, concept_idx: int, concept: str) -> Path:
        """Return (creating it if needed) the heatmap directory for a concept."""
        if concept_idx == 0 and self.first_concept_dir is not None:
            return self.first_concept_dir
        if concept_idx == 1 and self.second_concept_dir is not None:
            return self.second_concept_dir
        concept_dir = self.other_viz_dir / concept
        concept_dir.mkdir(parents=True, exist_ok=True)
        return concept_dir

    def generate_sample(self, sample_idx: int):
        """
        Generate a single sample (image + heatmaps + attention heads).

        Args:
            sample_idx: Index of the sample to generate.

        Returns:
            Dict containing metadata for the generated sample, with the keys
            ``index``, ``seed``, ``image_path``, ``concepts`` and
            ``attention_heads``.
        """
        # Use deterministic seed based on sample index
        seed = self.base_seed + sample_idx if self.base_seed is not None else None
        generator = set_seed(seed) if seed is not None else None

        print(f"\nGenerating sample {sample_idx + 1}/{self.num_samples}")
        print(f"Prompt: '{self.prompt}', Seed: {seed}")

        # Idempotent; re-installs the capture machinery only when it is missing.
        self.setup_attention_extraction()

        # Create a tracer to track cross-attention during generation
        with trace(self.pipe, low_memory=self.low_memory) as tracer:
            # Clear previous attention data
            if self.save_attention_heads:
                self.attention_extractor.clear_attention_data()

            # Generate the image with the diffusion model
            image = self.pipe(
                prompt=self.prompt,
                height=self.height,
                width=self.width,
                num_inference_steps=self.num_inference_steps,
                guidance_scale=self.guidance_scale,
                generator=generator,
            ).images[0]

            # Save the generated image
            image_path = self.images_dir / f"image_{sample_idx:04d}.jpg"
            image.save(image_path)
            print(f"Saved image to {image_path}")

            # Extract and save attention heads
            attention_paths = {}
            if self.save_attention_heads:
                attention_data = self.extract_attention_heads()
                attention_paths = self.save_attention_head_files(attention_data, sample_idx)
                print(f"Saved {len(attention_paths)} attention head files")

            # Compute the global heat map for visualization
            global_heat_map = tracer.compute_global_heat_map()

            # Process each concept and generate heatmap
            concept_data = {}
            for concept_idx, concept in enumerate(self.concepts):
                try:
                    # Compute word-specific heat map
                    word_heat_map = global_heat_map.compute_word_heat_map(concept)
                    concept_dir = self._concept_output_dir(concept_idx, concept)

                    # Save the raw heatmap (clean version without borders, axes, or titles)
                    heatmap_data = word_heat_map.expand_as(image)
                    heatmap_path = concept_dir / f"{concept}_{sample_idx:04d}.jpg"
                    self.save_clean_heatmap(heatmap_data, heatmap_path)

                    concept_data[concept] = {
                        "heatmap_path": str(heatmap_path.relative_to(self.output_dir)),
                    }

                    print(f"Saved heatmaps for concept '{concept}'")

                except ValueError as e:
                    # Record the failure for this concept and continue with the rest.
                    print(f"Could not generate heatmap for concept '{concept}': {e}")
                    concept_data[concept] = {"error": str(e)}

        # Record metadata for this sample
        sample_metadata = {
            "index": sample_idx,
            "seed": seed,
            "image_path": str(image_path.relative_to(self.output_dir)),
            "concepts": concept_data,
            "attention_heads": attention_paths if self.save_attention_heads else {},
        }

        return sample_metadata

    def run(self):
        """
        Generate all samples and save metadata.

        ``metadata.json`` is rewritten after every sample so an interrupted
        run keeps the records of the samples finished so far. Forward hooks
        are removed when the loop ends, whether it succeeded or raised.
        """
        print(f"Starting generation of {self.num_samples} samples for prompt: '{self.prompt}'")
        print(f"Tracking heatmaps for concepts: {', '.join(self.concepts)}")
        if self.save_attention_heads:
            print("Attention heads will be saved as .pt files")
        print(f"Output directory: {self.output_dir}")

        metadata_path = self.output_dir / "metadata.json"
        try:
            # Generate all samples with progress bar
            for i in tqdm(range(self.num_samples), desc="Generating samples"):
                sample_metadata = self.generate_sample(i)
                self.metadata["samples"].append(sample_metadata)

                # Save metadata after each sample (in case of interruptions)
                with open(metadata_path, "w") as f:
                    json.dump(self.metadata, f, indent=2)
        finally:
            # Clean up
            if self.save_attention_heads:
                self.attention_extractor.clear_hooks()

        print("\nGeneration completed successfully.")
        print(f"Generated {self.num_samples} samples.")
        print(f"Images saved to: {self.images_dir}")

        # Print locations of concept heatmaps
        if self.first_concept_dir is not None:
            print(f"Heatmaps for '{self.concepts[0]}' saved to: {self.first_concept_dir}")

        if self.second_concept_dir is not None and len(self.concepts) > 1:
            print(f"Heatmaps for '{self.concepts[1]}' saved to: {self.second_concept_dir}")

        if self.save_attention_heads:
            print(f"Attention heads saved to: {self.attention_dir}")

        print(f"Metadata saved to: {metadata_path}")


def repeat_ntimes(x, n):
    """Return a list in which every item of ``x`` appears ``n`` times in a row.

    The order of ``x`` is preserved, e.g. ``repeat_ntimes([1, 2], 2)`` gives
    ``[1, 1, 2, 2]``.
    """
    repeated = []
    for element in x:
        repeated.extend([element] * n)
    return repeated

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the data-generation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``; passing a list allows programmatic use.

    Returns:
        Namespace with the attributes ``prompt``, ``concepts``,
        ``output_dir``, ``num_samples``, ``guidance_scale``, ``steps``,
        ``seed``, ``low_memory``, ``height``, ``width``, ``model_path`` and
        ``no_attention_heads``.
    """
    parser = argparse.ArgumentParser(description="Generate images and DAAM heatmaps for a specific prompt")
    parser.add_argument("--prompt", type=str, default="a photo of a woman", help="Text prompt for image generation")
    parser.add_argument("--concepts", nargs="+", default=["woman"], help="Concepts to generate heatmaps for")
    parser.add_argument("--output_dir", type=str, default="./datasets", help="Output directory")
    parser.add_argument("--num_samples", type=int, default=2000, help="Number of samples to generate")
    parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale")
    parser.add_argument("--steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=42, help="Base random seed")
    parser.add_argument("--low_memory", action="store_true", help="Use low memory mode")
    parser.add_argument("--height", type=int, default=512, help="Image height")
    parser.add_argument("--width", type=int, default=512, help="Image width")
    parser.add_argument("--model_path", type=str, default="CompVis/stable-diffusion-v1-4",
                        help="Path to pretrained model")
    parser.add_argument("--no_attention_heads", action="store_true",
                        help="Disable saving attention heads")

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: write the label files, then generate the dataset.
    args = parse_args()

    # Concept name -> integer class index.
    concept_dict = {c: i for i, c in enumerate(["female", "male"])}
    image_prompt = [
        "a photo of a person",
    ]

    # One entry per image: a list of [prompt, target concepts] pairs.
    input_prompt_and_target_concept = [
        [
            ["a photo of a person", ["female"]],
        ],
    ]

    # Repeat the templates so there is one label entry per generated sample.
    # NOTE(review): image_prompt is not used further down in this script.
    image_prompt = repeat_ntimes(image_prompt, args.num_samples)
    input_prompt_and_target_concept = repeat_ntimes(input_prompt_and_target_concept, args.num_samples)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Context managers make sure both files are flushed and closed.
    with open(output_dir / "labels.json", "w") as labels_file:
        json.dump(input_prompt_and_target_concept, labels_file)
    with open(output_dir / "concept_dict.json", "w") as concept_file:
        json.dump(concept_dict, concept_file)

    creator = DAAMDataCreator(
        prompt=args.prompt,
        concepts=args.concepts,
        output_dir=args.output_dir,
        num_samples=args.num_samples,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.steps,
        height=args.height,
        width=args.width,
        low_memory=args.low_memory,
        base_seed=args.seed,
        model_path=args.model_path,
        save_attention_heads=not args.no_attention_heads
    )

    creator.run()
