
import argparse
import os
import pickle
import random
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import StableDiffusion3Pipeline, FlowMatchEulerDiscreteScheduler
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import euclidean_distances
from torchvision.transforms import ToPILImage
from tqdm.auto import tqdm
import open_clip

try:
    import umap.umap_ as umap
except Exception:
    umap = None


# Global compute device shared by every helper in this file.
# The second GPU ("cuda:1") is preferred when it exists; hosts with a single
# GPU fall back to "cuda:0" instead of failing with an invalid device ordinal,
# and hosts without CUDA run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options (attribute names are the flag
        names with dashes replaced by underscores).

    Notes:
        ``--prompts`` has a non-empty default, so unless it is explicitly set
        to an empty string it takes precedence over ``--prompt``.
    """
    parser = argparse.ArgumentParser(description="UNet-Centroid Full-Feature Sampling")
    parser.add_argument("--model-id", type=str, default="stabilityai/stable-diffusion-3-medium-diffusers", help="Diffusers model ID (Stable Diffusion 3 Medium)")
    parser.add_argument("--seed", type=int, default=18, help="Random seed")
    parser.add_argument("--num-inference-steps", type=int, default=28, help="Number of sampling (denoising) steps")
    parser.add_argument("--num-init-samples", type=int, default=200, help="Number of initial samples for phase 1")
    parser.add_argument("--n-clusters", type=int, default=10, help="Number of clusters for GMM")
    parser.add_argument("--target-samples", type=int, default=200, help="Number of fair samples to generate")
    parser.add_argument("--output-dir", type=str, default="SD-3-generated-image", help="Directory to save outputs")
    parser.add_argument("--alpha", type=float, default=0.3, help="Mixing weight for prototype guidance")
    parser.add_argument("--init-from-prototype", action="store_true", help="Start Phase 2 sampling from the selected cluster prototype latent")
    parser.add_argument("--init-jitter-sigma", type=float, default=0.25, help="Std of Gaussian jitter added to prototype init latent")
    parser.add_argument("--height", type=int, default=512, help="Output image height")
    parser.add_argument("--width", type=int, default=512, help="Output image width")
    parser.add_argument("--fp16", action="store_true", help="Use float16 for SD pipeline")
    # Example prompts: "a headshot of a manager.", "a photo of a bald either gender."
    parser.add_argument("--prompt", type=str, default="", help="Prompt for SD3 (empty for unconditional)")
    # Example: "a headshot of a manager regardless of gender."
    parser.add_argument("--prompts", type=str, default="A photo of a single firefighter.", help="Optional comma-separated list of prompts. If provided, overrides --prompt")
    parser.add_argument("--skip-visualization", action="store_true", help="Skip visualization plots")
    parser.add_argument("--agg", type=str, default="mean", choices=["mean", "median"], help="Timestep aggregation method")
    parser.add_argument("--embed-model", type=str, default="clip", choices=["clip"], help="Embedding model for Phase 1")
    parser.add_argument("--dim-reduction", type=str, default="umap", choices=["pca", "umap"], help="Dimensionality reduction for Phase 1 clustering")
    parser.add_argument("--save-one-step-preview", action="store_true", help="Save a 1-step decoded image for sanity check")
    parser.add_argument("--save-one-step-generated", action="store_true", help="Run a 1-step generation and save the final image")
    parser.add_argument("--save-initial-first", action="store_true", help="Save the first image from initial sampling phase")
    # Default kept as a float so it matches ``type=float`` (argparse does not convert non-string defaults).
    parser.add_argument("--init-guidance-scale", type=float, default=7.0, help="CFG guidance scale used during initial sample generation (Phase 1)")
    parser.add_argument("--save-baseline-pipeline", action="store_true", help="Save a baseline image using pipeline() for quality comparison")
    parser.add_argument("--use-pipeline-for-initial", action="store_true", help="Use pipeline() with callback to generate initial images and capture features")
    return parser.parse_args()


def seed_everything(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: value handed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_model_sample(model_output):
    """Unwrap the prediction tensor from a model output.

    Dict-like outputs are probed for ``"sample"`` and then ``"pred"``; other
    objects are probed for a ``sample`` attribute. When nothing matches, the
    output itself is returned unchanged.
    """
    if not isinstance(model_output, dict):
        return getattr(model_output, "sample", model_output)
    fallback = model_output.get("pred", model_output)
    return model_output.get("sample", fallback)


def scale_model_input_safe(scheduler, sample, timestep):
    """Apply ``scheduler.scale_model_input`` when the scheduler provides it.

    The two-argument form ``(sample, timestep)`` is tried first; if that
    raises ``TypeError`` the one-argument form ``(sample)`` is used. Schedulers
    without the method leave ``sample`` untouched.
    """
    scale_fn = getattr(scheduler, "scale_model_input", None)
    if scale_fn is not None:
        try:
            scaled = scale_fn(sample, timestep)
        except TypeError:
            scaled = scale_fn(sample)
        return scaled
    return sample


def ensure_timestep_1d(t, device):
    """Return ``t`` as a tensor with at least one dimension on ``device``.

    Floating-point timesteps keep their floating dtype: flow-matching
    schedulers emit fractional timesteps, and casting them to ``long`` would
    silently truncate the value the model is conditioned on. Integer
    timesteps are returned as ``torch.long``.

    Args:
        t: a tensor (0-d is promoted to shape ``[1]``; other shapes are kept)
            or a Python/NumPy scalar.
        device: target device for the returned tensor.

    Returns:
        torch.Tensor: the timestep(s) on ``device``.
    """
    if isinstance(t, torch.Tensor):
        dtype = t.dtype if t.is_floating_point() else torch.long
        t = t.to(device=device, dtype=dtype)
        return t.unsqueeze(0) if t.dim() == 0 else t

    value = float(t)
    # Keep fractional scalars (and anything given as a Python float) in floating point.
    if isinstance(t, float) or not value.is_integer():
        return torch.tensor([value], device=device, dtype=torch.float32)
    return torch.tensor([int(t)], device=device, dtype=torch.long)

class ImageFeatureExtractor:
    """Embeds PIL images with an OpenCLIP ViT-H-14 image encoder."""

    def __init__(self, model_name: str = "clip"):
        """Load the CLIP model and its preprocessing transform onto ``device``.

        Raises:
            ValueError: if ``model_name`` is anything other than ``"clip"``.
        """
        if model_name != "clip":
            raise ValueError("Only CLIP is supported for now")
        clip_model, _, transform = open_clip.create_model_and_transforms(
            "ViT-H-14", pretrained='laion2b_s32b_b79k'
        )
        self.preprocess = transform
        self.model = clip_model.to(device).eval()

    def extract_features(self, images_pil):
        """Return L2-normalised image embeddings as a CPU tensor.

        Args:
            images_pil: a single image or a list/tuple of images accepted by
                the CLIP preprocessing transform.
        """
        images = images_pil if isinstance(images_pil, (list, tuple)) else [images_pil]
        pixel_values = torch.stack([self.preprocess(image) for image in images], dim=0).to(device)
        with torch.no_grad():
            embeddings = self.model.encode_image(pixel_values)
        return F.normalize(embeddings, dim=-1).detach().cpu()


class TransformerFeatureExtractor:
    """Collects pooled activations from submodules of ``pipeline.transformer``.

    Forward hooks are attached to every submodule whose qualified name
    contains one of ``_HOOK_KEYS``. Each hook pools the module output over its
    token/spatial dimensions and stores the result on the CPU in
    ``self.features`` under the module name (later forwards overwrite earlier
    entries). Call ``remove_hooks`` when done.
    """

    # Name fragments selecting which submodules receive a forward hook.
    _HOOK_KEYS = ("attn", "mlp", "norm", "proj", "blocks", "transformer_blocks")

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.features = {}  # module name -> pooled activation (CPU tensor)
        self.hooks = []     # handles returned by register_forward_hook
        self._register_hooks()

    @staticmethod
    def _pool(tensor):
        """Pool a 4-D tensor over its last two dims or a 3-D tensor over dim 1.

        Returns ``None`` for any other rank.
        """
        if tensor.ndim == 4:
            return F.adaptive_avg_pool2d(tensor, (1, 1)).squeeze(-1).squeeze(-1)
        if tensor.ndim == 3:
            return tensor.mean(dim=1)
        return None

    def _register_hooks(self):
        """Attach pooling hooks to the matching transformer submodules."""
        def hook_fn(name):
            def hook(module, inputs, output):
                # Non-tensor outputs (e.g. tuples) and unsupported ranks are skipped.
                if not isinstance(output, torch.Tensor):
                    return
                pooled = self._pool(output)
                if pooled is not None:
                    self.features[name] = pooled.detach().cpu()
            return hook

        # Hook a broad set of transformer submodules
        for name, module in self.pipeline.transformer.named_modules():
            if any(key in name for key in self._HOOK_KEYS):
                try:
                    self.hooks.append(module.register_forward_hook(hook_fn(name)))
                except Exception as exc:
                    # Best effort: report the module that could not be hooked and continue.
                    print(f"Skipping feature hook on '{name}': {exc}")

    def extract_features(self, x, t, encoder_hidden_states=None, pooled_projections=None):
        """Run one transformer forward pass and return the concatenated hook features.

        Args:
            x: latent input; cast to the transformer's dtype before the call.
            t: timestep, normalised with ``ensure_timestep_1d``.
            encoder_hidden_states: optional text embeddings.
            pooled_projections: optional pooled text embeddings (only passed
                when ``encoder_hidden_states`` is given).

        Returns:
            torch.Tensor: per-module pooled features concatenated along dim 1.
            If no hook captured anything, the pooled transformer output of the
            same forward pass is returned instead.
        """
        self.features.clear()
        t = ensure_timestep_1d(t, x.device)
        x = x.to(dtype=self.pipeline.transformer.dtype)
        with torch.no_grad():
            if encoder_hidden_states is None:
                output = self.pipeline.transformer(hidden_states=x, timestep=t)
            else:
                output = self.pipeline.transformer(hidden_states=x, timestep=t, encoder_hidden_states=encoder_hidden_states, pooled_projections=pooled_projections)

        per_layer = [feat for feat in self.features.values() if feat is not None and feat.numel() > 0]
        if per_layer:
            return torch.cat(per_layer, dim=1)

        # Fallback: reuse the output of the forward pass above rather than
        # running the transformer a second time with identical inputs.
        sample = get_model_sample(output)
        pooled = self._pool(sample)
        if pooled is None:
            pooled = sample
        return pooled.detach().cpu()

    def remove_hooks(self):
        """Detach every registered hook and forget the handles."""
        for h in self.hooks:
            h.remove()
        self.hooks.clear()


def aggregate_across_timesteps(per_timestep_features, method: str = "mean"):
    """Collapse a list of per-timestep feature tensors along the timestep axis.

    Args:
        per_timestep_features: non-empty sequence of equally shaped tensors.
        method: ``"mean"`` or ``"median"``.

    Raises:
        ValueError: if the sequence is empty or ``method`` is not supported.
    """
    if not per_timestep_features:
        raise ValueError("No timestep features to aggregate")

    # Stack to [T, ...] so the reduction runs over timesteps.
    timeline = torch.stack(per_timestep_features, dim=0)
    if method not in ("mean", "median"):
        raise ValueError(f"Unsupported aggregation: {method}")
    if method == "mean":
        return timeline.mean(dim=0)
    return timeline.median(dim=0).values


def prepare_sd3_conditioning(pipeline, prompts, batch_size: int = 1):
    """Encode prompts with ``pipeline.encode_prompt`` (CFG disabled).

    Args:
        pipeline: object exposing ``encode_prompt``, ``device`` and
            ``transformer.dtype``.
        prompts: a string (repeated ``batch_size`` times) or a list; a
            one-element list is repeated when ``batch_size > 1``.
        batch_size: number of copies for a single prompt.

    Returns:
        tuple: ``(prompt_embeds, pooled_prompt_embeds)`` moved to the
        pipeline device and cast to the transformer dtype.
    """
    if isinstance(prompts, str):
        prompts = [prompts] * batch_size
    elif isinstance(prompts, list) and len(prompts) == 1 and batch_size > 1:
        prompts = prompts * batch_size

    with torch.no_grad():
        encoded = pipeline.encode_prompt(
            prompt=prompts,
            prompt_2=prompts,
            prompt_3=prompts,
            device=pipeline.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
        )
        if isinstance(encoded, (tuple, list)):
            # Longer tuples carry the pooled embeddings at index 2; otherwise index 1.
            pooled_index = 2 if len(encoded) > 2 else 1
            embeds, pooled = encoded[0], encoded[pooled_index]
        elif isinstance(encoded, dict):
            embeds = encoded.get("prompt_embeds")
            pooled = encoded.get("pooled_prompt_embeds")
        else:
            embeds = getattr(encoded, "prompt_embeds", encoded)
            pooled = getattr(encoded, "pooled_prompt_embeds", None)

    moved_embeds = embeds.to(device=pipeline.device, dtype=pipeline.transformer.dtype)
    moved_pooled = pooled.to(device=pipeline.device, dtype=pipeline.transformer.dtype)
    return (moved_embeds, moved_pooled)


def prepare_sd3_uncond(pipeline, batch_size: int = 1):
    """Encode ``batch_size`` empty prompts for the unconditional branch of CFG.

    Returns:
        tuple: ``(uncond_embeds, uncond_pooled)`` moved to the pipeline device
        and cast to the transformer dtype.
    """
    empty_prompts = [""] * batch_size
    with torch.no_grad():
        encoded = pipeline.encode_prompt(
            prompt=empty_prompts,
            prompt_2=empty_prompts,
            prompt_3=empty_prompts,
            device=pipeline.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
        )
        if isinstance(encoded, (tuple, list)):
            # Longer tuples carry the pooled embeddings at index 2; otherwise index 1.
            pooled_index = 2 if len(encoded) > 2 else 1
            embeds, pooled = encoded[0], encoded[pooled_index]
        elif isinstance(encoded, dict):
            embeds = encoded.get("prompt_embeds")
            pooled = encoded.get("pooled_prompt_embeds")
        else:
            embeds = getattr(encoded, "prompt_embeds", encoded)
            pooled = getattr(encoded, "pooled_prompt_embeds", None)

    moved_embeds = embeds.to(device=pipeline.device, dtype=pipeline.transformer.dtype)
    moved_pooled = pooled.to(device=pipeline.device, dtype=pipeline.transformer.dtype)
    return (moved_embeds, moved_pooled)


def decode_latents_to_images(pipeline, latents: torch.Tensor):
    """Decode latents to image tensors with the pipeline's VAE.

    Latents are divided by the VAE ``scaling_factor`` and, when the VAE config
    defines a ``shift_factor`` (as the SD3 VAE does), shifted by it before
    decoding. Skipping the shift decodes from the wrong latent distribution.

    Args:
        pipeline: object exposing ``vae`` with ``config``, ``dtype`` and ``decode``.
        latents: latent tensor in the model's (scaled) latent space.

    Returns:
        The ``.sample`` attribute of the VAE decode result.
    """
    vae_config = pipeline.vae.config
    latents = latents / vae_config.scaling_factor
    shift_factor = getattr(vae_config, "shift_factor", None)
    if shift_factor is not None:
        latents = latents + shift_factor
    with torch.no_grad():
        imgs = pipeline.vae.decode(latents.to(dtype=pipeline.vae.dtype)).sample
    return imgs


def save_one_step_generated_image(pipeline, noise_scheduler, height, width, prompt: str, output_dir: str):
    """Run a single denoising step from random noise and save the decoded image.

    The image is written to ``<output_dir>/one_step_generated.png``. Note that
    ``noise_scheduler`` is reconfigured to one timestep by this call.
    """
    os.makedirs(output_dir, exist_ok=True)

    transformer = pipeline.transformer
    downscale = getattr(pipeline, "vae_scale_factor", 8)
    channels = transformer.config.in_channels
    latents = torch.randn(1, channels, height // downscale, width // downscale, device=device, dtype=transformer.dtype)

    noise_scheduler.set_timesteps(1, device=device)
    init_sigma = getattr(noise_scheduler, "init_noise_sigma", None)
    if init_sigma is not None:
        latents = latents * init_sigma

    text_embeds, pooled_embeds = prepare_sd3_conditioning(pipeline, prompt, batch_size=1)

    step_t = noise_scheduler.timesteps[0]
    scaled_input = scale_model_input_safe(noise_scheduler, latents, step_t)
    step_t_1d = ensure_timestep_1d(step_t, scaled_input.device)
    with torch.no_grad():
        raw_out = pipeline.transformer(
            hidden_states=scaled_input.to(dtype=pipeline.transformer.dtype),
            timestep=step_t_1d,
            encoder_hidden_states=text_embeds,
            pooled_projections=pooled_embeds,
        )
        prediction = get_model_sample(raw_out)
        latents = noise_scheduler.step(prediction, step_t, latents).prev_sample

    decoded = decode_latents_to_images(pipeline, latents)
    # Map the decoded range [-1, 1] to [0, 1] for PIL conversion.
    unit_range = decoded.squeeze(0).cpu().clip(-1, 1) * 0.5 + 0.5
    path = os.path.join(output_dir, "one_step_generated.png")
    ToPILImage()(unit_range).save(path)
    print(f"Saved 1-step generated image to: {path}")

def save_baseline_pipeline_image(pipeline, prompt: str, output_dir: str, height: int, width: int, num_inference_steps: int, guidance_scale: float):
    """Generate one image by calling ``pipeline`` directly and save it.

    The image is written to ``<output_dir>/baseline_pipeline.png``; the
    directory is created when missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    call_kwargs = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
    }
    with torch.no_grad():
        result = pipeline(**call_kwargs)
    image = result.images[0]
    path = os.path.join(output_dir, "baseline_pipeline.png")
    image.save(path)
    print(f"Saved baseline pipeline image to: {path}")


def generate_initial_samples_and_features_with_pipeline(
    pipeline,
    num_samples: int,
    num_inference_steps: int,
    height: int,
    width: int,
    agg_method: str,
    prompt: str,
    prompts_list=None,
    output_dir: str = "",
    save_initial_first: bool = False,
    guidance_scale: float = 7.0,
):
    """Generate initial samples via ``pipeline()`` while capturing transformer features.

    One image is generated per iteration from a freshly drawn noise latent.
    After every denoising step the hooked transformer activations are
    concatenated, and at the end of a sample they are aggregated over
    timesteps with ``agg_method``.

    Args:
        pipeline: diffusion pipeline; called once per sample.
        num_samples: number of images to generate.
        num_inference_steps: denoising steps per image.
        height, width: output image size in pixels.
        agg_method: aggregation passed to ``aggregate_across_timesteps``.
        prompt: prompt used when ``prompts_list`` yields no prompts.
        prompts_list: optional prompts, cycled through round-robin.
        output_dir: directory for the optional first-image dump
            (``"initial_first"`` when empty).
        save_initial_first: save the first generated image when true.
        guidance_scale: CFG scale; values that are ``None`` or ``<= 1.0``
            disable classifier-free guidance.

    Returns:
        tuple: ``(full_features, all_latents, all_decoded, meta)`` where
        ``all_latents`` holds the initial noise latents (not the denoised
        ones), ``all_decoded`` holds images scaled to ``[-1, 1]`` and ``meta``
        records which prompt was used for each sample.
    """
    print(f"(Pipeline) Generating {num_samples} initial samples and full features...")
    feature_extractor = TransformerFeatureExtractor(pipeline)
    aggregated_features = []
    generated_latents = []
    decoded_images = []

    vae_sf = getattr(pipeline, "vae_scale_factor", 8)
    latent_h, latent_w = height // vae_sf, width // vae_sf
    num_channels_latents = pipeline.transformer.config.in_channels

    prompts_pool = None
    prompt_indices = []
    prompts_used = []
    if prompts_list:
        prompts_pool = [p.strip() for p in prompts_list if p.strip()]

    # The pipeline compares guidance_scale numerically, so "CFG off" must be
    # expressed as a number (<= 1) rather than None.
    use_cfg = guidance_scale is not None and guidance_scale > 1.0
    effective_guidance = guidance_scale if use_cfg else 1.0

    try:
        for i in tqdm(range(num_samples), desc="Pipeline gen+extract"):
            latents0 = torch.randn(1, num_channels_latents, latent_h, latent_w, device=device, dtype=pipeline.transformer.dtype)
            init_snapshot = latents0.detach().cpu().to(torch.float32)
            sample_batch = latents0.shape[0]

            if prompts_pool is not None and len(prompts_pool) > 0:
                this_idx = i % len(prompts_pool)
                this_prompt = prompts_pool[this_idx]
                prompt_indices.append(this_idx)
                prompts_used.append(this_prompt)
            else:
                this_prompt = prompt

            per_timestep_feats = []

            def on_step_end(pipe, step, timestep, callback_kwargs):
                per_layer = [feat for feat in feature_extractor.features.values() if feat is not None and feat.numel() > 0]
                if per_layer:
                    pooled = torch.cat(per_layer, dim=1).detach().cpu()
                    # With CFG the pipeline batches [unconditional, conditional];
                    # keep only the conditional rows so there is exactly one
                    # feature row per generated sample.
                    if pooled.shape[0] > sample_batch:
                        pooled = pooled[-sample_batch:]
                    per_timestep_feats.append(pooled)
                return callback_kwargs

            with torch.no_grad():
                out = pipeline(
                    this_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=effective_guidance,
                    height=height,
                    width=width,
                    latents=latents0,
                    callback_on_step_end=on_step_end,
                    callback_on_step_end_tensor_inputs=["latents"],
                )
                img_pil = out.images[0]

            if per_timestep_feats:
                agg_bt = aggregate_across_timesteps(per_timestep_feats, method=agg_method)
            else:
                # No hook produced features for this sample: placeholder row.
                agg_bt = torch.zeros((1, 1), dtype=torch.float32)
            aggregated_features.append(agg_bt)
            generated_latents.append(init_snapshot)  # Save initial noise latents, not the final denoised latents

            # Convert the PIL image to a CHW float tensor in [-1, 1].
            np_img = np.array(img_pil).astype(np.float32) / 255.0
            img_t = torch.from_numpy(np_img).permute(2, 0, 1)
            img_t = img_t * 2.0 - 1.0
            decoded_images.append(img_t.unsqueeze(0))

            if save_initial_first and i == 0:
                try:
                    os.makedirs(output_dir or "initial_first", exist_ok=True)
                    img_pil.save(f"{output_dir or 'initial_first'}/initial_first.png")
                    print(f"Saved initial first image to: {output_dir or 'initial_first'}/initial_first.png")
                except Exception as e:
                    print(f"Failed to save initial first image: {e}")
    finally:
        # Always detach the hooks, even if generation fails part-way.
        feature_extractor.remove_hooks()

    full_features = torch.cat(aggregated_features, dim=0)
    all_latents = torch.cat(generated_latents, dim=0)
    all_decoded = torch.cat(decoded_images, dim=0)
    print(f"(Pipeline) Aggregated UNet features shape: {full_features.shape}")
    meta = {
        'prompt_indices': np.array(prompt_indices) if prompts_pool is not None else None,
        'prompts_pool': prompts_pool,
        'prompts_used': prompts_used if prompts_pool is not None else None,
    }
    return full_features, all_latents, all_decoded, meta

def generate_initial_samples_and_features(pipeline, noise_scheduler, num_samples, num_inference_steps, height=512, width=512, agg_method="mean", prompt: str = "", prompts_list=None, save_one_step_preview: bool = False, output_dir: str = "", save_initial_first: bool = False, init_guidance_scale: float = 0.0):
    """Generate initial samples with a manual denoising loop and capture transformer features.

    For every sample a random latent is drawn, denoised over
    ``num_inference_steps`` scheduler steps, and decoded with the VAE. At each
    step the hooked transformer features are extracted (conditional
    embeddings only) and later aggregated over timesteps with ``agg_method``.

    Args:
        pipeline: diffusion pipeline providing ``transformer`` and ``vae``.
        noise_scheduler: scheduler; its timesteps are reset for every batch.
        num_samples: number of samples to generate (one per batch).
        num_inference_steps: denoising steps per sample.
        height, width: output image size in pixels.
        agg_method: aggregation passed to ``aggregate_across_timesteps``.
        prompt: prompt used when ``prompts_list`` yields no prompts.
        prompts_list: optional prompts, cycled through round-robin.
        save_one_step_preview: decode and save the latent after the first
            step of the first batch.
        output_dir: directory for the optional image dumps.
        save_initial_first: save the first fully denoised image.
        init_guidance_scale: CFG scale; CFG is used only when it is > 1.0.

    Returns:
        tuple: ``(full_features, all_latents, all_decoded, meta)`` where
        ``all_latents`` holds the initial noise latents and ``meta`` records
        the prompt used for each sample.
    """
    print(f"Generating {num_samples} initial samples and full features (SD3 latent space)...")

    feature_extractor = TransformerFeatureExtractor(pipeline)
    aggregated_features = []
    generated_latents = []  # Stores initial noise latents (after scheduler scaling)
    decoded_images = []

    batch_size = 1
    num_batches = (num_samples + batch_size - 1) // batch_size

    vae_sf = getattr(pipeline, "vae_scale_factor", None)
    if vae_sf is None:
        # Fallback if attribute missing
        vae_sf = 8
    latent_h, latent_w = height // vae_sf, width // vae_sf
    num_channels_latents = pipeline.transformer.config.in_channels

    prompt_indices = []
    prompts_used = []
    prompts_pool = None
    if prompts_list:
        prompts_pool = [p.strip() for p in prompts_list if p.strip()]
    for bi in tqdm(range(num_batches), desc="Generating+extracting"):
        current_bs = min(batch_size, num_samples - len(generated_latents))
        latents = torch.randn(current_bs, num_channels_latents, latent_h, latent_w, device=device, dtype=pipeline.transformer.dtype)
        # Re-initialise the scheduler's timesteps for this sample.
        noise_scheduler.set_timesteps(num_inference_steps, device=device)
        sigma = getattr(noise_scheduler, "init_noise_sigma", None)
        if sigma is not None:
            latents = latents * sigma
        init_latents_snapshot = latents.detach().cpu().to(torch.float32)

        # Pick the prompt round-robin from the pool, or fall back to `prompt`.
        if prompts_pool is not None and len(prompts_pool) > 0:
            prompt_idx = len(generated_latents) % len(prompts_pool)
            this_prompt = prompts_pool[prompt_idx]
            prompt_indices.append(prompt_idx)
            prompts_used.append(this_prompt)
            prompt_embeds, pooled_prompt_embeds = prepare_sd3_conditioning(pipeline, this_prompt, batch_size=current_bs)
        else:
            prompt_embeds, pooled_prompt_embeds = prepare_sd3_conditioning(pipeline, prompt, batch_size=current_bs)
        do_cfg = init_guidance_scale is not None and init_guidance_scale > 1.0
        if do_cfg:
            # Unconditional embeddings go first, matching the chunk(2) order below.
            uncond_embeds, uncond_pooled = prepare_sd3_uncond(pipeline, batch_size=current_bs)
            cat_embeddings = torch.cat([uncond_embeds, prompt_embeds], dim=0)
            cat_pooled = torch.cat([uncond_pooled, pooled_prompt_embeds], dim=0)

        per_timestep = []
        one_step_saved = False
        for t in noise_scheduler.timesteps:
            model_input = scale_model_input_safe(noise_scheduler, latents, t)
            t_1d = ensure_timestep_1d(t, model_input.device)
            # Feature extraction runs its own transformer forward with the conditional embeddings.
            feats_bt = feature_extractor.extract_features(model_input, t, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds)  # [B, D]
            per_timestep.append(feats_bt)
            with torch.no_grad():
                if do_cfg:
                    mi = torch.cat([model_input] * 2)
                    out = pipeline.transformer(hidden_states=mi.to(dtype=pipeline.transformer.dtype), timestep=t_1d, encoder_hidden_states=cat_embeddings, pooled_projections=cat_pooled)
                    pred = get_model_sample(out)
                    uncond, text = pred.chunk(2)
                    noise_pred = uncond + init_guidance_scale * (text - uncond)
                else:
                    out = pipeline.transformer(hidden_states=model_input.to(dtype=pipeline.transformer.dtype), timestep=t_1d, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds)
                    noise_pred = get_model_sample(out)
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # Best-effort preview of the latent right after the first step of the first batch.
            if save_one_step_preview and not one_step_saved and bi == 0:
                try:
                    os.makedirs(output_dir or "one_step_preview", exist_ok=True)
                    preview = decode_latents_to_images(pipeline, latents.detach())
                    img_pil = ToPILImage()(preview.squeeze(0).cpu().clip(-1, 1) * 0.5 + 0.5)
                    fname = f"{output_dir or 'one_step_preview'}/one_step_preview_initial.png"
                    img_pil.save(fname)
                    print(f"Saved 1-step preview to: {fname}")
                except Exception as e:
                    print(f"Failed to save 1-step preview: {e}")
                one_step_saved = True

        agg_bt = aggregate_across_timesteps(per_timestep, method=agg_method)  # [B, D]
        aggregated_features.append(agg_bt)
        generated_latents.append(init_latents_snapshot)

        imgs = decode_latents_to_images(pipeline, latents.detach())
        decoded_images.append(imgs.detach().cpu().to(torch.float32))
        if save_initial_first and bi == 0:
            try:
                os.makedirs(output_dir or "initial_first", exist_ok=True)
                img0 = imgs[0].detach().cpu().clip(-1, 1) * 0.5 + 0.5
                ToPILImage()(img0).save(f"{output_dir or 'initial_first'}/initial_first.png")
                print(f"Saved initial first image to: {output_dir or 'initial_first'}/initial_first.png")
            except Exception as e:
                print(f"Failed to save initial first image: {e}")

    feature_extractor.remove_hooks()

    full_features = torch.cat(aggregated_features, dim=0)  # [N, D]
    all_latents = torch.cat(generated_latents, dim=0)       # [N, num_channels_latents, H/vae_sf, W/vae_sf]
    all_decoded = torch.cat(decoded_images, dim=0)          # [N, C, H, W] as returned by the VAE decoder
    print(f"Aggregated UNet features shape: {full_features.shape}")
    meta = {
        'prompt_indices': np.array(prompt_indices) if prompts_pool is not None else None,
        'prompts_pool': prompts_pool,
        'prompts_used': prompts_used if prompts_pool is not None else None,
    }
    return full_features, all_latents, all_decoded, meta


def extract_clip_features_for_images(initial_samples: torch.Tensor, embed_model: str = "clip"):
    """Embed a batch of image tensors with CLIP.

    Args:
        initial_samples: image tensor indexed along dim 0; values are clamped
            to ``[-1, 1]`` and rescaled to ``[0, 1]`` before PIL conversion.
        embed_model: embedding backend name forwarded to ``ImageFeatureExtractor``.

    Returns:
        torch.Tensor: embeddings for all images, concatenated along dim 0.
    """
    print("Extracting CLIP embeddings for generated images...")
    extractor = ImageFeatureExtractor(model_name=embed_model)
    to_pil = ToPILImage()
    pil_images = [to_pil(sample.cpu().clamp(-1, 1) * 0.5 + 0.5) for sample in initial_samples]

    # Embed in chunks of 32 images.
    chunk_size = 32
    chunk_feats = [
        extractor.extract_features(pil_images[offset:offset + chunk_size])
        for offset in tqdm(range(0, len(pil_images), chunk_size), desc="CLIP feats")
    ]
    feats = torch.cat(chunk_feats, dim=0)
    print(f"CLIP features shape: {feats.shape}")
    return feats


def extract_clip_text_features(prompts_list, embed_model: str = "clip"):
    """Return L2-normalised CLIP text embeddings for ``prompts_list`` as a CPU tensor.

    A fresh OpenCLIP ViT-H-14 model is loaded on every call.

    Raises:
        ValueError: if ``embed_model`` is anything other than ``"clip"``.
    """
    if embed_model != "clip":
        raise ValueError("Only CLIP is supported for now")
    clip_model, _, _ = open_clip.create_model_and_transforms(
        "ViT-H-14", pretrained='laion2b_s32b_b79k'
    )
    tokenize = open_clip.get_tokenizer("ViT-H-14")
    clip_model = clip_model.to(device).eval()
    with torch.no_grad():
        token_ids = tokenize(prompts_list).to(device)
        embeddings = clip_model.encode_text(token_ids)
    return F.normalize(embeddings, dim=-1).detach().cpu()


def reduce_and_cluster_latents(latents: torch.Tensor, n_clusters: int, method: str = "pca", n_components: int = 32):
    """Reduce embeddings to low dimension and cluster them.

    Two regimes:

    * Degenerate (``n_clusters <= 1`` or fewer samples than
      ``max(2, n_clusters)``): a single reduction followed by KMeans, with the
      cluster count and PCA dimensionality clamped to what the data allows.
    * Regular: a 2-D base reduction is split into at most two parent clusters
      with KMeans; each parent is then re-embedded (from a t-SNE projection of
      the full data) and split into child clusters with a Gaussian mixture,
      falling back to KMeans if the mixture fit fails.

    Args:
        latents: tensor of embeddings; assumed 2-D ``[num_samples, dim]``.
        n_clusters: requested total number of clusters.
        method: ``"pca"`` or ``"umap"`` (UMAP falls back to PCA when the
            package is unavailable).
        n_components: upper bound on the reduced dimensionality in the
            degenerate regime.

    Returns:
        tuple: ``(labels, reducer, reduced)`` with integer labels per sample,
        the fitted base reducer and the reduced coordinates.
    """
    x = latents.numpy()
    num_samples = x.shape[0]
    if n_clusters <= 1 or num_samples < max(2, n_clusters):
        if method == "umap" and umap is not None:
            reducer = umap.UMAP(n_components=min(10, n_components), random_state=0)
        else:
            if method == "umap":
                print("UMAP not available. Falling back to PCA.")
            # PCA cannot produce more components than samples or features.
            pca_dims = min(2, n_components, num_samples, x.shape[1])
            reducer = PCA(n_components=pca_dims, random_state=0)
        reduced = reducer.fit_transform(x)
        # KMeans requires n_clusters <= n_samples; this branch is reached
        # precisely when there may be fewer samples than requested clusters.
        k = max(1, min(n_clusters, num_samples))
        kmeans = KMeans(n_clusters=k, n_init=10, random_state=0)
        labels = kmeans.fit_predict(reduced)
        print(f"Cluster distribution: {Counter(labels)}")
        return labels, reducer, reduced
    if method == "umap" and umap is not None:
        base_reducer = umap.UMAP(n_components=2, random_state=0)
    else:
        if method == "umap":
            print("UMAP not available. Using PCA for base reduction.")
        base_reducer = PCA(n_components=2, random_state=0)
    base_2d = base_reducer.fit_transform(x)
    try:
        tsne_2d = TSNE(n_components=2, random_state=42, init="random").fit_transform(x)
    except Exception:
        # t-SNE can reject small inputs; reuse the base projection instead.
        tsne_2d = base_2d

    # Stage 1: split into (at most) two parent clusters.
    num_parents = min(2, n_clusters)
    parent_kmeans = KMeans(n_clusters=num_parents, n_init=10, random_state=0)
    parent_labels = parent_kmeans.fit_predict(base_2d)
    unique_parents = np.unique(parent_labels)
    # Distribute the requested cluster count across parents as evenly as possible.
    base_children = n_clusters // len(unique_parents)
    remainder = n_clusters % len(unique_parents)
    children_per_parent = {}
    for i, p in enumerate(unique_parents):
        children_per_parent[p] = base_children + (1 if i < remainder else 0)

    final_labels = -np.ones(num_samples, dtype=int)
    current_label = 0

    # Stage 2: sub-cluster each parent independently.
    for p in unique_parents:
        mask = parent_labels == p
        idx_subset = np.where(mask)[0]
        if idx_subset.size == 0:
            continue

        subset_tsne = tsne_2d[mask]
        if umap is not None:
            local_reducer = umap.UMAP(n_components=2, random_state=0)
            subset_emb = local_reducer.fit_transform(subset_tsne)
        else:
            local_reducer = PCA(n_components=2, random_state=0)
            subset_emb = local_reducer.fit_transform(subset_tsne)

        k_child = max(1, children_per_parent[p])
        k_child = min(k_child, subset_emb.shape[0])

        if k_child > 1:
            try:
                sub_labels = GaussianMixture(n_components=k_child, random_state=0).fit_predict(subset_emb)
            except Exception:
                sub_labels = KMeans(n_clusters=k_child, n_init=10, random_state=0).fit_predict(subset_emb)
        else:
            sub_labels = np.zeros(subset_emb.shape[0], dtype=int)

        # Renumber the non-empty child clusters into one global, contiguous label space.
        for s in np.unique(sub_labels):
            label_indices = idx_subset[sub_labels == s]
            final_labels[label_indices] = current_label
            current_label += 1

    # Safety net: any sample left unlabelled joins the last cluster.
    if np.any(final_labels < 0):
        final_labels[final_labels < 0] = max(0, current_label - 1)

    print(f"Cluster distribution: {Counter(final_labels)}")
    return final_labels, base_reducer, base_2d


def compute_unet_centroids_by_labels(unet_full_features: torch.Tensor, labels: np.ndarray, n_clusters: int):
    """Average the feature rows of each cluster.

    Args:
        unet_full_features: feature matrix indexed by sample along dim 0.
        labels: cluster id per sample.
        n_clusters: number of clusters; ids ``0 .. n_clusters - 1`` are used.

    Returns:
        np.ndarray: one centroid per cluster, stacked along axis 0. Clusters
        without members get an all-zero float32 row.
    """
    def centroid_of(cluster_id):
        members = np.nonzero(labels == cluster_id)[0]
        if len(members) > 0:
            return unet_full_features[members].mean(dim=0).numpy()
        return np.zeros((unet_full_features.shape[1],), dtype=np.float32)

    return np.stack([centroid_of(cluster_id) for cluster_id in range(n_clusters)], axis=0)


def select_prototypes_full(initial_latents, full_features: torch.Tensor, centroids: np.ndarray, cluster_labels: np.ndarray):
    """Pick, for each cluster, the latent whose features lie closest to the centroid.

    Args:
        initial_latents: latents indexed by sample.
        full_features: feature matrix indexed by sample along dim 0.
        centroids: one centroid row per cluster.
        cluster_labels: cluster id per sample.

    Returns:
        list: one CPU latent per cluster. A cluster with no members falls back
        to ``initial_latents[0]``.
    """
    prototypes = []
    for cluster_id in range(centroids.shape[0]):
        members = np.nonzero(cluster_labels == cluster_id)[0]
        if len(members) == 0:
            prototypes.append(initial_latents[0].detach().cpu())
            continue

        member_feats = full_features[members].detach().cpu()  # [m, D]
        center = torch.from_numpy(centroids[cluster_id]).detach().cpu().unsqueeze(0)  # [1, D]
        # Euclidean distance of every member to its cluster centroid.
        distances = torch.norm(member_feats - center, dim=1)
        nearest_local = int(torch.argmin(distances).item())
        nearest_global = int(members[nearest_local])
        prototypes.append(initial_latents[nearest_global].detach().cpu())
    return prototypes


def unet_centroid_guided_sampling_mixed(
    latents,
    noise_scheduler,
    pipeline,
    text_embeddings,
    prototype_latents=None,
    alpha: float = 0.3,
    target_timesteps_ratio: float = 0.7,
    uncond_embeddings: torch.Tensor = None,
    guidance_scale: float = 7.0,
):
    """Denoise ``latents`` while mixing in a prototype-derived guidance term.

    At every scheduler timestep a base prediction is computed (with
    classifier-free guidance when ``uncond_embeddings`` is given and
    ``guidance_scale > 1``). During the first
    ``int(len(timesteps) * target_timesteps_ratio)`` steps, and only when a
    prototype latent is supplied, the transformer is also evaluated on the
    prototype latent; the part of that prediction orthogonal to the base
    prediction is added with weight ``alpha``.

    Args:
        latents: starting latents.
        noise_scheduler: scheduler whose ``timesteps`` are already set.
        pipeline: pipeline providing ``transformer`` and ``device``.
        text_embeddings: pair ``(text_embeds, pooled_text)``.
        prototype_latents: optional prototype latent; a 3-D tensor is given a
            leading batch dimension.
        alpha: weight of the orthogonal prototype component.
        target_timesteps_ratio: fraction of steps that use prototype guidance.
        uncond_embeddings: optional pair ``(uncond_embeds, uncond_pooled)``
            (unpacked as a 2-tuple despite the annotation).
        guidance_scale: CFG scale.

    Returns:
        The latents after the final scheduler step.
    """
    if prototype_latents is not None and prototype_latents.ndim == 3:
        prototype_latents = prototype_latents.unsqueeze(0)

    proto_lat = None
    if prototype_latents is not None:
        proto_lat = prototype_latents.detach().clone().to(device=latents.device, dtype=latents.dtype)
        # Separate scheduler instance built from the same config; in this
        # function it is only used to scale the prototype model input.
        ProtoSchedulerClass = noise_scheduler.__class__
        proto_scheduler = ProtoSchedulerClass.from_config(noise_scheduler.config)
        proto_scheduler.set_timesteps(len(noise_scheduler.timesteps), device=latents.device)
    else:
        proto_scheduler = None

    # Number of leading steps during which prototype guidance is active.
    target_timesteps = int(len(noise_scheduler.timesteps) * target_timesteps_ratio)

    do_cfg = uncond_embeddings is not None and guidance_scale is not None and guidance_scale > 1.0
    if do_cfg:
        uncond_embeds, uncond_pooled = uncond_embeddings
        text_embeds, pooled_text = text_embeddings
        uncond_embeds = uncond_embeds.to(device=pipeline.device, dtype=pipeline.transformer.dtype)
        text_embeds = text_embeds.to(device=pipeline.device, dtype=pipeline.transformer.dtype)
        # Unconditional first, conditional second: matches chunk(2) below.
        # NOTE(review): the pooled tensors are not moved/cast here — presumably
        # the caller already placed them on the right device/dtype; confirm.
        cat_embeddings = torch.cat([uncond_embeds, text_embeds], dim=0)
        cat_pooled = torch.cat([uncond_pooled, pooled_text], dim=0)
    else:
        text_embeds, pooled_text = text_embeddings

    for i, t in enumerate(tqdm(noise_scheduler.timesteps, desc="Sampling (mixed)")):
        if do_cfg:
            model_input = torch.cat([latents] * 2)
        else:
            model_input = latents
        model_input = scale_model_input_safe(noise_scheduler, model_input, t)
        t_1d = ensure_timestep_1d(t, model_input.device)

        with torch.no_grad():
            if do_cfg:
                base_pred = get_model_sample(pipeline.transformer(hidden_states=model_input.to(dtype=pipeline.transformer.dtype), timestep=t_1d, encoder_hidden_states=cat_embeddings, pooled_projections=cat_pooled))
                base_uncond, base_text = base_pred.chunk(2)
                base_score = base_uncond + guidance_scale * (base_text - base_uncond)
            else:
                base_score = get_model_sample(pipeline.transformer(hidden_states=model_input.to(dtype=pipeline.transformer.dtype), timestep=t_1d, encoder_hidden_states=text_embeds, pooled_projections=pooled_text))

            use_proto = (i < target_timesteps) and (proto_lat is not None)
            if use_proto:
                # The prototype latent itself is never stepped; it is
                # re-evaluated at the current timestep on every iteration.
                proto_single = scale_model_input_safe(proto_scheduler, proto_lat, t)
                t_1d_proto = ensure_timestep_1d(t, proto_single.device)
                if do_cfg:
                    # Under CFG the prototype is scored with the unconditional embeddings.
                    proto_uncond_score = get_model_sample(
                        pipeline.transformer(hidden_states=proto_single.to(dtype=pipeline.transformer.dtype), timestep=t_1d_proto, encoder_hidden_states=uncond_embeds, pooled_projections=uncond_pooled)
                    )
                else:
                    # Without CFG the text embeddings are used (the variable name notwithstanding).
                    proto_uncond_score = get_model_sample(
                        pipeline.transformer(hidden_states=proto_single.to(dtype=pipeline.transformer.dtype), timestep=t_1d_proto, encoder_hidden_states=text_embeds, pooled_projections=pooled_text)
                    )
                eps = 1e-6  # guards the division below against a near-zero base prediction
                base_flat = base_score
                proto_flat = proto_uncond_score
                # Per-sample projection of the prototype prediction onto the
                # base prediction (means over dims 1-3 on both sides).
                dot = (proto_flat * base_flat).mean(dim=(1, 2, 3), keepdim=True)
                base_norm2 = (base_flat.float().pow(2).mean(dim=(1, 2, 3), keepdim=True) + eps)
                proj_coeff = dot / base_norm2
                # Remove the component parallel to the base prediction.
                proto_orth = proto_flat - proj_coeff * base_flat
                cur_alpha = float(alpha)
                noise_pred = base_score + cur_alpha * proto_orth
            else:
                noise_pred = base_score

            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    return latents


def visualize(reducer, clip_latents: torch.Tensor, cluster_labels: np.ndarray, centroid_usage, output_dir: str):
    """Save a 2-D PCA scatter plot of the CLIP latents, colored by cluster label.

    This is a best-effort diagnostic: any failure is reported on stdout and
    swallowed so that a plotting problem never aborts the surrounding run.

    Args:
        reducer: Unused by this function; kept for interface compatibility.
        clip_latents: Feature tensor handed to PCA (one row per sample is
            what PCA expects). It is detached and moved to CPU before the
            NumPy conversion.
        cluster_labels: Per-sample labels used as scatter colors.
        centroid_usage: Unused by this function; kept for interface
            compatibility.
        output_dir: Directory (created if missing) that receives
            ``clip_latents_pca2.png``.
    """
    fig = None
    try:
        os.makedirs(output_dir, exist_ok=True)
        # Tensor.numpy() raises for CUDA or grad-tracking tensors, so
        # normalise the tensor first instead of failing the whole plot.
        features = clip_latents.detach().cpu().numpy()
        xy = PCA(n_components=2, random_state=0).fit_transform(features)
        fig = plt.figure(figsize=(12, 8))
        plt.scatter(xy[:, 0], xy[:, 1], c=cluster_labels, s=10, alpha=0.6)
        plt.title('CLIP latents (PCA 2D)')
        plt.savefig(os.path.join(output_dir, "clip_latents_pca2.png"), dpi=200, bbox_inches='tight')
    except Exception as e:
        # Deliberately broad: visualization must never take the run down.
        print(f"Visualization failed: {e}")
    finally:
        # Close the figure even when plotting/saving failed, otherwise it
        # stays registered with pyplot and leaks memory.
        if fig is not None:
            plt.close(fig)


def _least_used_prompt_index(prompt_usage, num_prompts):
    """Return a random index among the prompts that have been used least often so far."""
    fewest = min(prompt_usage.get(pi, 0) for pi in range(num_prompts))
    candidates = [pi for pi in range(num_prompts) if prompt_usage.get(pi, 0) == fewest]
    return random.choice(candidates) if candidates else 0


def _phase2_start_latents(args, prototype, noise_scheduler, latent_shape, latent_dtype):
    """Build the starting latent for one Phase 2 sample: a jittered prototype or fresh Gaussian noise."""
    if args.init_from_prototype and prototype is not None:
        latents = prototype.to(device=device, dtype=latent_dtype).unsqueeze(0)
        if args.init_jitter_sigma > 0.0:
            latents = latents + args.init_jitter_sigma * torch.randn_like(latents)
        return latents

    latents = torch.randn(latent_shape, device=device, dtype=latent_dtype)
    # Not every scheduler exposes init_noise_sigma; scale only when it does.
    sigma = getattr(noise_scheduler, "init_noise_sigma", None)
    if sigma is not None:
        latents = latents * sigma
    return latents


def _save_phase2_image(image, output_dir, cluster_id, sample_number, prompt_idx=None):
    """Write one decoded sample to ``output_dir`` as a PNG named after its cluster (and prompt, if given)."""
    # Clip to [-1, 1] and rescale to [0, 1] before the PIL conversion.
    img_pil = ToPILImage()(image.squeeze(0).cpu().clip(-1, 1) * 0.5 + 0.5)
    file_name = f"cluster_{cluster_id:02d}_id_{cluster_id:02d}_sample_{sample_number:04d}.png"
    if prompt_idx is not None:
        file_name = f"p{prompt_idx:02d}_{file_name}"
    img_pil.save(os.path.join(output_dir, file_name))


def main(args):
    """Run the two-phase sampling pipeline described by the CLI arguments.

    Phase 1 generates ``args.num_init_samples`` initial samples, embeds the
    decoded images (optionally concatenated with the embedding of each
    image's prompt), clusters them and derives one centroid and prototype
    latent per cluster. Phase 2 generates ``args.target_samples`` images,
    always picking the currently least-used centroid (and, when a prompt
    pool exists, one of the least-used prompts), and writes every image to
    ``args.output_dir`` as soon as it has been decoded.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
    """
    seed_everything(args.seed)

    # Half precision only when it was requested AND CUDA is available; the
    # selected dtype is what the pipeline is actually loaded with.
    dtype = torch.float16 if args.fp16 and torch.cuda.is_available() else torch.float32
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        args.model_id,
        torch_dtype=dtype,
        use_safetensors=True,
        safety_checker=None,
    ).to(device)
    pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipeline.scheduler.config)
    noise_scheduler = pipeline.scheduler
    noise_scheduler.set_timesteps(args.num_inference_steps, device=device)

    # --prompts is a comma-separated list; blank entries are dropped.
    prompts_list = None
    if args.prompts and args.prompts.strip():
        prompts_list = [p.strip() for p in args.prompts.split(',') if p.strip()]

    # ---- Phase 1: initial samples and their features ----
    if args.use_pipeline_for_initial:
        full_features, initial_latents, decoded_images, init_meta = generate_initial_samples_and_features_with_pipeline(
            pipeline,
            num_samples=args.num_init_samples,
            num_inference_steps=args.num_inference_steps,
            height=args.height,
            width=args.width,
            agg_method=args.agg,
            prompt=args.prompt,
            prompts_list=prompts_list,
            output_dir=args.output_dir,
            save_initial_first=args.save_initial_first,
            guidance_scale=args.init_guidance_scale,
        )
    else:
        full_features, initial_latents, decoded_images, init_meta = generate_initial_samples_and_features(
            pipeline,
            noise_scheduler,
            args.num_init_samples,
            args.num_inference_steps,
            height=args.height,
            width=args.width,
            agg_method=args.agg,
            prompt=args.prompt,
            prompts_list=prompts_list,
            save_one_step_preview=args.save_one_step_preview,
            output_dir=args.output_dir,
            save_initial_first=args.save_initial_first,
            init_guidance_scale=args.init_guidance_scale,
        )

    if args.save_one_step_generated:
        save_one_step_generated_image(
            pipeline,
            noise_scheduler,
            height=args.height,
            width=args.width,
            prompt=args.prompt,
            output_dir=args.output_dir,
        )

    # ---- Clustering features: image embeddings, optionally joined with text ----
    prompts_pool = init_meta.get('prompts_pool')
    clip_latents = extract_clip_features_for_images(decoded_images, embed_model=args.embed_model)
    if prompts_pool:
        text_feats = extract_clip_text_features(prompts_pool, embed_model=args.embed_model)
        idxs = init_meta.get('prompt_indices')
        if idxs is not None and len(idxs) == clip_latents.shape[0]:
            text_for_images = text_feats[idxs]
        else:
            # No usable per-image prompt index: pair every image with the
            # first prompt's text embedding.
            text_for_images = text_feats[0:1].repeat(clip_latents.shape[0], 1)
        clustering_features = F.normalize(torch.cat([clip_latents, text_for_images], dim=1), dim=-1)
    else:
        clustering_features = clip_latents
    cluster_labels, reducer, reduced = reduce_and_cluster_latents(
        clustering_features, n_clusters=args.n_clusters, method=args.dim_reduction
    )

    centroids = compute_unet_centroids_by_labels(full_features, cluster_labels, args.n_clusters)
    prototypes = select_prototypes_full(initial_latents, full_features, centroids, cluster_labels)

    # ---- Phase 2: centroid-balanced guided sampling ----
    os.makedirs(args.output_dir, exist_ok=True)
    num_centroids = len(centroids)
    centroid_usage_count = {c: 0 for c in range(num_centroids)}
    prompt_usage = Counter()

    per_centroid_target = max(1, args.target_samples // num_centroids)
    print(f"Target per centroid: {per_centroid_target}")

    # Latent geometry does not change between samples, so compute it once.
    vae_sf = getattr(pipeline, "vae_scale_factor", None)
    if vae_sf is None:
        vae_sf = 8
    latent_shape = (
        1,
        pipeline.transformer.config.in_channels,
        args.height // vae_sf,
        args.width // vae_sf,
    )
    latent_dtype = pipeline.transformer.dtype

    count = 0
    for _ in range(args.target_samples):
        # Pick the centroid with the lowest usage relative to its target.
        usages = np.array([centroid_usage_count[c] for c in range(num_centroids)])
        ratios = usages / (per_centroid_target + 1e-8)
        k = int(np.argmin(ratios))

        if prompts_pool:
            chosen_prompt_idx = _least_used_prompt_index(prompt_usage, len(prompts_pool))
            chosen_prompt = prompts_pool[chosen_prompt_idx]
        else:
            chosen_prompt_idx = None
            chosen_prompt = args.prompt

        # Re-arm the scheduler for every sample before building its latent.
        noise_scheduler.set_timesteps(args.num_inference_steps, device=device)
        latents = _phase2_start_latents(args, prototypes[k], noise_scheduler, latent_shape, latent_dtype)

        text_embeddings = prepare_sd3_conditioning(pipeline, chosen_prompt, batch_size=1)
        uncond_embeddings = prepare_sd3_uncond(pipeline, batch_size=1)

        final_latents = unet_centroid_guided_sampling_mixed(
            latents,
            noise_scheduler,
            pipeline,
            text_embeddings=text_embeddings,
            prototype_latents=prototypes[k],
            alpha=args.alpha,
            target_timesteps_ratio=0.7,
            uncond_embeddings=uncond_embeddings,
            guidance_scale=6.0,
        )
        imgs = decode_latents_to_images(pipeline, final_latents)

        # Save immediately: nothing is kept in memory for the whole run and
        # a late failure does not discard the images already produced.
        count += 1
        _save_phase2_image(imgs, args.output_dir, k, count, prompt_idx=chosen_prompt_idx)

        if chosen_prompt_idx is not None:
            prompt_usage[chosen_prompt_idx] += 1
        centroid_usage_count[k] += 1

    print(f"Total images generated: {count}")
    print(f"Images saved to: {args.output_dir}")
    print("Completed full-feature prototype-based sampling.")


if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run the pipeline.
    # NOTE(review): `args` becomes a module-level name here; other functions in
    # this file may read it as a global, so confirm before inlining this call.
    args = parse_args()
    main(args)


