# var_proto_guided_pool.py
import os
import os.path as osp
import math
import random
from types import MethodType
from PIL import Image as PImage
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances
from sklearn.metrics import pairwise_distances_argmin_min
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
import os
import torch
from tqdm import tqdm
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
import argparse
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
from models import VQVAE, build_vae_var
import torch
import math
import torchvision.models as models
import torchvision.transforms as T
from sklearn.cluster import KMeans
import torch.nn.functional as F
irange = range
import math
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import copy
import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.vqvae import VQVAE, VectorQuantizer2




# NOTE:
# Some core implementation details are intentionally omitted in this version
# to preserve anonymity during the double-blind review process.
# The complete implementation will be publicly released upon paper acceptance.



# === user params: MODEL_DEPTH must be set to one of the allowed values ===
# MODEL_DEPTH selects the VAR checkpoint file name (var_d{MODEL_DEPTH}.pth) and
# is passed as `depth` to build_vae_var inside main(); the assert rejects any
# other value as soon as the module is loaded.
MODEL_DEPTH = 30
assert MODEL_DEPTH in {16, 20, 24, 30}
import numpy as np
from sklearn.cluster import KMeans

def cluster_selection(features: np.ndarray, k: int, seed: int = 42) -> np.ndarray:
    """
    Pick one representative row per k-means cluster.

    - If k >= N (number of rows in `features`), every index is returned.
    - Otherwise KMeans(n_clusters=k, random_state=seed, n_init=10) is fitted and,
      for each cluster in label order, the member closest (Euclidean) to the
      cluster centre is selected.

    Returns an int64 array of row indices into `features`.
    """
    n_samples = features.shape[0]
    if k >= n_samples:
        return np.arange(n_samples, dtype=np.int64)

    model = KMeans(n_clusters=k, random_state=seed, n_init=10)
    assignment = model.fit_predict(features)
    centroids = model.cluster_centers_

    def _closest_member(cluster_id):
        # indices of the rows assigned to this cluster
        members = np.where(assignment == cluster_id)[0]
        offsets = np.linalg.norm(features[members] - centroids[cluster_id], axis=1)
        return members[np.argmin(offsets)]

    return np.array([_closest_member(cid) for cid in range(k)], dtype=np.int64)

def kmax_selection(features: np.ndarray, k: int, seed: int = 42) -> np.ndarray:
    """
    Select k representative samples from features using medoid-style
    farthest-first selection.

    - If k >= N, return all indices.
    - If k <= 0, return an empty index array.
    - Step 1: compute the centroid and pick the sample closest to it
      (most representative).
    - Step 2: repeatedly add the sample whose distance to its nearest
      already-selected sample is largest (farthest-first, for diversity).

    `seed` is accepted for signature parity with the other selectors; the
    procedure is deterministic and does not use it.

    Returns an int64 array of distinct row indices, in selection order.
    """
    n_samples = features.shape[0]
    if k >= n_samples:
        return np.arange(n_samples, dtype=np.int64)
    if k <= 0:
        return np.array([], dtype=np.int64)

    # ---- Step 1: first medoid = sample closest to the centroid ----
    centroid = features.mean(axis=0)
    first_medoid = int(np.linalg.norm(features - centroid[None, :], axis=1).argmin())
    selected = [first_medoid]

    # ---- Step 2: farthest-first selection ----
    # Keep, for every sample, the distance to its nearest selected medoid and
    # refresh it with one (N, D) pass per new medoid, instead of rebuilding an
    # (N, S, D) tensor on every iteration.
    min_dist = np.linalg.norm(features - features[first_medoid], axis=1)
    # Selected samples are masked out so duplicated feature rows (distance 0)
    # can never cause the same index to be picked twice.
    min_dist[first_medoid] = -np.inf

    while len(selected) < k:
        next_medoid = int(min_dist.argmax())
        selected.append(next_medoid)
        dist_to_new = np.linalg.norm(features - features[next_medoid], axis=1)
        min_dist = np.minimum(min_dist, dist_to_new)
        min_dist[next_medoid] = -np.inf

    return np.array(selected, dtype=np.int64)


# --------------------- utilities ---------------------
def fps_selection(features: np.ndarray, k: int, seed: int = 42) -> np.ndarray:
    """
    Farthest-point sampling over the rows of `features`.

    - If k >= N, return all indices.
    - The starting row is drawn from np.random.default_rng(seed).
    - Each further pick is the row with the largest Euclidean distance to its
      nearest already-picked row.

    Returns an int64 array of row indices in pick order.
    """
    n_samples = features.shape[0]
    if k >= n_samples:
        return np.arange(n_samples, dtype=np.int64)

    generator = np.random.default_rng(seed)
    start = generator.integers(0, n_samples)
    picked = [start]

    # distance from every row to its nearest picked row so far
    nearest = np.linalg.norm(features - features[start], axis=1)
    remaining = k - 1
    while remaining > 0:
        farthest = int(np.argmax(nearest))
        picked.append(farthest)
        to_farthest = np.linalg.norm(features - features[farthest], axis=1)
        nearest = np.minimum(nearest, to_farthest)
        remaining -= 1

    return np.array(picked, dtype=np.int64)

def load_image_paths_from_dir(class_dir: str):
    """
    Return the sorted paths of the image files directly inside `class_dir`.

    A directory entry counts as an image when its extension (case-insensitive)
    is .jpg, .jpeg, .png or .bmp. The real extension is compared, so a name
    that merely ends in those letters (e.g. "thumbjpg") is not picked up.

    Returns an empty list when `class_dir` is not an existing directory.
    """
    if not osp.isdir(class_dir):
        return []
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp'}
    return sorted(
        osp.join(class_dir, name)
        for name in os.listdir(class_dir)
        if osp.splitext(name)[1].lower() in image_exts
    )

# ------------------- Guided AR sampler factory -------------------
def make_autoregressive_infer_cfg_guided_factory(feature_extractor, preprocess_tensor_fn, prototype_stage_feats_dict, args):
    """
    Returns a function to be bound to var instance.

    Arguments, as described by the original author:
    - feature_extractor: a CNN that maps normalized tensor [B,3,224,224] to [B,2048,1,1] (ResNet style).
    - preprocess_tensor_fn: callable(tensor_img) -> normalized tensor for feature_extractor
    - prototype_stage_feats_dict: dict[class_name] -> list of per-stage prototype features (or None)
    - args: contains proto_guidance_strength, proto_guidance_sigma, image_size, patch_nums

    NOTE(review): none of these four arguments is referenced by the part of the
    inner function that is present in this file; presumably they are used by the
    omitted prototype-guidance code - confirm once the full version is available.
    """
    def autoregressive_infer_cfg_guided(self, pseudo_memory_c, B: int, label_B,
                                        g_seed: int = None, cfg=5, top_k=0, top_p=0.0,
                                        more_smooth=False, prototype_stage_feats=None,
                                        proto_guidance_strength: float = 3.0, proto_guidance_sigma: float = 0.15,
                                        return_hlist: bool = False):
        """
        Per-stage autoregressive pass with classifier-free guidance on the logits.

        Only self, B, label_B, g_seed and cfg influence the code visible here;
        prototype_stage_feats is only tested for None, and the remaining
        parameters are not read.

        NOTE(review): the visible body stops after the guidance mix and has no
        return statement, so as written the function returns None. The sampling,
        prototype guidance and decoding steps appear to be among the omitted
        parts - TODO confirm.
        """
  
        # Optional reproducibility: reseed the model's own generator.
        if g_seed is None: rng = None
        else: self.rng.manual_seed(g_seed); rng = self.rng
        
        # label_B may be None (sample B labels from self.uniform_prob), an int
        # (broadcast to B entries; a negative int maps to self.num_classes), or is
        # otherwise used as given.
        if label_B is None:
            label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
        elif isinstance(label_B, int):
            label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device)
        
        # 2B embeddings: the B requested labels followed by B entries of index
        # self.num_classes; the two halves are combined by the guidance mix below.
        sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes)), dim=0))
        
        lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
        # Input tokens for the first stage: class embedding + start position + level/position embedding.
        next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l]
        
        cur_L = 0  # running count of tokens over the stages processed so far
        # Zero-initialised feature map at the final-stage resolution (not updated in the visible code).
        f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
        
        # Turn on key/value caching in every attention block.
        for b in self.blocks: b.attn.kv_caching(True)

        # If prototype_stage_feats provided, it should be list(len(patch_nums)) of 1D arrays (2048) or None
        use_proto = (prototype_stage_feats is not None)

        # prefetch embedding weights if needed
        # NOTE(review): any failure here is swallowed and leaves emb_weight as None;
        # emb_weight is not used further in the visible code.
        emb_weight = None
        try:
            emb_weight = self.vae_quant_proxy[0].embedding.weight  # shape [V, Cvae]
        except Exception:
            emb_weight = None

        # Never appended to in the visible code; presumably filled when
        # return_hlist is requested - TODO confirm.
        h_list_store = []

        for si, pn in enumerate(self.patch_nums):
            # (A commented-out duplicate of the forward pass and guidance mix below
            # was removed from this spot.)
            # Stage progress: 0 at the first stage; reaches 1 at the last stage if
            # num_stages_minus_1 == len(patch_nums) - 1 (presumably - confirm).
            ratio = si / self.num_stages_minus_1
            # last_L = cur_L
            cur_L += pn*pn
            # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
            cond_BD_or_gss = self.shared_ada_lin(cond_BD)
            x = next_token_map
            # NOTE(review): bare attribute access; the result is discarded.
            AdaLNSelfAttn.forward
            for b in self.blocks:
                x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
            logits_BlV = self.get_logits(x, cond_BD)
            
            # classifier-free guidance mixing; the weight t grows linearly with the stage index
            t = cfg * ratio
            logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]
            

          
    return autoregressive_infer_cfg_guided

# -------------------------- Main --------------------------
def main(args):
    """
    Entry point: prepare the models and the prototype-feature helpers.

    Visible steps, in order:
    - seed torch, disable autograd globally and pick the device;
    - read the class lists under ./misc and keep the slice selected by
      args.spec, args.nclass and args.phase;
    - locate the VAE / VAR checkpoints, build both models with build_vae_var,
      load their weights and freeze them;
    - fix seeds and the cuDNN / TF32 settings and create args.save_dir;
    - build a ResNet-50 feature extractor (final layer removed) and the nested
      helpers that compute per-stage prototype features.

    NOTE(review): this file states that parts of the implementation are omitted.
    As written, the per-class generation loop is absent, and the tail of this
    function uses `time` and `start`, neither of which is imported or defined in
    the visible code.
    """
    # Setup
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- load classes (as in your original code) ---
    # all_classes: one entry per line; a class label is the line index in this file.
    with open('./misc/class_indices.txt', 'r') as fp:
        all_classes = [x.strip() for x in fp.readlines()]
    # Any args.spec other than 'woof', 'nette' or 'all' falls back to class100.txt.
    if args.spec == 'woof':
        file_list = './misc/class_woof.txt'
    elif args.spec == 'nette':
        file_list = './misc/class_nette.txt'
    elif args.spec == 'all':
        file_list = './misc/class_indices.txt'
    
    else:
        file_list = './misc/class100.txt'
    with open(file_list, 'r') as fp:
        sel_classes = [x.strip() for x in fp.readlines()]

    # Process the classes in chunks of args.nclass; a negative phase is treated as 0.
    phase = max(0, args.phase)
    cls_from = args.nclass * phase
    cls_to = args.nclass * (phase + 1)
    sel_classes = sel_classes[cls_from:cls_to]
    # list.index raises ValueError if a selected class is missing from all_classes.
    class_labels = [all_classes.index(s) for s in sel_classes]

    if args.ckpt is None:
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000

    # Load model checkpoints.
    # NOTE(review): hf_home is a relative path rather than a URL, and the wget
    # target is built by joining it with another relative path - verify that the
    # download fallback works as intended.
    hf_home = '../var/resolve/main'
    vae_ckpt, var_ckpt = '../var_model/vae_ch160v4096z32.pth', f'../var_model/var_d{MODEL_DEPTH}.pth'
    if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
    if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')


    # Side length (in tokens) of each VAR stage, coarse to fine.
    patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
    # NOTE(review): `vae` and `var` are assigned in this function, so they are
    # locals here. If both names ever existed in module globals this branch would
    # be skipped and the load_state_dict calls below would raise
    # UnboundLocalError.
    if 'vae' not in globals() or 'var' not in globals():
        vae, var = build_vae_var(
            V=4096, Cvae=32, ch=160, share_quant_resi=4,
            device=device, patch_nums=patch_nums,
            num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
        )

    print("MODEL_DEPTH", MODEL_DEPTH)

    # strict=True: checkpoint keys must match the model exactly.
    vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
    var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
    vae.eval(); var.eval()
    for p in vae.parameters(): p.requires_grad_(False)
    for p in var.parameters(): p.requires_grad_(False)
    print('prepare finished.')

    # sampling params
    # NOTE(review): the argument parser in this file defines --cfg-scale (stored
    # as args.cfg_scale) and no more_smooth option, so with that parser both
    # getattr calls return their defaults.
    seed = args.seed
    cfg = getattr(args, 'cfg', 5)
    more_smooth = getattr(args, 'more_smooth', True)

    torch.manual_seed(seed); random.seed(seed); np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    tf32 = True
    torch.backends.cudnn.allow_tf32 = bool(tf32)
    torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
    torch.set_float32_matmul_precision("high" if tf32 else "highest")

    os.makedirs(args.save_dir, exist_ok=True)

    # ---------------- feature extractor & transforms ----------------
    # NOTE(review): args.arch_name is not consulted here; ResNet-50 is always used.
    resnet = models.resnet50(pretrained=True).to(device)
    resnet.eval()
    feature_extractor = torch.nn.Sequential(*list(resnet.children())[:-1])  # returns [B,2048,1,1]
    # preprocess tensor function (expects image in [B,3,H,W] in [0,1])
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1,3,1,1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1,3,1,1)
    def preprocess_tensor_fn(tensor_img):
        """Resize a [B,3,H,W] batch in [0,1] to 224x224 and apply the mean/std normalisation."""
        # tensor_img: [B,3,H,W] in [0,1] float
        x = F.interpolate(tensor_img, size=(224,224), mode='bilinear', align_corners=False)
        x = (x - mean) / std
        return x

    # ---------------- prototype per-stage feature builder ----------------
    # Not read or written anywhere else in the visible code.
    prototype_stage_feats_cache = {}  # class_name -> list of per-stage 2048-d vectors (or None)
    def build_multiproto_feats_for_class_cycle(prototype_dir, class_name, num_prototypes=5, sample_size=500):
        """
        - Randomly samples a subset of images for each prototype
        - For each subset, selects 1 representative prototype (medoid)
        - Computes per-stage ResNet features for each prototype
        - Returns list of length = num_prototypes, each element is list(len(patch_nums)) of np.float32 arrays (2048)
        - Returns None when prototype_dir/class_name contains no images
        """
        train_class_dir = osp.join(prototype_dir, class_name)
        paths = load_image_paths_from_dir(train_class_dir)
        if len(paths) == 0:
            return None

        tmp_preprocess = T.Compose([
            T.Resize(256), T.CenterCrop(224), T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])

        all_proto_feats = []

        for _ in range(num_prototypes):
           
            # Fresh random subset per prototype (Python's `random`, seeded in main).
            cur_paths = random.sample(paths, min(sample_size, len(paths)))

            feats_per_image = []
            pil_imgs = []
            for pth in cur_paths:
                pil = PImage.open(pth).convert('RGB')
                pil_imgs.append(pil)
                t = tmp_preprocess(pil).unsqueeze(0).to(device)
                with torch.no_grad():
                    feats = feature_extractor(t).view(1, -1).cpu().numpy()[0]
                feats_per_image.append(feats)
            feats_per_image = np.stack(feats_per_image)

           
            # Medoid: the image whose feature lies closest to the subset mean.
            centroid = feats_per_image.mean(axis=0)
            dists = np.linalg.norm(feats_per_image - centroid[None, :], axis=1)
            medoid_idx = int(dists.argmin())
            proto_img = pil_imgs[medoid_idx]

           
            # One feature per stage: shrink the prototype to the stage's share of
            # the image size (at least 16 px), then bring it back to 224 for ResNet.
            stage_feats = []
            for pn in patch_nums:
                scale = pn / float(patch_nums[-1])
                stage_size = max(16, int(round(args.image_size * scale)))
                img_rs = proto_img.resize((stage_size, stage_size), resample=PImage.BILINEAR)
                t = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])(img_rs).unsqueeze(0).to(device)
                t = (t - mean.to(t.device)) / std.to(t.device)
                with torch.no_grad():
                    feat = feature_extractor(t).view(-1).cpu().numpy()
                stage_feats.append(feat.astype(np.float32))
            all_proto_feats.append(stage_feats)

        return all_proto_feats
    def build_multiproto_feats_for_class_from_dir(prototype_dir, class_name, num_prototypes=1):
        """
        - Loads images in prototype_dir/class_name (at most the first 2000 sorted paths)
        - Selects up to `num_prototypes` representative prototypes (medoids)
        - For each prototype, computes per-stage ResNet features
        - Returns list of length = num_prototypes, each element is list(len(patch_nums)) of np.float32 arrays (2048)
        - Returns None when the directory contains no images
        """
        train_class_dir = osp.join(prototype_dir, class_name)
        paths = load_image_paths_from_dir(train_class_dir)
        if len(paths) == 0:
            return None

        max_search = min(len(paths), 2000)
        feats_per_image = []
        pil_imgs = []

        # --- Step 1: compute global features for medoid selection ---
        tmp_preprocess = T.Compose([
            T.Resize(256), T.CenterCrop(224), T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])
        for pth in paths[:max_search]:
            pil = PImage.open(pth).convert('RGB')
            pil_imgs.append(pil)
            t = tmp_preprocess(pil).unsqueeze(0).to(device)
            with torch.no_grad():
                feats = feature_extractor(t).view(1, -1).cpu().numpy()[0]
            feats_per_image.append(feats)

        feats_per_image = np.stack(feats_per_image)
        centroid = feats_per_image.mean(axis=0)
        dists = np.linalg.norm(feats_per_image - centroid[None, :], axis=1)
        
        # --- Step 2: select multiple medoids (farthest-first to increase diversity) ---
        # The first medoid is the image closest to the centroid.
        # NOTE(review): the loop bound uses len(paths), not the number of images
        # actually loaded (max_search).
        medoid_indices = [int(dists.argmin())]
        for _ in range(1, min(num_prototypes, len(paths))):
            # pick next medoid as farthest from existing ones
            dist_to_existing = np.min(
                np.linalg.norm(feats_per_image[:, None, :] - feats_per_image[medoid_indices][None, :, :], axis=2),
                axis=1
            )
            next_medoid = int(dist_to_existing.argmax())
            medoid_indices.append(next_medoid)

        # --- Step 3: compute per-stage features for each prototype ---
        all_proto_feats = []
        for midx in medoid_indices:
            proto_img = pil_imgs[midx]
            stage_feats = []
            for pn in patch_nums:
                scale = pn / float(patch_nums[-1])
                stage_size = max(16, int(round(args.image_size * scale)))
                img_rs = proto_img.resize((stage_size, stage_size), resample=PImage.BILINEAR)
                t = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])(img_rs).unsqueeze(0).to(device)
                t = (t - mean.to(t.device)) / std.to(t.device)
                with torch.no_grad():
                    feat = feature_extractor(t).view(-1).cpu().numpy()
                stage_feats.append(feat.astype(np.float32))
            all_proto_feats.append(stage_feats)

        return all_proto_feats
    def build_proto_feats_for_class_from_dir(prototype_dir, class_name):
        """
        - loads images in prototype_dir/class_name (at most the first 500 sorted paths)
        - picks the medoid image (feature closest to the mean feature) as the single prototype
        - compute features for each VAR stage by resizing proto to a stage-specific size then extracting ResNet features
        - returns list(len(patch_nums)) of 1D numpy arrays (2048) or None
        """
        train_class_dir = osp.join(prototype_dir, class_name)
        # train_class_dir = osp.join(class_dir, class_name)
        paths = load_image_paths_from_dir(train_class_dir)
        if len(paths) == 0:
            return None
        # limit num prototypes used for medoid selection
        max_search = min(len(paths), 500) # previously 200
        feats_per_image = []
        pil_imgs = []
        for pth in paths[:max_search]:
            pil = PImage.open(pth).convert('RGB')
            pil_imgs.append(pil)
            # extract a global resnet feature at canonical size as fallback (we also compute per-stage below)
            tmp = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])])
            t = tmp(pil).unsqueeze(0).to(device)
            with torch.no_grad():
                feats = feature_extractor(t).view(1, -1).cpu().numpy()[0]
            feats_per_image.append(feats)
        feats_per_image = np.stack(feats_per_image)
        centroid = feats_per_image.mean(axis=0)
        medoid_idx = int(np.linalg.norm(feats_per_image - centroid[None,:], axis=1).argmin())
        # choose medoid image as prototype
        proto_img = pil_imgs[medoid_idx]

        # compute per-stage features
        stage_feats = []
        for pn in patch_nums:
            # stage image size approx: scale = pn / patch_nums[-1]
            scale = pn / float(patch_nums[-1])
            stage_size = max(16, int(round(args.image_size * scale)))
            # resize prototype to stage_size, then to 224 for resnet
            img_rs = proto_img.resize((stage_size, stage_size), resample=PImage.BILINEAR)
            # normalize and extract
            t = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])
            timg = t(img_rs).unsqueeze(0).to(device)
            timg = (timg - mean.to(timg.device)) / std.to(timg.device)
            with torch.no_grad():
                feat = feature_extractor(timg).view(-1).cpu().numpy()
            stage_feats.append(feat.astype(np.float32))
        return stage_feats
    
    @torch.no_grad()
    def teacher_filter_and_cluster(model_teacher, pool_images, pool_features, sel_class, class_label, args, n_right=100, seed=0):
        # Visible part: writes the selected images as PNG files to
        # <args.save_dir>/<sel_class>/, named class_label * 50 + rank (zero-padded
        # to six digits).
        # NOTE(review): selected_indices, filtered_images and select_k are not
        # defined in the visible code; the filtering / clustering step that would
        # produce them appears to be omitted - TODO confirm.
       

        save_path = os.path.join(args.save_dir, sel_class)
        os.makedirs(save_path, exist_ok=True)
        for rank, idx in enumerate(selected_indices):
            img_arr = filtered_images[idx]
            img = PImage.fromarray(img_arr)
            filename = f"{class_label * 50 + rank:06d}.png"
            img.save(os.path.join(save_path, filename))

        print(f"Class {sel_class}: pool size {filtered_images.shape[0]}, selected {select_k} samples saved to {save_path}")

    print("All classes processed.")

    # NOTE(review): `time` is not imported and `start` is never assigned in the
    # visible file, so these lines raise NameError as written.
    end = time.time()     

    print("time:", end - start, "s")


# -------------------- argument parsing and run --------------------
if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, add_argument keyword arguments) pairs.
    # They are registered in exactly this order, so --help output is unchanged.
    cli_options = (
        ("--vae", dict(type=str, choices=["ema", "mse"], default="mse")),
        ("--image-size", dict(type=int, choices=[256, 512], default=256, dest='image_size')),
        ("--num-classes", dict(type=int, default=1000)),
        ("--cfg-scale", dict(type=float, default=5.0)),
        ("--num-sampling-steps", dict(type=int, default=50)),
        ("--seed", dict(type=int, default=1)),
        ("--ckpt", dict(type=str, default=None)),
        ("--spec", dict(type=str, default='all')),
        ("--save-dir", dict(type=str, default='./generated_samples_1220_0')),
        ("--num-samples", dict(type=int, default=100)),
        ("--total-shift", dict(type=int, default=0)),
        ("--nclass", dict(type=int, default=1000)),
        ("--phase", dict(type=int, default=0)),
        ("--trick", dict(action='store_true')),
        ("--prototype-dir", dict(type=str, default="/data/datasets/train",
                                 help="optional dir with per-class prototype images")),
        # size of the candidate pool
        ("--pool-size", dict(type=int, default=1000)),
        # number of samples kept per class
        ("--n-select-per-class", dict(type=int, default=100)),
        ("--batch-size", dict(type=int, default=20)),
        ("--proto-guidance-strength", dict(type=float, default=3.0)),
        ("--proto-guidance-sigma", dict(type=float, default=0.15)),
        ("--vae-ckpt", dict(type=str, default=None)),
        ("--var-ckpt", dict(type=str, default=None)),
        ("--arch-name", dict(type=str, default="resnet18",
                             help="arch name from pretrained torchvision models")),
    )

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    main(args)
