#!/usr/bin/env python3
"""
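External Heads Guided Image Generation for SD3.5
================================================

Loads an SD3.5 pipeline and a trained external-heads checkpoint, then, for
each seed, saves a baseline image plus one image per coefficient in
COEFFICIENT_LIST. For the coefficient images, the coefficient-scaled external
heads are added to the attention outputs of the target transformer layers.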
"""

import json
import os
import traceback
from datetime import datetime
from typing import Callable, Dict, List

import torch
from diffusers import StableDiffusion3Pipeline

# ============================================================================
# CONFIGURATION
# ============================================================================

# SD3.5 model (must match the model used for training).
# NOTE(review): Turbo checkpoints are usually sampled with very few steps and
# guidance_scale=0.0; confirm that the 50 steps / 4.5 guidance below match how
# the external heads were trained.
MODEL_ID = "stabilityai/stable-diffusion-3.5-large-turbo"

# Prompt config
PROMPT = "a photo of a doctor"
# Empty string disables the negative prompt; examples: "low quality, blurry"
# or "female, woman".
NEGATIVE_PROMPT = ""
GUIDANCE_SCALE = 4.5
NUM_INFERENCE_STEPS = 50

# Square output size in pixels (used for both height and width). Must match
# TrainConfig.resolution used in training (512); otherwise the runtime token
# count no longer matches the checkpoint's seq_len.
IMAGE_RESOLUTION = 512

# Seeds / sampling: seeds BASE_SEED .. BASE_SEED + NUM_SEEDS - 1 are used.
BASE_SEED = 42
NUM_SEEDS = 1

# External heads
# Path to the trained checkpoint written by train_external_heads_SD35.py.
EXTERNAL_HEADS_PATH = "./external_heads_sd35_anti_sexual/external_heads_epoch_30.pt"

# Target layers must match the training config (range(38) by default).
TARGET_LAYERS: List[int] = list(range(38))

# Optional subset of head indices kept in every target layer; None = all heads.
# Indices outside [0, num_heads) are ignored.
TARGET_HEADS = [9, 19, 12, 28]

# Coefficients to sweep (scale applied to the external-head residual).
# A coefficient of 0 installs the hooks but adds a zero residual.
COEFFICIENT_LIST = [0, 10]

# Output directory
OUTPUT_DIR = "images_out"

# Device / dtype
DEVICE = "cuda"
DTYPE = torch.bfloat16  # match training (bf16)


# ============================================================================
# EXTERNAL HEADS MODULE FOR GENERATION
# ============================================================================

class ExternalHeads(torch.nn.Module):
    """
    Wrapper around trained external heads for SD3.5 generation.

    The wrapped tensor has shape [num_layers, num_heads, seq_len, head_dim].
    Entry ``i`` along the first axis belongs to the ``i``-th layer of
    ``sorted(target_layers)`` (the same order as used in training).

    Attributes:
        target_layers_sorted: sorted global layer indices covered by the tensor.
        layer_to_index: global layer index -> local index into the tensor.
        num_heads, seq_len, head_dim: sizes read from the tensor shape.
        external_heads: the tensor wrapped as a ``torch.nn.Parameter``.
    """

    def __init__(self, external_tensor: torch.Tensor, target_layers: List[int]):
        """
        Args:
            external_tensor: 4-D tensor [num_layers, num_heads, seq_len, head_dim].
            target_layers: global layer indices the tensor was trained for.

        Raises:
            ValueError: if the tensor is not 4-D, or its layer count differs
                from ``len(target_layers)``.
        """
        super().__init__()

        self.target_layers_sorted = sorted(target_layers)
        self.layer_to_index: Dict[int, int] = {
            layer: i for i, layer in enumerate(self.target_layers_sorted)
        }

        if external_tensor.dim() != 4:
            raise ValueError(
                "Expected external_heads tensor of shape "
                "[num_layers, num_heads, seq_len, head_dim], got "
                f"{tuple(external_tensor.shape)}."
            )

        L_ckpt, H_ckpt, S_ckpt, D_ckpt = external_tensor.shape
        if L_ckpt != len(self.target_layers_sorted):
            raise ValueError(
                f"Checkpoint layers ({L_ckpt}) != len(target_layers) "
                f"({len(self.target_layers_sorted)}). Ensure you use the same "
                f"TARGET_LAYERS as in training."
            )

        self.num_heads = H_ckpt
        self.seq_len = S_ckpt
        self.head_dim = D_ckpt

        # Main parameter
        self.external_heads = torch.nn.Parameter(external_tensor)

    def _head_mask(self, target_heads) -> torch.Tensor:
        """Return a [num_heads] 0/1 mask with the parameter's dtype and device."""
        # Out-of-range (including negative) indices are ignored, not wrapped.
        keep = [h for h in target_heads if 0 <= h < self.num_heads]
        mask = torch.zeros(
            self.num_heads,
            device=self.external_heads.device,
            dtype=self.external_heads.dtype,
        )
        if keep:
            mask[keep] = 1
        return mask

    def get_params_for_layer(
        self,
        layer_idx: int,
        seq_len_current: int,
        batch_size: int,
        target_heads=None,
    ) -> torch.Tensor:
        """
        Return one layer's external heads, optionally restricted to a subset of heads.

        Args:
            layer_idx: global transformer layer index.
            seq_len_current: runtime token count; the first ``seq_len_current``
                positions of the checkpoint are used.
            batch_size: size of the leading (broadcast) batch dimension.
            target_heads: iterable of head indices to keep; every other head is
                zeroed. Indices outside ``[0, num_heads)`` are ignored.
                ``None`` keeps all heads.

        Returns:
            Tensor [batch_size, num_heads, seq_len_current, head_dim] with the
            parameter's dtype and device. It is all zeros for layers not in the
            checkpoint. The batch dimension is an ``expand`` view, not a copy.

        Raises:
            ValueError: if ``seq_len_current`` exceeds the checkpoint seq_len
                (checked only for layers present in the checkpoint).
        """
        params = self.external_heads

        if layer_idx not in self.layer_to_index:
            # Layer not trained: contribute nothing.
            return torch.zeros(
                batch_size,
                self.num_heads,
                seq_len_current,
                self.head_dim,
                device=params.device,
                dtype=params.dtype,
            )

        if seq_len_current > self.seq_len:
            raise ValueError(
                f"Runtime seq_len ({seq_len_current}) > checkpoint seq_len ({self.seq_len}). "
                f"Resolution or tokenization mismatch between training and generation."
            )

        # [H, seq_len_current, D]
        layer_heads = params[self.layer_to_index[layer_idx], :, :seq_len_current, :]

        if target_heads is not None:
            # The mask has the parameter's dtype, so masking does not promote a
            # bf16 tensor to float32.
            layer_heads = layer_heads * self._head_mask(target_heads).view(-1, 1, 1)

        # [B, H, N, D]
        return layer_heads.unsqueeze(0).expand(batch_size, -1, -1, -1)


# ============================================================================
# HOOK INSTALLATION FOR SD3.5 ATTENTION
# ============================================================================

def install_external_heads_hooks_sd35(
    pipe: StableDiffusion3Pipeline,
    external_heads: ExternalHeads,
    target_layers: List[int],
    coefficient: float,
    target_heads=None,
) -> Dict[int, Callable]:
    """
    Patch SD3.5 transformer attention modules to add scaled external-head residuals.

    For each layer in ``target_layers``, the wrapper replaces
    ``pipe.transformer.transformer_blocks[layer].attn.forward``. It calls the
    original forward, then takes the first output (the whole output if it is
    not a tuple). It reshapes the layer's external heads [B, H, N, D] to
    [B, N, H*D] and adds ``coefficient`` times them to that output. Any further
    tuple elements are returned unchanged.

    Args:
        pipe: loaded pipeline; its transformer is patched in place.
        external_heads: trained external heads (an ``ExternalHeads`` instance).
        target_layers: global transformer layer indices to patch. Duplicates
            are patched once.
        coefficient: scale applied to the external-head residual.
        target_heads: optional subset of head indices; ``None`` = all heads.

    Returns:
        dict mapping layer index -> original ``attn.forward``. Pass it to
        ``reset_external_heads_hooks_sd35`` to undo the patch.

    Raises:
        Any error raised while patching (e.g. IndexError for an invalid layer
        index). Layers patched before the failure are restored first, so a
        failed install leaves the pipeline unpatched. At call time, the
        wrapper raises RuntimeError if ``H * D`` differs from the attention
        hidden size.
    """
    print(f"\nInstalling external-heads hooks with coefficient={coefficient} ...")
    baseline_forwards: Dict[int, Callable] = {}
    blocks = pipe.transformer.transformer_blocks

    def make_wrapper(orig_fwd, layer_idx_closure):
        # Bind the original forward and the layer index per layer; closing
        # over the loop variable directly would late-bind to the last layer.
        def wrapped_forward(hidden_states, *args, **kwargs):
            # Call original attention
            out = orig_fwd(hidden_states, *args, **kwargs)

            if isinstance(out, tuple):
                attn_hidden, extra = out[0], out[1:]
            else:
                attn_hidden, extra = out, None

            B, N, C = attn_hidden.shape
            H = external_heads.num_heads
            D = external_heads.head_dim

            if H * D != C:
                raise RuntimeError(
                    f"[Shape mismatch] Layer {layer_idx_closure}: "
                    f"attn_hidden_dim={C}, but external_heads has H={H}, D={D}, H*D={H*D}."
                )

            # [B, H, N, D] for this layer
            ext = external_heads.get_params_for_layer(
                layer_idx_closure,
                seq_len_current=N,
                batch_size=B,
                target_heads=target_heads,
            )
            ext = ext.to(attn_hidden.dtype)  # dtype match

            # [B, H, N, D] -> [B, N, H, D] -> [B, N, H*D]
            ext = ext.permute(0, 2, 1, 3).reshape(B, N, H * D)

            # Add scaled residual
            attn_hidden = attn_hidden + coefficient * ext

            if extra is not None:
                return (attn_hidden, *extra)
            return attn_hidden

        return wrapped_forward

    try:
        # dict.fromkeys de-duplicates while keeping order. Patching a layer
        # twice would record the first wrapper as the "original", and reset
        # could then never undo it.
        for layer_idx in dict.fromkeys(target_layers):
            attn = blocks[layer_idx].attn
            orig_forward = attn.forward
            attn.forward = make_wrapper(orig_forward, layer_idx)
            baseline_forwards[layer_idx] = orig_forward
            print(f"  ✓ Patched attention forward for SD3.5 layer {layer_idx}")
    except BaseException:
        # The caller never receives baseline_forwards on failure, so undo any
        # partial patching here before propagating the error.
        for layer_idx, orig_forward in baseline_forwards.items():
            blocks[layer_idx].attn.forward = orig_forward
        raise

    print("Done installing hooks.\n")
    return baseline_forwards


def reset_external_heads_hooks_sd35(
    pipe: StableDiffusion3Pipeline,
    target_layers: List[int],
    baseline_forwards: Dict[int, Callable],
):
    """
    Restore the original ``attn.forward`` of each target layer.

    Args:
        pipe: pipeline whose transformer blocks were patched.
        target_layers: layer indices to consider for restoration.
        baseline_forwards: layer index -> original forward, as returned by
            ``install_external_heads_hooks_sd35``. Layers absent from this dict
            are not touched, not even indexed. An empty dict (e.g. after a
            failed install) is therefore a harmless no-op.
    """
    blocks = pipe.transformer.transformer_blocks
    for layer_idx in target_layers:
        if layer_idx in baseline_forwards:
            blocks[layer_idx].attn.forward = baseline_forwards[layer_idx]
    print("✓ Restored original attention forwards.\n")


# ============================================================================
# LOAD EXTERNAL HEADS
# ============================================================================

def load_external_heads_sd35(
    checkpoint_path: str,
    target_layers: List[int],
    device: str,
    dtype: torch.dtype,
) -> ExternalHeads:
    """
    Load trained external heads from a checkpoint written by train_external_heads_SD35.py.

    The checkpoint is expected to be a dict with an ``"external_heads"`` entry
    of shape [L, H, S, D]. A bare tensor of that shape is also accepted.

    Args:
        checkpoint_path: path to the ``.pt`` file.
        target_layers: layer indices used in training (must match L).
        device: device to move the heads to.
        dtype: dtype to cast the heads to.

    Returns:
        An ``ExternalHeads`` module on ``device`` with ``dtype``.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        KeyError: if the checkpoint dict has no ``"external_heads"`` entry.
        TypeError: if the checkpoint is neither a dict nor a tensor.
        ValueError: from ``ExternalHeads`` on a shape / layer-count mismatch.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"External heads checkpoint not found: {checkpoint_path}"
        )

    print(f"\n📦 Loading external heads from: {checkpoint_path}")
    # SECURITY: torch.load unpickles the file (unless weights_only is in
    # effect); only load checkpoints from trusted sources.
    state = torch.load(checkpoint_path, map_location="cpu")

    if isinstance(state, torch.Tensor):
        ext_tensor = state
    elif isinstance(state, dict):
        if "external_heads" not in state:
            # If the entire state_dict was saved from ExternalHeads,
            # it should still contain the "external_heads" parameter.
            # map(str, ...) keeps non-string keys from breaking the message.
            possible_keys = ", ".join(map(str, state.keys()))
            raise KeyError(
                f"'external_heads' key not found in checkpoint. "
                f"Available keys: {possible_keys}"
            )
        ext_tensor = state["external_heads"]  # [L, H, S, D]
    else:
        raise TypeError(
            f"Unsupported checkpoint content of type {type(state).__name__}; "
            f"expected a dict with an 'external_heads' entry or a tensor."
        )

    print(
        f"  external_heads tensor shape: {tuple(ext_tensor.shape)} "
        f"(layers, heads, seq_len, head_dim)"
    )

    external_heads = ExternalHeads(ext_tensor, target_layers)
    external_heads = external_heads.to(device=device, dtype=dtype)

    print(
        f"  num_layers={len(external_heads.target_layers_sorted)}, "
        f"num_heads={external_heads.num_heads}, "
        f"seq_len={external_heads.seq_len}, "
        f"head_dim={external_heads.head_dim}"
    )
    print("✓ External heads loaded.\n")
    return external_heads


# ============================================================================
# IMAGE GENERATION
# ============================================================================

def generate_image(
    pipe: StableDiffusion3Pipeline,
    prompt: str,
    negative_prompt: str,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
):
    """
    Run the pipeline once and return the first generated image.

    A new ``torch.Generator`` on ``pipe.device`` is created and seeded with
    ``seed`` for every call. Height and width are both pinned to
    IMAGE_RESOLUTION so the transformer sees the same seq_len as during
    training.
    """
    rng = torch.Generator(device=pipe.device)
    rng.manual_seed(seed)

    side = IMAGE_RESOLUTION
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=rng,
        height=side,
        width=side,
    )
    return result.images[0]



# ============================================================================
# MAIN
# ============================================================================

def _load_pipeline() -> StableDiffusion3Pipeline:
    """
    Load MODEL_ID as a StableDiffusion3Pipeline in DTYPE and move it to DEVICE.

    The Hugging Face token is the first non-empty value of HF_TOKEN,
    HUGGINGFACE_HUB_TOKEN or HUGGINGFACE_TOKEN, and is passed to from_pretrained.
    """
    print("\n" + "=" * 80)
    print("LOADING SD3.5 PIPELINE")
    print("=" * 80)

    hf_token = (
        os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        or os.environ.get("HUGGINGFACE_TOKEN")
    )
    print("HF_TOKEN visible:", "YES" if hf_token else "NO")

    pipe = StableDiffusion3Pipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        token=hf_token,
    )
    pipe.to(DEVICE)
    print(f"✓ Loaded model: {MODEL_ID} on {DEVICE} with dtype={DTYPE}")
    print()
    return pipe


def _generate_seed_images(
    pipe: StableDiffusion3Pipeline,
    external_heads: ExternalHeads,
    seed_idx: int,
) -> int:
    """
    Generate the baseline plus one image per coefficient for one seed.

    Images go to OUTPUT_DIR/seed_<seed>/. If the baseline fails, the seed is
    skipped. A failing coefficient is reported and the sweep continues. Hooks
    are always reset after each coefficient.

    Returns:
        Number of images actually saved for this seed.
    """
    current_seed = BASE_SEED + seed_idx
    seed_dir = os.path.join(OUTPUT_DIR, f"seed_{current_seed}")
    os.makedirs(seed_dir, exist_ok=True)

    print("\n" + "=" * 80)
    print(f"SEED {current_seed} ({seed_idx + 1}/{NUM_SEEDS})")
    print("=" * 80)
    print(f"Output directory: {seed_dir}")

    # Baseline image (no hooks installed).
    print("\n  Generating baseline (no external heads)...")
    try:
        baseline_image = generate_image(
            pipe,
            PROMPT,
            NEGATIVE_PROMPT,
            current_seed,
            NUM_INFERENCE_STEPS,
            GUIDANCE_SCALE,
        )
        baseline_filename = os.path.join(seed_dir, "baseline_coef_0.0.png")
        baseline_image.save(baseline_filename)
        print(f"  ✓ Saved: {baseline_filename}")
    except Exception as e:
        print(f"  ✗ ERROR during baseline generation: {e}")
        traceback.print_exc()
        return 0
    saved = 1

    # Images with external heads for each coefficient
    print("\n  Generating images with external heads...")
    for coef in COEFFICIENT_LIST:
        baseline_forwards = {}
        try:
            baseline_forwards = install_external_heads_hooks_sd35(
                pipe,
                external_heads,
                TARGET_LAYERS,
                coefficient=coef,
                target_heads=TARGET_HEADS,
            )
            image = generate_image(
                pipe,
                PROMPT,
                NEGATIVE_PROMPT,
                current_seed,
                NUM_INFERENCE_STEPS,
                GUIDANCE_SCALE,
            )
            img_path = os.path.join(seed_dir, f"image_coef_{coef:.2f}.png")
            image.save(img_path)
            print(f"  ✓ Saved: {img_path}")
            saved += 1
        except Exception as e:
            print(f"  ✗ ERROR: Generating image with coefficient {coef}: {e}")
            traceback.print_exc()
        finally:
            # Always restore original forwards so the next coefficient (and
            # the next seed's baseline) runs on an unpatched transformer.
            try:
                reset_external_heads_hooks_sd35(
                    pipe,
                    TARGET_LAYERS,
                    baseline_forwards,
                )
            except Exception as e_reset:
                print(f"  ✗ ERROR while restoring hooks: {e_reset}")

    # Report what was actually saved, not what was planned.
    print(
        f"\n  ✓ Completed seed {current_seed}: "
        f"{saved} images (1 baseline + {saved - 1} with external heads)"
    )
    return saved


def _write_summary(total_images: int) -> None:
    """
    Write generation_summary_sd35.txt into OUTPUT_DIR.

    The file records the configuration, the approach, the number of images
    actually generated (and how many were expected), and the expected
    directory layout. Write errors are printed, not raised.
    """
    print("\n" + "=" * 80)
    print("WRITING SUMMARY FILE")
    print("=" * 80)

    summary_path = os.path.join(OUTPUT_DIR, "generation_summary_sd35.txt")
    expected_images = NUM_SEEDS * (1 + len(COEFFICIENT_LIST))
    try:
        with open(summary_path, "w", encoding="utf-8") as f:
            f.write("SD3.5 External Heads Guided Generation\n")
            f.write("=" * 80 + "\n\n")
            f.write("Configuration:\n")
            f.write(f"  Model: {MODEL_ID}\n")
            f.write(f"  Prompt: {PROMPT!r}\n")
            f.write(f"  Negative prompt: {NEGATIVE_PROMPT!r}\n")
            f.write(f"  Guidance scale: {GUIDANCE_SCALE}\n")
            f.write(f"  Inference steps: {NUM_INFERENCE_STEPS}\n")
            f.write(f"  Base seed: {BASE_SEED}\n")
            f.write(f"  Num seeds: {NUM_SEEDS}\n")
            f.write(f"  Seeds used: {BASE_SEED} .. {BASE_SEED + NUM_SEEDS - 1}\n")
            f.write(f"  Target layers: {TARGET_LAYERS}\n")
            f.write(
                f"  Target heads: {TARGET_HEADS if TARGET_HEADS is not None else 'All heads'}\n"
            )
            f.write(f"  Coefficients: {COEFFICIENT_LIST}\n")
            f.write(f"  External heads path: {EXTERNAL_HEADS_PATH}\n")
            f.write(f"  Device: {DEVICE}\n")
            f.write(f"  DType: {DTYPE}\n\n")

            f.write("Approach:\n")
            f.write("  1. Load SD3.5 pipeline.\n")
            f.write("  2. Load trained external heads tensor [L, H, S, D].\n")
            f.write("  3. For each coefficient:\n")
            f.write("     - Patch attn.forward in target layers.\n")
            f.write("     - For each call, add coefficient * external_delta\n")
            f.write("       to the attention output in hidden space.\n")
            f.write("     - Restore original forwards after generation.\n\n")

            f.write("Results:\n")
            # Actual count: differs from the expected count if any generation failed.
            f.write(
                f"  Total images: {total_images} (expected {expected_images})\n"
            )
            f.write(
                f"  Images per seed: 1 baseline + {len(COEFFICIENT_LIST)} with external heads\n\n"
            )

            f.write("Directory structure:\n")
            for seed_idx in range(NUM_SEEDS):
                s = BASE_SEED + seed_idx
                f.write(f"  seed_{s}/\n")
                f.write("    - baseline_coef_0.0.png\n")
                for coef in COEFFICIENT_LIST:
                    f.write(f"    - image_coef_{coef:.2f}.png\n")
                f.write("\n")

            f.write(
                f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
            )

        print(f"✓ Summary written to: {summary_path}")
    except Exception as e:
        print(f"✗ ERROR writing summary file: {e}")


def main():
    """
    Run the full sweep: print the configuration, load the pipeline and the
    external heads, generate baseline and per-coefficient images for every
    seed, write the summary file, and print a final report.
    """
    print("=" * 80)
    print("SD3.5 EXTERNAL HEADS GUIDED IMAGE GENERATION")
    print("=" * 80)
    print(f"Model: {MODEL_ID}")
    print(f"Prompt: '{PROMPT}'")
    print(f"Negative prompt: '{NEGATIVE_PROMPT}'")
    print(f"Target layers: {TARGET_LAYERS}")
    print(f"Target heads: {TARGET_HEADS if TARGET_HEADS is not None else 'All heads'}")
    print(f"Base seed: {BASE_SEED}")
    print(f"Number of seeds: {NUM_SEEDS}")
    print(f"Coefficients: {COEFFICIENT_LIST}")
    print(f"Output directory: {OUTPUT_DIR}")
    print("=" * 80)

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    pipe = _load_pipeline()
    external_heads = load_external_heads_sd35(
        EXTERNAL_HEADS_PATH,
        TARGET_LAYERS,
        device=DEVICE,
        dtype=DTYPE,
    )

    # Count only images that were actually saved, so the summary and final
    # report stay accurate when some generations fail.
    total_images = 0
    for seed_idx in range(NUM_SEEDS):
        total_images += _generate_seed_images(pipe, external_heads, seed_idx)

    _write_summary(total_images)

    print("\n" + "=" * 80)
    print("GENERATION COMPLETE")
    print("=" * 80)
    print(f"Output directory: {OUTPUT_DIR}")
    print(f"Total images generated: {total_images}")
    print(f"Seeds: {BASE_SEED} .. {BASE_SEED + NUM_SEEDS - 1}")
    print(f"Coefficients: {COEFFICIENT_LIST}")
    print("✓ All done!")


# Script entry point: run the generation sweep only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
