import argparse
import os
import pickle
import random
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.metrics.pairwise import euclidean_distances
from torchvision.transforms import ToPILImage
from tqdm.auto import tqdm
import open_clip
from typing import Dict, List, Tuple, Optional, Union
try:
    from cuml.manifold.umap import UMAP as cuUMAP  # type: ignore
    HAS_CUML_UMAP = True
except Exception:
    cuUMAP = None  # type: ignore
    HAS_CUML_UMAP = False


# Run on GPU 3 when the machine actually has one (the index this script was
# written for); otherwise use the first GPU, or the CPU when CUDA is unavailable.
# Selecting "cuda:3" unconditionally fails with an invalid device ordinal on
# hosts that expose fewer than four GPUs.
if torch.cuda.is_available():
    device = torch.device("cuda:3" if torch.cuda.device_count() > 3 else "cuda:0")
else:
    device = torch.device("cpu")


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        The parsed options as an ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="FID_compute")
    parser.add_argument("--model-id", type=str, default="google/ddpm-celebahq-256", help="Diffusers model ID")
    parser.add_argument("--seed", type=int, default=4, help="Random seed")  # 2
    parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of DDIM steps")
    parser.add_argument("--num-init-samples", type=int, default=600, help="Number of initial samples for phase 1")  # 600
    parser.add_argument("--n-clusters", type=int, default=12, help="Number of clusters for GMM")
    parser.add_argument("--target-samples", type=int, default=10000, help="Number of fair samples to generate")
    parser.add_argument("--output-dir", type=str, default="CelebA_Fair_Sampling_Full_Features", help="Directory to save outputs")
    parser.add_argument("--alpha", type=float, default=0.07, help="Mixing weight for prototype guidance")
    parser.add_argument("--image-size", type=int, default=256, help="Image resolution (assumes square)")
    parser.add_argument("--skip-visualization", action="store_true", help="Skip visualization plots")
    parser.add_argument("--agg", type=str, default="mean", choices=["mean", "median"], help="Timestep aggregation method")
    parser.add_argument("--embed-model", type=str, default="clip", choices=["clip"], help="Embedding model for Phase 1")
    # The default must be one of the choices: argparse does not validate string
    # defaults against `choices`, so a misspelled default would pass silently.
    parser.add_argument("--dim-reduction", type=str, default="umap", choices=["pca", "umap"], help="Dimensionality reduction for Phase 1 clustering")
    return parser.parse_args(argv)


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN for repeatable runs.

    Args:
        seed: Value passed unchanged to each of the three seeding functions.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Ask cuDNN for deterministic kernels and switch off its autotuner.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def get_unet_sample(unet_output):
    """Return the ``sample`` entry of a UNet forward result.

    Dict instances are read with ``["sample"]``; anything else is read through
    its ``.sample`` attribute.
    """
    return unet_output["sample"] if isinstance(unet_output, dict) else unet_output.sample

class ImageFeatureExtractor:
    """Image embedder backed by the OpenCLIP ``ViT-H-14`` (``laion2b_s32b_b79k``) model."""

    def __init__(self, model_name: str = "clip"):
        """Load the CLIP model and its preprocessing transform onto ``device``.

        Raises:
            ValueError: If ``model_name`` is anything other than ``"clip"``.
        """
        if model_name != "clip":
            raise ValueError("Only CLIP is supported for now")
        clip_model, _, transform = open_clip.create_model_and_transforms(
            "ViT-H-14", pretrained='laion2b_s32b_b79k'
        )
        self.preprocess = transform
        self.model = clip_model.to(device).eval()

    def extract_features(self, images_pil):
        """Encode one image or a list/tuple of images and return the model's image features.

        A single image is treated as a batch of one. The result is whatever
        ``encode_image`` returns for the stacked, preprocessed batch; gradients
        are disabled during the call.
        """
        images = images_pil if isinstance(images_pil, (list, tuple)) else [images_pil]
        pixel_values = torch.stack([self.preprocess(image) for image in images], dim=0).to(device)
        with torch.no_grad():
            return self.model.encode_image(pixel_values)


class UNetFeatureExtractor:
    """Collects spatially averaged activations from selected UNet submodules via forward hooks.

    Hooks are attached at construction time to every submodule of
    ``pipeline.unet`` whose name mentions a down/up/mid block and also one of
    ``resnets``, ``conv`` or ``attn``. Call ``remove_hooks`` when finished.
    """

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.features = {}  # module name -> pooled activation ([B, C], on CPU)
        self.hooks = []     # handles returned by register_forward_hook
        self._register_hooks()

    def _register_hooks(self):
        """Attach a pooling hook to every submodule whose name matches the filters."""
        block_tags = ('down_blocks', 'up_blocks', 'mid_block')
        layer_tags = ('resnets', 'conv', 'attn')

        def make_hook(layer_name):
            def store_pooled(module, inputs, output):
                # Only 4-D tensor outputs are recorded; everything else is ignored.
                if not isinstance(output, torch.Tensor) or output.ndim != 4:
                    return
                pooled = F.adaptive_avg_pool2d(output, (1, 1)).flatten(1)
                self.features[layer_name] = pooled.detach().cpu()
            return store_pooled

        for layer_name, layer in self.pipeline.unet.named_modules():
            in_block = any(tag in layer_name for tag in block_tags)
            is_tracked_layer = any(tag in layer_name for tag in layer_tags)
            if in_block and is_tracked_layer:
                self.hooks.append(layer.register_forward_hook(make_hook(layer_name)))

    def extract_features(self, x, t):
        """Run the UNet on ``(x, t)`` and return the hooked features concatenated along dim 1.

        If no hook recorded anything, the UNet is run once more and the
        spatially averaged ``sample`` output is returned instead.
        """
        self.features.clear()
        with torch.no_grad():
            self.pipeline.unet(x, t)

        captured = [feat for feat in self.features.values() if feat is not None and feat.numel() > 0]
        if not captured:
            with torch.no_grad():
                sample = get_unet_sample(self.pipeline.unet(x, t))
                pooled = F.adaptive_avg_pool2d(sample, (1, 1)).squeeze(-1).squeeze(-1)
                return pooled.detach().cpu()
        return torch.cat(captured, dim=1)

    def remove_hooks(self):
        """Detach every registered hook and empty the handle list."""
        while self.hooks:
            self.hooks.pop().remove()


def aggregate_across_timesteps(per_timestep_features, method: str = "mean"):
    """Collapse a list of per-timestep feature tensors into one tensor.

    Args:
        per_timestep_features: Non-empty sequence of equally shaped tensors,
            one per timestep.
        method: ``"mean"`` or ``"median"``; applied along the stacked timestep axis.

    Returns:
        A tensor with the shape of a single list entry.

    Raises:
        ValueError: If the list is empty or ``method`` is not supported.
    """
    if not per_timestep_features:
        raise ValueError("No timestep features to aggregate")

    timestep_major = torch.stack(per_timestep_features, dim=0)  # [T, B, D]
    if method == "median":
        reduced, _ = timestep_major.median(dim=0)
        return reduced
    if method != "mean":
        raise ValueError(f"Unsupported aggregation: {method}")
    return timestep_major.mean(dim=0)


def generate_initial_samples_and_features(pipeline, noise_scheduler, num_samples, image_size=256, agg_method="mean", batch_size=1):
    """Sample images with the scheduler loop while recording UNet features at every step.

    Args:
        pipeline: Object exposing the denoising network as ``pipeline.unet``.
        noise_scheduler: Scheduler providing ``timesteps``, ``scale_model_input``
            and ``step``; its timesteps must already be set.
        num_samples: Total number of images to generate (must be positive).
        image_size: Side length of the square, 3-channel starting noise.
        agg_method: Aggregation passed to ``aggregate_across_timesteps``.
        batch_size: Number of images denoised together (must be positive).

    Returns:
        ``(full_features, all_samples)``: features of shape ``[N, D]`` and the
        final latents/images of shape ``[N, 3, H, W]``, both on the CPU.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    print(f"Generating {num_samples} initial samples and full features...")

    feature_extractor = UNetFeatureExtractor(pipeline)

    # The feature extractor already runs the UNet once per step. Capture that
    # forward result so the noise prediction does not need a second, identical pass.
    last_output = {}

    def _remember_output(module, inputs, output):
        last_output["value"] = output

    output_hook = pipeline.unet.register_forward_hook(_remember_output)

    aggregated_features = []
    generated_samples = []
    try:
        for start in tqdm(range(0, num_samples, batch_size), desc="Generating+extracting"):
            # Size the last batch from the number of samples already scheduled,
            # not from the number of batches stored so far.
            current_bs = min(batch_size, num_samples - start)
            x = torch.randn(current_bs, 3, image_size, image_size, device=device)

            per_timestep = []
            for t in noise_scheduler.timesteps:
                model_input = noise_scheduler.scale_model_input(x, t)
                last_output.clear()
                per_timestep.append(feature_extractor.extract_features(model_input, t))  # [B, D]
                with torch.no_grad():
                    unet_output = last_output.pop("value", None)
                    if unet_output is None:
                        # Nothing was captured (extractor did not call the UNet module); run it here.
                        unet_output = pipeline.unet(model_input, t)
                    x = noise_scheduler.step(get_unet_sample(unet_output), t, x).prev_sample

            aggregated_features.append(aggregate_across_timesteps(per_timestep, method=agg_method))  # [B, D]
            generated_samples.append(x.detach().cpu())
    finally:
        # Always detach the hooks so later sampling does not keep paying for them.
        output_hook.remove()
        feature_extractor.remove_hooks()

    full_features = torch.cat(aggregated_features, dim=0)  # [N, D]
    all_samples = torch.cat(generated_samples, dim=0)       # [N, 3, H, W]
    print(f"Aggregated UNet features shape: {full_features.shape}")
    return full_features, all_samples


def extract_clip_features_for_images(initial_samples: torch.Tensor, embed_model: str = "clip"):
    """Embed a batch of generated images with ``ImageFeatureExtractor``.

    Each sample is clamped to ``[-1, 1]``, rescaled to ``[0, 1]``, converted to
    a PIL image, and then encoded in chunks of 32.

    Args:
        initial_samples: Tensor indexed by sample along its first dimension.
        embed_model: Model name forwarded to ``ImageFeatureExtractor``.

    Returns:
        The per-chunk features concatenated along dim 0.
    """
    print("Extracting CLIP embeddings for generated images...")
    extractor = ImageFeatureExtractor(model_name=embed_model)
    tensor_to_pil = ToPILImage()
    pil_images = [
        tensor_to_pil(sample.cpu().clamp(-1, 1) * 0.5 + 0.5)
        for sample in initial_samples
    ]

    chunk_size = 32
    chunk_feats = [
        extractor.extract_features(pil_images[offset:offset + chunk_size])
        for offset in tqdm(range(0, len(pil_images), chunk_size), desc="CLIP feats")
    ]
    clip_feats = torch.cat(chunk_feats, dim=0)
    print(f"CLIP features shape: {clip_feats.shape}")
    return clip_feats

"""
def reduce_and_cluster_latents(latents: torch.Tensor, n_clusters: int, method: str = "pca", n_components: int = 32):
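    # Legacy single-stage version, disabled by the enclosing string literal and kept for reference only.
    # Reduces the CLIP latents, then clusters them. Returns labels, reducer object, reduced array.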
    x = latents.numpy()
    if method == "umap":
        if umap is None:
            print("UMAP not available. Falling back to PCA.")
            reducer = PCA(n_components=n_components, random_state=0)
            reduced = reducer.fit_transform(x)
        else:
            reducer = umap.UMAP(n_components=min(10, n_components), random_state=0)
            reduced = reducer.fit_transform(x)
    else:
        reducer = PCA(n_components=n_components, random_state=0)
        reduced = reducer.fit_transform(x)

    print(f"Reduced dims: {reduced.shape}")
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    labels = kmeans.fit_predict(reduced)
    print(f"Cluster distribution: {Counter(labels)}")
    return labels, reducer, reduced
"""

def _reduce_embedding(x: np.ndarray, n_components: int) -> Tuple[np.ndarray, str]:
    """Project ``x`` with cuML UMAP when available, otherwise PCA; return ``(embedding, method)``."""
    n_rows = x.shape[0]
    if HAS_CUML_UMAP:
        reducer = cuUMAP(  # type: ignore
            n_components=min(max(2, n_components), max(2, n_rows - 1)),
            init="random",
            random_state=42,
        )
        method = "umap"
    else:
        # PCA cannot produce more components than samples or input features.
        reducer = PCA(n_components=min(n_components, x.shape[1], n_rows), random_state=0)
        method = "pca"
    return reducer.fit_transform(x), method


def reduce_and_cluster_latents(
    latents: Union[torch.Tensor, np.ndarray],
    n_clusters: int,
    n_components: int = 64,
) -> Tuple[np.ndarray, str, int]:
    """Cluster latents in two stages: split into two parent groups, then sub-cluster each.

    The latents are reduced (cuML UMAP if importable, else PCA) and split into
    two parents with KMeans. ``n_clusters`` is divided evenly between the
    parents (the first parents receive the remainder), and each parent is
    reduced again on its own rows and sub-clustered. When ``n_clusters <= 1``
    or there are fewer samples than clusters, a single flat KMeans is used.

    Args:
        latents: 2-D array or tensor of shape ``[num_samples, num_features]``.
        n_clusters: Requested number of clusters.
        n_components: Target dimensionality for the reductions.

    Returns:
        ``(labels, method, dim)``: integer labels per sample, the reduction
        method name (``"umap"`` or ``"pca"``), and the dimensionality of the
        global reduced embedding. Fewer than ``n_clusters`` distinct labels may
        be produced when a parent group has fewer samples than its share.

    Raises:
        ValueError: If ``latents`` contains no samples.
    """
    if isinstance(latents, torch.Tensor):
        x = latents.detach().to("cpu").to(torch.float32).numpy()
    else:
        x = np.asarray(latents).astype(np.float32, copy=False)

    num_samples = x.shape[0]
    if num_samples == 0:
        raise ValueError("Cannot cluster an empty set of latents")

    base_emb, base_method = _reduce_embedding(x, n_components)
    base_dim = int(base_emb.shape[1])

    if n_clusters <= 1 or num_samples < max(2, n_clusters):
        # KMeans rejects n_clusters > n_samples, so cap the request at the sample count.
        k_flat = max(1, min(n_clusters, num_samples))
        labels = KMeans(n_clusters=k_flat, n_init=10, random_state=0).fit_predict(base_emb)
        return labels, base_method, base_dim

    # Explicit n_init keeps this call consistent with the other KMeans calls here
    # and independent of the library's version-specific default.
    parent_labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(base_emb)
    unique_parents = np.unique(parent_labels)
    base_children, remainder = divmod(n_clusters, len(unique_parents))
    children_per_parent = {
        p: base_children + (1 if rank < remainder else 0)
        for rank, p in enumerate(unique_parents)
    }

    final_labels = np.full(num_samples, -1, dtype=int)
    next_label = 0

    for p in unique_parents:
        idx_subset = np.flatnonzero(parent_labels == p)
        k_child = max(1, min(children_per_parent[p], idx_subset.size))

        if k_child > 1:
            subset_emb, _ = _reduce_embedding(x[idx_subset], n_components)
            try:
                sub_labels = KMeans(n_clusters=k_child, n_init=10, random_state=0).fit_predict(subset_emb)
            except Exception as exc:
                # Best-effort fallback kept from the original, but no longer silent.
                print(f"KMeans failed for parent {p} ({exc}); falling back to GaussianMixture.")
                sub_labels = GaussianMixture(n_components=k_child, random_state=0).fit_predict(subset_emb)
        else:
            # A single child needs no reduction: every member gets the same label.
            sub_labels = np.zeros(idx_subset.size, dtype=int)

        # Renumber through np.unique so the global labels stay contiguous even
        # if the clusterer left one of its labels unused.
        for s in np.unique(sub_labels):
            final_labels[idx_subset[sub_labels == s]] = next_label
            next_label += 1

    return final_labels, base_method, base_dim




def compute_unet_centroids_by_labels(unet_full_features: torch.Tensor, labels: np.ndarray, n_clusters: int):
    """Average the feature rows of each cluster.

    Args:
        unet_full_features: Tensor of shape ``[N, D]``.
        labels: Cluster id for each of the ``N`` rows.
        n_clusters: Number of cluster ids to produce centroids for (``0..n_clusters-1``).

    Returns:
        Array of shape ``[n_clusters, D]``; a cluster with no members gets an
        all-zero float32 row.
    """
    feature_dim = unet_full_features.shape[1]

    def centroid_of(cluster_id):
        members = np.flatnonzero(labels == cluster_id)
        if members.size:
            return unet_full_features[members].mean(dim=0).numpy()
        return np.zeros((feature_dim,), dtype=np.float32)

    return np.stack([centroid_of(cluster_id) for cluster_id in range(n_clusters)], axis=0)


def select_prototypes_full(initial_samples, full_features: torch.Tensor, centroids: np.ndarray, cluster_labels: np.ndarray):
    """Pick, for each centroid, the sample whose features lie closest to it.

    Args:
        initial_samples: Tensor of samples indexed along dim 0.
        full_features: Feature tensor with one row per sample.
        centroids: Array of shape ``[n_clusters, D]``.
        cluster_labels: Cluster id of each sample.

    Returns:
        A list with one CPU tensor per centroid. Only members of the cluster
        are considered (Euclidean distance); an empty cluster falls back to
        ``initial_samples[0]``.
    """
    prototypes = []
    for cluster_id, centroid in enumerate(centroids):
        members = np.flatnonzero(cluster_labels == cluster_id)
        if members.size == 0:
            chosen = 0
        else:
            member_feats = full_features[members].numpy()  # [m, D]
            distances = np.linalg.norm(member_feats - centroid[None, :], axis=1)
            chosen = members[np.argmin(distances)]
        prototypes.append(initial_samples[chosen].detach().cpu())
    return prototypes


def unet_centroid_guided_sampling_mixed(x, noise_scheduler, pipeline, prototype=None, alpha=0.07, guidance_strength=0.2, target_timesteps_ratio=0.7):
    """Denoise ``x`` while blending in the UNet prediction obtained for a prototype.

    For the first ``target_timesteps_ratio`` fraction of the schedule (and only
    when a prototype is given) the prediction is
    ``(1 - alpha) * unet(x) + alpha * unet(prototype)``. On every other step the
    plain prediction is perturbed with Gaussian noise scaled by
    ``guidance_strength * (1 - i / num_steps)``.

    Args:
        x: Starting noise tensor; also determines the device used.
        noise_scheduler: Scheduler providing ``timesteps``, ``scale_model_input`` and ``step``.
        pipeline: Object exposing the denoising network as ``pipeline.unet``.
        prototype: Optional prototype tensor; a 3-D tensor gets a leading batch dimension.
        alpha: Mixing weight of the prototype prediction.
        guidance_strength: Scale of the noise added on non-mixed steps.
        target_timesteps_ratio: Fraction of steps that use prototype mixing.

    Returns:
        The tensor produced by the final scheduler step.
    """
    if prototype is not None:
        if prototype.ndim == 3:
            prototype = prototype.unsqueeze(0)
        # Move the prototype once instead of repeating the transfer on every step.
        prototype = prototype.to(x.device)

    timesteps = noise_scheduler.timesteps
    num_steps = len(timesteps)
    target_timesteps = int(num_steps * target_timesteps_ratio)

    for i, t in enumerate(tqdm(timesteps, desc="Sampling (mixed)")):
        model_input = noise_scheduler.scale_model_input(x, t)
        with torch.no_grad():
            base_score = get_unet_sample(pipeline.unet(model_input, t))
            if i < target_timesteps and prototype is not None:
                proto_input = noise_scheduler.scale_model_input(prototype, t)
                proto_score = get_unet_sample(pipeline.unet(proto_input, t))
                noise_pred = (1.0 - alpha) * base_score + alpha * proto_score
            else:
                noise_scale = guidance_strength * (1 - i / num_steps)
                noise_pred = base_score + torch.randn_like(base_score) * noise_scale
            x = noise_scheduler.step(noise_pred, t, x).prev_sample
    return x


def visualize(reducer, clip_latents: torch.Tensor, cluster_labels: np.ndarray, centroid_usage, output_dir: str):
    """Save a 2-D PCA scatter plot of the CLIP latents coloured by cluster label.

    The plot is written to ``<output_dir>/clip_latents_pca2.png``. This is a
    best-effort helper: any failure is printed and swallowed.

    Args:
        reducer: Unused; kept for interface compatibility.
        clip_latents: Latents of shape ``[N, D]`` (tensor on any device, or array-like).
        cluster_labels: Label per latent, used as the scatter colour.
        centroid_usage: Unused; kept for interface compatibility.
        output_dir: Directory for the image; created if missing.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)
        if isinstance(clip_latents, torch.Tensor):
            # .numpy() only works for CPU tensors, and PCA needs a regular float
            # dtype, so move and cast first.
            latents_np = clip_latents.detach().cpu().float().numpy()
        else:
            latents_np = np.asarray(clip_latents, dtype=np.float32)

        xy = PCA(n_components=2, random_state=0).fit_transform(latents_np)
        fig = plt.figure(figsize=(12, 8))
        try:
            plt.scatter(xy[:, 0], xy[:, 1], c=cluster_labels, s=10, alpha=0.6)
            plt.title('CLIP latents (PCA 2D)')
            plt.savefig(f"{output_dir}/clip_latents_pca2.png", dpi=200, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
    except Exception as e:
        print(f"Visualization failed: {e}")


def main(args):
    """Run the full pipeline: initial sampling, clustering, then balanced prototype-guided sampling.

    Phase 1 generates ``args.num_init_samples`` images, embeds them with CLIP,
    clusters the embeddings, and picks one prototype per cluster from the UNet
    features. Phase 2 generates ``args.target_samples`` images, always guiding
    with the prototype of the least-used cluster, and writes each image to
    ``args.output_dir`` as soon as it is produced.

    Args:
        args: Parsed command-line options (see ``parse_args``).
    """
    seed_everything(args.seed)

    pipeline = DDPMPipeline.from_pretrained(args.model_id).to(device)
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    noise_scheduler = pipeline.scheduler
    noise_scheduler.set_timesteps(args.num_inference_steps)

    full_features, initial_samples = generate_initial_samples_and_features(
        pipeline, noise_scheduler, args.num_init_samples, image_size=args.image_size, agg_method=args.agg
    )
    clip_latents = extract_clip_features_for_images(initial_samples, embed_model=args.embed_model)
    # The second and third return values are the reduction method name and the
    # reduced dimensionality; neither is needed here.
    cluster_labels, _reduction_method, _reduced_dim = reduce_and_cluster_latents(
        clip_latents, n_clusters=args.n_clusters
    )

    centroids = compute_unet_centroids_by_labels(full_features, cluster_labels, args.n_clusters)
    prototypes = select_prototypes_full(initial_samples, full_features, centroids, cluster_labels)

    os.makedirs(args.output_dir, exist_ok=True)
    num_centroids = len(centroids)
    centroid_usage_count = [0] * num_centroids

    per_centroid_target = max(1, args.target_samples // num_centroids)
    print(f"Target per centroid: {per_centroid_target}")

    to_pil = ToPILImage()
    count = 0
    for _ in range(args.target_samples):
        # Least-used centroid first (ties go to the lowest index). Dividing every
        # count by the same per-centroid target would not change the argmin.
        k = int(np.argmin(centroid_usage_count))
        x = torch.randn(1, 3, args.image_size, args.image_size, device=device)
        sample = unet_centroid_guided_sampling_mixed(
            x, noise_scheduler, pipeline, prototype=prototypes[k], alpha=args.alpha
        )
        centroid_usage_count[k] += 1
        count += 1

        # Save immediately so samples do not pile up in GPU memory and finished
        # images survive a later failure.
        cluster_name = f"cluster_{k:02d}"
        img_pil = to_pil(sample.squeeze(0).detach().cpu().clip(-1, 1) * 0.5 + 0.5)
        img_pil.save(f"{args.output_dir}/{cluster_name}_id_{k:02d}_sample_{count:04d}.png")

    print(f"Total images generated: {count}")
    print(f"Images saved to: {args.output_dir}")
    print("Completed full-feature prototype-based sampling.")


# Script entry point: parse the command-line options and run the pipeline.
if __name__ == "__main__":
    main(parse_args())


