#!/usr/bin/env python3
"""
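UNet-Centroid Full-Feature Sampling
Aggregates per-sample UNet features across all timesteps, clusters the samples on
their CLIP image embeddings (optionally joined with CLIP text embeddings of the
prompts), and selects each cluster's prototype in the full (unreduced) aggregated
UNet feature space for guidance.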
"""
import argparse
import os
import pickle
import random
import re
from collections import Counter
from typing import Dict, List, Tuple, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, StableDiffusionPipeline
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import euclidean_distances
from torchvision.transforms import ToPILImage
from tqdm.auto import tqdm
import open_clip

# Optional GPU UMAP from RAPIDS cuML. Any exception during the import (not only
# ImportError) disables it; reduce_and_cluster_latents then uses sklearn PCA.
try:
    from cuml.manifold.umap import UMAP as cuUMAP  # type: ignore
    HAS_CUML_UMAP = True
except Exception:
    cuUMAP = None  # type: ignore
    HAS_CUML_UMAP = False


# Single global device for the SD pipeline, the CLIP models and the latent tensors.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def parse_args():
    """Parse command-line options for the two-phase sampling script.

    ``--prompts`` defaults to the empty string so that ``--prompt`` takes
    effect unless an explicit prompt list is supplied. With the previous
    non-empty default, ``main`` always took the ``--prompts`` branch and any
    ``--prompt`` passed on the command line was silently ignored.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="UNet-Centroid Full-Feature Sampling")
    parser.add_argument("--model-id", type=str, default="stabilityai/stable-diffusion-2-1-base", help="Diffusers model ID (Stable Diffusion 2.1)")
    parser.add_argument("--seed", type=int, default=2, help="Random seed")
    parser.add_argument("--num-inference-steps", type=int, default=28, help="Number of scheduler denoising steps")
    parser.add_argument("--num-init-samples", type=int, default=600, help="Number of initial samples for phase 1")
    parser.add_argument("--n-clusters", type=int, default=12, help="Number of clusters (KMeans, GMM fallback)")  # 10 for doctor
    parser.add_argument("--target-samples", type=int, default=200, help="Number of fair samples to generate")
    parser.add_argument("--output-dir", type=str, default="unet_centroid_fair_gen_full", help="Directory to save outputs")
    parser.add_argument("--alpha", type=float, default=0.3, help="Mixing weight for prototype guidance")
    parser.add_argument("--init-from-prototype", action="store_true", help="Start Phase 2 sampling from the selected cluster prototype latent")
    parser.add_argument("--init-jitter-sigma", type=float, default=0.1, help="Std of Gaussian jitter added to prototype init latent")
    parser.add_argument("--height", type=int, default=512, help="Output image height")
    parser.add_argument("--width", type=int, default=512, help="Output image width")
    parser.add_argument("--fp16", action="store_true", help="Use float16 for SD pipeline")
    # Other example prompts: a photo of a doctor. a headshot of a manager. a photo of a bald person regardless of gender.
    parser.add_argument("--prompt", type=str, default="a headshot of a firefighter", help="Prompt for SD2.1 (empty for unconditional)")
    parser.add_argument("--prompts", type=str, default="", help="Optional comma-separated list of prompts. If non-empty, overrides --prompt")
    parser.add_argument("--skip-visualization", action="store_true", help="Skip visualization plots")
    parser.add_argument("--agg", type=str, default="mean", choices=["mean", "median"], help="Timestep aggregation method")
    parser.add_argument("--embed-model", type=str, default="clip", choices=["clip"], help="Embedding model for Phase 1")
    # NOTE(review): main() never reads --dim-reduction; the reducer is chosen by
    # whether cuML UMAP could be imported (UMAP) or not (PCA).
    parser.add_argument("--dim-reduction", type=str, default="umap", choices=["pca", "umap"], help="Dimensionality reduction for Phase 1 clustering")
    return parser.parse_args()


def sanitize_for_dirname(text: str, max_len: int = 80) -> str:
    """Convert a free-form prompt into a safe, lowercase directory name.

    Empty or whitespace-only input (or ``None``) maps to ``"unconditional"``.
    Input with no usable characters left after filtering maps to ``"prompt"``.
    Longer results are cut to ``max_len`` characters, and trailing ``_``/``-``
    are trimmed after the cut.
    """
    cleaned = (text or "").strip().lower()
    if not cleaned:
        return "unconditional"

    # Collapse whitespace, turn path separators into underscores, and keep
    # only [a-z0-9._ -].
    cleaned = re.sub(r"\s+", " ", cleaned)
    cleaned = cleaned.replace("/", "_").replace("\\", "_")
    cleaned = re.sub(r"[^a-z0-9._ \-]", "", cleaned)
    cleaned = re.sub(r"_+", "_", cleaned.replace(" ", "_")).strip("_ .-")
    cleaned = cleaned or "prompt"

    return cleaned[:max_len].rstrip("_-") if len(cleaned) > max_len else cleaned


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Deterministic kernels and no autotuning, so repeated runs pick the same algorithms.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_unet_sample(unet_output):
    """Return the ``sample`` entry of a UNet forward result.

    Mapping-style outputs (any ``dict`` subclass) are indexed with ``"sample"``.
    Everything else is accessed through the ``.sample`` attribute.
    """
    return unet_output["sample"] if isinstance(unet_output, dict) else unet_output.sample

class ImageFeatureExtractor:
    """Wrapper around the open_clip ViT-H-14 (laion2b_s32b_b79k) image encoder.

    Only ``model_name="clip"`` is accepted. The model is moved to the module-level
    ``device`` and put in eval mode.
    """

    def __init__(self, model_name: str = "clip"):
        if model_name == "clip":
            clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
                "ViT-H-14", pretrained='laion2b_s32b_b79k'
            )
        else:
            raise ValueError("Only CLIP is supported for now")
        self.preprocess = clip_preprocess
        self.model = clip_model.to(device).eval()

    def extract_features(self, images_pil):
        """Encode one PIL image, or a list/tuple of them, into CLIP image embeddings.

        No normalization is applied here; the result stays on ``device``.
        """
        batch = images_pil if isinstance(images_pil, (list, tuple)) else [images_pil]
        # Stacking the preprocessed [C, H, W] tensors equals concatenating [1, C, H, W] views.
        pixel_values = torch.stack([self.preprocess(img) for img in batch], dim=0).to(device)
        with torch.no_grad():
            embeddings = self.model.encode_image(pixel_values)
        return embeddings


class UNetFeatureExtractor:
    """Capture globally average-pooled activations from selected UNet submodules.

    Forward hooks are registered on every module of ``pipeline.unet`` whose
    qualified name contains one of ``down_blocks``/``up_blocks``/``mid_block``
    and one of ``resnets``/``conv``/``attn``. A hook keeps only 4-D tensor
    outputs, pools them to ``[B, C]`` and stores them on the CPU in
    ``self.features``, keyed by module name. Call ``remove_hooks()`` when done.
    """

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.features = {}  # module name -> pooled [B, C] CPU tensor from the latest forward
        self.hooks = []     # handles returned by register_forward_hook
        self._register_hooks()

    def _register_hooks(self):
        def hook_fn(name):
            def hook(module, input, output):
                # Non-tensor outputs (e.g. output dataclasses) and non-4-D tensors are skipped.
                if isinstance(output, torch.Tensor) and output.ndim == 4:
                    pooled = F.adaptive_avg_pool2d(output, (1, 1)).squeeze(-1).squeeze(-1)
                    self.features[name] = pooled.detach().cpu()
            return hook

        for name, module in self.pipeline.unet.named_modules():
            if (
                ('down_blocks' in name or 'up_blocks' in name or 'mid_block' in name)
                and ('resnets' in name or 'conv' in name or 'attn' in name)
            ):
                self.hooks.append(module.register_forward_hook(hook_fn(name)))

    def _run_unet(self, x, t, encoder_hidden_states=None):
        """Run one no-grad UNet forward. The hooks fire as a side effect."""
        with torch.no_grad():
            if encoder_hidden_states is None:
                return self.pipeline.unet(x, t)
            return self.pipeline.unet(x, t, encoder_hidden_states=encoder_hidden_states)

    def extract_features(self, x, t, encoder_hidden_states=None):
        """Run the UNet once and return the hooked features as a ``[B, D]`` CPU tensor.

        Per-layer pooled features are concatenated along dim 1, in the order
        their keys were first inserted into ``self.features`` during this
        forward. If no hooked module produced a 4-D tensor, the UNet's own
        output sample is average-pooled instead.
        """
        self.features.clear()
        output = self._run_unet(x, t, encoder_hidden_states)
        per_layer = [feat for feat in self.features.values() if feat is not None and feat.numel() > 0]
        if per_layer:
            return torch.cat(per_layer, dim=1)
        # Fallback: reuse the forward we already ran instead of running the UNet a second time.
        pooled = F.adaptive_avg_pool2d(get_unet_sample(output), (1, 1)).squeeze(-1).squeeze(-1)
        return pooled.detach().cpu()

    def remove_hooks(self):
        """Detach all registered forward hooks. Calling it again is a no-op."""
        for h in self.hooks:
            h.remove()
        self.hooks.clear()


def aggregate_across_timesteps(per_timestep_features, method: str = "mean"):
    """Collapse a sequence of equally shaped per-timestep tensors into one tensor.

    The inputs are stacked on a new leading (time) axis and reduced over it.

    Args:
        per_timestep_features: non-empty sequence of tensors with identical shapes.
        method: ``"mean"``, or ``"median"`` (torch semantics: the lower of the two
            middle values when the count is even).

    Raises:
        ValueError: if the sequence is empty or ``method`` is not supported.
    """
    if not per_timestep_features:
        raise ValueError("No timestep features to aggregate")
    over_time = torch.stack(per_timestep_features, dim=0)  # [T, ...]
    strategies = (
        ("mean", lambda tensor: tensor.mean(dim=0)),
        ("median", lambda tensor: tensor.median(dim=0).values),
    )
    for name, reduce_fn in strategies:
        if method == name:
            return reduce_fn(over_time)
    raise ValueError(f"Unsupported aggregation: {method}")


def prepare_text_embeddings(pipeline, prompts, batch_size: int = 1):
    """Tokenize and encode prompt(s) with the pipeline's text encoder.

    A single string is repeated ``batch_size`` times. A one-element list is also
    broadcast when ``batch_size > 1``. Any other input is encoded as given, so
    its length (not ``batch_size``) sets the batch. Returns the encoder's first
    output on ``pipeline.device`` in the UNet's dtype.
    """
    if isinstance(prompts, str):
        batch_prompts = [prompts] * batch_size
    elif isinstance(prompts, list) and len(prompts) == 1 and batch_size > 1:
        batch_prompts = prompts * batch_size
    else:
        batch_prompts = prompts

    tokenizer = pipeline.tokenizer
    token_ids = tokenizer(
        batch_prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(pipeline.device)

    with torch.no_grad():
        hidden_states = pipeline.text_encoder(token_ids, output_hidden_states=False)[0]
    return hidden_states.to(device=pipeline.device, dtype=pipeline.unet.dtype)


def decode_latents_to_images(pipeline, latents: torch.Tensor):
    """Decode latents to images with the pipeline's VAE.

    Args:
        pipeline: object exposing ``vae`` with ``config.scaling_factor``,
            ``dtype`` and ``decode(...).sample``.
        latents: latent batch produced by the sampler (any float dtype).

    Returns:
        Decoded images as a float32 tensor. Callers treat the range as [-1, 1]
        and clamp before converting to PIL.
    """
    latents = latents / pipeline.vae.config.scaling_factor
    # Feed the VAE its own dtype. Casting the input to float32 regardless made
    # decoding fail with a dtype mismatch when the pipeline was loaded in float16.
    with torch.no_grad():
        imgs = pipeline.vae.decode(latents.to(dtype=pipeline.vae.dtype)).sample
    # Return float32 so downstream clamp / PIL conversion sees a uniform dtype.
    return imgs.to(dtype=torch.float32)


def _features_from_last_forward(feature_extractor, unet_sample):
    """Return the ``[B, D]`` features captured by ``feature_extractor``'s hooks in the last UNet forward.

    Mirrors ``UNetFeatureExtractor.extract_features``: per-layer pooled
    activations are concatenated along dim 1. If no hooked module produced a
    4-D activation, the UNet output sample itself is average-pooled instead.
    """
    per_layer = [feat for feat in feature_extractor.features.values() if feat is not None and feat.numel() > 0]
    if per_layer:
        return torch.cat(per_layer, dim=1)
    pooled = F.adaptive_avg_pool2d(unet_sample, (1, 1)).squeeze(-1).squeeze(-1)
    return pooled.detach().cpu()


def generate_initial_samples_and_features(pipeline, noise_scheduler, num_samples, num_inference_steps, height=512, width=512, agg_method="mean", prompt: str = "", prompts_list=None):
    """Phase 1: sample ``num_samples`` images and record timestep-aggregated UNet features.

    Each sample starts from fresh Gaussian noise scaled by the scheduler's
    ``init_noise_sigma`` and is denoised with the text-conditioned UNet only
    (no classifier-free guidance). At every timestep, the hooked UNet
    activations are captured from the same forward pass that yields the noise
    prediction. They are then aggregated over timesteps with ``agg_method``.

    Args:
        pipeline: Stable Diffusion pipeline (UNet, VAE, tokenizer, text encoder).
        noise_scheduler: scheduler used for denoising; its timesteps are reset per sample.
        num_samples: number of samples to generate (must be >= 1).
        num_inference_steps: denoising steps per sample.
        height, width: output size in pixels (latents are height/8 x width/8).
        agg_method: ``"mean"`` or ``"median"`` timestep aggregation.
        prompt: prompt used when ``prompts_list`` yields no prompts.
        prompts_list: optional prompts, cycled round-robin over samples.

    Returns:
        ``(full_features [N, D], initial noise latents [N, C, H/8, W/8] float32,
        decoded images [N, 3, H, W] float32, meta)``, all on the CPU.
        ``meta`` holds ``prompt_indices``, ``prompts_pool`` and ``prompts_used``.

    Raises:
        ValueError: if ``num_samples < 1``.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    print(f"Generating {num_samples} initial samples and full features (SD2.1 latent space)...")

    latent_h, latent_w = height // 8, width // 8
    num_channels_latents = pipeline.unet.config.in_channels
    prompts_pool = [p.strip() for p in prompts_list if p.strip()] if prompts_list else None

    aggregated_features = []
    generated_latents = []  # initial noise latents (after init_noise_sigma scaling)
    decoded_images = []
    prompt_indices = []
    prompts_used = []
    text_cache = {}  # prompt -> batch-1 text embeddings; encoding is deterministic

    feature_extractor = UNetFeatureExtractor(pipeline)
    try:
        for sample_idx in tqdm(range(num_samples), desc="Generating+extracting"):
            latents = torch.randn(1, num_channels_latents, latent_h, latent_w, device=device, dtype=pipeline.unet.dtype)
            # Reset the schedule so each sample is denoised from the first timestep.
            noise_scheduler.set_timesteps(num_inference_steps, device=device)
            latents = latents * noise_scheduler.init_noise_sigma
            init_latents_snapshot = latents.detach().cpu().to(torch.float32)

            if prompts_pool:
                prompt_idx = sample_idx % len(prompts_pool)
                this_prompt = prompts_pool[prompt_idx]
                prompt_indices.append(prompt_idx)
                prompts_used.append(this_prompt)
            else:
                this_prompt = prompt
            if this_prompt not in text_cache:
                text_cache[this_prompt] = prepare_text_embeddings(pipeline, this_prompt, batch_size=1)
            text_embeddings = text_cache[this_prompt]

            per_timestep = []
            for t in noise_scheduler.timesteps:
                model_input = noise_scheduler.scale_model_input(latents, t)
                # A single forward serves both purposes: the hooks record the
                # features while the return value is the noise prediction.
                feature_extractor.features.clear()
                with torch.no_grad():
                    noise_pred = get_unet_sample(pipeline.unet(model_input, t, encoder_hidden_states=text_embeddings))
                per_timestep.append(_features_from_last_forward(feature_extractor, noise_pred))  # [B, D]
                with torch.no_grad():
                    latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            aggregated_features.append(aggregate_across_timesteps(per_timestep, method=agg_method))  # [B, D]
            generated_latents.append(init_latents_snapshot)
            imgs = decode_latents_to_images(pipeline, latents.detach())  # [-1, 1]
            decoded_images.append(imgs.detach().cpu().to(torch.float32))
    finally:
        # Always detach the hooks so the shared pipeline.unet is left clean, even on error.
        feature_extractor.remove_hooks()

    full_features = torch.cat(aggregated_features, dim=0)  # [N, D]
    all_latents = torch.cat(generated_latents, dim=0)       # [N, C, H/8, W/8]
    all_decoded = torch.cat(decoded_images, dim=0)          # [N, 3, H, W]
    print(f"Aggregated UNet features shape: {full_features.shape}")
    meta = {
        'prompt_indices': np.array(prompt_indices) if prompts_pool is not None else None,
        'prompts_pool': prompts_pool,
        'prompts_used': prompts_used if prompts_pool is not None else None,
    }
    return full_features, all_latents, all_decoded, meta


def extract_clip_features_for_images(initial_samples: torch.Tensor, embed_model: str = "clip"):
    """Embed a batch of decoded images with CLIP.

    Args:
        initial_samples: image tensor indexed along dim 0. Values are clamped to
            [-1, 1] and mapped to [0, 1] before PIL conversion.
        embed_model: forwarded to ``ImageFeatureExtractor`` (only ``"clip"``).

    Returns:
        float32 CPU tensor of CLIP image embeddings, one row per image.
    """
    print("Extracting CLIP embeddings for generated images...")
    extractor = ImageFeatureExtractor(model_name=embed_model)
    to_pil = ToPILImage()
    pil_images = [to_pil(image.cpu().clamp(-1, 1) * 0.5 + 0.5) for image in initial_samples]

    chunk_size = 32
    chunk_feats = [
        extractor.extract_features(pil_images[offset:offset + chunk_size])
        for offset in tqdm(range(0, len(pil_images), chunk_size), desc="CLIP feats")
    ]
    embeddings = torch.cat(chunk_feats, dim=0).detach().cpu().to(torch.float32)
    print(f"CLIP features shape: {embeddings.shape}")
    return embeddings


def extract_clip_text_features(prompts_list, embed_model: str = "clip"):
    """Encode prompts with the ViT-H-14 CLIP text tower.

    Returns L2-normalized text embeddings on the CPU, one row per prompt.
    Note: every call creates a fresh CLIP model instance.
    """
    if embed_model == "clip":
        clip_model = open_clip.create_model_and_transforms(
            "ViT-H-14", pretrained='laion2b_s32b_b79k'
        )[0]
        tokenize = open_clip.get_tokenizer("ViT-H-14")
    else:
        raise ValueError("Only CLIP is supported for now")
    clip_model = clip_model.to(device).eval()
    with torch.no_grad():
        embeddings = clip_model.encode_text(tokenize(prompts_list).to(device))
    return F.normalize(embeddings, dim=-1).detach().cpu()

def _fit_reducer(x: np.ndarray, n_components: int) -> Tuple[np.ndarray, str]:
    """Reduce the rows of ``x`` with cuML UMAP if available, else sklearn PCA.

    The component count is clamped to what the data supports. Returns the
    embedding and the method name (``"umap"`` or ``"pca"``).
    """
    n_rows = x.shape[0]
    if HAS_CUML_UMAP:
        reducer = cuUMAP(n_components=min(max(2, n_components), max(2, n_rows - 1)), init="random", random_state=42)  # type: ignore
        return reducer.fit_transform(x), "umap"
    reducer = PCA(n_components=min(n_components, x.shape[1], n_rows), random_state=0)
    return reducer.fit_transform(x), "pca"


def reduce_and_cluster_latents(
    latents: Union[torch.Tensor, np.ndarray],
    n_clusters: int,
    n_components: int = 64,
) -> Tuple[np.ndarray, str, int]:
    """Cluster feature rows with a two-level (parent -> child) scheme.

    1. The rows are reduced with cuML UMAP if it imported, otherwise with
       sklearn PCA.
    2. KMeans splits the reduced embedding into two parent groups.
    3. Each parent's raw rows are reduced again locally and split into that
       parent's share of the ``n_clusters`` children (KMeans, with a
       GaussianMixture fallback if KMeans raises).
    4. Child labels are numbered consecutively across parents.

    If ``n_clusters <= 1`` or there are too few rows for the two-level scheme,
    a single flat KMeans on the base embedding is used instead. Its cluster
    count is capped at the number of rows.

    Args:
        latents: ``[N, D]`` feature matrix (tensor or array).
        n_clusters: requested number of clusters.
        n_components: target reduced dimensionality (clamped per reducer).

    Returns:
        ``(labels, method, dim)``:
        - labels: integer labels of length N;
        - method: ``"umap"`` or ``"pca"``;
        - dim: the width of the base embedding.

    Raises:
        ValueError: if ``latents`` has no rows.
    """
    if isinstance(latents, torch.Tensor):
        x = latents.detach().to("cpu").to(torch.float32).numpy()
    else:
        x = latents.astype(np.float32, copy=False)
    num_samples = x.shape[0]
    if num_samples == 0:
        raise ValueError("Cannot cluster an empty feature matrix")

    base_emb, base_method = _fit_reducer(x, n_components)

    if n_clusters <= 1 or num_samples < max(2, n_clusters):
        # Degenerate case: flat KMeans on the base embedding. KMeans requires
        # n_samples >= n_clusters, so cap k at the number of rows (passing
        # n_clusters unchanged raised whenever n_clusters > num_samples).
        k = max(1, min(n_clusters, num_samples))
        labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(base_emb)
        return labels, base_method, int(base_emb.shape[1])

    # Level 1: two parent groups on the base embedding.
    num_parents = 2
    parent_labels = KMeans(n_clusters=num_parents, random_state=0).fit_predict(base_emb)
    unique_parents = np.unique(parent_labels)
    # Spread the n_clusters children across parents as evenly as possible.
    base_children, remainder = divmod(n_clusters, len(unique_parents))
    children_per_parent = {
        p: base_children + (1 if i < remainder else 0) for i, p in enumerate(unique_parents)
    }

    final_labels = -np.ones(num_samples, dtype=int)
    current_label = 0

    for p in unique_parents:
        idx_subset = np.flatnonzero(parent_labels == p)
        if idx_subset.size == 0:
            continue
        # Level 2: re-reduce this parent's raw rows locally, then split them.
        subset_emb, _ = _fit_reducer(x[idx_subset], n_components)
        k_child = max(1, min(children_per_parent[p], subset_emb.shape[0]))

        if k_child > 1:
            try:
                sub_labels = KMeans(n_clusters=k_child, n_init=10, random_state=0).fit_predict(subset_emb)
            except Exception:
                # Best-effort fallback when KMeans fails on this subset.
                sub_labels = GaussianMixture(n_components=k_child, random_state=0).fit_predict(subset_emb)
        else:
            sub_labels = np.zeros(subset_emb.shape[0], dtype=int)

        for s in np.unique(sub_labels):
            final_labels[idx_subset[sub_labels == s]] = current_label
            current_label += 1

    if np.any(final_labels < 0):
        final_labels[final_labels < 0] = max(0, current_label - 1)

    return final_labels, base_method, int(base_emb.shape[1])



def compute_unet_centroids_by_labels(unet_full_features: torch.Tensor, labels: np.ndarray, n_clusters: int):
    """Compute the mean UNet feature vector for each cluster id ``0..n_clusters-1``.

    Args:
        unet_full_features: ``[N, D]`` tensor of aggregated UNet features.
        labels: length-N integer cluster labels (array or sequence).
        n_clusters: number of centroid rows to return.

    Returns:
        ``[n_clusters, D]`` float32 array; clusters with no members get a zero row.
        Means are taken in float32, so the result dtype does not depend on the
        feature dtype (e.g. half precision) or on whether some cluster was empty.
    """
    feats = unet_full_features.detach().cpu().to(torch.float32)
    labels = np.asarray(labels)  # a plain list would compare unequal to every k
    centroids = np.zeros((max(0, n_clusters), feats.shape[1]), dtype=np.float32)
    for k in range(n_clusters):
        idx = np.flatnonzero(labels == k)
        if idx.size:
            centroids[k] = feats[torch.as_tensor(idx)].mean(dim=0).numpy()
    return centroids


def select_prototypes_full(initial_latents, full_features: torch.Tensor, centroids: np.ndarray, cluster_labels: np.ndarray):
    """For each cluster, pick the initial latent of the member closest to its centroid.

    Distances are Euclidean in the full (unreduced) feature space. Empty
    clusters fall back to ``initial_latents[0]``. Returns a list with one
    detached CPU tensor per cluster, in cluster-id order.
    """
    def _closest_member(cluster_id):
        members = np.where(cluster_labels == cluster_id)[0]
        if len(members) == 0:
            return initial_latents[0].detach().cpu()
        member_feats = full_features[members].detach().cpu()  # [m, D]
        center = torch.from_numpy(centroids[cluster_id]).detach().cpu().unsqueeze(0)
        distances = torch.norm(member_feats - center, dim=1)  # [m]
        winner = int(members[int(torch.argmin(distances).item())])
        return initial_latents[winner].detach().cpu()

    return [_closest_member(cluster_id) for cluster_id in range(centroids.shape[0])]


def unet_centroid_guided_sampling_mixed(
    latents,
    noise_scheduler,
    pipeline,
    text_embeddings,
    prototype_latents=None,
    alpha: float = 0.3,
    target_timesteps_ratio: float = 0.7,
    uncond_embeddings: torch.Tensor = None,
    guidance_scale: float = 7.0,
):
    """Denoise ``latents`` with optional classifier-free guidance plus prototype guidance.

    The base prediction at each step is the UNet output. It is combined with
    CFG when ``uncond_embeddings`` is given and ``guidance_scale > 1``.

    For the first ``int(len(timesteps) * target_timesteps_ratio)`` steps, a
    second UNet prediction is made on the prototype latent. It uses the
    unconditional embeddings under CFG and the text embeddings otherwise.
    The component of that prediction orthogonal to the base prediction is
    added with weight ``alpha``.

    Returns:
        The latents after all scheduler steps.
    """
    if prototype_latents is not None and prototype_latents.ndim == 3:
        # Add a batch dimension to a single [C, H, W] prototype.
        prototype_latents = prototype_latents.unsqueeze(0)

    proto_lat = None
    proto_scheduler = None
    if prototype_latents is not None:
        proto_lat = prototype_latents.detach().clone().to(device=latents.device, dtype=latents.dtype)
        # Separate scheduler instance so the prototype branch never touches the
        # main scheduler. It is only used for scale_model_input below.
        proto_scheduler = noise_scheduler.__class__.from_config(noise_scheduler.config)
        proto_scheduler.set_timesteps(len(noise_scheduler.timesteps), device=latents.device)

    target_timesteps = int(len(noise_scheduler.timesteps) * target_timesteps_ratio)
    eps = 1e-6  # guards the projection against a zero-norm base prediction

    do_cfg = uncond_embeddings is not None and guidance_scale is not None and guidance_scale > 1.0
    if do_cfg:
        uncond_embeddings = uncond_embeddings.to(device=pipeline.device, dtype=pipeline.unet.dtype)
        cat_embeddings = torch.cat([uncond_embeddings, text_embeddings], dim=0)

    for i, t in enumerate(tqdm(noise_scheduler.timesteps, desc="Sampling (mixed)")):
        model_input = torch.cat([latents] * 2) if do_cfg else latents
        model_input = noise_scheduler.scale_model_input(model_input, t)

        with torch.no_grad():
            if do_cfg:
                base_pred = get_unet_sample(pipeline.unet(model_input, t, encoder_hidden_states=cat_embeddings))
                base_uncond, base_text = base_pred.chunk(2)
                base_score = base_uncond + guidance_scale * (base_text - base_uncond)
            else:
                base_score = get_unet_sample(pipeline.unet(model_input, t, encoder_hidden_states=text_embeddings))

            if proto_lat is not None and i < target_timesteps:
                # NOTE(review): proto_lat is never advanced (proto_scheduler.step is
                # never called), so the prototype prediction is always made on the
                # fixed initial prototype latent, evaluated at the current t.
                # Confirm this is intended.
                proto_input = proto_scheduler.scale_model_input(proto_lat, t)
                proto_context = uncond_embeddings if do_cfg else text_embeddings
                proto_score = get_unet_sample(pipeline.unet(proto_input, t, encoder_hidden_states=proto_context))

                # Remove the part of the prototype prediction parallel to the base
                # prediction. All math is done in float32 and cast back at the end.
                # Previously the dot product stayed in the UNet dtype while the norm
                # was float32, so under fp16 a float32 prediction was handed to the
                # scheduler alongside half-precision latents.
                base32 = base_score.float()
                proto32 = proto_score.float()
                dot = (proto32 * base32).mean(dim=(1, 2, 3), keepdim=True)
                base_norm2 = base32.pow(2).mean(dim=(1, 2, 3), keepdim=True) + eps
                proto_orth = proto32 - (dot / base_norm2) * base32
                noise_pred = (base32 + float(alpha) * proto_orth).to(base_score.dtype)
            else:
                noise_pred = base_score

            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    return latents


def visualize(reducer, clip_latents: torch.Tensor, cluster_labels: np.ndarray, centroid_usage, output_dir: str):
    """Best-effort 2-D PCA scatter of the CLIP embeddings, colored by cluster label.

    Writes ``<output_dir>/clip_latents_pca2.png``. ``reducer`` and
    ``centroid_usage`` are accepted but not used. Any failure is printed and
    otherwise ignored, so plotting never aborts a run.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)
        projector = PCA(n_components=2, random_state=0)
        points = (
            clip_latents.detach().cpu().numpy()
            if isinstance(clip_latents, torch.Tensor)
            else np.asarray(clip_latents)
        )
        coords = projector.fit_transform(points)

        plt.figure(figsize=(12, 8))
        plt.scatter(coords[:, 0], coords[:, 1], c=cluster_labels, s=10, alpha=0.6)
        plt.title('CLIP latents (PCA 2D)')
        plt.savefig(f"{output_dir}/clip_latents_pca2.png", dpi=200, bbox_inches='tight')
        plt.close()
    except Exception as e:
        print(f"Visualization failed: {e}")


def main(args):
    """Run both phases end to end.

    Phase 1 samples ``args.num_init_samples`` images and records aggregated
    UNet features. It then embeds the images with CLIP (concatenated with CLIP
    text features when a prompt list is used) and clusters them.

    Phase 2 draws ``args.target_samples`` images. Each draw uses the
    least-used cluster and guides sampling with that cluster's prototype
    latent. Images are written under ``<output_dir>/<sanitized prompt>/``,
    and an analysis pickle is written to ``output_dir``.
    """
    seed_everything(args.seed)

    dtype = torch.float16 if args.fp16 and torch.cuda.is_available() else torch.float32
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.model_id,
        torch_dtype=dtype,
        use_safetensors=True,
        safety_checker=None,
    ).to(device)
    pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
    noise_scheduler = pipeline.scheduler
    noise_scheduler.set_timesteps(args.num_inference_steps, device=device)

    # ---- Phase 1: initial samples, UNet features, CLIP embeddings, clustering ----
    prompts_list = None
    if args.prompts and args.prompts.strip():
        prompts_list = [p.strip() for p in args.prompts.split(',') if p.strip()]
    full_features, initial_latents, decoded_images, init_meta = generate_initial_samples_and_features(
        pipeline, noise_scheduler, args.num_init_samples, args.num_inference_steps,
        height=args.height, width=args.width, agg_method=args.agg,
        prompt=args.prompt, prompts_list=prompts_list,
    )
    clip_latents = extract_clip_features_for_images(decoded_images, embed_model=args.embed_model)

    prompts_pool = init_meta.get('prompts_pool') or []
    if prompts_pool:
        text_feats = extract_clip_text_features(prompts_pool, embed_model=args.embed_model)
        idxs = init_meta.get('prompt_indices')
        if idxs is not None and len(idxs) == clip_latents.shape[0]:
            text_for_images = text_feats[idxs]
        else:
            text_for_images = text_feats[0:1].repeat(clip_latents.shape[0], 1)
        # NOTE(review): the text features are L2-normalized but the image features
        # are not, so the text half's weight after this joint normalization depends
        # on the image-embedding norm. Confirm this is intended.
        clustering_features = F.normalize(torch.cat([clip_latents, text_for_images], dim=1), dim=-1)
    else:
        clustering_features = clip_latents
    cluster_labels, reducer, reduced = reduce_and_cluster_latents(
        clustering_features, n_clusters=args.n_clusters
    )

    centroids = compute_unet_centroids_by_labels(full_features, cluster_labels, args.n_clusters)
    prototypes = select_prototypes_full(initial_latents, full_features, centroids, cluster_labels)

    # ---- Phase 2: balanced, prototype-guided sampling ----
    os.makedirs(args.output_dir, exist_ok=True)
    num_centroids = len(centroids)
    generated_samples = []
    centroid_usage_count = {i: 0 for i in range(num_centroids)}
    sample_centroid_mapping = []
    sample_prompt_mapping = []

    per_centroid_target = max(1, args.target_samples // num_centroids)
    print(f"Target per centroid: {per_centroid_target}")

    # Loop invariants: latent geometry and the prompt-independent uncond embedding.
    latent_h, latent_w = args.height // 8, args.width // 8
    num_channels_latents = pipeline.unet.config.in_channels
    uncond_embeddings = prepare_text_embeddings(pipeline, "", batch_size=1)
    text_cache = {}  # prompt -> text embeddings (encoding is deterministic)

    for _ in range(args.target_samples):
        # Pick the cluster furthest below its per-centroid target.
        usages = np.array([centroid_usage_count[c] for c in range(num_centroids)])
        k = int(np.argmin(usages / (per_centroid_target + 1e-8)))

        if prompts_pool:
            # Balance prompts: choose uniformly among the least-used ones.
            counts_per_prompt = Counter(sample_prompt_mapping)
            min_count = min(counts_per_prompt.get(pi, 0) for pi in range(len(prompts_pool)))
            candidate_idxs = [pi for pi in range(len(prompts_pool)) if counts_per_prompt.get(pi, 0) == min_count]
            chosen_prompt_idx = random.choice(candidate_idxs)
            chosen_prompt = prompts_pool[chosen_prompt_idx]
        else:
            chosen_prompt_idx = None
            chosen_prompt = args.prompt

        noise_scheduler.set_timesteps(args.num_inference_steps, device=device)
        if args.init_from_prototype and prototypes[k] is not None:
            latents = prototypes[k].to(device=device, dtype=pipeline.unet.dtype).unsqueeze(0)
            if args.init_jitter_sigma > 0.0:
                latents = latents + args.init_jitter_sigma * torch.randn_like(latents)
        else:
            latents = torch.randn(1, num_channels_latents, latent_h, latent_w, device=device, dtype=pipeline.unet.dtype)
            latents = latents * noise_scheduler.init_noise_sigma

        if chosen_prompt not in text_cache:
            text_cache[chosen_prompt] = prepare_text_embeddings(pipeline, chosen_prompt, batch_size=1)

        final_latents = unet_centroid_guided_sampling_mixed(
            latents,
            noise_scheduler,
            pipeline,
            text_embeddings=text_cache[chosen_prompt],
            prototype_latents=prototypes[k],
            alpha=args.alpha,
            target_timesteps_ratio=0.7,
            uncond_embeddings=uncond_embeddings,
            guidance_scale=6.0,
        )
        # Keep finished images on the CPU so device memory does not grow with target_samples.
        generated_samples.append(decode_latents_to_images(pipeline, final_latents).detach().to("cpu", torch.float32))
        sample_centroid_mapping.append(k)
        sample_prompt_mapping.append(chosen_prompt_idx if chosen_prompt_idx is not None else 0)
        centroid_usage_count[k] += 1

    to_pil = ToPILImage()
    for count, (sample, c_id, p_idx) in enumerate(
        zip(generated_samples, sample_centroid_mapping, sample_prompt_mapping), start=1
    ):
        cluster_name = f"cluster_{c_id:02d}"
        img_pil = to_pil(sample.squeeze(0).clip(-1, 1) * 0.5 + 0.5)
        if prompts_pool and isinstance(p_idx, int) and 0 <= p_idx < len(prompts_pool):
            prompt_str = prompts_pool[p_idx]
        else:
            prompt_str = args.prompt
        target_dir = os.path.join(args.output_dir, sanitize_for_dirname(prompt_str))
        os.makedirs(target_dir, exist_ok=True)
        img_pil.save(f"{target_dir}/{cluster_name}_id_{c_id:02d}_sample_{count:04d}.png")

    print(f"Total images generated: {len(generated_samples)}")
    print(f"Images saved to: {args.output_dir}")

    if not args.skip_visualization:
        visualize(reducer, clip_latents, cluster_labels, centroid_usage_count, args.output_dir)
    analysis_data = {
        'centroids_full_unet': centroids,
        'cluster_labels_clip': cluster_labels,
        'unet_full_features': full_features.detach().cpu().numpy() if isinstance(full_features, torch.Tensor) else np.asarray(full_features),
        'clip_latents': clip_latents.detach().cpu().numpy() if isinstance(clip_latents, torch.Tensor) else np.asarray(clip_latents),
        'reduced_latents': reduced,
        'centroid_usage': centroid_usage_count,
        'sample_centroid_mapping': sample_centroid_mapping,
        'init_meta': init_meta,
        'sample_prompt_mapping': sample_prompt_mapping,
    }
    with open(f"{args.output_dir}/unet_centroid_full_analysis.pkl", 'wb') as f:
        pickle.dump(analysis_data, f)

    print("Completed full-feature prototype-based sampling.")


if __name__ == "__main__":
    # Script entry point: parse CLI options, then run Phase 1 and Phase 2.
    args = parse_args()
    main(args)


