#!/usr/bin/env python3
"""
CORRECTED SD3.5 External Heads Training

"""

import os
import json
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusers import StableDiffusion3Pipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from utils_data import TrainingDataset


# ============================================================
# Config
# ============================================================

@dataclass
class TrainConfig:
    """Hyperparameters, prompts and paths for the external-heads training run."""

    # Model id handed to StableDiffusion3Pipeline.from_pretrained and to the
    # scheduler loader.
    model_id: str = "stabilityai/stable-diffusion-3.5-large-turbo"

    # Square image side (pixels) used by the resize/crop transform; the
    # external heads' sequence length is derived from it.
    resolution: int = 512
    # DataLoader batch size (incomplete final batches are dropped).
    batch_size: int = 4
    # Number of full passes over the dataloader.
    num_epochs: int = 30
    # AdamW learning rate and weight decay for the external heads.
    learning_rate: float = 1e-4
    weight_decay: float = 0.0

    # NOTE(review): not read anywhere in the visible code — the pipeline is
    # loaded in float32 regardless of this value.
    mixed_precision: str = "bf16"
    # Directory that receives checkpoints and head-importance JSON files.
    output_dir: str = "external_heads_sd35_anti_sexual"

    # Seed passed to torch.manual_seed.
    seed: int = 42

    # Training prompts: "person" is the source prompt (heads ON),
    # "woman" is the target prompt (heads OFF reference).
    prompt_person: str = "a photo of a person"
    prompt_woman: str = "a photo of a woman"

    # Write a checkpoint and importance JSON every N epochs.
    save_every: int = 10

    # Image folder handed to TrainingDataset.
    train_data_dir: str = "datasets/person_sd"


# Module-level configuration instance (all defaults) read by main().
cfg = TrainConfig()


# ============================================================
# External Heads Module
# ============================================================

class ExternalHeads(nn.Module):
    """
    Holds a single learnable tensor of shape
    [num_layers, num_heads, seq_len, head_dim], initialised to zeros.
    """
    def __init__(self, num_layers, num_heads, seq_len, head_dim):
        super().__init__()
        shape = (num_layers, num_heads, seq_len, head_dim)
        # All-zero start: the parameter contributes nothing until trained.
        self.external_heads = nn.Parameter(torch.zeros(*shape))

        summary = (
            f"Initialized ExternalHeads: {num_layers} layers, {num_heads} heads, "
            f"seq_len={seq_len}, head_dim={head_dim}"
        )
        print(summary)

    def forward(self):
        """Return the full external-heads parameter (no inputs required)."""
        return self.external_heads


# ============================================================
# Hook for SD3.5 Attention
# ============================================================

def install_external_heads_hooks(
    pipe, external_heads, target_layers, device, dtype
):
    """
    Install additive residual external heads into SD3.5 transformer blocks.

    For every index in ``target_layers`` the ``forward`` of
    ``pipe.transformer.transformer_blocks[idx].attn`` is replaced by a wrapper
    that calls the original forward and then adds ``external_heads[idx]``
    (flattened from [H, S, D] to [S, H*D]) to the first attention output.
    The addition is skipped while ``pipe.transformer.use_external_heads`` is
    falsy (a missing attribute counts as enabled).

    Args:
        pipe: Object exposing ``transformer.transformer_blocks[i].attn``.
        external_heads: Tensor/parameter indexed as [layer, H, S, D].
        target_layers: Iterable of block indices to patch.
        device: Accepted for backward compatibility; the delta is moved to
            the attention output's own device instead.
        dtype: Accepted for backward compatibility; the delta is cast to the
            attention output's own dtype instead.

    Returns:
        dict mapping each patched layer index to its original ``forward``,
        so callers can restore the unpatched behaviour.

    Raises:
        RuntimeError: (from the wrapper, at call time) if the attention
            output's sequence length or channel width does not match the
            external heads for that layer.
    """
    print("\nInstalling external heads hooks (CORRECTED DIRECTION)...")

    baseline_forwards = {}

    def make_wrapper(orig_fwd, layer_index):
        # Factory so each wrapper captures its own forward and layer index
        # (avoids the late-binding closure pitfall inside the loop below).

        def wrapped_forward(hidden_states, *args, **kwargs):
            out = orig_fwd(hidden_states, *args, **kwargs)

            # If external heads are disabled, just return baseline attention output
            if not getattr(pipe.transformer, "use_external_heads", True):
                return out

            # Remember the return container so a tuple stays a tuple even
            # when it carries no extra elements.
            returned_tuple = isinstance(out, tuple)
            if returned_tuple:
                attn_hidden, *rest = out
            else:
                attn_hidden, rest = out, []

            _, N, C = attn_hidden.shape

            # Extract trained heads for this layer
            layer_heads = external_heads[layer_index]   # [H, S, D]
            H, S, D = layer_heads.shape

            if S != N:
                raise RuntimeError(
                    f"Seq length mismatch: training S={S}, generation N={N} "
                    f"in layer {layer_index}"
                )
            if H * D != C:
                raise RuntimeError(
                    f"Channel mismatch: heads*head_dim={H * D}, attention C={C} "
                    f"in layer {layer_index}"
                )

            # [H,S,D] → [S,H*D]; the batch dimension is handled by
            # broadcasting, so no per-sample copy of the heads is made.
            delta = layer_heads.permute(1, 0, 2).reshape(S, H * D)
            delta = delta.to(device=attn_hidden.device, dtype=attn_hidden.dtype)

            attn_hidden = attn_hidden + delta

            if returned_tuple:
                return (attn_hidden, *rest)
            return attn_hidden

        return wrapped_forward

    for layer_idx in target_layers:
        attn = pipe.transformer.transformer_blocks[layer_idx].attn
        orig_forward = attn.forward
        baseline_forwards[layer_idx] = orig_forward

        attn.forward = make_wrapper(orig_forward, layer_idx)
        print(f"  ✓ Patched layer {layer_idx}")

    print("Hooks installed.\n")
    return baseline_forwards


# ============================================================
# Corrected Training Step (WOMAN − PERSON)
# ============================================================

def pipe_encode_prompts(pipe, prompts, device, max_sequence_length=256):
    """
    Encode prompts through ``pipe.encode_prompt``.

    The same ``prompts`` are supplied for ``prompt``, ``prompt_2`` and
    ``prompt_3``; all negative prompts are ``None`` and one image per prompt
    is requested. ``encode_prompt`` must return a 4-tuple; its first and
    third elements are returned as ``(prompt_embeds, pooled_prompt_embeds)``.
    """
    encode_kwargs = dict(
        prompt=prompts,
        prompt_2=prompts,
        prompt_3=prompts,
        negative_prompt=None,
        negative_prompt_2=None,
        negative_prompt_3=None,
        device=device,
        num_images_per_prompt=1,
        max_sequence_length=max_sequence_length,
    )
    # Second and fourth entries are the negative-prompt outputs, unused here.
    text_embeds, _neg_embeds, pooled_embeds, _neg_pooled = pipe.encode_prompt(
        **encode_kwargs
    )
    return text_embeds, pooled_embeds


def training_step(
    pipe,
    noise_scheduler,
    pixel_values,
    prompts_person,
    prompts_woman,
    device,
):
    """
    Single training step for SD3.5 external heads.

    CORRECTED DIRECTION:
      - We want (person + heads) ≈ woman
      - g = woman − person
    So we:
      1) Run a HEADS-OFF pass with the target prompt "woman" to get a reference
         prediction (no gradients).
      2) Run a HEADS-ON pass with the source prompt "person".
      3) Minimize MSE(pred(person + heads), ref(woman, no heads)).

    Args:
        pipe: Pipeline exposing ``transformer``, ``vae`` and ``encode_prompt``.
        noise_scheduler: Scheduler providing ``config.num_train_timesteps``
            and, for flow matching, a populated ``timesteps`` tensor.
        pixel_values: Image batch fed to ``vae.encode``; first dim is batch.
        prompts_person: Source prompts (one per sample), used with heads ON.
        prompts_woman: Target prompts (one per sample), used with heads OFF.
        device: Device for timesteps, text embeddings and transformer inputs.

    Returns:
        Scalar float32 MSE loss tensor; gradients flow only through the
        heads-on pass.
    """
    transformer = pipe.transformer
    vae = pipe.vae

    B = pixel_values.shape[0]
    num_train_timesteps = float(noise_scheduler.config.num_train_timesteps)

    # Encode images → latents. SD3-family VAEs define a shift_factor that is
    # subtracted before scaling; VAE configs without one are only scaled.
    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
        shift_factor = getattr(vae.config, "shift_factor", None)
        if shift_factor is not None:
            latents = latents - shift_factor
        latents = latents * vae.config.scaling_factor

    # Sample timesteps & build the noised input
    noise = torch.randn_like(latents)
    if isinstance(noise_scheduler, FlowMatchEulerDiscreteScheduler):
        # Draw from the scheduler's precomputed timestep table.
        all_timesteps = noise_scheduler.timesteps.to(device)
        idx = torch.randint(0, all_timesteps.shape[0], (B,), device=device)
        timesteps = all_timesteps[idx]

        # Flow matching: timesteps are sigma * num_train_timesteps, and the
        # noised sample is the interpolation (1 - sigma) * x + sigma * noise,
        # so the input actually depends on the sampled timestep.
        sigmas = timesteps.to(device=latents.device, dtype=latents.dtype)
        sigmas = sigmas / num_train_timesteps
        sigmas = sigmas.view(-1, *([1] * (latents.ndim - 1)))
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
    else:
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (B,), device=device
        ).long()
        if hasattr(noise_scheduler, "add_noise"):
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        else:
            noisy_latents = latents + noise

    # Match transformer dtype. The SD3 transformer is conditioned on the raw
    # scheduler timestep (same scale the pipeline passes at inference), so
    # the timestep is NOT rescaled to [0, 1].
    tdtype = getattr(transformer, "dtype", noisy_latents.dtype)
    hidden_states = noisy_latents.to(device=device, dtype=tdtype)
    model_timesteps = timesteps.to(device, dtype=tdtype)

    # ===== Encode text prompts using SD3.5 encode_prompt =====
    with torch.no_grad():
        # Source prompt: "person"
        txt_person, pooled_person = pipe_encode_prompts(
            pipe, prompts_person, device=device, max_sequence_length=256
        )
        # Target prompt: "woman"
        txt_woman, pooled_woman = pipe_encode_prompts(
            pipe, prompts_woman, device=device, max_sequence_length=256
        )

        txt_person = txt_person.to(device, dtype=tdtype)
        pooled_person = pooled_person.to(device, dtype=tdtype)
        txt_woman = txt_woman.to(device, dtype=tdtype)
        pooled_woman = pooled_woman.to(device, dtype=tdtype)

    # ========== STEP 1: HEADS OFF, TARGET PROMPT (WOMAN) ==========
    transformer.use_external_heads = False
    try:
        with torch.no_grad():
            noise_ref = transformer(
                hidden_states=hidden_states,
                timestep=model_timesteps,
                encoder_hidden_states=txt_woman,
                pooled_projections=pooled_woman,
                return_dict=False,
            )[0]
    finally:
        # Re-enable the heads even if the reference pass fails, so the
        # transformer is never left silently running without them.
        transformer.use_external_heads = True

    # ========== STEP 2: HEADS ON, SOURCE PROMPT (PERSON) ==========
    noise_pred = transformer(
        hidden_states=hidden_states,
        timestep=model_timesteps,
        encoder_hidden_states=txt_person,
        pooled_projections=pooled_person,
        return_dict=False,
    )[0]

    # Corrected loss: (person + heads) ≈ woman
    loss = F.mse_loss(noise_pred.float(), noise_ref.float())
    return loss


# ============================================================
# Main Training Loop
# ============================================================

def main():
    """
    Train the external heads end to end.

    Loads the SD3.5 pipeline on CUDA in float32, freezes all pretrained
    weights, creates one ExternalHeads parameter covering every transformer
    block, patches the attention forwards, then runs ``cfg.num_epochs``
    epochs of ``training_step`` with AdamW. Every ``cfg.save_every`` epochs
    it writes the heads' state dict and a JSON of per-head importance
    (EMA of per-head gradient norms) into ``cfg.output_dir``.

    Raises:
        RuntimeError: if the dataloader yields no batches (fewer images than
            ``cfg.batch_size`` with ``drop_last=True``).
    """
    torch.manual_seed(cfg.seed)

    os.makedirs(cfg.output_dir, exist_ok=True)

    print(f"Loading SD3.5 model: {cfg.model_id}")
    pipe = StableDiffusion3Pipeline.from_pretrained(
        cfg.model_id,
        torch_dtype=torch.float32,
        use_safetensors=True,
    )
    device = "cuda"
    pipe.to(device)

    # Flag to enable/disable external heads inside attention hooks
    pipe.transformer.use_external_heads = True

    # Freeze everything except external heads. The pipeline carries up to
    # three text encoders; freeze every one that is present.
    pipe.vae.requires_grad_(False)
    pipe.transformer.requires_grad_(False)
    for encoder_name in ("text_encoder", "text_encoder_2", "text_encoder_3"):
        encoder = getattr(pipe, encoder_name, None)
        if encoder is not None:
            encoder.requires_grad_(False)

    # Determine head shape
    first_block = pipe.transformer.transformer_blocks[0]
    attn = first_block.attn

    # SD3.5 correct attributes
    H = getattr(attn, "heads", None)
    D = getattr(attn, "head_dim", None)

    if H is None or D is None:
        # fallback to config
        H = pipe.transformer.config.num_attention_heads
        D = pipe.transformer.config.attention_head_dim

    print(f"[SD3.5] detected heads={H}, head_dim={D}")

    # Determine sequence length at training resolution.
    # Image tokens per side = resolution / (VAE downsample × patch size);
    # for SD3.5 that is 8 × 2 = 16, e.g. 512 → 32 × 32 = 1024 tokens.
    vae_scale = getattr(pipe, "vae_scale_factor", 8)
    patch_size = getattr(pipe.transformer.config, "patch_size", 2)
    tokens_per_side = cfg.resolution // (vae_scale * patch_size)
    seq_len = tokens_per_side * tokens_per_side

    num_blocks = len(pipe.transformer.transformer_blocks)
    external_heads_module = ExternalHeads(
        num_layers=num_blocks,
        num_heads=H,
        seq_len=seq_len,
        head_dim=D,
    ).to(device)

    # Install residual hooks for learning
    install_external_heads_hooks(
        pipe,
        external_heads_module.external_heads,
        target_layers=list(range(num_blocks)),
        device=device,
        dtype=torch.bfloat16,
    )

    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        cfg.model_id, subfolder="scheduler"
    )
    # Populate the scheduler's timestep table; training_step samples from it.
    if hasattr(noise_scheduler, "set_timesteps"):
        noise_scheduler.set_timesteps(
            noise_scheduler.config.num_train_timesteps, device=device
        )

    # Dataset
    transform = transforms.Compose([
        transforms.Resize((cfg.resolution, cfg.resolution)),
        transforms.CenterCrop(cfg.resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])

    def sd3_tokenizer_wrapper(text_list):
        # Tokenizer callable expected by TrainingDataset: returns input ids only.
        out = pipe.tokenizer_3(
            text_list,
            padding=False,
            truncation=True,
            max_length=pipe.tokenizer_3.model_max_length
        )
        return out["input_ids"]

    dataset = TrainingDataset(
        image_folder=cfg.train_data_dir,
        transform=transform,
        tokenizer=sd3_tokenizer_wrapper,
        max_concept_length=128,
        select="random",
    )

    def collate_pixel_values(samples):
        # Only the first element of each dataset item (the image tensor) is
        # used for training; everything else in the item is discarded.
        return {
            "pixel_values": torch.stack([x[0] for x in samples]).to(torch.float32)
        }

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_pixel_values,
    )

    # With drop_last=True a dataset smaller than one batch yields nothing;
    # fail early with a clear message instead of a NameError after epoch 0.
    if len(dataloader) == 0:
        raise RuntimeError(
            f"No training batches: dataset at '{cfg.train_data_dir}' has fewer "
            f"than batch_size={cfg.batch_size} samples."
        )

    head_importance = torch.zeros(
        external_heads_module.external_heads.shape[:2],
        dtype=torch.float32,
        device=device,
    )
    importance_ema_decay = 0.99

    def update_head_importance(param, ema_decay=importance_ema_decay):
        # EMA of the per-(layer, head) L2 gradient norm; call after backward().
        if param.grad is None:
            return
        grad = param.grad.detach()
        per_head_norm = grad.pow(2).sum(dim=(2, 3)).sqrt()
        head_importance.mul_(ema_decay).add_(per_head_norm * (1 - ema_decay))

    optimizer = torch.optim.AdamW(
        external_heads_module.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )

    print("\n=== START TRAINING (woman − person) ===\n")

    for epoch in range(cfg.num_epochs):
        # Accumulate on-device to avoid a host sync on every step.
        loss_sum = torch.zeros((), dtype=torch.float32, device=device)
        num_steps = 0

        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(device=device, dtype=torch.float32)

            prompts_person = [cfg.prompt_person] * pixel_values.size(0)
            prompts_woman = [cfg.prompt_woman] * pixel_values.size(0)

            loss = training_step(
                pipe,
                noise_scheduler,
                pixel_values,
                prompts_person,
                prompts_woman,
                device,
            )

            optimizer.zero_grad()
            loss.backward()
            update_head_importance(external_heads_module.external_heads)
            optimizer.step()

            loss_sum += loss.detach().to(loss_sum.dtype)
            num_steps += 1

        # Report the mean loss over the epoch rather than the last batch only.
        mean_loss = (loss_sum / max(num_steps, 1)).item()
        print(f"Epoch {epoch}: loss = {mean_loss:.4f}")

        if (epoch + 1) % cfg.save_every == 0:
            # ------------------------------------------------------------
            # Save external head weights
            # ------------------------------------------------------------
            save_path = os.path.join(cfg.output_dir, f"external_heads_epoch_{epoch + 1}.pt")
            torch.save(external_heads_module.state_dict(), save_path)
            print(f"💾 Saved external heads checkpoint → {save_path}")

            # ------------------------------------------------------------
            # Save head importance JSON (EMA gradient norms per head)
            # ------------------------------------------------------------
            arr = head_importance.detach().cpu().tolist()
            imp_dict = {
                f"layer_{layer_idx}_head_{h}": value
                for layer_idx, layer_row in enumerate(arr)
                for h, value in enumerate(layer_row)
            }

            json_path = os.path.join(cfg.output_dir, f"head_importance_epoch_{epoch + 1}.json")
            with open(json_path, "w") as f:
                json.dump(imp_dict, f, indent=2)

            print(f"📊 Saved head importance JSON → {json_path}")

    print("\n=== DONE TRAINING ===")


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
