import os

# Cap each native math/threading backend at five threads. This runs ahead of
# the numpy/torch imports further down; presumably the backends read these
# variables when they are first loaded -- confirm before moving this code.
os.environ.update(
    {
        name: "5"
        for name in (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        )
    }
)
# Disable OpenMP dynamic thread adjustment and expose only GPUs 0, 1 and 3.
os.environ.update({"OMP_DYNAMIC": "false", "CUDA_VISIBLE_DEVICES": "0,1,3"})
import json
import time
import random
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn.functional as F
import open_clip
from transformers import (
    AutoImageProcessor,
    AutoModel,
    ViTImageProcessor,
    ViTModel,
    ResNetForImageClassification,
    CLIPModel,
)

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.metrics.cluster import (
    fowlkes_mallows_score,
    v_measure_score,
    adjusted_rand_score,
    normalized_mutual_info_score,
)
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
from tqdm.auto import tqdm
from sklearn.manifold import TSNE


# Optional CPU UMAP backend. `_umap` is bound to the imported module, or to
# None when the import fails for any reason (not only ImportError).
try:
    import umap as _umap
except Exception:
    _umap = None


# Optional GPU UMAP from cuML. HAS_CUML_UMAP records whether the import
# succeeded; when it did not, `cuUMAP` is None and must not be called.
try:
    from cuml.manifold.umap import UMAP as cuUMAP
    HAS_CUML_UMAP = True
except Exception:
    cuUMAP = None
    HAS_CUML_UMAP = False

# Limit PyTorch's intra-op and inter-op thread pools to 5 threads each,
# matching the 5-thread environment variables set at the top of the file.
torch.set_num_threads(5)
torch.set_num_interop_threads(5)



def seed_everything(seed: int = 0) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every random number generator.
    """
    # Same seeding order as before: stdlib random, then NumPy, then PyTorch.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    # Turn off cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class CelebADataset(Dataset):
    """Images of one dataframe split paired with selected integer attributes.

    Rows are kept when their ``"split"`` column equals ``split_value``. The
    image file name is read from the first column of the dataframe and joined
    onto ``root_dir``.

    Each item is a dict with:
        ``"pixel_values"``: the processor output for the image.
        ``"attrs"``: a list of ints, one per entry of ``attribute_cols``.
    """

    def __init__(
        self,
        dataframe,
        attribute_cols: Tuple[str, str, str, str],
        root_dir: str,
        split_value: int,
        processor,
    ) -> None:
        """Filter ``dataframe`` to one split and cache the attribute columns.

        Args:
            dataframe: Table with a ``"split"`` column, the attribute columns,
                and the image file name in its first column.
            attribute_cols: Names of the attribute columns to return per item.
            root_dir: Directory that contains the image files.
            split_value: Value of ``"split"`` selecting the rows to keep.
            processor: Callable that turns a PIL image into model input. It is
                first called with ``return_tensors="pt"``; if that raises
                ``TypeError`` it is called with the image only.
        """
        # Boolean-mask selection returns a new frame, so the caller's
        # dataframe is left untouched.
        self.data_frame = dataframe[dataframe["split"] == split_value].reset_index(drop=True)
        self.root_dir = root_dir
        self.attr_cols = attribute_cols
        # Plain arrays make the per-item attribute lookup a cheap index.
        self.attr_arrays = [self.data_frame[c].values for c in attribute_cols]
        self.processor = processor

    def __len__(self) -> int:
        """Return the number of rows in the selected split."""
        return len(self.data_frame)

    def __getitem__(self, idx: int):
        """Load, convert and preprocess the image at row ``idx``."""
        img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])
        # The context manager closes the file handle deterministically;
        # convert() returns an independent, fully loaded RGB copy.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert("RGB")

        attrs = [int(a[idx]) for a in self.attr_arrays]

        # Only the keyword-argument call is guarded: a TypeError there is
        # taken to mean a plain transform that accepts the image alone.
        try:
            proc = self.processor(image, return_tensors="pt")
        except TypeError:
            pv = self.processor(image)
            pixel_values = pv if isinstance(pv, torch.Tensor) else torch.as_tensor(pv)
        else:
            pixel_values = proc["pixel_values"].squeeze(0)  # drop the batch axis

        return {
            "pixel_values": pixel_values,
            "attrs": attrs,
        }


def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Tuple[Dict[str, torch.Tensor], List[List[int]]]:
    """Merge dataset samples into one stacked image tensor plus attribute rows.

    Args:
        batch: Samples carrying ``"pixel_values"`` and ``"attrs"`` entries.

    Returns:
        ``({"pixel_values": stacked}, attrs)`` where ``stacked`` has the
        samples along a new leading axis and ``attrs`` keeps one entry per
        sample, in batch order.
    """
    tensors = []
    attr_rows = []
    for sample in batch:
        tensors.append(sample["pixel_values"])
        attr_rows.append(sample["attrs"])

    stacked = torch.stack(tensors, dim=0)
    return {"pixel_values": stacked}, attr_rows


def load_model_and_processor(model_name: str):
    """Load the pretrained backbone and its image preprocessor.

    Args:
        model_name: One of "DINOV3", "DINOV2", "OpenCLIP", "ViT", "ResNet"
            or "CLIP".

    Returns:
        ``(processor, model)``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "OpenCLIP":
        # open_clip also returns a training transform; only the validation
        # transform is used as the processor.
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            "ViT-H-14", pretrained='laion2b_s32b_b79k'
        )
        if torch.cuda.is_available():
            model = model.to("cuda")
        return preprocess_val, model

    # Hugging Face checkpoints, keyed by the public model name.
    hf_checkpoints = {
        "DINOV3": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
        "DINOV2": "facebook/dinov2-large",
        "ViT": "google/vit-large-patch16-224-in21k",
        "ResNet": "microsoft/resnet-50",
        "CLIP": "openai/clip-vit-large-patch14",
    }
    checkpoint = hf_checkpoints.get(model_name)
    if checkpoint is None:
        raise ValueError(f"Unsupported model_name: {model_name}")

    # Pick the loader classes for this checkpoint family.
    if model_name == "ViT":
        processor_cls, model_cls = ViTImageProcessor, ViTModel
    elif model_name == "ResNet":
        processor_cls, model_cls = AutoImageProcessor, ResNetForImageClassification
    elif model_name == "CLIP":
        processor_cls, model_cls = AutoImageProcessor, CLIPModel
    else:  # DINOV3 / DINOV2
        processor_cls, model_cls = AutoImageProcessor, AutoModel

    processor = processor_cls.from_pretrained(checkpoint)
    model = model_cls.from_pretrained(checkpoint, device_map="auto")
    return processor, model


def get_pooled_features(outputs) -> torch.Tensor:
    """Pick one feature vector per image from a model output object.

    Sources are tried in this order, and the first usable one is returned:
        1. ``pooler_output``
        2. token 0 of a 3-D ``last_hidden_state`` with at least one token
        3. ``image_embeds``
        4. the mean over axis 1 of a 3-D ``last_hidden_state``

    Args:
        outputs: Object exposing some of the attributes listed above.

    Returns:
        The selected tensor.

    Raises:
        ValueError: If none of the sources is available.
    """
    pooler = getattr(outputs, "pooler_output", None)
    if pooler is not None:
        return pooler

    hidden = getattr(outputs, "last_hidden_state", None)
    hidden_is_3d = hidden is not None and hidden.ndim == 3
    if hidden_is_3d and hidden.size(1) >= 1:
        return hidden[:, 0]

    embeds = getattr(outputs, "image_embeds", None)
    if embeds is not None:
        return embeds

    # Only reachable for a 3-D hidden state whose token axis is empty, since
    # any non-empty one was already handled by the token-0 branch above.
    if hidden_is_3d:
        return hidden.mean(dim=1)

    raise ValueError("Unable to derive pooled features from model outputs")


def _pooled_image_features(model, images: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Run one forward pass and return one feature row per image."""
    if isinstance(model, CLIPModel):
        return model.get_image_features(pixel_values=images["pixel_values"])
    if hasattr(model, "encode_image"):
        return model.encode_image(images["pixel_values"])
    if isinstance(model, ResNetForImageClassification):
        outputs = model(**images, output_hidden_states=True, return_dict=True)
        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is not None and len(hidden_states) > 0:
            last_h = hidden_states[-1]
            if last_h.ndim == 4:
                # Global average pool over the two trailing axes.
                return F.adaptive_avg_pool2d(last_h, (1, 1)).flatten(1)
            if last_h.ndim == 3:
                return last_h.mean(dim=1)
        # No usable hidden state: fall back to the classifier logits.
        return outputs.logits
    return get_pooled_features(model(**images))


def extract_features(
    dataloader: DataLoader,
    model,
    use_device_move: bool = False,
    max_batches: Optional[int] = 1251,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Embed every batch of ``dataloader`` and collect four attribute columns.

    Args:
        dataloader: Iterable yielding ``(images, attrs)`` where ``images`` is
            a dict of tensors and ``attrs`` holds one 4-element sequence per
            sample.
        model: Feature extractor; put into eval mode before the loop.
        use_device_move: When True, every tensor in ``images`` is cast to the
            device and dtype of the model's first parameter.
        max_batches: Stop after this many batches. The default of 1251
            matches the limit that used to be hard-coded; ``None`` consumes
            the whole loader.

    Returns:
        ``(features, attr1, attr2, attr3, attr4)``: the concatenated features
        as a NumPy array, then one int32 array per attribute position.

    Raises:
        ValueError: If ``max_batches`` is smaller than 1, or if the loader
            yields no batches.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError("max_batches must be a positive integer or None")

    all_feats: List[torch.Tensor] = []
    attr_columns: Tuple[List[int], ...] = ([], [], [], [])

    model.eval()
    if use_device_move:
        # Looked up once instead of once per batch; nothing in the loop
        # moves or re-casts the model.
        first_param = next(model.parameters())
        device, dtype = first_param.device, first_param.dtype

    batches_done = 0
    with torch.inference_mode():
        for images, attrs in dataloader:
            if use_device_move:
                images = {k: v.to(device=device, dtype=dtype) for k, v in images.items()}

            all_feats.append(_pooled_image_features(model, images).detach().cpu())
            for row in attrs:
                for column, value in zip(attr_columns, (row[0], row[1], row[2], row[3])):
                    column.append(int(value))

            batches_done += 1
            # Checked after the batch so the next one is never fetched.
            if max_batches is not None and batches_done >= max_batches:
                break

    if not all_feats:
        raise ValueError("dataloader yielded no batches; no features to extract")

    features = torch.cat(all_feats, dim=0).numpy()
    attr1, attr2, attr3, attr4 = (np.asarray(col, dtype=np.int32) for col in attr_columns)
    return features, attr1, attr2, attr3, attr4


def map_clusters_to_binary_labels(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Collapse cluster ids onto the two classes of a binary ground truth.

    A 2 x k contingency table (true class x cluster) is built. An optimal
    one-to-one assignment first gives each class its best-overlapping
    cluster, so both classes are represented whenever there are at least
    two clusters. Every other cluster takes its majority class, with ties
    going to class 0.

    The table is built with NumPy on purpose. sklearn's ``confusion_matrix``
    is square over the union of both label sets, so with more than two
    clusters it never has the 2 x k shape this mapping needs.

    Args:
        y_true: 1-D ground-truth labels.
        y_pred: 1-D cluster ids, same length as ``y_true``. Ids may be
            arbitrary values and need not be contiguous or start at 0.

    Returns:
        An int array with one entry per sample:
          * two true classes: the class index (0 or 1) in sorted order of
            the distinct ``y_true`` values, which for 0/1 labels is the
            label itself;
          * one true class: all zeros;
          * otherwise (no samples, or more than two classes): a copy of
            ``y_pred``.

    Raises:
        ValueError: If the two inputs differ in length.
    """
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {y_true.shape[0]} and {y_pred.shape[0]}"
        )

    classes, class_idx = np.unique(y_true, return_inverse=True)
    if classes.size == 1:
        # Only one class exists, so every cluster maps onto it.
        return np.zeros(y_true.shape[0], dtype=int)
    if classes.size != 2:
        return y_pred.copy()

    clusters, cluster_idx = np.unique(y_pred, return_inverse=True)
    class_idx = class_idx.reshape(-1)
    cluster_idx = cluster_idx.reshape(-1)

    # contingency[c, j] = number of samples of class c that fell in cluster j.
    contingency = np.zeros((2, clusters.size), dtype=np.int64)
    np.add.at(contingency, (class_idx, cluster_idx), 1)

    # Default every cluster to its majority class (strict ">" sends ties to 0),
    # then let the optimal assignment override the clusters it selected.
    cluster_to_class = (contingency[1] > contingency[0]).astype(int)
    row_ind, col_ind = linear_sum_assignment(-contingency)
    cluster_to_class[col_ind] = row_ind

    return cluster_to_class[cluster_idx].astype(int)


def cluster_full(latents: np.ndarray, n_clusters: int, method: str) -> np.ndarray:
    """Cluster the features directly, without dimensionality reduction.

    Args:
        latents: Feature matrix with one row per sample.
        n_clusters: Number of clusters (K-means) or mixture components (GMM).
        method: ``"K-means"`` or ``"GMM"``.

    Returns:
        One cluster label per row of ``latents``.

    Raises:
        ValueError: If ``method`` is neither ``"K-means"`` nor ``"GMM"``.
    """
    if method == "K-means":
        # n_init is pinned to the value reduce_and_cluster uses, so both paths
        # run the same number of restarts whatever the installed
        # scikit-learn's default is.
        return KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit_predict(latents)
    if method == "GMM":
        return GaussianMixture(n_components=n_clusters, random_state=0).fit_predict(latents)
    raise ValueError(f"Unsupported full-space method: {method}")


def reduce_and_cluster(
    latents: np.ndarray,
    n_clusters: int,
    method: str,
    n_components: int = 64,
) -> Tuple[np.ndarray, str, int]:
    """Reduce the feature dimension, then run K-means on the result.

    Args:
        latents: 2-D feature matrix with one row per sample.
        n_clusters: Number of K-means clusters.
        method: ``"umap"`` or ``"pca"`` (case-insensitive). Any other value,
            including ``None``, skips the reduction.
        n_components: Requested output dimension. It is clamped to what the
            data allows; UMAP always gets at least 2.

    Returns:
        ``(labels, used_method, used_dims)``. ``used_method`` is ``"umap"``,
        ``"pca"`` or ``"none"`` and reports what was actually applied. It is
        ``"none"`` for a ``"umap"`` request when no UMAP backend is
        importable. ``used_dims`` is the dimension K-means saw.
    """
    num_samples, num_features = latents.shape
    method_l = (method or "").lower()

    # Prefer the GPU implementation. Otherwise use the CPU `umap` module that
    # the file imports optionally, rather than silently skipping the
    # reduction. getattr guards against an unrelated module named `umap`.
    umap_cls = None
    if method_l == "umap":
        if HAS_CUML_UMAP:
            umap_cls = cuUMAP
        elif _umap is not None:
            umap_cls = getattr(_umap, "UMAP", None)

    if umap_cls is not None:
        target_dims = min(max(2, n_components), max(2, num_samples - 1))
        reducer = umap_cls(n_components=target_dims, init="random", random_state=0)  # type: ignore
        reduced = reducer.fit_transform(latents)
        used_method = "umap"
        used_dims = int(reduced.shape[1])
    elif method_l == "pca":
        target_dims = min(n_components, num_features, num_samples)
        reducer = PCA(n_components=target_dims, random_state=0)
        reduced = reducer.fit_transform(latents)
        used_method = "pca"
        used_dims = int(reduced.shape[1])
    else:
        reduced = latents
        used_method = "none"
        used_dims = int(latents.shape[1])

    labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit_predict(reduced)
    return labels, used_method, used_dims


def _hierarchical_reducer(num_samples: int, num_features: int, n_components: int):
    """Build the reducer for one level: cuML UMAP if importable, else PCA.

    Returns ``(reducer, method_name)``; the reducer is not fitted yet.
    """
    if HAS_CUML_UMAP:
        dims = min(max(2, n_components), max(2, num_samples - 1))
        return cuUMAP(n_components=dims, init="random", random_state=42), "umap"  # type: ignore
    dims = min(n_components, num_features, num_samples)
    return PCA(n_components=dims, random_state=0), "pca"


def hierarchical_reduce_and_cluster(
    latents: np.ndarray,
    n_clusters: int,
    n_components: int = 64,
) -> Tuple[np.ndarray, str, int]:
    """Two-level clustering: split into two parents, then split each parent.

    All samples are embedded and split into two parent clusters with
    K-means. ``n_clusters`` is divided as evenly as possible between the
    parents. Each parent that gets more than one child is re-embedded on its
    own samples and clustered again.

    When ``n_clusters <= 1`` or there are fewer samples than
    ``max(2, n_clusters)``, a single K-means on the global embedding is run
    instead.

    Args:
        latents: 2-D feature matrix with one row per sample.
        n_clusters: Total number of leaf clusters requested.
        n_components: Requested embedding dimension, clamped to the data.

    Returns:
        ``(labels, base_method, base_dims)``: one label per sample, the
        reducer used (``"umap"`` or ``"pca"``), and the dimension of the
        global embedding.
    """
    x = latents.astype(np.float32, copy=False)
    num_samples, num_features = x.shape

    base_reducer, base_method = _hierarchical_reducer(num_samples, num_features, n_components)
    base_emb = base_reducer.fit_transform(x)
    base_dims = int(base_emb.shape[1])

    # Degenerate request: no hierarchy to build, one flat K-means instead.
    if n_clusters <= 1 or num_samples < max(2, n_clusters):
        labels = KMeans(n_clusters=max(1, n_clusters), n_init=10, random_state=0).fit_predict(base_emb)
        return labels, base_method, base_dims

    # n_init is pinned on every K-means here so that all levels use the same
    # number of restarts whatever the installed scikit-learn's default is.
    parent_labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(base_emb)
    unique_parents = np.unique(parent_labels)

    # Spread the children evenly; the first `remainder` parents get one extra.
    base_children, remainder = divmod(n_clusters, len(unique_parents))

    final_labels = np.full(num_samples, -1, dtype=int)
    current_label = 0

    for rank, parent in enumerate(unique_parents):
        idx_subset = np.flatnonzero(parent_labels == parent)
        wanted_children = base_children + (1 if rank < remainder else 0)
        # Never ask for more children than the parent has samples.
        k_child = max(1, min(wanted_children, idx_subset.size))

        if k_child > 1:
            x_subset = x[idx_subset]
            local_reducer, _ = _hierarchical_reducer(x_subset.shape[0], x_subset.shape[1], n_components)
            subset_emb = local_reducer.fit_transform(x_subset)
            try:
                sub_labels = KMeans(n_clusters=k_child, n_init=10, random_state=0).fit_predict(subset_emb)
            except Exception:
                # Best-effort fallback kept from the original behaviour.
                sub_labels = GaussianMixture(n_components=k_child, random_state=0).fit_predict(subset_emb)
        else:
            # A single child needs no embedding: the parent is the leaf.
            sub_labels = np.zeros(idx_subset.size, dtype=int)

        # Renumber the children into one global, consecutive label range.
        for s in np.unique(sub_labels):
            final_labels[idx_subset[sub_labels == s]] = current_label
            current_label += 1

    return final_labels, base_method, base_dims


def _clustering_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Score one label assignment against one ground-truth attribute."""
    return {
        "fmi": float(fowlkes_mallows_score(y_true, y_pred)),
        "v_measure": float(v_measure_score(y_true, y_pred)),
        "ari": float(adjusted_rand_score(y_true, y_pred)),
        "nmi": float(normalized_mutual_info_score(y_true, y_pred)),
    }


def run_experiments(
    csv_file: str,
    root_dir: str,
    batch_size: int = 8,
    splits_to_use: Tuple[int, ...] = (2,),
    attribute_cols: Tuple[str, str, str, str] = ("Male", "Young", "Eyeglasses", "Blond_Hair"),
    model_names: Tuple[str, ...] = ("DINOV3", "DINOV2", "ViT", "ResNet", "CLIP"),
    cluster_methods: Tuple[str, ...] = ("K-means", "GMM", "pca", "umap", "hierarchical"),
    cluster_counts: Tuple[int, ...] = (2, 4, 8, 16),
    reduction_dim: int = 64,
    output_json: Optional[str] = None,
) -> str:
    """Extract features per model and split, cluster them, and save the scores.

    For every model, split, clustering method and cluster count, the cluster
    labels are scored against each attribute with FMI, V-measure, ARI and
    NMI. When more than two clusters were requested, the labels are first
    passed through ``map_clusters_to_binary_labels`` for that attribute.

    Args:
        csv_file: Metadata CSV; every ``-1`` in it is replaced by ``0``.
        root_dir: Directory containing the images.
        batch_size: Batch size for feature extraction.
        splits_to_use: Values of the ``"split"`` column to evaluate.
        attribute_cols: The four attribute columns; also the metric keys.
        model_names: Names accepted by ``load_model_and_processor``.
        cluster_methods: Any of "K-means", "GMM", "pca", "umap",
            "hierarchical".
        cluster_counts: Cluster counts to try.
        reduction_dim: Target dimension for the reducing methods.
        output_json: Result path. Defaults to a timestamped file next to
            this script.

    Returns:
        The path of the JSON file that was written.

    Notes:
        A failing (method, k) combination is stored as a record with an
        ``"error"`` field and the run continues. Whatever was collected is
        written even if the run aborts, e.g. on a model that fails to load;
        the exception still propagates.
    """
    import gc

    import pandas as pd

    seed_everything(0)

    if output_json is None:
        ts = time.strftime("%Y%m%d-%H%M%S")
        output_json = os.path.join(os.path.dirname(__file__), f"experiment_results_{ts}.json")

    df = pd.read_csv(csv_file)
    df = df.replace(-1, 0)

    results: List[Dict] = []

    try:
        for model_name in tqdm(model_names, desc="Models"):
            processor, model = load_model_and_processor(model_name)

            for split in tqdm(splits_to_use, desc=f"{model_name} splits", leave=False):
                dataset = CelebADataset(
                    dataframe=df,
                    attribute_cols=attribute_cols,
                    root_dir=root_dir,
                    split_value=split,
                    processor=processor,
                )
                dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

                features, a1, a2, a3, a4 = extract_features(
                    tqdm(dataloader, desc=f"Extract {model_name} split {split}", leave=False),
                    model,
                    use_device_move=True,
                )
                # Pair each ground-truth array with the column it came from,
                # so the metric keys follow `attribute_cols` rather than a
                # fixed name list.
                attr_truths = list(zip(attribute_cols, (a1, a2, a3, a4)))

                for method in tqdm(cluster_methods, desc=f"Methods {model_name} split {split}", leave=False):
                    for k in tqdm(cluster_counts, desc=f"k for {method}", leave=False):
                        try:
                            reduction_method = "none"
                            reduction_dims = features.shape[1]
                            if method in ("K-means", "GMM"):
                                labels = cluster_full(features, n_clusters=k, method=method)
                            elif method in ("pca", "umap"):
                                labels, reduction_method, reduction_dims = reduce_and_cluster(
                                    features,
                                    n_clusters=k,
                                    method=method,
                                    n_components=reduction_dim,
                                )
                            elif method == "hierarchical":
                                labels, reduction_method, reduction_dims = hierarchical_reduce_and_cluster(
                                    features,
                                    n_clusters=k,
                                    n_components=reduction_dim,
                                )
                            else:
                                raise ValueError(f"Unknown method: {method}")

                            metrics = {}
                            for attr_name, y_true in attr_truths:
                                if k > 2:
                                    y_pred = map_clusters_to_binary_labels(y_true, labels)
                                else:
                                    y_pred = labels
                                metrics[attr_name] = _clustering_metrics(y_true, y_pred)

                            results.append(
                                {
                                    "model_name": model_name,
                                    # Same key the error records carry.
                                    "split": int(split),
                                    "num_samples": int(features.shape[0]),
                                    "latent_dim": int(features.shape[1]),
                                    "cluster_method": method,
                                    "n_clusters": int(k),
                                    "reduction_used": reduction_method,
                                    "reduction_dims": int(reduction_dims),
                                    "metrics": metrics,
                                }
                            )
                        except Exception as e:
                            # Deliberate best-effort: log the failure, go on.
                            results.append(
                                {
                                    "model_name": model_name,
                                    "split": int(split),
                                    "cluster_method": method,
                                    "n_clusters": int(k),
                                    "error": str(e),
                                }
                            )

            # Drop this model before the next one loads, so two backbones are
            # never referenced at the same time.
            del processor, model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    finally:
        # Persist whatever was collected, even when the run aborts.
        with open(output_json, "w") as f:
            json.dump(results, f, indent=2)

    return output_json


if __name__ == "__main__":
    # Script entry point: run the full sweep on the CelebA paths below and
    # print where the results were written.
    CSV_FILE = "/celeba/metadata.csv"
    ROOT_DIR = "/celeba/img_align_celeba/"

    experiment_config = dict(
        csv_file=CSV_FILE,
        root_dir=ROOT_DIR,
        batch_size=8,
        splits_to_use=(2,),
        attribute_cols=("Male", "Young", "Eyeglasses", "Blond_Hair"),
        model_names=("OpenCLIP", "DINOV3", "DINOV2", "ViT", "ResNet", "CLIP"),
        cluster_methods=("K-means", "GMM", "pca", "umap", "hierarchical"),
        cluster_counts=(2, 4, 6, 8, 10),
        reduction_dim=32,
    )
    out = run_experiments(**experiment_config)
    print(f"Saved results to: {out}")


