import copy
import os
import sys
import subprocess
import tempfile
import shutil
import glob
import random

import torch
from torch import distributed as dist
from torchdyn.core import NeuralODE

# from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid, save_image
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import matplotlib.cm as cm
from matplotlib.lines import Line2D
from matplotlib.offsetbox import OffsetImage, AnnotationBbox


def _format_imb_str(imb: float) -> str:
    """Format imbalance factor into compact string (e.g., 0.001 -> '0.001')."""
    s = f"{imb:.6f}".rstrip('0').rstrip('.')
    return s if s != "" else "0"


def find_model_directory(base_dir, target_dir_name):
    """
    Locate ``target_dir_name`` under ``base_dir``, falling back to ``<name>_multigpu``.

    Args:
        base_dir: Directory to search in.
        target_dir_name: Name of the directory to find.

    Returns:
        Full path of the exact directory if it exists, otherwise of the
        ``_multigpu`` variant (a notice is printed in that case).

    Raises:
        FileNotFoundError: If neither the exact name nor the ``_multigpu`` variant exists.
    """
    exact = os.path.join(base_dir, target_dir_name)
    if os.path.exists(exact):
        return exact

    fallback_name = target_dir_name + "_multigpu"
    fallback = os.path.join(base_dir, fallback_name)
    if not os.path.exists(fallback):
        raise FileNotFoundError(f"Neither '{target_dir_name}' nor '{fallback_name}' found in {base_dir}")

    print(f"Exact folder '{target_dir_name}' not found, but found _multigpu version: {fallback_name}")
    return fallback


def lt_cache_dir(dataset_name: str, imb_factor: float, data_root: str = "./data") -> str:
    """Return the on-disk cache path for a long-tailed CIFAR variant.

    The path is ``<data_root>/<dataset_name>_imb<imb>``, where ``<imb>`` is the
    compact form of ``imb_factor`` (e.g. ``cifar10_lt_imb0.01``).

    Raises:
        ValueError: if ``dataset_name`` is not "cifar10_lt" or "cifar100_lt".
    """
    imb_tag = _format_imb_str(imb_factor)
    if dataset_name not in ("cifar10_lt", "cifar100_lt"):
        raise ValueError("lt_cache_dir: dataset_name must be one of ['cifar10_lt','cifar100_lt']")
    return os.path.join(data_root, dataset_name + "_imb" + imb_tag)


def ensure_lt_dataset_dir(dataset_name: str, imb_factor: float, split: str = "train",
                          data_root: str = "./data", per_class_subdir: bool = True) -> str:
    """Materialize a long-tailed CIFAR split as PNG files in a cache dir and return the dir.

    Supported ``dataset_name`` values are ``"cifar10_lt"`` and ``"cifar100_lt"``;
    ``lt_cache_dir`` rejects anything else with ``ValueError``. The ``"train"``
    split is cached at ``lt_cache_dir(dataset_name, imb_factor, data_root)``.
    Any other split uses that path with a ``_{split}`` suffix, so a cached train
    split is never handed back for e.g. ``split="test"``.

    If the cache dir already exists and is non-empty, it is returned unchanged.
    Otherwise the split is loaded with ``torchcfm.utils.ImbalanceCIFAR10`` or
    ``ImbalanceCIFAR100`` (``train=(split == "train")``, ``download=True``).
    Image ``i`` is saved as ``{split}_lt_{i:06d}.png``. When ``per_class_subdir``
    is True, it goes inside a two-digit class folder (``00``, ``01``, ...), the
    layout ``torchvision.datasets.ImageFolder`` reads.

    Images are first written to a staging directory next to the cache dir. That
    staging directory is renamed to the cache dir only after every image has been
    saved. As a result, an interrupted export cannot leave a partially filled
    cache that later calls would mistake for a complete one.

    Args:
        dataset_name: "cifar10_lt" or "cifar100_lt".
        imb_factor: Imbalance factor passed to the LT dataset and encoded in the dir name.
        split: "train" selects the training set; any other value selects the test set.
        data_root: Root for both the CIFAR download and the PNG cache.
        per_class_subdir: Write images into per-class subfolders (ImageFolder layout).

    Returns:
        Path of the populated cache directory.
    """
    out_dir = lt_cache_dir(dataset_name, imb_factor, data_root)
    if split != "train":
        # Keep the historical path for "train"; give every other split its own cache.
        out_dir = f"{out_dir}_{split}"
    if os.path.isdir(out_dir) and len(os.listdir(out_dir)) > 0:
        return out_dir

    if dataset_name == "cifar10_lt":
        from torchcfm.utils import ImbalanceCIFAR10 as _LT
        num_classes = 10
    elif dataset_name == "cifar100_lt":
        from torchcfm.utils import ImbalanceCIFAR100 as _LT
        num_classes = 100
    else:
        # Safety net only: lt_cache_dir has already rejected other names.
        os.makedirs(out_dir, exist_ok=True)
        return out_dir

    # Minimal transform: 32x32 tensors in [0, 1], which is what save_image expects.
    tfm = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    ds = _LT(root=data_root, train=(split == "train"), download=True, transform=tfm, imb_factor=imb_factor)

    # Stage into a per-process sibling dir so the final publish is a single rename
    # within the same parent directory.
    os.makedirs(os.path.dirname(out_dir) or ".", exist_ok=True)
    staging_dir = f"{out_dir}.partial-{os.getpid()}"
    shutil.rmtree(staging_dir, ignore_errors=True)  # stale leftover from a crashed run with the same PID
    os.makedirs(staging_dir)
    try:
        if per_class_subdir:
            # Create every class folder up front so even empty classes exist.
            for c in range(num_classes):
                os.makedirs(os.path.join(staging_dir, f"{c:02d}"), exist_ok=True)

        for idx in range(len(ds)):
            img, label = ds[idx]
            if per_class_subdir:
                dest_dir = os.path.join(staging_dir, f"{int(label):02d}")
                os.makedirs(dest_dir, exist_ok=True)  # tolerate labels outside range(num_classes)
            else:
                dest_dir = staging_dir
            save_image(img, os.path.join(dest_dir, f"{split}_lt_{idx:06d}.png"))

        # An empty out_dir (e.g. created by an earlier version of this function)
        # would block the rename, so remove it first.
        if os.path.isdir(out_dir) and not os.listdir(out_dir):
            os.rmdir(out_dir)
        try:
            os.rename(staging_dir, out_dir)
        except OSError:
            # If a non-empty cache appeared meanwhile (e.g. another process finished
            # first), keep it and drop ours; otherwise the failure is real.
            if not (os.path.isdir(out_dir) and os.listdir(out_dir)):
                raise
            shutil.rmtree(staging_dir, ignore_errors=True)
    except BaseException:
        shutil.rmtree(staging_dir, ignore_errors=True)
        raise
    return out_dir


def get_real_dataset(dataset_name: str, split: str, transform, data_root: str = "./data",
                     imb_factor: float = None):
    """
    Build the real-image dataset used for evaluation/visualization.

    - ``cifar10`` / ``cifar100``: the torchvision dataset (``download=True``).
    - ``cifar10_lt`` / ``cifar100_lt``: an ``ImageFolder`` over the PNG cache directory
      returned by ``ensure_lt_dataset_dir`` (which exports it if missing). If that
      directory contains no class subfolders, the in-memory ``ImbalanceCIFAR10`` /
      ``ImbalanceCIFAR100`` dataset from ``torchcfm.utils`` is returned instead.

    ``split == "train"`` selects the training set; any other value selects the test set.

    Raises:
        ValueError: if ``imb_factor`` is missing for an ``*_lt`` dataset, or the
        dataset name is unknown.
    """
    use_train = split == "train"

    if dataset_name in ("cifar10", "cifar100"):
        tv_cls = datasets.CIFAR10 if dataset_name == "cifar10" else datasets.CIFAR100
        return tv_cls(root=data_root, train=use_train, download=True, transform=transform)

    if dataset_name not in ("cifar10_lt", "cifar100_lt"):
        raise ValueError(f"Unknown dataset: {dataset_name}")

    if imb_factor is None:
        raise ValueError("imb_factor must be provided for *_lt datasets")

    cache_dir = ensure_lt_dataset_dir(dataset_name, imb_factor, split=split, data_root=data_root, per_class_subdir=True)
    entries = os.listdir(cache_dir)
    if any(os.path.isdir(os.path.join(cache_dir, name)) for name in entries):
        # Class subfolders present: ImageFolder can read the cache directly.
        return datasets.ImageFolder(root=cache_dir, transform=transform)

    # No class subfolders: construct the long-tailed dataset in memory instead.
    if dataset_name == "cifar10_lt":
        from torchcfm.utils import ImbalanceCIFAR10 as lt_cls
    else:
        from torchcfm.utils import ImbalanceCIFAR100 as lt_cls
    return lt_cls(root=data_root, train=use_train, download=True, transform=transform, imb_factor=imb_factor)


def setup(
    rank: int,
    total_num_gpus: int,
    master_addr: str = "localhost",
    master_port: str = "12355",
    backend: str = "nccl",
):
    """Join the default ``torch.distributed`` process group.

    Sets ``MASTER_ADDR`` and ``MASTER_PORT`` in the process environment, then
    calls ``dist.init_process_group`` with ``world_size=total_num_gpus``.

    Args:
        rank: Rank of the current process.
        total_num_gpus: Number of GPUs (processes) in the job.
        master_addr: Address of the master node.
        master_port: Port of the master node (a string, since it is stored in ``os.environ``).
        backend: Communication backend, e.g. "nccl" or "gloo".
    """
    for env_key, env_value in (("MASTER_ADDR", master_addr), ("MASTER_PORT", master_port)):
        os.environ[env_key] = env_value

    dist.init_process_group(backend=backend, rank=rank, world_size=total_num_gpus)


def generate_samples(model, parallel, savedir, step, net_="normal", device=None):
    """Save 64 generated images (8 x 8 grid) as a sanity check during training.

    Standard Gaussian noise of shape (64, 3, 32, 32) is integrated with torchdyn's
    ``NeuralODE`` (Euler solver, 100 time points on [0, 1]). The final state is
    clipped to [-1, 1], mapped to [0, 1], and written to
    ``{savedir}/{net_}_generated_FM_images_step_{step}.png``.

    Sampling runs on a deep copy of ``model``. ``model`` itself is switched to
    eval mode only for the duration of the call. Afterwards it is restored to the
    train/eval mode it had before, even if sampling raises.

    Parameters
    ----------
    model:
        The neural network (vector field) to generate samples from. When
        ``parallel`` is True, it is a wrapper exposing the network as ``.module``.
    parallel: bool
        Parallel training flag. Torchdyn only runs on one device, so the wrapped
        ``.module`` is unwrapped and moved to ``device``.
    savedir: str
        Directory where the image grid is written.
    step: int
        Current training step (used in the filename).
    net_: str
        Filename prefix (default "normal").
    device: torch.device or str, optional
        Device used for sampling. If None, the device of the model's first
        parameter is used; if the model has no parameters, the default (CPU).
    """
    was_training = model.training
    model.eval()
    try:
        model_ = copy.deepcopy(model)
        if parallel:
            model_ = model_.module
        if device is None:
            # Create the noise where the model lives instead of always on the CPU.
            first_param = next(model_.parameters(), None)
            device = first_param.device if first_param is not None else None
        if device is not None:
            model_ = model_.to(device)

        node_ = NeuralODE(model_, solver="euler", sensitivity="adjoint")
        with torch.no_grad():
            traj = node_.trajectory(
                torch.randn(64, 3, 32, 32, device=device),
                t_span=torch.linspace(0, 1, 100, device=device),
            )
            traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1)
            traj = traj / 2 + 0.5  # [-1, 1] -> [0, 1] for save_image
        save_image(traj, os.path.join(savedir, f"{net_}_generated_FM_images_step_{step}.png"), nrow=8)
    finally:
        model.train(was_training)


def ema(source, target, decay):
    """Update ``target`` in place as an exponential moving average of ``source``.

    Every entry of ``source.state_dict()`` (parameters and buffers) is written
    into the matching entry of ``target``:

    * floating-point / complex tensors: ``target = decay * target + (1 - decay) * source``;
    * integer / bool tensors (e.g. BatchNorm's ``num_batches_tracked``): copied
      from ``source``. A fractional average would be truncated by ``copy_`` into
      an integer tensor, so such a counter could never advance.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key, src in source_dict.items():
        tgt = target_dict[key]
        if tgt.is_floating_point() or tgt.is_complex():
            tgt.data.copy_(tgt.data * decay + src.data * (1 - decay))
        else:
            tgt.data.copy_(src.data)


def infiniteloop(dataloader):
    """Yield the inputs of ``dataloader`` forever, restarting it after each pass.

    Each batch must be an ``(inputs, labels)`` pair; only ``inputs`` is yielded.

    Raises
    ------
    ValueError
        If a complete pass over ``dataloader`` yields no batches (for example, an
        empty loader or an exhausted one-shot iterator). Without this check the
        generator would spin forever without producing anything.
    """
    while True:
        produced = False
        for x, _ in dataloader:
            produced = True
            yield x
        if not produced:
            raise ValueError("infiniteloop: dataloader yielded no batches")


def plot_fm_weights_histogram(fm_weight_tensor, weight_type, savedir, data_type="total", extra_info=""):
    """Plot and save a histogram of FM weights, returning summary statistics.

    The figure marks the min, mean and max with vertical lines and a text box.
    It is written at 300 dpi to
    ``{savedir}/uot_wfm_weights_histogram_{weight_type}_{data_type}.png``.

    Parameters
    ----------
    fm_weight_tensor: torch.Tensor or array-like
        Weights to analyze; flattened before plotting. Tensors that require
        grad are detached first.
    weight_type: str
        Type of weight used (e.g., 'none', 'inv_u', 'inv_v', 'inv_dv', 'inv_piT1');
        used in the title and filename.
    savedir: str
        Directory path where the histogram image is saved.
    data_type: str
        Subset label used in the title and filename (default "total").
    extra_info: str
        Extra text appended to the title only.

    Returns
    -------
    tuple
        ``(min, mean, max)`` of the weights, as NumPy scalars.
    """
    # Convert to a flat NumPy array; detach() allows tensors that are part of an autograd graph.
    if isinstance(fm_weight_tensor, torch.Tensor):
        weights_np = fm_weight_tensor.detach().cpu().numpy().flatten()
    else:
        weights_np = np.asarray(fm_weight_tensor).flatten()

    # Summary statistics
    min_val = np.min(weights_np)
    mean_val = np.mean(weights_np)
    max_val = np.max(weights_np)

    fig = plt.figure(figsize=(10, 6))
    try:
        try:
            plt.hist(weights_np, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
        except Exception:
            # Best-effort fallback when histogram binning fails (e.g. np.histogram
            # rejects a non-finite data range); no longer catches KeyboardInterrupt.
            plt.bar(weights_np, weights_np, width=0.001, alpha=0.7, color='skyblue', edgecolor='black')
        plt.axvline(mean_val, color='red', linestyle='--', label=f'Mean: {mean_val:.4f}')
        plt.axvline(min_val, color='green', linestyle='--', label=f'Min: {min_val:.4f}')
        plt.axvline(max_val, color='orange', linestyle='--', label=f'Max: {max_val:.4f}')

        # Statistics text box in the top-left corner (axes coordinates)
        stats_text = f'Min: {min_val:.4f}\nMean: {mean_val:.4f}\nMax: {max_val:.4f}'
        plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
                 verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        plt.xlabel('Weight Values')
        plt.ylabel('Frequency')
        plt.title(f'Distribution of FM Weights ({weight_type}_{data_type}_{extra_info})')
        plt.legend()
        plt.grid(True, alpha=0.3)

        histogram_path = os.path.join(savedir, f'uot_wfm_weights_histogram_{weight_type}_{str(data_type)}.png')
        plt.savefig(histogram_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return min_val, mean_val, max_val


def save_value_to_txt(value, filename, value_name="", extra_info_list=None):
    """Write ``value`` (labelled by ``value_name``) and optional extra lines to a text file.

    The file is overwritten. The first line is ``f"{value_name}: {value}"``; each
    item of ``extra_info_list`` (if given) is then written on its own line.
    ``extra_info_list`` defaults to ``None`` instead of a shared mutable ``[]``.
    """
    with open(filename, "w") as f:
        f.write(f"{value_name}: {value}\n")
        for extra_info in extra_info_list or ():
            f.write(f"{extra_info}\n")


#### compute precision and recall ####
# written based on https://github.com/blandocs/improved-precision-and-recall-metric-pytorch/blob/master/functions.py

class ImageDataset(Dataset):
    """Image files found recursively under a directory, prepared for ImageNet backbones.

    Files ending in .png/.jpg/.jpeg/.bmp (case-insensitive) are collected by
    walking ``dir_path`` and sorted by path. The list is then truncated to
    ``data_size`` entries. When ``data_size`` is at least ``batch_size``, it is
    first rounded down to a whole number of batches. Each item is
    ``(image_tensor, path)``: the image is converted to RGB, resized to
    224x224, scaled to [0, 1], normalized with ImageNet mean/std, and moved to
    ``device`` as float32.
    """

    def __init__(self, dir_path, data_size=100, batch_size=64, device=None):
        self.dir_path = dir_path
        self.device = device

        # Round down to whole batches, but never to zero: when fewer images than
        # one batch are requested, keep data_size as-is instead of selecting nothing.
        if data_size >= batch_size:
            data_size -= data_size % batch_size

        # Recursively collect image files (so class subfolders work), sorted for determinism.
        exts = ('.png', '.jpg', '.jpeg', '.bmp')
        self.img_paths = sorted(
            os.path.join(root, name)
            for root, _dirs, names in os.walk(dir_path)
            for name in names
            if name.lower().endswith(exts)
        )[:data_size]

        self.imsize = 224  # backbone input size

        # torchvision ImageNet preprocessing: RGB, [0, 1], ImageNet mean/std normalization.
        self.transformations = transforms.Compose([
            transforms.Resize((self.imsize, self.imsize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = Image.open(img_path).convert('RGB')
        image = self.transformations(image)
        return image.to(self.device, torch.float), img_path

    def __len__(self):
        return len(self.img_paths)

class feature_extractor(object):
    """Embed a directory of generated images and a directory of real images with a pretrained CNN.

    ``model_pr`` selects the torchvision backbone. The options are "vgg16" (also
    used when None), "resnet18", "resnet50", "resnet101", "inception_v3",
    "densenet121", "densenet169" and "densenet201". The classification layer is
    removed, so each image maps to a flat feature vector. Images are read with
    ``ImageDataset``, which returns at most ``data_size`` images per directory.
    """

    def __init__(self, generated_dir, real_dir, batch_size, data_size, device, model_pr, dataset_name='cifar10'):
        # parameters
        self.generated_dir = generated_dir
        self.real_dir = real_dir
        self.batch_size = batch_size
        self.data_size = data_size
        self.device = device
        self.model = model_pr             # backbone name (see class docstring)
        self.dataset_name = dataset_name  # stored for callers; not used by extract()

    def _build_head(self):
        """Return the selected pretrained backbone truncated to a flat-feature output."""
        if self.model == "vgg16" or self.model is None:
            cnn = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
            # features -> 7x7 adaptive pool -> flatten -> first five classifier layers
            return nn.Sequential(
                cnn.features,
                nn.AdaptiveAvgPool2d((7, 7)),
                nn.Flatten(),
                *[cnn.classifier[i] for i in range(5)],
            )
        if self.model == "resnet18":
            cnn = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
            # all children except the final FC layer, then flatten the pooled map
            return nn.Sequential(*list(cnn.children())[:-1], nn.Flatten())
        if self.model == "resnet50":
            cnn = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            return nn.Sequential(*list(cnn.children())[:-1], nn.Flatten())
        if self.model == "resnet101":
            cnn = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
            return nn.Sequential(*list(cnn.children())[:-1], nn.Flatten())
        if self.model == "inception_v3":
            # torchvision forces aux_logits=True whenever pretrained weights are
            # requested and raises ValueError if aux_logits=False is passed, so
            # load the default model and replace only the final classifier.
            # Calling the whole module (rather than nn.Sequential(children)) keeps
            # its forward pass: the input transform that belongs to these weights
            # is applied, and the auxiliary head is skipped outside training mode.
            cnn = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
            cnn.fc = nn.Identity()
            return cnn
        if self.model == "densenet121":
            cnn = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
            return nn.Sequential(cnn.features, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
        if self.model == "densenet169":
            cnn = models.densenet169(weights=models.DenseNet169_Weights.DEFAULT)
            return nn.Sequential(cnn.features, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
        if self.model == "densenet201":
            cnn = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
            return nn.Sequential(cnn.features, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
        raise ValueError(f"Invalid model: {self.model}")

    def _embed_dir(self, head, dir_path):
        """Run ``head`` over the images in ``dir_path``; return (per-image features, image paths)."""
        dataset = ImageDataset(dir_path, self.data_size, self.batch_size, device=self.device)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        features, paths = [], []
        for imgs, img_paths in tqdm(loader, ncols=80):
            imgs = imgs.to(self.device, non_blocking=True)
            batch_features = head(imgs)
            paths.extend(img_paths)
            # Split the batch into one (1, D) tensor per image.
            features.extend(torch.chunk(batch_features, batch_features.size(0), dim=0))
        return features, paths

    def extract(self):
        """Return ``(generated_features, real_features, generated_img_paths)``.

        Each feature list holds one ``(1, D)`` tensor per image, in the sorted-path
        order used by ``ImageDataset``. Features are computed under ``torch.no_grad()``
        with the backbone in eval mode.
        """
        head = self._build_head().to(self.device).eval()
        with torch.no_grad():
            generated_features, generated_img_paths = self._embed_dir(head, self.generated_dir)
            real_features, _ = self._embed_dir(head, self.real_dir)
        return generated_features, real_features, generated_img_paths

class precision_and_recall(object):
    """Improved precision and recall between generated and real images (k-NN manifolds, k = 3).

    Features come from ``feature_extractor`` with backbone ``ext_model_name``.
    For a reference set A, each sample's radius is the distance to its k-th
    nearest other sample in A. A query sample is inside A's manifold if it lies
    within that radius of at least one sample of A. Precision is the fraction
    of generated samples inside the real manifold; recall is the fraction of
    real samples inside the generated manifold.
    """

    def __init__(self, batch_size, data_size, device, ext_model_name, generated_dir, real_dir, dataset_name='cifar10'):
        self.batch_size = batch_size
        self.data_size = data_size
        self.k = 3  # neighbor index used for the manifold radius
        self.device = device
        self.ext_model_name = ext_model_name
        self.generated_dir = generated_dir
        self.real_dir = real_dir
        self.dataset_name = dataset_name

    def run(self):
        """Extract features and return ``(precision, recall)``.

        Both feature sets are truncated to the size of the smaller one. If either
        set is empty, prints "there is no data" and returns ``(0.0, 0.0)``, so
        callers can always unpack the result.
        """
        # Extract features with the configured backbone (ext_model_name).
        extractor = feature_extractor(generated_dir=self.generated_dir, real_dir=self.real_dir, batch_size=self.batch_size, data_size=self.data_size, device=self.device, model_pr=self.ext_model_name, dataset_name=self.dataset_name)
        generated_features, real_features, _ = extractor.extract()
        # Use an equal number of samples from each set.
        data_num = min(len(generated_features), len(real_features))
        print(f'data num: {data_num}')

        if data_num <= 0:
            print("there is no data")
            return 0.0, 0.0
        generated_features = generated_features[:data_num]
        real_features = real_features[:data_num]

        precision = self.manifold_estimate(real_features, generated_features, self.k)
        recall = self.manifold_estimate(generated_features, real_features, self.k)
        return precision, recall

    def manifold_estimate(self, A_features, B_features, k):
        """Return the fraction of ``B_features`` inside the k-NN manifold of ``A_features``.

        Reference implementation using explicit pairwise loops
        (|A|^2 + up to |A|*|B| distance evaluations). ``k`` is clamped to
        ``len(A_features) - 1``, so that small reference sets do not make
        ``np.partition`` fail. An empty ``B_features`` yields 0.0.
        """
        # Flatten every feature to a 1-D CPU tensor.
        A_feats = [a.detach().to("cpu").view(-1) for a in A_features]
        B_feats = [b.detach().to("cpu").view(-1) for b in B_features]
        if not B_feats:
            return 0.0
        # Index 0 of the sorted distances is the sample's zero distance to itself,
        # so index kth is its kth nearest other sample.
        kth = min(k, max(len(A_feats) - 1, 0))

        # Radius of each A sample: distance to its kth neighbor within A.
        thresholds = [0.0] * len(A_feats)
        for a_idx, A in enumerate(tqdm(A_feats, ncols=80)):
            pairwise_distances = np.zeros(len(A_feats), dtype=np.float64)
            for i, A_prime in enumerate(A_feats):
                pairwise_distances[i] = float(torch.norm(A - A_prime, p=2).item())
            thresholds[a_idx] = float(np.partition(pairwise_distances, kth)[kth])

        # Count B samples that fall inside at least one A ball (plain float comparison).
        n = 0
        for B in tqdm(B_feats, ncols=80):
            for A_prime, radius in zip(A_feats, thresholds):
                if float(torch.norm(B - A_prime, p=2).item()) <= radius:
                    n += 1
                    break

        return n / len(B_feats)

def compute_pr_slow(fdir1=None, fdir2=None, gen=None, dataset_name=None, batch_size=None, dataset_res=None, num_gen=50000, dataset_split='train', real_data_path=None, device='cuda:0', ext_model_name='vgg16', imb_factor=None):
    """
    Compute improved precision and recall with the slow reference implementation.

    Generated and real images that have to be exported are written to a temporary
    directory, which is removed before returning.

    Parameters
    ----------
    fdir1: str
        Path to a directory of generated images (takes priority over ``gen``).
    fdir2: str
        Path to a directory of real images (takes priority over ``dataset_name``).
    gen: function
        Called as ``gen(None)``; should return a batch of uint8 images, which are
        saved after dividing by 255.
    dataset_name: str
        "cifar10", "cifar100", "cifar10_lt", "cifar100_lt", or "custom" (with ``real_data_path``).
    batch_size: int
        Batch size for generation and feature extraction; required when ``gen`` is used.
    dataset_res: int
        Resolution real images are resized to before being saved.
    num_gen: int
        Number of images to use from each side.
    dataset_split: str
        Dataset split ("train", "test").
    real_data_path: str, optional
        Path to a real image folder (used when dataset_name is "custom").
    device: str
        Device for feature extraction.
    ext_model_name: str
        Backbone passed to ``feature_extractor`` (e.g. "vgg16").
    imb_factor: float, optional
        Imbalance factor, required for the ``*_lt`` datasets.

    Returns
    -------
    tuple: (precision, recall)
        Precision and recall scores; (0.0, 0.0) if no features could be extracted.

    Raises
    ------
    ValueError
        On missing inputs, a missing/invalid ``batch_size`` with ``gen``, or an unknown ``dataset_name``.
    """

    # Create temporary directories for generated and real images
    temp_dir = tempfile.mkdtemp()
    generated_dir = os.path.join(temp_dir, "generated")
    real_dir = os.path.join(temp_dir, "real")
    os.makedirs(generated_dir, exist_ok=True)
    os.makedirs(real_dir, exist_ok=True)

    try:
        if fdir1 is None and gen is not None:
            if batch_size is None or batch_size <= 0:
                raise ValueError("batch_size must be a positive integer when gen is provided")
            print(f"Generating {num_gen} images...")
            num_batches = (num_gen + batch_size - 1) // batch_size
            generated_count = 0
            for _ in tqdm(range(num_batches), desc="Generating images", ncols=80):
                remaining = num_gen - generated_count
                if remaining <= 0:
                    break

                generated_images = gen(None)  # unused_latent parameter

                # Never index past the batch gen actually returned.
                for j in range(min(batch_size, remaining, len(generated_images))):
                    img_path = os.path.join(generated_dir, f"gen_{generated_count:06d}.png")
                    # Convert uint8 tensor to float tensor in range [0, 1] for save_image
                    img_tensor = generated_images[j:j+1].float() / 255.0
                    save_image(img_tensor, img_path)
                    generated_count += 1
        elif fdir1 is not None:
            print(f"Using generated images from {fdir1}")
            generated_dir = fdir1
        else:
            raise ValueError("fdir1 or gen must be provided")

        if fdir2 is None and dataset_name is not None:
            # Prepare real images
            print("Preparing real images...")
            if dataset_name in ["cifar10", "cifar100", "cifar10_lt", "cifar100_lt"]:
                transform = transforms.Compose([
                    transforms.Resize((dataset_res, dataset_res)),
                    transforms.ToTensor(),
                ])
                dataset = get_real_dataset(dataset_name, split=dataset_split, transform=transform, data_root="./data",
                                           imb_factor=(imb_factor if "lt" in dataset_name else None))

                for i in range(min(num_gen, len(dataset))):
                    img, _ = dataset[i]
                    img_path = os.path.join(real_dir, f"real_{i:06d}.png")
                    save_image(img, img_path)

            elif dataset_name == "custom" and real_data_path:
                # Collect image files directly inside real_data_path (non-recursive).
                image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
                image_files = []
                for ext in image_extensions:
                    image_files.extend(glob.glob(os.path.join(real_data_path, ext)))
                    image_files.extend(glob.glob(os.path.join(real_data_path, ext.upper())))
                # On case-insensitive filesystems both patterns match the same file; de-duplicate.
                image_files = sorted(set(image_files))

                # Shuffle and take the required number of images
                random.shuffle(image_files)

                transform = transforms.Compose([
                    transforms.Resize((dataset_res, dataset_res)),
                    transforms.ToTensor(),
                ])

                for i, img_path in enumerate(image_files[:num_gen]):
                    try:
                        img = Image.open(img_path).convert('RGB')
                        img_tensor = transform(img)
                        save_path = os.path.join(real_dir, f"real_{i:06d}.png")
                        save_image(img_tensor, save_path)
                    except Exception as e:
                        print(f"Error processing {img_path}: {e}")
                        continue
            else:
                raise ValueError(f"Invalid dataset_name: {dataset_name}. Use 'cifar10', 'cifar100', 'cifar10_lt', 'cifar100_lt', or 'custom' with real_data_path")
        elif fdir2 is not None:
            print(f"Using real images from {fdir2}")
            real_dir = fdir2
        else:
            raise ValueError("fdir2 or dataset_name must be provided")

        # Compute Precision and Recall
        print("Computing precision and recall...")
        pr = precision_and_recall(batch_size=batch_size, data_size=num_gen, device=device, ext_model_name=ext_model_name, generated_dir=generated_dir, real_dir=real_dir, dataset_name=dataset_name)
        result = pr.run()
        if result is None:
            # Defensive: run() may signal "no data" with None; match compute_pr_fast.
            print("No features extracted; returning zeros.")
            return 0.0, 0.0
        precision, recall = result

        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")

        return precision, recall

    finally:
        # Clean up temporary directory
        shutil.rmtree(temp_dir, ignore_errors=True)

#####################


def compute_pr_fast(
    fdir1=None,
    fdir2=None,
    gen=None,
    dataset_name=None,
    batch_size=None,
    dataset_res=None,
    num_gen=50000,
    dataset_split='train',
    real_data_path=None,
    device='cuda:0',
    ext_model_name='vgg16',
    k: int = 3,
    top_m: int = 5,
    max_samples: int = None,
    use_faiss: bool = True,
    imb_factor=None,
):
    """Compute improved precision and recall with vectorized k-NN search (FAISS or scikit-learn).

    Inputs mirror ``compute_pr_slow``:
    - Generated images come from ``fdir1``, or are produced by calling
      ``gen(None)``, which is expected to return a uint8 batch that is saved
      after dividing by 255.
    - Real images come from ``fdir2``, or are exported from ``dataset_name``
      ("cifar10", "cifar100", "cifar10_lt", "cifar100_lt", or "custom" with
      ``real_data_path``).
    Exported images live in a temporary directory that is removed on return.

    Differences from the slow path:
    - k: k-th neighbor index for the manifold radius (clamped for small sets).
    - top_m: number of nearest reference samples checked per query sample.
      Membership is only tested against these, so the result can be lower than
      the exact estimate when top_m is smaller than the reference set.
    - max_samples: optional random cap on samples per set, to bound runtime/memory.
    - use_faiss: try FAISS first (a GPU index if available, else a CPU index),
      then fall back to scikit-learn.

    Returns
    -------
    tuple[float, float]
        (precision, recall); (0.0, 0.0) if no features were extracted.

    Raises
    ------
    ValueError
        Raised in three cases:
        - neither fdir1 nor gen is given, or neither fdir2 nor dataset_name;
        - ``batch_size`` is missing or non-positive when ``gen`` is used;
        - dataset_name is unknown.
    """
    # Temporary export location for generated/real images; removed in `finally`.
    temp_dir = tempfile.mkdtemp()
    generated_dir = os.path.join(temp_dir, "generated")
    real_dir = os.path.join(temp_dir, "real")
    os.makedirs(generated_dir, exist_ok=True)
    os.makedirs(real_dir, exist_ok=True)

    try:
        # Prepare generated images
        if fdir1 is None and gen is not None:
            if batch_size is None or batch_size <= 0:
                raise ValueError("batch_size must be a positive integer when gen is provided")
            print(f"Generating {num_gen} images (fast path)...")
            num_batches = (num_gen + batch_size - 1) // batch_size
            generated_count = 0
            for _ in tqdm(range(num_batches), desc="Generating images", ncols=80):
                remaining = num_gen - generated_count
                if remaining <= 0:
                    break
                generated_images = gen(None)
                # Never index past the batch gen actually returned.
                for j in range(min(batch_size, remaining, len(generated_images))):
                    img_path = os.path.join(generated_dir, f"gen_{generated_count:06d}.png")
                    # uint8 -> float in [0, 1] for save_image
                    img_tensor = generated_images[j:j+1].float() / 255.0
                    save_image(img_tensor, img_path)
                    generated_count += 1
        elif fdir1 is not None:
            print(f"Using generated images from {fdir1}")
            generated_dir = fdir1
        else:
            raise ValueError("fdir1 or gen must be provided")

        # Prepare real images
        if fdir2 is None and dataset_name is not None:
            print("Preparing real images (fast path)...")
            transform = transforms.Compose([
                transforms.Resize((dataset_res, dataset_res)),
                transforms.ToTensor(),
            ])
            if dataset_name in ["cifar10", "cifar100", "cifar10_lt", "cifar100_lt"]:
                # leverage common real dataset loader
                dataset = get_real_dataset(dataset_name, split=dataset_split, transform=transform, data_root="./data",
                                            imb_factor=(imb_factor if "lt" in dataset_name else None))
            elif dataset_name == "custom" and real_data_path:
                # Recursively collect image files (case-insensitive extension match).
                image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
                image_files = []
                for r, _dirs, fnames in os.walk(real_data_path):
                    for nm in fnames:
                        if nm.lower().endswith(image_extensions):
                            image_files.append(os.path.join(r, nm))
                random.shuffle(image_files)
                for i, img_path in enumerate(image_files[:num_gen]):
                    try:
                        img = Image.open(img_path).convert('RGB')
                        img_tensor = transform(img)
                        save_path = os.path.join(real_dir, f"real_{i:06d}.png")
                        save_image(img_tensor, save_path)
                    except Exception as e:
                        print(f"Error processing {img_path}: {e}")
                        continue
                dataset = None
            else:
                raise ValueError(f"Invalid dataset_name: {dataset_name}")

            if dataset is not None:
                for i in range(min(num_gen, len(dataset))):
                    img, _ = dataset[i]
                    img_path = os.path.join(real_dir, f"real_{i:06d}.png")
                    save_image(img, img_path)
        elif fdir2 is not None:
            print(f"Using real images from {fdir2}")
            real_dir = fdir2
        else:
            raise ValueError("fdir2 or dataset_name must be provided")

        # Extract CNN features using existing extractor, then stack to arrays
        extractor = feature_extractor(
            generated_dir=generated_dir,
            real_dir=real_dir,
            batch_size=batch_size,
            data_size=num_gen,
            device=device,
            model_pr=ext_model_name,
            dataset_name=dataset_name if dataset_name is not None else 'cifar10',
        )
        gen_feats_list, real_feats_list, _ = extractor.extract()

        def _stack_feats(feats_list):
            """Stack per-image feature tensors into an (N, D) float32 NumPy array."""
            feats = [f.detach().to("cpu").view(-1) for f in feats_list]
            if len(feats) == 0:
                return np.empty((0, 0), dtype=np.float32)
            return torch.stack(feats, dim=0).numpy().astype(np.float32, copy=False)

        gen_arr = _stack_feats(gen_feats_list)
        real_arr = _stack_feats(real_feats_list)

        def _subsample(arr, limit):
            """Randomly keep at most ``limit`` rows (no-op when limit is None)."""
            if limit is None or arr.shape[0] <= limit:
                return arr
            idx = np.random.choice(arr.shape[0], size=limit, replace=False)
            return arr[idx]

        gen_arr = _subsample(gen_arr, max_samples)
        real_arr = _subsample(real_arr, max_samples)

        if gen_arr.shape[0] == 0 or real_arr.shape[0] == 0:
            print("No features extracted; returning zeros.")
            return 0.0, 0.0

        def _faiss_knn(A, B, kk, mm):
            """Exact L2 k-NN with FAISS (GPU index if possible, else CPU).

            Returns (squared kk-th self-neighbor distance per row of A,
            squared distances from B to its mm nearest rows of A, their indices).
            """
            import faiss  # type: ignore
            cpu_index = faiss.IndexFlatL2(A.shape[1])
            try:
                res = faiss.StandardGpuResources()
                index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
                index.add(A)
                D_self, _ = index.search(A, kk + 1)
                D_cross, I_cross = index.search(B, mm)
            except Exception:
                # No GPU build / no GPU available: use the CPU index.
                cpu_index.add(A)
                D_self, _ = cpu_index.search(A, kk + 1)
                D_cross, I_cross = cpu_index.search(B, mm)
            return D_self[:, kk], D_cross, I_cross

        def _pr_membership(A, B):
            """Fraction of rows of B inside the k-NN manifold of A (checking B's top_m nearest rows of A)."""
            n_a = A.shape[0]
            # ensure k, top_m valid for the size of A
            kk = min(k, max(1, n_a - 1))
            mm = min(top_m, n_a)

            if use_faiss:
                try:
                    thr_sq, D_cross, I_cross = _faiss_knn(A, B, kk, mm)
                except ImportError:
                    pass  # FAISS not installed: use scikit-learn below
                except Exception as e:
                    print(f"FAISS k-NN failed ({e}); falling back to scikit-learn.")
                else:
                    # IndexFlatL2 reports squared L2 distances; compare them directly
                    # instead of round-tripping through sqrt.
                    hits = np.any(D_cross <= thr_sq[I_cross], axis=1)
                    return float(hits.mean())

            # scikit-learn fallback: fit A once and query it for both self and cross neighbors.
            from sklearn.neighbors import NearestNeighbors
            nbrs = NearestNeighbors(n_neighbors=kk + 1, algorithm='auto', metric='euclidean').fit(A)
            dist_self, _ = nbrs.kneighbors(A, n_neighbors=kk + 1)
            thr = dist_self[:, kk]
            dist_cross, ind_cross = nbrs.kneighbors(B, n_neighbors=mm)
            hits = np.any(dist_cross <= thr[ind_cross], axis=1)
            return float(hits.mean())

        precision = _pr_membership(real_arr, gen_arr)   # gen within real manifold
        recall = _pr_membership(gen_arr, real_arr)       # real within gen manifold

        print(f"Precision (fast): {precision:.4f}")
        print(f"Recall (fast): {recall:.4f}")
        return precision, recall

    finally:
        # ignore_errors=True already suppresses removal failures.
        shutil.rmtree(temp_dir, ignore_errors=True)


def compute_emd_2d(points1, points2):
    """
    Approximate the Earth Mover's Distance (1-Wasserstein) between two 2D point sets.

    The exact 2D EMD is not solved. Instead, the 1D Wasserstein distance is
    computed independently along the x axis and the y axis, and the two
    values are averaged.

    Parameters
    ----------
    points1: array-like
        First point set, shape (N1, 2). Lists/tuples are accepted.
    points2: array-like
        Second point set, shape (N2, 2). Lists/tuples are accepted.

    Returns
    -------
    float
        Mean of the per-axis 1D Wasserstein distances.
        If scipy is unavailable, returns the Euclidean distance between the two
        centroids instead.
        Returns ``inf`` when either set is empty or the computation fails.
    """
    try:
        # Convert first so column slicing works for lists/tuples too. Any
        # conversion failure falls through to the generic handler below.
        points1 = np.asarray(points1, dtype=float)
        points2 = np.asarray(points2, dtype=float)

        # NOTE: numpy must NOT be imported locally here. A function-local
        # `import numpy as np` placed after this line would make `np` a local
        # name and leave it unbound in the ImportError fallback.
        from scipy.stats import wasserstein_distance

        if len(points1) == 0 or len(points2) == 0:
            return float('inf')

        # Per-axis 1D Wasserstein distances, averaged.
        emd_x = wasserstein_distance(points1[:, 0], points2[:, 0])
        emd_y = wasserstein_distance(points1[:, 1], points2[:, 1])
        return (emd_x + emd_y) / 2.0

    except ImportError:
        print("scipy not available, using simplified distance metric")
        if len(points1) == 0 or len(points2) == 0:
            return float('inf')

        # Fallback: distance between the centroids of the two point sets.
        center1 = np.mean(points1, axis=0)
        center2 = np.mean(points2, axis=0)
        return float(np.linalg.norm(center1 - center2))

    except Exception as e:
        # Best-effort metric: report the problem and signal "no meaningful distance".
        print(f"Error computing EMD: {e}")
        return float('inf')


def compute_emd_high_dim(features1, features2, max_dims=10):
    """
    Approximate the Earth Mover's Distance between two high-dimensional feature sets.

    The full multi-dimensional EMD is not solved. The 1D Wasserstein distance
    is computed independently for each of the first ``max_dims`` feature
    dimensions, and the mean of those per-dimension distances is returned.

    Parameters
    ----------
    features1: array-like
        First feature set, shape (N1, D). Lists/tuples are accepted.
    features2: array-like
        Second feature set, shape (N2, D). Lists/tuples are accepted.
    max_dims: int, optional
        Maximum number of leading dimensions to compare (default 10, matching
        the previous hard-coded limit).

    Returns
    -------
    float
        Mean per-dimension 1D Wasserstein distance.
        Returns ``nan`` if there are no dimensions to compare.
        If scipy is unavailable, returns the Euclidean distance between the two
        feature centroids (over all dimensions).
        Returns ``inf`` when either set is empty or the computation fails.
    """
    try:
        # Convert first so column slicing works for lists/tuples too.
        features1 = np.asarray(features1, dtype=float)
        features2 = np.asarray(features2, dtype=float)

        # NOTE: do not re-import numpy locally; a local `import numpy as np`
        # after this line would leave `np` unbound in the ImportError fallback.
        from scipy.stats import wasserstein_distance

        if len(features1) == 0 or len(features2) == 0:
            return float('inf')

        num_dims = min(max_dims, features1.shape[1], features2.shape[1])
        if num_dims <= 0:
            # Same result as np.mean([]) but without the RuntimeWarning.
            return float('nan')

        emd_values = [
            wasserstein_distance(features1[:, dim], features2[:, dim])
            for dim in range(num_dims)
        ]
        return np.mean(emd_values)

    except ImportError:
        print("scipy not available, using simplified distance metric for high-dim")
        if len(features1) == 0 or len(features2) == 0:
            return float('inf')

        # Fallback: distance between the centroids of the two feature sets.
        center1 = np.mean(features1, axis=0)
        center2 = np.mean(features2, axis=0)
        return float(np.linalg.norm(center1 - center2))

    except Exception as e:
        # Best-effort metric: report the problem and signal "no meaningful distance".
        print(f"Error computing high-dim EMD: {e}")
        return float('inf')





def classify_generated_images(gen_dir, dataset_name, device, num_samples, return_confidence=False):
    """
    Classify generated images with a pretrained CIFAR classifier from torch.hub.

    Parameters
    ----------
    gen_dir: str
        Directory containing the generated images. Only top-level ``.png``,
        ``.jpg`` and ``.jpeg`` files are used, in sorted filename order.
    dataset_name: str
        One of 'cifar10', 'cifar10_lt', 'cifar100', 'cifar100_lt'. Selects the
        CIFAR-10 or CIFAR-100 classifier.
    device: str
        Torch device to run the classifier on.
    num_samples: int
        Maximum number of images to classify.
    return_confidence: bool, optional
        If True, also return the softmax probability of the predicted class.

    Returns
    -------
    np.ndarray or (np.ndarray, np.ndarray)
        - ``return_confidence=False``: predicted labels, shape [N].
        - ``return_confidence=True``: ``(labels, confidences)``, where
          confidences are the max softmax probabilities, shape [N].

        On any failure (model download, unsupported dataset, unreadable files)
        the function falls back to ``num_samples`` uniformly random labels,
        with all-zero confidences when requested.
    """
    # Input normalization must match the statistics the classifier was trained
    # with. The chenyaofo/pytorch-cifar-models checkpoints use per-dataset CIFAR
    # mean/std, not ImageNet statistics.
    cifar_norm_stats = {
        "cifar10": ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        "cifar100": ((0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    }
    try:
        from torchvision import transforms
        from torch.utils.data import DataLoader, Dataset

        # Pick the pretrained CIFAR classifier that matches the dataset.
        if dataset_name in ["cifar10", "cifar10_lt"]:
            base_name = "cifar10"
            print("Loading pretrained CIFAR-10 classifier from torch.hub...")
        elif dataset_name in ["cifar100", "cifar100_lt"]:
            base_name = "cifar100"
            print("Loading pretrained CIFAR-100 classifier from torch.hub...")
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name}")
        model = torch.hub.load("chenyaofo/pytorch-cifar-models", f"{base_name}_repvgg_a2", pretrained=True)

        model = model.to(device)
        model.eval()
        print("✓ Successfully loaded pretrained classifier from torch.hub")

        mean, std = cifar_norm_stats[base_name]
        transform = transforms.Compose([
            transforms.Resize((32, 32)),  # CIFAR classifiers expect 32x32 inputs
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        class GeneratedImageDataset(Dataset):
            """Image files from a single flat directory, in sorted filename order."""

            def __init__(self, img_dir, transform=None, max_samples=None):
                self.img_dir = img_dir
                self.transform = transform
                self.img_files = sorted(
                    f for f in os.listdir(img_dir)
                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
                )
                if max_samples is not None:
                    self.img_files = self.img_files[:max_samples]

            def __len__(self):
                return len(self.img_files)

            def __getitem__(self, idx):
                img_path = os.path.join(self.img_dir, self.img_files[idx])
                img = Image.open(img_path).convert('RGB')
                if self.transform:
                    img = self.transform(img)
                return img

        dataset = GeneratedImageDataset(gen_dir, transform=transform, max_samples=num_samples)
        dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

        predictions = []
        confidences = []
        with torch.no_grad():
            for batch in dataloader:
                batch = batch.to(device)
                outputs = model(batch)

                # Max softmax probability serves as the prediction confidence.
                probs = torch.softmax(outputs, dim=1)
                conf, predicted = torch.max(probs, 1)
                predictions.extend(predicted.cpu().numpy())
                confidences.extend(conf.cpu().numpy())

        if return_confidence:
            return np.array(predictions[:num_samples]), np.array(confidences[:num_samples])
        return np.array(predictions[:num_samples])

    except Exception as e:
        # Deliberate best-effort fallback so visualization can still proceed.
        print(f"Error in classifying generated images: {e}")
        print("Falling back to random labels...")
        num_classes = 10 if dataset_name in ["cifar10", "cifar10_lt"] else 100
        if return_confidence:
            # Confidence is unknown for random labels; report 0.
            return (
                np.random.randint(0, num_classes, size=num_samples),
                np.zeros(num_samples, dtype=float),
            )
        return np.random.randint(0, num_classes, size=num_samples)


def compute_pr(fdir1=None, fdir2=None, gen=None, dataset_name=None, batch_size=None, dataset_res=None, num_gen=50000, dataset_split='train', real_data_path=None, device='cuda:0', ext_model_name='vgg16', method="fast"):
    """
    Compute precision and recall between generated and real images.

    This is a dispatcher only. ``method == "fast"`` selects ``compute_pr_fast``.
    Any other value selects ``compute_pr_slow``. Every other argument is
    forwarded positionally, unchanged, to the selected implementation.

    Returns
    -------
    Whatever the selected implementation returns.
    """
    implementation = compute_pr_fast if method == "fast" else compute_pr_slow
    forwarded_args = (
        fdir1, fdir2, gen, dataset_name, batch_size, dataset_res,
        num_gen, dataset_split, real_data_path, device, ext_model_name,
    )
    return implementation(*forwarded_args)


def visualize_pca_comparison(gen_dir=None, real_dir=None, gen_func=None, dataset_name=None, 
                           num_samples=5000, device='cuda:0', save_path=None, step=None, 
                           ext_model_name='vgg16', batch_size=1024, dataset_res=32, dataset_split='train',
                           pca_mode='feature', max_images_per_class=3, imb_factor=None):
    """
    PCA       2 
    
    Parameters
    ----------
    gen_dir: str, optional
            
    real_dir: str, optional  
            
    gen_func: function, optional
           (gen_dir None  )
    dataset_name: str
          ('cifar10', 'cifar100' )
    num_samples: int
        PCA   
    device: str
         
    save_path: str
           
    step: int
           ( )
    ext_model_name: str
            (pca_mode='feature'  )
    batch_size: int
         
    dataset_res: int
         
    dataset_split: str
          ('train', 'test')
    pca_mode: str
        PCA : 'feature' (VGG16  )  'raw_pixel' (   )
    max_images_per_class: int
               (: 3)
        
    Returns
    -------
    tuple: (explained_variance_ratio, save_file_path)
        PCA     
    """
    try:
        from sklearn.decomposition import PCA
    except ImportError as e:
        print(f"Required library not found: {e}")
        print("Please install scikit-learn: pip install scikit-learn")
        return None, None
    
    # PCA  
    if pca_mode not in ['feature', 'raw_pixel']:
        raise ValueError(f"Invalid pca_mode: {pca_mode}. Must be 'feature' or 'raw_pixel'")
    
    print(f"PCA mode: {pca_mode}")
    if pca_mode == 'feature':
        print(f"Using {ext_model_name} features for PCA")
    else:
        print(f"Using raw image pixels for PCA")
    
    # num_samples :       
    if num_samples <= 0:
        print(f"num_samples ({num_samples}) is <= 0, will use all available data")
        num_samples = None  # None     
    
    #   
    temp_dir = tempfile.mkdtemp()
    temp_gen_dir = os.path.join(temp_dir, "generated")
    temp_real_dir = os.path.join(temp_dir, "real")
    os.makedirs(temp_gen_dir, exist_ok=True)
    os.makedirs(temp_real_dir, exist_ok=True)
    
    try:
        #         
        # (   actual_sample_size    )
        temp_gen_samples = num_samples  #  
        
        #   
        if gen_dir is None and gen_func is not None:
            #        
            actual_gen_dir = temp_gen_dir
            generate_images_later = True
        elif gen_dir is not None:
            #    
            print(f"Using generated images from {gen_dir}")
            actual_gen_dir = gen_dir
            generate_images_later = False
        else:
            raise ValueError("gen_dir or gen_func must be provided")
            

            
        #      
        real_labels = None
        if real_dir is None and dataset_name is not None:
            print("Preparing real images for PCA...")
            if dataset_name in ["cifar10", "cifar100", "cifar10_lt", "cifar100_lt"]:
                transform = transforms.Compose([
                    transforms.Resize((dataset_res, dataset_res)),
                    transforms.ToTensor(),
                ])
                dataset = get_real_dataset(dataset_name, split=dataset_split, transform=transform, data_root="./data",
                                           imb_factor=(imb_factor if "lt" in dataset_name else None))
                
                #       ( )
                real_labels = []
                total_samples = len(dataset)
                
                # num_samples None      
                if num_samples is None or num_samples > total_samples:
                    sample_size = total_samples
                    print(f"Using all {total_samples} real images (num_samples was {num_samples})")
                else:
                    sample_size = num_samples
                    print(f"Using {sample_size} out of {total_samples} real images")
                
                #     
                random_indices = np.random.choice(total_samples, size=sample_size, replace=False)
                
                print(f"Randomly sampling {sample_size} images from {total_samples} total images")
                
                for i, idx in enumerate(random_indices):
                    img, label = dataset[idx]
                    img_path = os.path.join(temp_real_dir, f"real_{i:06d}.png")
                    save_image(img, img_path)
                    real_labels.append(label)
                real_labels = np.array(real_labels)
                
                #   
                unique_labels, counts = np.unique(real_labels, return_counts=True)
                print("Real data class distribution after random sampling:")
                for label, count in zip(unique_labels, counts):
                    print(f"  Class {label}: {count} samples")
                
                #      (    )
                actual_sample_size = sample_size
                    
            elif dataset_name == "custom":
                raise ValueError("For custom dataset, real_dir must be provided")
            else:
                raise ValueError(f"Unsupported dataset: {dataset_name}")
                
            actual_real_dir = temp_real_dir
        elif real_dir is not None:
            print(f"Using real images from {real_dir}")
            actual_real_dir = real_dir
            
            #      (  )
            def _list_all_images(root_dir):
                exts = ('.png', '.jpg', '.jpeg')
                files = []
                for r, _dirs, fnames in os.walk(root_dir):
                    for nm in fnames:
                        if nm.lower().endswith(exts):
                            files.append(os.path.join(r, nm))
                files.sort()
                return files
            img_files = _list_all_images(real_dir)
            total_real_files = len(img_files)
            
            # num_samples None      
            if num_samples is None or num_samples > total_real_files:
                actual_sample_size = total_real_files
                print(f"Using all {total_real_files} real images from external dir (num_samples was {num_samples})")
            else:
                actual_sample_size = num_samples
                print(f"Using {actual_sample_size} out of {total_real_files} real images from external dir")

            #       
            real_selected_files = []
            real_labels = None
            if actual_sample_size > 0:
                sel_idx = np.random.choice(total_real_files, size=actual_sample_size, replace=False)
                sel_idx.sort()
                real_selected_files = [img_files[i] for i in sel_idx]
                labels = []
                for pth in real_selected_files:
                    cls_dir = os.path.basename(os.path.dirname(pth))
                    #  (: 00, 01 ...) ;  best-effort int  
                    try:
                        labels.append(int(cls_dir))
                    except Exception:
                        labels.append(-1)
                real_labels = np.array(labels)
        else:
            raise ValueError("real_dir or dataset_name must be provided")
        
        #       
        if generate_images_later:
            print(f"Generating {actual_sample_size} images to match real data count...")
            num_batches = (actual_sample_size + batch_size - 1) // batch_size
            generated_count = 0
            
            for i in range(num_batches):
                current_batch_size = min(batch_size, actual_sample_size - generated_count)
                if current_batch_size <= 0:
                    break
                    
                generated_images = gen_func(None)
                
                for j in range(current_batch_size):
                    if generated_count >= actual_sample_size:
                        break
                    img_path = os.path.join(temp_gen_dir, f"gen_{generated_count:06d}.png")
                    img_tensor = generated_images[j:j+1].float() / 255.0
                    save_image(img_tensor, img_path)
                    generated_count += 1
            
            print(f"Generated {generated_count} images")
        
        #  
        if pca_mode == 'feature':
            print("Extracting CNN features for PCA...")
            extractor = feature_extractor(
                generated_dir=actual_gen_dir,
                real_dir=actual_real_dir,
                batch_size=batch_size,
                data_size=actual_sample_size,
                device=device,
                model_pr=ext_model_name,
                dataset_name=dataset_name if dataset_name is not None else 'cifar10'
            )
            
            gen_features, real_features, _ = extractor.extract()
            
            #  numpy  
            if len(gen_features) == 0 or len(real_features) == 0:
                print("No features extracted. Skipping PCA visualization.")
                return None, None
                
            gen_array = torch.stack([f.detach().cpu().view(-1) for f in gen_features]).numpy()
            real_array = torch.stack([f.detach().cpu().view(-1) for f in real_features]).numpy()
            
        elif pca_mode == 'raw_pixel':
            print("Loading raw pixel data for PCA...")
            
            #    (  +    )
            def load_images_from_dir(img_dir, max_samples, file_list=None):
                images = []
                selected_files = []
                if file_list is not None:
                    selected_files = file_list[:max_samples]
                    print(f"Loading {len(selected_files)} images from provided file list")
                else:
                    # /      
                    exts = ('.png', '.jpg', '.jpeg')
                    img_files = []
                    for r, _dirs, fnames in os.walk(img_dir):
                        for nm in fnames:
                            if nm.lower().endswith(exts):
                                img_files.append(os.path.join(r, nm))
                    img_files.sort()  #  
                    total_files = len(img_files)
                    sample_size = min(max_samples, total_files)
                    if sample_size > 0:
                        random_indices = np.random.choice(total_files, size=sample_size, replace=False)
                        selected_files = [img_files[i] for i in random_indices]
                        print(f"Randomly sampling {sample_size} images from {total_files} total files in {img_dir}")

                for img_path in selected_files:
                    try:
                        img = Image.open(img_path).convert('RGB')
                        img = img.resize((dataset_res, dataset_res))
                        img_array = np.array(img).astype(np.float32) / 255.0
                        img_flat = img_array.flatten()
                        images.append(img_flat)
                    except Exception as e:
                        print(f"Error loading {img_path}: {e}")
                        continue
                        
                return np.array(images) if images else np.empty((0, 0))
            
            #      (   )
            gen_array = load_images_from_dir(actual_gen_dir, actual_sample_size)
            #  real_dir  ,   real_selected_files  
            try:
                real_array = load_images_from_dir(actual_real_dir, actual_sample_size, file_list=real_selected_files)
            except NameError:
                # real_selected_files    
                real_array = load_images_from_dir(actual_real_dir, actual_sample_size)
            
            if gen_array.shape[0] == 0 or real_array.shape[0] == 0:
                print("No images loaded. Skipping PCA visualization.")
                return None, None
                
            # raw_pixel    
            #     None 
        
        print(f"Generated features shape: {gen_array.shape}")
        print(f"Real features shape: {real_array.shape}")
        
        #    
        if gen_array.shape[0] == real_array.shape[0]:
            print(f"✓ Sample counts match: {gen_array.shape[0]} generated = {real_array.shape[0]} real")
        else:
            print(f"⚠ Sample count mismatch: {gen_array.shape[0]} generated ≠ {real_array.shape[0]} real")
        
        #    torch.hub   
        gen_labels = None
        if dataset_name in ["cifar10", "cifar100", "cifar10_lt", "cifar100_lt"]:
            print("Classifying generated images using torch.hub pretrained classifier...")
            gen_labels = classify_generated_images(actual_gen_dir, dataset_name, device, actual_sample_size)
        
        #    (  )
        if dataset_name in ["cifar10", "cifar10_lt"]:
            num_classes = 10
        elif dataset_name in ["cifar100", "cifar100_lt"]:
            num_classes = 100
        else:
            num_classes = 10  # 
            
        #  real_dir      real_labels 
        
        print(f"Number of classes for visualization: {num_classes}")
        
        #   
        print(f"gen_labels is None: {gen_labels is None}")
        print(f"real_labels is None: {real_labels is None}")
        if gen_labels is not None:
            print(f"gen_labels shape: {gen_labels.shape}, unique values: {np.unique(gen_labels)}")
        if real_labels is not None:
            print(f"real_labels shape: {real_labels.shape}, unique values: {np.unique(real_labels)}")
        
        # PCA  (8 )
        print("Performing PCA...")
        combined_features = np.concatenate([gen_array, real_array])
        n_components = min(8, combined_features.shape[1], combined_features.shape[0])  #  8,    
        pca = PCA(n_components=n_components)
        pca.fit(combined_features)
        
        gen_pca = pca.transform(gen_array)
        real_pca = pca.transform(real_array)
        
        #     
        def load_sample_images(img_dir, indices, max_samples=None):
            """    """
            sample_images = []
            img_files = [f for f in os.listdir(img_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            img_files.sort()
            
            #     
            valid_indices = [idx for idx in indices if idx < len(img_files)]
            if max_samples is not None:
                valid_indices = valid_indices[:max_samples]
            
            for idx in valid_indices:
                try:
                    img_path = os.path.join(img_dir, img_files[idx])
                    img = Image.open(img_path).convert('RGB')
                    img = img.resize((32, 32))  # CIFAR  
                    sample_images.append((idx, np.array(img)))
                except Exception as e:
                    print(f"Error loading sample image {img_path}: {e}")
                    continue
            return sample_images
        
        def select_representative_points(pca_coords, labels, max_per_class=3):
            """     (   )"""
            selected_indices = {}
            
            if labels is not None:
                unique_labels = np.unique(labels)
                for label in unique_labels:
                    mask = labels == label
                    class_coords = pca_coords[mask]
                    class_indices = np.where(mask)[0]
                    
                    if len(class_coords) > 0:
                        #    (   )
                        center = np.mean(class_coords, axis=0)
                        distances = np.linalg.norm(class_coords - center, axis=1)
                        
                        #      
                        sorted_indices = np.argsort(distances)[::-1]  #   
                        selected = sorted_indices[:max_per_class]
                        selected_indices[label] = class_indices[selected]
            else:
                #     
                total_samples = min(max_per_class, len(pca_coords))
                selected_indices[0] = np.random.choice(len(pca_coords), size=total_samples, replace=False)
            
            return selected_indices
        
        def select_gen_points_near_real(gen_pca_coords, gen_labels, real_selected_indices, real_pca_coords, max_per_class=3):
            """         """
            selected_indices = {}
            
            if gen_labels is not None and len(real_selected_indices) > 0:
                #     
                real_selected_coords = []
                for label, indices in real_selected_indices.items():
                    for idx in indices:
                        if idx < len(real_pca_coords):
                            real_selected_coords.append(real_pca_coords[idx])
                
                if len(real_selected_coords) > 0:
                    real_selected_coords = np.array(real_selected_coords)
                    
                    unique_labels = np.unique(gen_labels)
                    for label in unique_labels:
                        mask = gen_labels == label
                        class_coords = gen_pca_coords[mask]
                        class_indices = np.where(mask)[0]
                        
                        if len(class_coords) > 0:
                            #          
                            min_distances = []
                            for gen_coord in class_coords:
                                distances_to_real = np.linalg.norm(real_selected_coords - gen_coord, axis=1)
                                min_distances.append(np.min(distances_to_real))
                            
                            min_distances = np.array(min_distances)
                            
                            #      
                            sorted_indices = np.argsort(min_distances)  #   
                            selected = sorted_indices[:max_per_class]
                            selected_indices[label] = class_indices[selected]
                else:
                    #       
                    selected_indices = select_representative_points(gen_pca_coords, gen_labels, max_per_class)
            else:
                #         
                selected_indices = select_representative_points(gen_pca_coords, gen_labels, max_per_class)
            
            return selected_indices
        
        #  PC   
        print("Creating PCA visualizations...")
        
        # PC    (1-2, 3-4, 5-6, 7-8)
        pc_pairs = [(0, 1), (2, 3), (4, 5), (6, 7)]
        pc_names = ['PC1-PC2', 'PC3-PC4', 'PC5-PC6', 'PC7-PC8']
        
        saved_files = []
        
        for pair_idx, (pc_x, pc_y) in enumerate(pc_pairs):
            #  PC  
            if pc_x >= n_components or pc_y >= n_components:
                print(f"Skipping {pc_names[pair_idx]} - not enough components (only {n_components} available)")
                continue
                
            plt.figure(figsize=(14, 10))
            
            #    (  )
            colors = cm.tab10(np.linspace(0, 1, min(num_classes, 10)))  #  10 
            if num_classes > 10:
                colors = cm.tab20(np.linspace(0, 1, min(num_classes, 20)))  #    
            #   real/gen  : real , gen 
            def _darken_rgb(rgb, factor=0.6):
                base = np.array(rgb[:3])
                return tuple(np.clip(base * factor, 0.0, 1.0))
            def _brighten_rgb(rgb, factor=0.6):
                base = np.array(rgb[:3])
                bright = 1.0 - (1.0 - base) * factor
                return tuple(np.clip(bright, 0.0, 1.0))
            real_colors = [ _darken_rgb(c, 0.6) for c in colors ]
            gen_colors  = [ _brighten_rgb(c, 0.6) for c in colors ]
            
            #    :     
            print(f"Drawing all classes visualization for {pc_names[pair_idx]}")
            print(f"gen_labels available: {gen_labels is not None}")
            if gen_labels is not None:
                print(f"gen_labels distribution: {np.bincount(gen_labels, minlength=num_classes)}")
            
            for class_idx in range(num_classes):
                #  
                real_mask = real_labels == class_idx
                if np.any(real_mask):
                    plt.scatter(real_pca[real_mask, pc_x], real_pca[real_mask, pc_y], 
                               alpha=0.7, s=30, c=[real_colors[class_idx % len(real_colors)]], 
                               marker='o', label=f'Real Class {class_idx}' if class_idx < 10 else None,
                               edgecolors='black', linewidth=0.5)
                
                #   (classifier  )
                if gen_labels is not None:
                    gen_mask = gen_labels == class_idx
                    if np.any(gen_mask):
                        plt.scatter(gen_pca[gen_mask, pc_x], gen_pca[gen_mask, pc_y], 
                                   alpha=0.7, s=30, c=[gen_colors[class_idx % len(gen_colors)]], 
                                   marker='^', label=f'Gen Class {class_idx}' if class_idx < 10 else None,
                                   edgecolors='black', linewidth=0.5)
                        print(f"  -> Added {np.sum(gen_mask)} generated points for class {class_idx} in all-classes view")
            
            #       
            if gen_labels is None:
                print("  -> gen_labels is None, showing all generated data in gray")
                plt.scatter(gen_pca[:, pc_x], gen_pca[:, pc_y], 
                           alpha=0.7, s=30, c='gray', 
                           marker='^', label='Generated Data',
                           edgecolors='black', linewidth=0.5)
            else:
                print(f"  -> Total generated points classified: {len(gen_labels)}")
            
            plt.xlabel(f'PC{pc_x+1} ({pca.explained_variance_ratio_[pc_x]:.1%} variance)')
            plt.ylabel(f'PC{pc_y+1} ({pca.explained_variance_ratio_[pc_y]:.1%} variance)')
            
            # PCA    
            if pca_mode == 'feature':
                title = f'PCA Visualization {pc_names[pair_idx]} ({ext_model_name} Features): Class-colored, Shape-coded'
            else:
                title = f'PCA Visualization {pc_names[pair_idx]} (Raw Pixels): Class-colored, Shape-coded'
                
            if dataset_name:
                title += f' ({dataset_name})'
            if step is not None:
                title += f', step {step}'
            plt.title(title, fontsize=14)
            
            #  
            
            legend_elements = []
            
            #    ( 10 )
            if dataset_name in ["cifar10", "cifar10_lt"]:
                class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                              'dog', 'frog', 'horse', 'ship', 'truck']
            elif dataset_name in ["cifar100", "cifar100_lt"]:
                class_names = [f'Class {i}' for i in range(min(10, num_classes))]
            else:
                class_names = [f'Class {i}' for i in range(min(10, num_classes))]
            
            #    
            legend_elements.extend([
                Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', 
                       markersize=8, label='Real Data', markeredgecolor='black'),
                Line2D([0], [0], marker='^', color='w', markerfacecolor='gray', 
                       markersize=8, label='Generated Data', markeredgecolor='black')
            ])
            
            #   
            for i in range(min(10, num_classes)):
                legend_elements.append(
                    Line2D([0], [0], marker='s', color='w', 
                           markerfacecolor=colors[i % len(colors)], markersize=6, 
                           label=f'{class_names[i]}' if i < len(class_names) else f'Class {i}',
                           markeredgecolor='black', markeredgewidth=0.5)
                )
            
            plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True, alpha=0.3)
            
            #    
            try:
                #      ( N)
                if real_labels is not None:
                    real_selected = select_representative_points(real_pca[:, [pc_x, pc_y]], real_labels, max_per_class=1)  #  1
                    real_sample_images = []
                    for label, indices in real_selected.items():
                        if len(indices) > 0:
                            #  N   (  )
                            if len(real_sample_images) < max_images_per_class:
                                sample_imgs = load_sample_images(actual_real_dir, indices[:1], max_images_per_class)  #  1
                                real_sample_images.extend(sample_imgs)
                    
                    #   plot 
                    for idx, (orig_idx, img_array) in enumerate(real_sample_images):
                        if orig_idx < len(real_pca):
                            point_x, point_y = real_pca[orig_idx, pc_x], real_pca[orig_idx, pc_y]
                            
                            #    (2 )
                            img_size = 0.3  # plot  30% ( 15% 2)
                            imagebox = OffsetImage(img_array, zoom=img_size)
                            
                            #      
                            offset_distance = 0.3  #    
                            angle = idx * 60  #   60  
                            offset_x = offset_distance * np.cos(np.radians(angle))
                            offset_y = offset_distance * np.sin(np.radians(angle))
                            
                            img_x = point_x + offset_x
                            img_y = point_y + offset_y
                            
                            #     
                            plt.plot([point_x, img_x], [point_y, img_y], 
                                   'k--', linewidth=0.5, alpha=0.7)
                            
                            #    ( )
                            ab = AnnotationBbox(imagebox, (img_x, img_y), 
                                              frameon=False, boxcoords="data", pad=0)
                            plt.gca().add_artist(ab)
                
                #          ( N)
                if gen_labels is not None:
                    #      
                    real_selected = {}
                    if real_labels is not None:
                        real_selected = select_representative_points(real_pca[:, [pc_x, pc_y]], real_labels, max_per_class=1)
                    
                    #        
                    gen_selected = select_gen_points_near_real(gen_pca[:, [pc_x, pc_y]], gen_labels, real_selected, real_pca[:, [pc_x, pc_y]], max_per_class=1)
                    gen_sample_images = []
                    for label, indices in gen_selected.items():
                        if len(indices) > 0:
                            #  N  
                            if len(gen_sample_images) < max_images_per_class:
                                sample_imgs = load_sample_images(actual_gen_dir, indices[:1], max_images_per_class)  #  1
                                gen_sample_images.extend(sample_imgs)
                    
                    #   plot 
                    for idx, (orig_idx, img_array) in enumerate(gen_sample_images):
                        if orig_idx < len(gen_pca):
                            point_x, point_y = gen_pca[orig_idx, pc_x], gen_pca[orig_idx, pc_y]
                            
                            #    (2 )
                            img_size = 0.3  # plot  30% ( 15% 2)
                            imagebox = OffsetImage(img_array, zoom=img_size)
                            
                            #       (   )
                            offset_distance = 0.3  #    
                            angle = 180 + (idx * 60)  #    (180 )
                            offset_x = offset_distance * np.cos(np.radians(angle))
                            offset_y = offset_distance * np.sin(np.radians(angle))
                            
                            img_x = point_x + offset_x
                            img_y = point_y + offset_y
                            
                            #     
                            plt.plot([point_x, img_x], [point_y, img_y], 
                                   'k--', linewidth=0.5, alpha=0.7)
                            
                            #    ( )
                            ab = AnnotationBbox(imagebox, (img_x, img_y), 
                                              frameon=False, boxcoords="data", pad=0)
                            plt.gca().add_artist(ab)
                
            except Exception as e:
                print(f"Warning: Could not add sample images to plot: {e}")
                #     plot 
            
            # PCA      
            if save_path is None:
                save_path = "."
            
            pca_subdir = os.path.join(save_path, f"pca_analysis_{pca_mode}")
            os.makedirs(pca_subdir, exist_ok=True)
            
            # 1.     
            filename_all = f'pca_visualization_{pca_mode}_{pc_names[pair_idx].lower().replace("-", "")}_all'
            if dataset_name:
                filename_all += f'_{dataset_name}'
            if step is not None:
                filename_all += f'_step_{step}'
            filename_all += '.png'
            
            save_file_all = os.path.join(pca_subdir, filename_all)
            plt.savefig(save_file_all, dpi=300, bbox_inches='tight')
            
            saved_files.append(save_file_all)
            print(f"PCA visualization {pc_names[pair_idx]} (all classes) saved: {save_file_all}")

            # :     (Real=green, Gen=grapefruit)  
            plt.figure(figsize=(14, 10))
            # Real: green, Gen: grapefruit-like (#FF6F61)
            base_alpha = 0.7
            try:
                if real_pca is not None and real_pca.shape[0] > 0:
                    plt.scatter(
                        real_pca[:, pc_x], real_pca[:, pc_y],
                        alpha=base_alpha, s=30, c=['#2ecc71'], marker='o',
                        label='Real', edgecolors='black', linewidth=0.5,
                    )
                if gen_pca is not None and gen_pca.shape[0] > 0:
                    plt.scatter(
                        gen_pca[:, pc_x], gen_pca[:, pc_y],
                        alpha=base_alpha, s=30, c=['#FF6F61'], marker='^',
                        label='Generated', edgecolors='black', linewidth=0.5,
                    )
            except Exception as _e:
                pass

            plt.xlabel(f'PC{pc_x+1} ({pca.explained_variance_ratio_[pc_x]:.1%} variance)')
            plt.ylabel(f'PC{pc_y+1} ({pca.explained_variance_ratio_[pc_y]:.1%} variance)')
            if pca_mode == 'feature':
                title_duo = f'PCA Visualization {pc_names[pair_idx]} ({ext_model_name} Features): Real vs Gen'
            else:
                title_duo = f'PCA Visualization {pc_names[pair_idx]} (Raw Pixels): Real vs Gen'
            if dataset_name:
                title_duo += f' ({dataset_name})'
            if step is not None:
                title_duo += f', step {step}'
            plt.title(title_duo, fontsize=14)
            plt.legend(loc='upper right')
            plt.grid(True, alpha=0.3)

            filename_all_duo = f'pca_visualization_{pca_mode}_{pc_names[pair_idx].lower().replace("-", "")}_all_duo'
            if dataset_name:
                filename_all_duo += f'_{dataset_name}'
            if step is not None:
                filename_all_duo += f'_step_{step}'
            filename_all_duo += '.png'

            save_file_all_duo = os.path.join(pca_subdir, filename_all_duo)
            plt.savefig(save_file_all_duo, dpi=300, bbox_inches='tight')
            plt.close()
            saved_files.append(save_file_all_duo)
            print(f"PCA visualization {pc_names[pair_idx]} (all classes duo-color) saved: {save_file_all_duo}")
            
            # 2.    
            print(f"Checking conditions for class-wise saving: gen_labels is not None: {gen_labels is not None}, real_labels is not None: {real_labels is not None}")
            if real_labels is not None:  #      
                print("Real data labels available. Creating class-wise visualizations...")
            else:
                print("Real data labels not available. Skipping class-wise visualizations.")
                
            if real_labels is not None:
                for class_idx in range(min(10, num_classes)):  #  10 
                    plt.figure(figsize=(12, 8))
                    
                    #       
                    real_mask = real_labels == class_idx
                    gen_mask = gen_labels == class_idx if gen_labels is not None else np.zeros(len(gen_pca), dtype=bool)
                    
                    print(f"Class {class_idx}: real_mask has {np.sum(real_mask)} points, gen_mask has {np.sum(gen_mask)} points")
                    if np.any(real_mask) or np.any(gen_mask):  #       
                        # EMD  (   vs  ) - Raw image/feature 
                        emd_value = float('inf')
                        if np.any(real_mask) and np.any(gen_mask):
                            # Raw image/feature  EMD  (PCA   )
                            real_features_class = real_array[real_mask]  #   
                            gen_features_class = gen_array[gen_mask]     #   
                            emd_value = compute_emd_high_dim(real_features_class, gen_features_class)
                        
                        if np.any(real_mask):
                            plt.scatter(real_pca[real_mask, pc_x], real_pca[real_mask, pc_y], 
                                       alpha=0.7, s=40, c=['#2ecc71'], 
                                       marker='o', label='Real Data',
                                       edgecolors='black', linewidth=0.5)
                        
                        if gen_labels is not None and np.any(gen_mask):
                            plt.scatter(gen_pca[gen_mask, pc_x], gen_pca[gen_mask, pc_y], 
                                       alpha=0.7, s=40, c=['#FF6F61'], 
                                       marker='^', label='Generated Data',
                                       edgecolors='black', linewidth=0.5)
                            print(f"  -> Added {np.sum(gen_mask)} generated points for class {class_idx}")
                        elif gen_labels is None:
                            print(f"  -> gen_labels is None, skipping generated data for class {class_idx}")
                        else:
                            print(f"  -> No generated points for class {class_idx}")
                        
                        plt.xlabel(f'PC{pc_x+1} ({pca.explained_variance_ratio_[pc_x]:.1%} variance)')
                        plt.ylabel(f'PC{pc_y+1} ({pca.explained_variance_ratio_[pc_y]:.1%} variance)')
                        
                        class_name = class_names[class_idx] if class_idx < len(class_names) else f'Class {class_idx}'
                        
                        #   /   
                        real_count = np.sum(real_mask) if np.any(real_mask) else 0
                        gen_count = np.sum(gen_mask) if gen_labels is not None and np.any(gen_mask) else 0
                        
                        if pca_mode == 'feature':
                            title = f'PCA {pc_names[pair_idx]} ({ext_model_name}): {class_name} (Real:{real_count}, Gen:{gen_count})'
                        else:
                            title = f'PCA {pc_names[pair_idx]} (Raw Pixels): {class_name} (Real:{real_count}, Gen:{gen_count})'
                        
                        # EMD   
                        if emd_value != float('inf'):
                            title += f' (EMD: {emd_value:.4f})'
                        else:
                            title += ' (EMD: N/A)'
                        
                        if dataset_name:
                            title += f' ({dataset_name})'
                        if step is not None:
                            title += f', step {step}'
                        plt.title(title, fontsize=14)
                        
                        plt.legend()
                        plt.grid(True, alpha=0.3)
                        
                        #      (/   3)
                        try:
                            #    
                            if np.any(real_mask):
                                real_class_indices = np.where(real_mask)[0][:max_images_per_class]
                                real_imgs = load_sample_images(actual_real_dir, real_class_indices, max_images_per_class)
                                
                                for img_idx, (orig_idx, img_array) in enumerate(real_imgs):
                                    if orig_idx < len(real_pca):
                                        point_x, point_y = real_pca[orig_idx, pc_x], real_pca[orig_idx, pc_y]
                                        
                                        #   2 
                                        img_size = 0.4  #  view   ( 0.2 2)
                                        imagebox = OffsetImage(img_array, zoom=img_size)
                                        
                                        #      
                                        offset_distance = 0.4  #  view   
                                        angle = img_idx * 120  #   120   (3)
                                        offset_x = offset_distance * np.cos(np.radians(angle))
                                        offset_y = offset_distance * np.sin(np.radians(angle))
                                        
                                        img_x = point_x + offset_x
                                        img_y = point_y + offset_y
                                        
                                        #     
                                        plt.plot([point_x, img_x], [point_y, img_y], 
                                               'k--', linewidth=0.5, alpha=0.7)
                                        
                                        #    ( )
                                        ab = AnnotationBbox(imagebox, (img_x, img_y), 
                                                          frameon=False, boxcoords="data", pad=0)
                                        plt.gca().add_artist(ab)
                            
                            #     (     )
                            if gen_labels is not None and np.any(gen_mask):
                                #     
                                if np.any(real_mask):
                                    real_class_coords = real_pca[real_mask][:, [pc_x, pc_y]]
                                    real_class_indices = np.where(real_mask)[0]
                                    
                                    #       
                                    gen_class_coords = gen_pca[gen_mask][:, [pc_x, pc_y]]
                                    gen_class_indices = np.where(gen_mask)[0]
                                    
                                    if len(real_class_coords) > 0 and len(gen_class_coords) > 0:
                                        #         
                                        min_distances = []
                                        for gen_coord in gen_class_coords:
                                            distances_to_real = np.linalg.norm(real_class_coords - gen_coord, axis=1)
                                            min_distances.append(np.min(distances_to_real))
                                        
                                        min_distances = np.array(min_distances)
                                        #      
                                        sorted_indices = np.argsort(min_distances)[:max_images_per_class]  #  max_images_per_class
                                        gen_class_indices = gen_class_indices[sorted_indices]
                                    else:
                                        gen_class_indices = gen_class_indices[:max_images_per_class]  #  
                                else:
                                    gen_class_indices = np.where(gen_mask)[0][:max_images_per_class]  #  
                                
                                gen_imgs = load_sample_images(actual_gen_dir, gen_class_indices, max_images_per_class)
                                
                                for img_idx, (orig_idx, img_array) in enumerate(gen_imgs):
                                    if orig_idx < len(gen_pca):
                                        point_x, point_y = gen_pca[orig_idx, pc_x], gen_pca[orig_idx, pc_y]
                                        
                                        #   2 
                                        img_size = 0.4  #  view   ( 0.2 2)
                                        imagebox = OffsetImage(img_array, zoom=img_size)
                                        
                                        #       (   )
                                        offset_distance = 0.4  #  view   
                                        angle = 180 + (img_idx * 120)  #    (180 )
                                        offset_x = offset_distance * np.cos(np.radians(angle))
                                        offset_y = offset_distance * np.sin(np.radians(angle))
                                        
                                        img_x = point_x + offset_x
                                        img_y = point_y + offset_y
                                        
                                        #     
                                        plt.plot([point_x, img_x], [point_y, img_y], 
                                               'k--', linewidth=0.5, alpha=0.7)
                                        
                                        #    ( )
                                        ab = AnnotationBbox(imagebox, (img_x, img_y), 
                                                          frameon=False, boxcoords="data", pad=0)
                                        plt.gca().add_artist(ab)
                                        
                        except Exception as e:
                            print(f"Warning: Could not add sample images to class {class_idx} plot: {e}")
                            #     plot 
                        
                        #    (EMD  )
                        if emd_value != float('inf'):
                            emd_str = f'_emd{emd_value:.4f}'
                        else:
                            emd_str = '_emdNA'
                        
                        filename_class = f'pca_visualization_{pca_mode}_{pc_names[pair_idx].lower().replace("-", "")}_class{class_idx}{emd_str}'
                        if dataset_name:
                            filename_class += f'_{dataset_name}'
                        if step is not None:
                            filename_class += f'_step_{step}'
                        filename_class += '.png'
                        
                        save_file_class = os.path.join(pca_subdir, filename_class)
                        plt.savefig(save_file_class, dpi=300, bbox_inches='tight')
                        plt.close()
                        
                        saved_files.append(save_file_class)
                        if emd_value != float('inf'):
                            print(f"PCA visualization {pc_names[pair_idx]} (class {class_idx}: {class_name}, EMD: {emd_value:.4f}) saved: {save_file_class}")
                        else:
                            print(f"PCA visualization {pc_names[pair_idx]} (class {class_idx}: {class_name}, EMD: N/A) saved: {save_file_class}")
            
            plt.close()  #   figure 
        
        #    
        total_variance = pca.explained_variance_ratio_.sum()
        print(f"Total variance explained by {n_components} components: {total_variance:.1%}")
        for i in range(min(8, n_components)):
            print(f"PC{i+1}: {pca.explained_variance_ratio_[i]:.1%}")
        
        #        (  )
        main_save_file = saved_files[0] if saved_files else None
        
        #   
        log_message = ""
        if real_labels is None:
            log_message = "Real data labels not available - class-wise analysis skipped"
        
        # tuple     (   3  )
        return pca.explained_variance_ratio_, main_save_file, log_message
        
    except Exception as e:
        print(f"Error in PCA visualization: {e}")
        import traceback
        traceback.print_exc()
        return None, None, "Error occurred during PCA visualization"
        
    finally:
        #   
        try:
            shutil.rmtree(temp_dir, ignore_errors=True)
        except Exception:
            pass