import copy
import random
from typing import List, Callable, Tuple
import os
import json
from datetime import datetime

import math
import pandas as pd
from sklearn.metrics.pairwise import rbf_kernel, euclidean_distances
import numpy as np
import torch
from avalanche.benchmarks.utils import AvalancheDataset
from torch import nn
from torchvision import transforms, models

# from MERS.sampling_strategies.herding import HerdingSelectionStrategy
from MERS.sampling_strategies.teal import TEALExemplarsSelectionStrategy, calculate_typicality, kmeans
from sklearn.metrics.pairwise import cosine_distances
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import pairwise_distances

import matplotlib.pyplot as plt


def knn_density(embeddings: np.ndarray, k: int = 5, metric: str = 'cosine') -> np.ndarray:
    """
    Score every row of `embeddings` by the inverse of its summed k-NN distance.

    The score of point i is
        rho_i = 1 / sum_{j=1}^k d_{i, j}
    where d_{i, j} is the distance from point i to its j-th nearest neighbor
    under `metric`. Note that the numerator is 1, not k.

    Parameters
    ----------
    embeddings : np.ndarray, shape (n_samples, n_features)
        The learned representations / embedding vectors.
    k : int
        Number of neighbors used for the density estimate.
    metric : str
        Distance metric handed to sklearn's NearestNeighbors.

    Returns
    -------
    densities : np.ndarray, shape (n_samples,)
        Density score rho_i for each point.
    """
    # Query k + 1 neighbours: the points being queried are the fitted points,
    # so the closest hit is expected to be the point itself at distance 0.
    neighbor_model = NearestNeighbors(n_neighbors=k + 1, algorithm='auto', metric=metric)
    neighbor_model.fit(embeddings)
    all_distances, _ = neighbor_model.kneighbors(embeddings)

    # Skip column 0 (the self match) and add up the remaining k distances.
    distance_totals = np.sum(all_distances[:, 1:], axis=1)

    return 1 / distance_totals

def _as_2d_float(x):
    """Convert `x` to a float array; a 1-D input becomes a single column."""
    arr = np.asarray(x, dtype=float)
    # Only rank-1 input is reshaped; every other rank is returned untouched.
    return arr[:, np.newaxis] if arr.ndim == 1 else arr

def _l2_normalize_rows(X, eps=1e-12):
    """Divide each row of `X` by its Euclidean norm plus `eps` (avoids 0/0)."""
    row_norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / (row_norms + eps)

def median_radial_density_fullSi(embeddings: np.ndarray) -> np.ndarray:
    r"""
    Median "radial" density of every point under the cosine metric.

    For each point i, define S_i = {x : d_cos(x, centroid) <= d_cos(x_i, centroid)}.
    Within S_i, compute densities for all s in S_i using ALL other points in S_i:
        rho_s = 1 / sum_{t in S_i \ {s}} d_cos(s, t)
    and return the median rho_s over s in S_i. (This corresponds to K = |S_i|-1.)

    Parameters
    ----------
    embeddings : array-like
        Input vectors; a 1-D input is treated as a single feature column.

    Returns
    -------
    np.ndarray, shape (n_samples,)
        One median density per input point; 0.0 when S_i holds a single point.

    Notes
    -----
    The docstring is a raw string because it contains a backslash (set
    difference); in a normal string literal ``\ `` is an invalid escape.
    The double loop makes the cost grow cubically with the number of points
    in the worst case.
    """
    X = _as_2d_float(embeddings)
    Xn = _l2_normalize_rows(X)

    # cosine-centroid: mean direction of the unit vectors, re-normalized
    c = Xn.mean(axis=0, keepdims=True)
    c /= (np.linalg.norm(c, axis=1, keepdims=True) + 1e-12)

    # cosine distance of every point to the centroid
    d2c = pairwise_distances(Xn, c, metric="cosine").ravel()

    # pairwise cosine distances; the diagonal is set to inf so that the
    # self-distance can be filtered out with isfinite below
    D = pairwise_distances(Xn, Xn, metric="cosine")
    np.fill_diagonal(D, np.inf)

    n = len(Xn)
    out = np.empty(n, dtype=float)

    for i in range(n):
        # S_i: every point at least as close to the centroid as point i
        idx = np.flatnonzero(d2c <= d2c[i])
        if idx.size <= 1:
            out[i] = 0.0
            continue

        # Distances restricted to S_i; row t belongs to member idx[t].
        within = D[np.ix_(idx, idx)]
        # K = |S_i|-1: sum every finite (i.e. non-self) distance per member.
        densities = [1.0 / (np.sum(row[np.isfinite(row)]) + 1e-12) for row in within]

        out[i] = float(np.median(densities))

    return out

def knn_based_delta(X, max_size, embedding_type='unknown', metric='cosine', scale_method='cv'):
    """
    Derive a delta (radius) from k-NN distances, rescaled by global statistics.

    The base delta is the median over all k-NN distances, with
    k = min(len(X) - 1, len(X) // max_size). It is multiplied by a scale factor
    computed from the distribution of all pairwise distances in `X`.

    Parameters:
    -----------
    X : array-like
        Feature embeddings
    max_size : int
        Maximum size for selection (used only to derive k)
    embedding_type : str
        Accepted for interface compatibility; not read by this function
    metric : str
        Distance metric used for both the k-NN and the pairwise distances
    scale_method : str
        'cv'            -> std / mean of the pairwise distances
        'skewness'      -> 1 / (1 + skew) for positive skew, 1 + |skew| otherwise
        'cv_skewness'   -> product of the two factors above
        'inverse_mean'  -> 1 / mean pairwise distance
        'percentile_25' -> 25th percentile of the pairwise distances
        anything else   -> same factor as 'cv' (logged as "Default CV")

    Returns:
    --------
    base_delta * scale_factor
    """
    from sklearn.neighbors import NearestNeighbors
    from sklearn.metrics import pairwise_distances
    from scipy.stats import skew

    # Same k rule as the one used for optimal_delta elsewhere in the project.
    n_points = len(X)
    k = min(n_points - 1, n_points // max_size)

    # k + 1 neighbours are requested so the self match (column 0) can be dropped.
    neighbor_dists, _ = NearestNeighbors(n_neighbors=k + 1, metric=metric).fit(X).kneighbors(X)
    base_delta = np.median(neighbor_dists[:, 1:])

    # Every unordered pair once: strict upper triangle of the distance matrix.
    dist_matrix = pairwise_distances(X, metric=metric)
    distances_flat = dist_matrix[np.triu_indices_from(dist_matrix, k=1)]
    mean_dist = np.mean(distances_flat)
    std_dist = np.std(distances_flat)

    def coefficient_of_variation():
        # std / mean of the pairwise distances, guarded against a zero mean
        return std_dist / (mean_dist + 1e-8)

    def skew_to_scale(value):
        # right-skewed distances shrink delta, anything else expands it
        return 1.0 / (1.0 + value) if value > 0 else 1.0 + abs(value)

    if scale_method == 'cv':
        scale_factor = coefficient_of_variation()
        report = f"KNN-based delta (CV): k={k}, base_delta={base_delta:.4f}, mean_dist={mean_dist:.4f}, cv_scale={scale_factor:.4f}, final={base_delta * scale_factor:.4f}"
    elif scale_method == 'skewness':
        distance_skewness = skew(distances_flat)
        scale_factor = skew_to_scale(distance_skewness)
        report = f"KNN-based delta (Skewness): k={k}, base_delta={base_delta:.4f}, skewness={distance_skewness:.4f}, skew_scale={scale_factor:.4f}, final={base_delta * scale_factor:.4f}"
    elif scale_method == 'cv_skewness':
        cv = coefficient_of_variation()
        distance_skewness = skew(distances_flat)
        skew_scale = skew_to_scale(distance_skewness)
        scale_factor = cv * skew_scale
        report = f"KNN-based delta (CV+Skewness): k={k}, base_delta={base_delta:.4f}, mean_dist={mean_dist:.4f}, skewness={distance_skewness:.4f}, cv_scale={cv:.4f}, skew_scale={skew_scale:.4f}, final_scale={scale_factor:.4f}, final={base_delta * scale_factor:.4f}"
    elif scale_method == 'inverse_mean':
        scale_factor = 1.0 / (mean_dist + 1e-8)
        report = f"KNN-based delta (Inverse Mean): k={k}, base_delta={base_delta:.4f}, mean_dist={mean_dist:.4f}, inverse_mean_scale={scale_factor:.4f}, final={base_delta * scale_factor:.4f}"
    elif scale_method == 'percentile_25':
        p25 = np.percentile(distances_flat, 25)
        scale_factor = p25
        report = f"KNN-based delta (25th Percentile): k={k}, base_delta={base_delta:.4f}, p25={p25:.4f}, final={base_delta * scale_factor:.4f}"
    else:
        # Unrecognised method names fall back to the CV factor.
        scale_factor = coefficient_of_variation()
        report = f"KNN-based delta (Default CV): k={k}, base_delta={base_delta:.4f}, mean_dist={mean_dist:.4f}, cv_scale={scale_factor:.4f}, final={base_delta * scale_factor:.4f}"

    print(report)
    return base_delta * scale_factor


def calculate_weight(max_size, remaining_features, weights, weight_method='ratio_median_knn_density_k_1'):
    """
    Overwrite `weights[0]` / `weights[1]` according to `weight_method`.

    `remaining_features[0]` and `remaining_features[1]` are the two feature
    spaces being balanced; `weights` is modified in place. An empty or
    unrecognised `weight_method` leaves `weights` untouched.

    Returns
    -------
    k : int
        min(len(remaining_features[0]) - 1, len(remaining_features[0]) // max_size),
        the neighbour count used by the density-based methods.
    """
    k = min(len(remaining_features[0]) - 1, len(remaining_features[0]) // max_size)

    def median_density(space, n_neighbors, *metric):
        # median k-NN density of one feature space (metric defaults to knn_density's)
        return np.median(knn_density(remaining_features[space], n_neighbors, *metric))

    def density_ratio(space, top_k, bottom_k, *metric):
        # median density at top_k neighbours over median density at bottom_k
        return median_density(space, top_k, *metric) / median_density(space, bottom_k, *metric)

    def assign_shares(first, second):
        # normalise two scores to fractions; a non-positive total changes nothing
        total = first + second
        if total > 0:
            weights[0] = first / total
            weights[1] = second / total

    # Methods that apply one formula to both spaces: name -> (formula, log line).
    symmetric_rules = {
        'ratio_median_knn_density_k_1': (
            lambda space: density_ratio(space, k, 1),
            "The weight method is ratio_median_knn_density_k_1"),
        'importance_sampling': (
            lambda space: knn_density(remaining_features[space], k),
            "The weight method is importance_sampling_ess"),
        'euclidean_ratio_median_knn_density_k_1': (
            lambda space: density_ratio(space, k, 1, "euclidean"),
            "The weight method is euclidean_ratio_median_knn_density_k_1"),
        'ratio_median_knn_density_1_k': (
            lambda space: density_ratio(space, 1, k),
            "The weight method is ratio_median_knn_density_1_k"),
        'median_knn_density_knn': (
            lambda space: median_density(space, k),
            "The weight method is median_knn_density_knn"),
        'euclidean_median_knn_density_knn': (
            lambda space: median_density(space, k, "euclidean"),
            None),  # this method has never logged anything
        'median_knn_density_1': (
            lambda space: median_density(space, 1),
            "The weight method is median_knn_density_1"),
        'inverse_median_knn_density_knn': (
            lambda space: 1 / median_density(space, k),
            "The weight method is inverse_median_knn_density_knn"),
        'inverse_median_knn_density_1': (
            lambda space: 1 / median_density(space, 1),
            "The weight method is inverse_median_knn_density_1"),
        'knn_all': (
            lambda space: median_density(space, len(remaining_features[space]) - 1),
            "The weight method is median_cosine"),
    }

    if weight_method in symmetric_rules:
        formula, log_line = symmetric_rules[weight_method]
        weights[0] = formula(0)
        weights[1] = formula(1)
        if log_line is not None:
            print(log_line)

    elif weight_method == 'heuristic':
        # NOTE(review): weights[0] is derived from space 1 and weights[1] from
        # space 0 with max_size=len(space 0) - confirm this crossing is intended.
        weights[0] = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=max_size, X=remaining_features[1])
        weights[1] = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=len(remaining_features[0]), X=remaining_features[0])

    elif weight_method == 'ratio_median_knn_density_k_1_only_supervised':
        weights[0] = density_ratio(0, k, 1)
        weights[1] = 1
        print("The weight method is ratio_median_knn_density_k_1_only_supervised")

    elif weight_method == 'ratio_median_knn_density_1_k_only_supervised':
        weights[0] = density_ratio(0, 1, k)
        weights[1] = 1
        print("The weight method is ratio_median_knn_density_1_k_only_supervised")

    elif weight_method == 'median_knn_density_radial_to_centroid':
        # cosine metric, neighbourhood size adapts to each radial region
        radial = [median_radial_density_fullSi(remaining_features[space]) for space in (0, 1)]
        weights[0] = float(np.median(radial[0]))
        weights[1] = float(np.median(radial[1]))
        print("The weight method is median_knn_density_radial_to_centroid (cosine, adaptive k)")

    elif weight_method == 'adaptive_entropy':
        def feature_entropy(X):
            # Mean per-column entropy after binning values at the global
            # quartiles. Defined here but not called by this branch.
            from scipy.stats import entropy
            binned = np.digitize(X, np.percentile(X, [25, 50, 75]), right=True)
            column_entropies = []
            for col in range(X.shape[1]):
                _, counts = np.unique(binned[:, col], return_counts=True)
                column_entropies.append(entropy(counts))
            return np.mean(column_entropies)

        # NOTE(review): herding_effectiveness is not defined in the visible part
        # of this file; if it is missing, the NameError is caught below and the
        # incoming weights are kept as they are (they are not reset to equal).
        try:
            assign_shares(herding_effectiveness(remaining_features[0], max_size),
                          herding_effectiveness(remaining_features[1], max_size))
            print(f"The weight method is herding_specific_balance: {weights}")
        except Exception as e:
            print(f"Herding specific balance failed: {e}, using equal weights")

    elif weight_method == 'mean_approximation_quality':
        def mean_approx_quality(X, max_size):
            # Inverse of the smallest distance between the full mean and the
            # mean of randomly drawn subsets of a few candidate sizes.
            target = np.mean(X, axis=0)
            draws = min(20, len(X))
            closest = float('inf')
            for size in (1, max_size // 4, max_size // 2, max_size):
                if not 1 <= size < len(X):
                    continue
                for _ in range(draws):
                    chosen = np.random.choice(len(X), size, replace=False)
                    gap = np.linalg.norm(np.mean(X[chosen], axis=0) - target)
                    closest = min(closest, gap)
            return 1.0 / (closest + 1e-8)

        # On failure the incoming weights are kept (they are not reset).
        try:
            assign_shares(mean_approx_quality(remaining_features[0], max_size),
                          mean_approx_quality(remaining_features[1], max_size))
            print(f"The weight method is mean_approximation_quality: {weights}")
        except Exception as e:
            print(f"Mean approximation quality failed: {e}, using equal weights")

    elif weight_method == 'spread_vs_compactness':
        def spread_compactness_ratio(X):
            # mean pairwise distance divided by mean distance to the centroid
            centre = np.mean(X, axis=0)
            compactness = np.mean(np.linalg.norm(X - centre, axis=1))
            count = len(X)
            gaps = [np.linalg.norm(X[a] - X[b])
                    for a in range(count)
                    for b in range(a + 1, count)]
            spread = np.mean(gaps) if gaps else 1.0
            return spread / (compactness + 1e-8)

        # On failure the incoming weights are kept (they are not reset).
        try:
            assign_shares(spread_compactness_ratio(remaining_features[0]),
                          spread_compactness_ratio(remaining_features[1]))
            print(f"The weight method is spread_vs_compactness: {weights}")
        except Exception as e:
            print(f"Spread vs compactness failed: {e}, using equal weights")

    elif weight_method == '':
        print("The weight method is not set, using args weights")

    return k


def save_weights_to_file(weights, weight_method, episode_info, exp_dir, class_id=None, task_id=None):
    """
    Write one weight record as JSON under `<exp_dir>/weights/` and log it.

    Parameters:
    -----------
    weights : list
        [model_based_weight, self_supervised_weight]
    weight_method : str
        Method used to calculate the weights
    episode_info : dict
        Episode metadata; the keys 'episode' and 'features_type' are read here
    exp_dir : str
        Experiment directory (the "weights" sub-directory is created if needed)
    class_id : int, optional
        When given, the file is named after the class
    task_id : int, optional
        Used for the file name only when `class_id` is None

    Returns:
    --------
    str
        Path of the JSON file that was written
    """
    weights_dir = os.path.join(exp_dir, "weights")
    os.makedirs(weights_dir, exist_ok=True)

    # A zero self-supervised weight is reported as an infinite ratio.
    if weights[1] != 0:
        ratio = float(weights[0] / weights[1])
    else:
        ratio = float('inf')

    weight_data = {
        "timestamp": datetime.now().isoformat(),
        "episode_info": episode_info,
        "weight_method": weight_method,
        "model_based_weight": float(weights[0]),
        "self_supervised_weight": float(weights[1]),
        "class_id": class_id,
        "task_id": task_id,
        "total_weight": float(weights[0] + weights[1]),
        "weight_ratio": ratio,
        "is_nvidia": episode_info.get('features_type', '').endswith('_nvidia'),
    }

    # File name: the class id takes precedence over the task id.
    episode_tag = episode_info.get('episode', 'unknown')
    if class_id is not None:
        stem = f"weights_class_{class_id}"
    elif task_id is not None:
        stem = f"weights_task_{task_id}"
    else:
        stem = "weights"
    filepath = os.path.join(weights_dir, f"{stem}_episode_{episode_tag}.json")

    with open(filepath, 'w') as f:
        json.dump(weight_data, f, indent=2)

    print(f"Weights saved to: {filepath}")
    print(f"Model-based weight: {weights[0]:.6f}")
    print(f"Self-supervised weight ({episode_info.get('features_type', 'unknown')}): {weights[1]:.6f}")
    print(f"Weight ratio (mb/ss): {weight_data['weight_ratio']:.6f}")
    if weight_data['is_nvidia']:
        print(f"Using NVIDIA implementation: {episode_info.get('features_type', 'unknown')}")

    return filepath


def save_episode_weights_summary(episode_weights, episode_info, exp_dir):
    """
    Write a JSON summary of all per-class weights of one episode.

    Parameters:
    -----------
    episode_weights : list
        One dict per class with at least the keys 'model_based_weight',
        'self_supervised_weight' and 'weight_ratio'
    episode_info : dict
        Episode metadata; the key 'episode' is used for the file name
    exp_dir : str
        Experiment directory (the "weights" sub-directory is created if needed)

    Returns:
    --------
    str
        Path of the summary file that was written
    """
    weights_dir = os.path.join(exp_dir, "weights")
    os.makedirs(weights_dir, exist_ok=True)

    # Statistics reported for every series, in output order.
    reducers = (
        ("mean", np.mean),
        ("std", np.std),
        ("min", np.min),
        ("max", np.max),
        ("median", np.median),
    )

    def summarize(values):
        # An empty series yields all-zero statistics instead of NaN/errors.
        if len(values) == 0:
            return {label: 0.0 for label, _ in reducers}
        return {label: float(reduce_fn(values)) for label, reduce_fn in reducers}

    mb_weights = [entry['model_based_weight'] for entry in episode_weights]
    ss_weights = [entry['self_supervised_weight'] for entry in episode_weights]
    # Infinite ratios (zero self-supervised weight) are left out of the stats.
    finite_ratios = [entry['weight_ratio'] for entry in episode_weights
                     if entry['weight_ratio'] != float('inf')]

    episode_summary = {
        "timestamp": datetime.now().isoformat(),
        "episode_info": episode_info,
        "total_classes": len(episode_weights),
        "class_weights": episode_weights,
        "statistics": {
            "model_based": summarize(mb_weights),
            "self_supervised": summarize(ss_weights),
            "weight_ratios": summarize(finite_ratios),
        },
    }

    filename = f"episode_weights_summary_episode_{episode_info.get('episode', 'unknown')}.json"
    filepath = os.path.join(weights_dir, filename)

    with open(filepath, 'w') as f:
        json.dump(episode_summary, f, indent=2)

    print(f"Episode weights summary saved to: {filepath}")
    print(f"Episode {episode_info.get('episode', 'unknown')} - {len(episode_weights)} classes processed")
    print(f"Average MB weight: {episode_summary['statistics']['model_based']['mean']:.6f}")
    print(f"Average SS weight: {episode_summary['statistics']['self_supervised']['mean']:.6f}")

    return filepath


def pretrained_representations(data, dataset, ss_method, seed, order=None, nvidia=None,
                               root_dir='/cs/labs/daphna/danit.yanowsky/CL/Plugins'):
    """
    Load the precomputed self-supervised features of one class from disk.

    The class is the first entry of `data.targets.uniques`. Exactly one file is
    read; which one depends on `nvidia` and `order`:

    * `nvidia` truthy -> `<root_dir>/dino_nvidia/...` (DINOv2 ViT-B/14 features,
      always seed 0; the dataset name "tinyimg" is mapped to "tinyimagenet").
      `ss_method`, `seed` and `order` are ignored in this case.
    * otherwise, `order` truthy -> `<root_dir>/order_seed_<order>/...`
    * otherwise -> `<root_dir>/regular_order/...`
      (NOTE(review): a falsy `order` such as 0 also lands here - confirm.)

    Parameters
    ----------
    data : object exposing `targets.uniques`
    dataset : str
    ss_method : str
    seed : int
    order : optional
        Class-order seed used in the directory name.
    nvidia : optional
        Truthy to load the NVIDIA DINOv2 representations instead.
    root_dir : str
        Directory that contains the representation folders. Defaults to the
        location that used to be hard-coded.

    Returns
    -------
    np.ndarray
        The stored 'features' array with every row whose L2 norm exceeds 1
        scaled down to unit norm; rows with norm <= 1 are returned unchanged.
    """
    class_indicator = list(data.targets.uniques)[0]

    # Choose the file first so that only the one that is actually needed is
    # read; the NVIDIA file must not depend on the regular file being present.
    if nvidia:
        if dataset == "tinyimg":
            dataset = "tinyimagenet"
        tag = f"{dataset}_dinov2_torch_dinov2_vitb14"
        path = os.path.join(root_dir, "dino_nvidia",
                            f"representations_trained_{tag}_seed_0",
                            f"{tag}_all_{class_indicator}.npy")
        message = f"Loading representations for dataset {dataset} with nvidia dino"
    else:
        order_dir = f"order_seed_{order}" if order else "regular_order"
        path = os.path.join(root_dir, order_dir,
                            f"representations_trained_{dataset}_{ss_method}_seed_{seed}",
                            f"{dataset}_{ss_method}_all_{class_indicator}.npy")
        message = (f"Loading representations for dataset {dataset} with order {order} and seed {seed}"
                   if order else None)

    # SECURITY: allow_pickle=True unpickles the file content, which can execute
    # arbitrary code - only point root_dir at files you trust.
    obj = np.load(path, allow_pickle=True)
    if message is not None:
        print(message)

    # The file stores a pickled dict; .item() unwraps the 0-d object array.
    features = obj.item()['features']
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    # Dividing by max(norm, 1) only shrinks rows longer than unit length.
    features = features / np.maximum(norms, 1)
    print(f"features of class {class_indicator} loaded, length:{len(features)}")
    return features


def calculate_remaining_indices(features, features_ss, integrated, remaining_features, ss_method):
    """
    Append per-sample feature copies to `remaining_features` and return indices.

    * If `ss_method == 'model_based'` or `integrated` is truthy, a list with a
      copy of every row of `features` is appended.
    * If `ss_method` is 'dino', 'simclr' or 'vicreg', a list with a copy of
      every row of `features_ss` is appended.

    Both conditions can hold at once (integrated mode); the model-based list is
    then appended first and the returned indices follow `features_ss`.

    Parameters
    ----------
    features, features_ss : sequence of arrays
        Model-based and self-supervised features, one row per sample.
    integrated : bool
    remaining_features : list
        Output list, extended in place.
    ss_method : str

    Returns
    -------
    list
        `list(np.arange(n))` for the feature set appended last.

    Raises
    ------
    ValueError
        If neither condition applies, so no features could be selected
        (previously this surfaced as an UnboundLocalError).
    """
    remaining_indices = None
    if ss_method == 'model_based' or integrated:
        print(f"{ss_method} features are used, {integrated} integrated features are used")
        remaining_features.append([row.copy() for row in features])
        remaining_indices = list(np.arange(len(features)))
    if ss_method in ('dino', 'simclr', 'vicreg'):
        print(f"{ss_method} features are used")
        remaining_features.append([row.copy() for row in features_ss])
        remaining_indices = list(np.arange(len(features_ss)))
    if remaining_indices is None:
        raise ValueError(
            f"No features selected: ss_method={ss_method!r} is not one of "
            f"'model_based', 'dino', 'simclr', 'vicreg' and integrated is falsy"
        )
    return remaining_indices


class MaxHerding(TEALExemplarsSelectionStrategy):
    def __init__(self,
                 args, device):
        """
        Initialize K-Medoids Selection Strategy.

        Parameters:
        -----------
        features_type: str
            Type of features to use: 'model_based', 'dino', 'simclr', or 'vicreg'
        dataset_name: str
            Name of the dataset
        integrated_features: bool
            Whether to use integrated features (multiple feature types)
        alpha: float
            The weight for the first feature type when using integrated features
        seed: int
            Random seed for reproducibility
        max_iterations: int
            Maximum number of iterations for k-medoids
        tol: float
            Tolerance for convergence in k-medoids
        device: str
            Device to use for computation
        """
        self.args=args
        self.ss_method = args.features_type
        self.dataset_name = args.dataset
        self.integrated = args.integrated_features
        self.integrated = self.integrated.strip().lower() == "true"
        self.alpha = args.alpha
        self.seed = args.seed
        self.max_iterations = 10
        self.tol = 1e-4
        self.device = device
        self.weight_method = args.weight_method

        self.features = None
        self.features_ss = None
        self.buffer_indices = []
        self.new_order = []
        self.group_to_len = {}  # Dictionary to track class lengths
        self.episode = 0  # Track current episode
        self.episode_weights = []  # Store weights for current episode

        # Set random seed for reproducibility
        np.random.seed(self.seed)
        random.seed(self.seed)

        print(f"Initialized Max Herding with:")
        print(f"  - Features type: {self.ss_method}")
        print(f"  - Integrated features: {self.integrated}")
        print(f"  - Alpha: {self.alpha}")
        print(f"  - Max iterations: {self.max_iterations}")
        print(f"  - Tolerance: {self.tol}")

    def init_features(self, data, model):
        if self.integrated:
            TEALExemplarsSelectionStrategy.init_features(self, data, model)
            norms = np.linalg.norm(self.features, axis=1, keepdims=True)
            self.features = self.features / np.maximum(norms, 1)
            features = pretrained_representations(data, self.dataset_name, self.ss_method, seed=self.seed, order=self.args.order, nvidia=self.args.nvidia)
            features = np.vstack(features)
            features = torch.tensor(features)
            self.features_ss = features.cpu().numpy()
        # elif self.ss_method == 'model_based':
        #     TEALExemplarsSelectionStrategy.init_features(self, data, model)
        # elif self.ss_method == 'dino' or self.ss_method == 'simclr' or self.ss_method == 'vicreg':
        #     features = pretrained_representations(data, self.dataset_name, self.ss_method, seed=self.seed)
        #     features = np.vstack(features)
        #     features = torch.tensor(features)
        #     self.features_ss = features.cpu().numpy()

    def greedy_k_herding(self,
                         X: np.ndarray,
                         b: int,
                         kernel_func: Callable) -> np.ndarray:
        """
        Fast, fully-vectorized greedy max-herding:
          - X: (n, d) array of features
          - b: budget (# points to pick)
          - kernel_func: if it supports array inputs, you can do K = kernel_func(X, X);
                         otherwise we build K once with one nested loop.
        Returns:
          indices of the b picked points (into X).
        """
        n = X.shape[0]

        # 1) Build Gram matrix K[i,j] = k(X[i], X[j]) exactly once
        try:
            # If kernel_func can take two matrices and return an (n,n) Gram:
            K = kernel_func(X, X)
        except:
            # Fallback to one nested loop (still only O(n^2) once)
            K = np.zeros((n, n), dtype=float)
            for i in range(n):
                # compute row i all at once if possible
                for j in range(n):
                    K[i, j] = kernel_func(X[i:i + 1], X[j:j + 1])

        # 2) Initialize coverage scores and pick list
        k_i = np.zeros(n, dtype=float)
        picked = []

        # 3) Greedy rounds
        for _ in range(min(b, n)):
            # Compute marginal gains for **all** j in one shot:
            #   G[j] = sum_i max(K[i,j] - k_i[i], 0)
            gains = np.maximum(K - k_i[:, None], 0.0).sum(axis=0)

            # Exclude already-picked indices
            gains[picked] = -np.inf

            # Pick the best
            j_star = int(np.argmax(gains))
            picked.append(j_star)

            # Update coverage scores: k_i = max(k_i, K[:, j_star])
            k_i = np.maximum(k_i, K[:, j_star])

        return np.array(picked, dtype=int)

    def make_sorted_indices(
            self, strategy: "SupervisedTemplate", data: AvalancheDataset
    ) -> List[int]:
        """
        Make sorted indices using the kernel k-medoids algorithm.

        Parameters:
        -----------
        data: numpy.ndarray or Dict
            Data to select from (used only for length if features already set)
        max_size: int
            Maximum number of samples to select in each iteration (budget)
        cur_class: Any
            Current class ID for class-specific selection
        iterations: int
            Number of iterations for the non-greedy algorithm
        kernel_func: callable
            Custom kernel function to use (optional)

        Returns:
        --------
        List[int]
            Sorted indices with selected points first, followed by remaining points
        """
        # Check if we need to process a specific class
        if len(list(data.targets.uniques)) < 1:
            return []
        cur_class = list(data.targets.uniques)[0]
        self.cur_class = cur_class
        if self.group_to_len.get(cur_class) is not None:
            return list(range(len(data)))
        else:
            self.group_to_len[cur_class] = cur_class
            self.episode = strategy.experience.current_experience
        self.init_features(data, strategy.model)
        weights = [self.alpha, 1 - self.alpha]
        max_size = strategy.plugins[1].storage_policy.max_size // len(strategy.plugins[1].storage_policy.seen_groups)
        remaining_features = []
        ss_method = self.ss_method
        integrated = self.integrated
        features = self.features
        features_ss = self.features_ss
        remaining_indices = calculate_remaining_indices(features, features_ss, integrated, remaining_features,
                                                        ss_method)
        k = calculate_weight(max_size=max_size,
                         remaining_features=remaining_features,
                         weights=weights,
                         weight_method=self.weight_method)
        print(f"weights of model based is: {weights[0]}")
        print(f"weights of self-supervised is: {weights[1]}")
        
        # Save weights to file
        # Determine the actual features type for weight saving
        actual_features_type = self.ss_method
        if hasattr(self.args, 'nvidia') and self.args.nvidia and self.ss_method in ['dino', 'simclr', 'vicreg']:
            actual_features_type = f"{self.ss_method}_nvidia"
        
        episode_info = {
            "dataset": self.dataset_name,
            "seed": self.seed,
            "features_type": actual_features_type,
            "integrated_features": self.integrated,
            "alpha": self.alpha,
            "episode": self.episode,
            "class_id": cur_class
        }
        
        # Get experiment directory from args if available
        exp_dir = getattr(self.args, 'exp_dir', './experiments')
        
        # Save individual class weights
        weight_file = save_weights_to_file(
            weights=weights,
            weight_method=self.weight_method,
            episode_info=episode_info,
            exp_dir=exp_dir,
            class_id=cur_class
        )
        
        # Store weight data for episode summary
        weight_data = {
            "class_id": cur_class,
            "model_based_weight": float(weights[0]),
            "self_supervised_weight": float(weights[1]),
            "weight_ratio": float(weights[0] / weights[1]) if weights[1] != 0 else float('inf'),
            "k_value": int(k),
            "weight_method": self.weight_method,
            "timestamp": datetime.now().isoformat()
        }
        self.episode_weights.append(weight_data)
        max_size = strategy.plugins[1].storage_policy.max_size // len(strategy.plugins[1].storage_policy.seen_groups)
        if self.integrated:
            if self.features is None or self.features_ss is None:
                raise ValueError("Both model-based and self-supervised features must be set for integrated mode.")
            model_features = self.features.copy()
            ss_features = self.features_ss.copy()
            model_dim = model_features.shape[1]

            # Initialize remaining subsets
            U_model = model_features.copy()
            U_ss = ss_features.copy()
            U_indices = np.arange(U_model.shape[0])

            def integrated_kernel(x: np.ndarray, y: np.ndarray) -> np.ndarray:
                x_model = x[:, :model_dim]
                y_model = y[:, :model_dim]
                x_ss = x[:, model_dim:]
                y_ss = y[:, model_dim:]
                if self.args.sigma_mb=='1nn':
                    sigma_mb = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=len(x_model), X=x_model)
                elif self.args.sigma_mb == 'knn':
                    sigma_mb = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=max_size, X=x_model)
                elif self.args.sigma_mb == 'knn_l2':
                    sigma_mb = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=max_size, X=x_model,metric = "euclidean")
                elif self.args.sigma_mb=='1':
                    sigma_mb=1
                elif self.args.sigma_mb=='median_l2':
                    sigma_mb = np.median(pairwise_distances(x_model, metric="euclidean"))
                elif self.args.sigma_mb=='median_cosine':
                    sigma_mb = np.median(pairwise_distances(x_model, metric="cosine"))
                elif self.args.sigma_mb=='knn_based':
                    # Use k-NN based sigma scaling
                    sigma_mb = knn_based_delta(x_model, max_size, 'model_based')
                elif self.args.sigma_mb.startswith('knn_based_'):
                    # Use k-NN based sigma scaling with specific scale method
                    scale_method = self.args.sigma_mb.split('_', 2)[2]  # Extract scale method after 'knn_based_'
                    sigma_mb = knn_based_delta(x_model, max_size, 'model_based', scale_method=scale_method)
                elif self.args.sigma_mb == 'percentile_25_50_l2':
                    # Sigma as the distance between 25th and 50th percentiles of L2 distances
                    l2_distances = pairwise_distances(x_model, metric="euclidean")
                    # Get upper triangle to avoid duplicates and self-distances
                    upper_triangle = np.triu_indices_from(l2_distances, k=1)
                    distances_flat = l2_distances[upper_triangle]
                    p25 = np.percentile(distances_flat, 25)
                    p50 = np.percentile(distances_flat, 50)
                    sigma_mb = p50 - p25
                elif self.args.sigma_mb == 'percentile_25_75_half_l2':
                    # Sigma as half the distance between 25th and 75th percentiles of L2 distances
                    l2_distances = pairwise_distances(x_model, metric="euclidean")
                    # Get upper triangle to avoid duplicates and self-distances
                    upper_triangle = np.triu_indices_from(l2_distances, k=1)
                    distances_flat = l2_distances[upper_triangle]
                    p25 = np.percentile(distances_flat, 25)
                    p75 = np.percentile(distances_flat, 75)
                    sigma_mb = (p75 - p25) / 2.0
                elif self.args.sigma_mb == 'max_l2':
                    # Sigma as the maximum L2 distance
                    l2_distances = pairwise_distances(x_model, metric="euclidean")
                    # Get upper triangle to avoid duplicates and self-distances
                    upper_triangle = np.triu_indices_from(l2_distances, k=1)
                    distances_flat = l2_distances[upper_triangle]
                    sigma_mb = np.max(distances_flat)
                elif self.args.sigma_mb == 'knn_l2':
                    sigma_ss = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=max_size, X=x_ss,metric = "euclidean")
                if self.args.sigma_ss == 'knn':
                    sigma_ss = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=max_size, X=x_ss)
                if self.args.sigma_ss == '1nn':
                    sigma_ss = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=len(x_ss), X=x_ss)
                if self.args.sigma_ss=='1':
                    sigma_ss=1
                elif self.args.sigma_ss=='median_l2':
                   sigma_ss =np.median(pairwise_distances(x_ss, metric="euclidean"))
                elif self.args.sigma_ss=='median_cosine':
                   sigma_ss =np.median(pairwise_distances(x_ss, metric="cosine"))
                elif self.args.sigma_ss=='knn_based':
                    # Use k-NN based sigma scaling
                    sigma_ss = knn_based_delta(x_ss, max_size, self.ss_method)
                elif self.args.sigma_ss.startswith('knn_based_'):
                    # Use k-NN based sigma scaling with specific scale method
                    scale_method = self.args.sigma_ss.split('_', 2)[2]  # Extract scale method after 'knn_based_'
                    sigma_ss = knn_based_delta(x_ss, max_size, self.ss_method, scale_method=scale_method)
                elif self.args.sigma_ss == 'percentile_25_50_l2':
                    # Sigma as the distance between 25th and 50th percentiles of L2 distances
                    l2_distances = pairwise_distances(x_ss, metric="euclidean")
                    # Get upper triangle to avoid duplicates and self-distances
                    upper_triangle = np.triu_indices_from(l2_distances, k=1)
                    distances_flat = l2_distances[upper_triangle]
                    p25 = np.percentile(distances_flat, 25)
                    p50 = np.percentile(distances_flat, 50)
                    sigma_ss = p50 - p25
                elif self.args.sigma_ss == 'percentile_25_75_half_l2':
                    # Sigma as half the distance between 25th and 75th percentiles of L2 distances
                    l2_distances = pairwise_distances(x_ss, metric="euclidean")
                    # Get upper triangle to avoid duplicates and self-distances
                    upper_triangle = np.triu_indices_from(l2_distances, k=1)
                    distances_flat = l2_distances[upper_triangle]
                    p25 = np.percentile(distances_flat, 25)
                    p75 = np.percentile(distances_flat, 75)
                    sigma_ss = (p75 - p25) / 2.0
                elif self.args.sigma_ss == 'max_l2':
                    # Sigma as the maximum L2 distance
                    l2_distances = pairwise_distances(x_ss, metric="euclidean")
                    # Get upper triangle to avoid duplicates and self-distances
                    upper_triangle = np.triu_indices_from(l2_distances, k=1)
                    distances_flat = l2_distances[upper_triangle]
                    sigma_ss = np.max(distances_flat)
                elif self.args.sigma_ss == 'knn_l2':
                    sigma_ss = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=max_size, X=x_ss,metric = "euclidean")


                gamma_mb =1/(2 * (sigma_mb ** 2))
                gamma_ss = 1 / (2 * (sigma_ss ** 2))
                K_model = rbf_kernel(x_model, y_model, gamma=gamma_mb)
                K_ss = rbf_kernel(x_ss, y_ss, gamma= gamma_ss)
                k=len(x_model) // max_size
                # if weights[0] ==0 or weights[1]==0:
                #     return weights[0] * K_model + weights[1] *K_ss
                return weights[0] * K_model + weights[1] * K_ss

            kernel_function = integrated_kernel
        else:
            raise ValueError(
                "Integrated features are not supported in MaxHerding strategy. Please use TEALExemplarsSelectionStrategy instead.")

        query_indices = []

        # Perform selection over multiple iterations
        for t in range(1):
            U_remaining = np.hstack([U_model, U_ss])

            if len(U_remaining) <= max_size:
                # If remaining unlabeled set is smaller than budget, select all
                selected_indices = np.arange(len(U_remaining))
            else:
                # Run kernel k-medoids to select representative points
                selected_indices = self.greedy_k_herding(U_remaining, max_size, kernel_function)

            # Get the original indices
            original_indices = U_indices[selected_indices]

            # Add to query indices
            query_indices.extend(original_indices)

            # Remove selected points from unlabeled set
            mask = np.ones(len(U_remaining), dtype=bool)
            mask[selected_indices] = False
            U_remaining = U_remaining[mask]
            U_indices = U_indices[mask]

            # Break if all points are selected or remaining are less than budget
            if len(U_remaining) == 0:
                break

        # Get remaining indices (not selected)
        n_samples = self.features.shape[0] if self.ss_method == 'model_based' else self.features_ss.shape[0]
        all_indices = np.arange(n_samples)
        remaining_indices = np.setdiff1d(all_indices, query_indices)

        # Shuffle the remaining indices
        np.random.shuffle(remaining_indices)

        # Save selected indices
        self.buffer_indices = query_indices
        
        # Save top 5 selected features_ss (if available and if we have at least 5)
        if hasattr(self, 'features_ss') and self.features_ss is not None and len(query_indices) >= 5:
            self.selected_features_ss = self.features_ss[query_indices[:5]]  # Top 5 selected
        elif hasattr(self, 'features_ss') and self.features_ss is not None:
            self.selected_features_ss = self.features_ss[query_indices]  # All selected if less than 5
        else:
            self.selected_features_ss = None
            
        # Save remaining features_ss (everything not selected) for KNN testing
        if hasattr(self, 'features_ss') and self.features_ss is not None:
            # Get remaining indices (not selected)
            remaining_indices_for_features = np.setdiff1d(np.arange(len(self.features_ss)), query_indices)
            if len(remaining_indices_for_features) > 0:
                self.remaining_features_ss = self.features_ss[remaining_indices_for_features]  # Remaining features for KNN testing
            else:
                self.remaining_features_ss = None
        else:
            self.remaining_features_ss = None

        # Concatenate query_indices with the shuffled remaining indices
        self.new_order = np.concatenate((query_indices, remaining_indices)).astype(int)

        return self.new_order.tolist()  # Convert to list for compatibility

    def save_episode_summary(self):
        """
        Persist the per-class weight records gathered during the current episode.

        Meant to be invoked once an episode is finished. When no records were
        collected, a notice is printed and nothing is written.

        Returns:
        --------
        Whatever `save_episode_weights_summary` returns, or None when there
        were no recorded weights for this episode.
        """
        recorded = self.episode_weights
        if not recorded:
            print("No weights to save for this episode.")
            return None

        # The label gets an "_nvidia" suffix only for the listed self-supervised
        # backbones and only when the (optional) `nvidia` argument is truthy.
        method = self.ss_method
        run_args = self.args
        nvidia_variant = getattr(run_args, 'nvidia', False) and method in ['dino', 'simclr', 'vicreg']
        features_label = f"{method}_nvidia" if nvidia_variant else method

        episode_info = dict([
            ("dataset", self.dataset_name),
            ("seed", self.seed),
            ("features_type", features_label),
            ("integrated_features", self.integrated),
            ("alpha", self.alpha),
            ("episode", self.episode),
            ("selection_strategy", "MaxHerding"),
        ])

        # Fall back to the default experiment directory when args carries none.
        summary_file = save_episode_weights_summary(
            episode_weights=recorded,
            episode_info=episode_info,
            exp_dir=getattr(self.args, 'exp_dir', './experiments'),
        )

        # Start the next episode with an empty record list.
        self.episode_weights = []

        return summary_file

    def get_selected_indices(self):
        """
        Return the indices chosen for the buffer (the query indices).

        The stored object is handed back as-is, not a copy.

        Returns:
        --------
        List[int]
            Indices currently held in `self.buffer_indices`
        """
        chosen = self.buffer_indices
        return chosen

    def get_sorted_order(self):
        """
        Return the full ordering stored in `self.new_order` as a plain list.

        Returns:
        --------
        List[int]
            Result of calling `tolist()` on the stored ordering
        """
        ordering = self.new_order
        return ordering.tolist()





class ProbCoverExemplarsSelectionStrategy(TEALExemplarsSelectionStrategy):
    def __init__(self, args, device, extra_args=None):
        """
        Set up the ProbCover selection strategy from parsed run arguments.

        `args.concatenated` and `args.integrated_features` are read as strings
        and turned into booleans (true only for "true", ignoring case and
        surrounding whitespace). `self.alpha` is assigned only when
        `args.sel_strategy` equals 'probcover'.
        """
        super().__init__(args, device, extra_args)

        def _is_true(raw):
            # String flag -> bool; anything other than "true" counts as False.
            return raw.strip().lower() == "true"

        self.args = args

        self.concatenated = _is_true(args.concatenated)
        print(f"concatenated is {self.concatenated}")

        # Feature source, dataset and strategy identifiers.
        self.ss_method = args.features_type
        self.dataset_name = args.dataset
        self.selection_strategy = args.sel_strategy
        if self.selection_strategy == 'probcover':
            self.alpha = args.alpha
            print(f"alpha is {self.alpha}")

        self.seed = args.seed
        self.weight_method = args.weight_method
        self.device = device

        # Feature / coverage state, filled in later.
        self.features = None
        self.counts = np.array([])
        self.delta = args.delta
        self.integrated = _is_true(args.integrated_features)
        self.log_iterative = args.teal_type
        self.increase_factor = args.increase_factor

        # Per-run bookkeeping.
        self.cur_class = None
        self.episode = 0  # index of the current episode
        self.episode_weights = []  # weight records collected for the current episode

    def make_sorted_indices(
            self, strategy: "SupervisedTemplate", data: AvalancheDataset
    ) -> List[int]:
        if len(list(data.targets.uniques)) < 1:
            return []
        d_func = cosine_distances
        cur_class = list(data.targets.uniques)[0]
        self.cur_class = cur_class
        if self.group_to_len.get(cur_class) is not None:
            return list(range(len(data)))
        else:
            self.group_to_len[cur_class] = cur_class
            self.episode = strategy.experience.current_experience  # Increment episode for new class
        self.init_features(data, strategy.model)
        max_size = strategy.plugins[1].storage_policy.max_size // len(strategy.plugins[1].storage_policy.seen_groups)
        # Create an array of ProbCover objects
        if self.integrated:
            num_probcovers = 2
        # else:
        #     num_probcovers = 1
        if self.integrated:
            self.weights = [self.alpha, 1 - self.alpha]  # Weights for each ProbCover object
        # else:
        #     weights = [1]
        # Assuming you have the original feature datasets before splitting
        # Initialize remaining_features and remaining_indices for each probcover
        remaining_features = []
        ss_method = self.ss_method
        integrated = self.integrated
        features = self.features
        features_ss = self.features_ss
        remaining_indices = calculate_remaining_indices(features, features_ss, integrated, remaining_features,
                                                        ss_method)
        # Number of ProbCover objects to use
        # Initialize
        self.new_order = self.apply_probcover_graph(d_func, max_size, num_probcovers, remaining_features,
                                                        remaining_indices, self.weights, cur_class)
        return self.new_order

    @staticmethod
    def optimal_delta(max_size: int = 1, X=None, num_log: int = 0, src_features=False, metric = 'cosine'):
        """
        Calculate the optimal delta value based on the current dataset.
        Calculate the average distance between all pairs of x k-nn.
        X is the number of features in the dataset divided by max_size
        :return: optimal delta value
        """
        # Choose an appropriate k value (can be adjusted based on dataset size)
        k = min(len(X) - 1, len(X) // max_size)
        # Initialize nearest neighbors model
        nn_model = NearestNeighbors(n_neighbors=k + 1, metric=metric)  # +1 because the point itself is included
        nn_model.fit(X)
        # Find k-nearest neighbors
        distances, _ = nn_model.kneighbors(X)
        # Remove the first column (distance to self is always 0)
        distances = distances[:, 1:]
        # Calculate median distance between points and their k nearest neighbors
        delta = np.median(distances)
        return delta

    def probcover_delta_geometric(self,X, B, metric='cosine'):
        """
        Compute a single delta as the geometric mean of two neighbour scales.

        The two scales are the median distance to the nearest neighbour and the
        median distance to the k-th nearest neighbour, with k = len(X) // B
        clamped to [1, len(X) - 1].

        X : (n, d) array     # embeddings to measure
        B : int              # budget used to derive k
        metric : str         # distance metric passed to NearestNeighbors

        Returns
        -------
        delta : float        # 0.0 when len(X) <= 1 or B <= 0
        """
        n_points = len(X)
        if not (n_points > 1 and B > 0):
            return 0.0

        # Neighbourhood size of roughly n / B, kept inside [1, n - 1].
        neighbourhood = min(n_points - 1, n_points // B)
        if neighbourhood < 1:
            neighbourhood = 1

        searcher = NearestNeighbors(n_neighbors=neighbourhood + 1, metric=metric)
        searcher.fit(X)

        # Column 0 is taken to be the point itself; later columns are the other neighbours.
        knn_dists, _ = searcher.kneighbors(X, return_distance=True)
        nearest_scale = float(np.median(knn_dists[:, 1]))
        kth_scale = float(np.median(knn_dists[:, neighbourhood]))

        # Both scales are already scalars (medians over points), so the result
        # is one scalar, not a per-point array.
        return np.sqrt(nearest_scale * kth_scale)
    def apply_probcover_graph(self, d_func, max_size, num_probcovers, remaining_features, remaining_indices, weights,
                              cur_class, num_log=0):
        n_samples = len(remaining_features[0])
        is_candidate = np.ones(n_samples, dtype=bool)
        query_indices = []
        # Pre-compute distance matrices for each probcover
        distance_matrices = []
        for j in range(num_probcovers):
            distances = d_func(remaining_features[j])
            distance_matrices.append(distances)
        # Create edge matrices (True if distance <= delta)
        edge_matrices = []

        for j in range(num_probcovers):
            if j == 0:  ##model based
                delta_knn = self.optimal_delta(max_size=max_size, num_log=num_log,
                                             X=remaining_features[0], src_features=True)
                delta_1nn = self.optimal_delta(max_size=len(remaining_features[0]), num_log=num_log,
                                                       X=remaining_features[0], src_features=True)
                if self.args.delta_mb=='1nn':
                    edges = distance_matrices[j] <= delta_1nn
                elif self.args.delta_mb=='knn':
                    edges = distance_matrices[j] <= delta_knn
                elif self.args.delta_mb=='knn_cross_ratio':
                    edges = distance_matrices[j] <= delta_knn*(weights[1]/weights[0])
                elif self.args.delta_mb=='median_cosine':
                   edges = distance_matrices[j] <=np.median(pairwise_distances(remaining_features[0], metric="cosine"))
                elif self.args.delta_mb=='geometric':
                    edges = distance_matrices[j] <= self.probcover_delta_geometric(remaining_features[0], max_size)
                elif self.args.delta_mb=='knn_based':
                    # Use k-NN based delta scaling
                    delta_knn = knn_based_delta(remaining_features[0], max_size, 'model_based')
                    edges = distance_matrices[j] <= delta_knn
                elif self.args.delta_mb.startswith('knn_based_'):
                    # Use k-NN based delta scaling with specific scale method
                    scale_method = self.args.delta_mb.split('_', 2)[2]  # Extract scale method after 'knn_based_'
                    delta_knn = knn_based_delta(remaining_features[0], max_size, 'model_based', scale_method=scale_method)
                    edges = distance_matrices[j] <= delta_knn
                elif self.args.delta_mb=='coverage_optimal':
                    # Delta that ensures each selected point covers approximately n/B points
                    n = len(remaining_features[0])
                    B = max_size
                    target_coverage = n / B  # How many points each selection should cover on average
                    
                    # Get all pairwise distances
                    cosine_dists = pairwise_distances(remaining_features[0], metric="cosine")
                    
                    # For each point, count how many neighbors it would have with different deltas
                    delta_candidates = np.percentile(cosine_dists.flatten(), [20, 30, 40, 50, 60, 70, 80, 90])
                    best_delta = delta_candidates[0]  # Default
                    min_error = float('inf')
                    
                    for delta in delta_candidates:
                        # Count neighbors for each point with this delta
                        neighbor_counts = (cosine_dists <= delta).sum(axis=1) - 1  # -1 to exclude self
                        avg_coverage = np.median(neighbor_counts)  # Use median to be robust to outliers
                        
                        # Find delta that gives coverage closest to target
                        error = abs(avg_coverage - target_coverage)
                        if error < min_error:
                            min_error = error
                            best_delta = delta
                    
                    edges = distance_matrices[j] <= best_delta
                elif self.args.delta_mb=='budget_scaled_median':
                    # Scale median distance based on budget ratio
                    median_dist = np.median(pairwise_distances(remaining_features[0], metric="cosine"))
                    budget_ratio = max_size / len(remaining_features[0])
                    
                    # Scale factor: larger budget -> smaller delta, smaller budget -> larger delta
                    # Use inverse relationship with some smoothing
                    scale_factor = 1.0 / (budget_ratio + 0.1)  # +0.1 prevents division by zero
                    
                    # Apply bounds to prevent extreme values
                    scale_factor = np.clip(scale_factor, 0.5, 2.0)  # Scale between 0.5x and 2x
                    
                    delta_scaled = median_dist * scale_factor
                    edges = distance_matrices[j] <= delta_scaled
                elif self.args.delta_mb=='median_1nn_half':
                    # Use median 1-NN distance divided by half
                    delta_1nn = self.optimal_delta(max_size=len(remaining_features[0]), num_log=num_log, X=remaining_features[0], src_features=True)
                    delta_half = delta_1nn / 2.0
                    edges = distance_matrices[j] <= delta_half
                else:
                    raise("Error! undefined delta model based")
            else:
                if self.args.delta_ss=='1nn':
                    delta_1nn = self.optimal_delta(max_size=len(remaining_features[1]), num_log=num_log, X=remaining_features[1])
                    edges = distance_matrices[j] <= delta_1nn
                elif self.args.delta_ss=='knn':
                    delta_knn= self.optimal_delta(max_size=max_size, num_log=num_log, X=remaining_features[1])
                    edges = distance_matrices[j] <= delta_knn
                elif self.args.delta_ss=='median_cosine':
                   edges = distance_matrices[j] <=np.median(pairwise_distances(remaining_features[1], metric="cosine"))
                elif self.args.delta_ss=='geometric':
                    edges = distance_matrices[j] <= self.probcover_delta_geometric(remaining_features[1], max_size)
                elif self.args.delta_ss=='knn_based':
                    # Use k-NN based delta scaling
                    delta_knn = knn_based_delta(remaining_features[1], max_size, self.ss_method)
                    edges = distance_matrices[j] <= delta_knn
                elif self.args.delta_ss.startswith('knn_based_'):
                    # Use k-NN based delta scaling with specific scale method
                    scale_method = self.args.delta_ss.split('_', 2)[2]  # Extract scale method after 'knn_based_'
                    delta_knn = knn_based_delta(remaining_features[1], max_size, self.ss_method, scale_method=scale_method)
                    edges = distance_matrices[j] <= delta_knn
                elif self.args.delta_ss=='coverage_optimal':
                    # Delta that ensures each selected point covers approximately n/B points
                    n = len(remaining_features[1])
                    B = max_size
                    target_coverage = n / B  # How many points each selection should cover on average
                    
                    # Get all pairwise distances
                    cosine_dists = pairwise_distances(remaining_features[1], metric="cosine")
                    
                    # For each point, count how many neighbors it would have with different deltas
                    delta_candidates = np.percentile(cosine_dists.flatten(), [20, 30, 40, 50, 60, 70, 80, 90])
                    best_delta = delta_candidates[0]  # Default
                    min_error = float('inf')
                    
                    for delta in delta_candidates:
                        # Count neighbors for each point with this delta
                        neighbor_counts = (cosine_dists <= delta).sum(axis=1) - 1  # -1 to exclude self
                        avg_coverage = np.median(neighbor_counts)  # Use median to be robust to outliers
                        
                        # Find delta that gives coverage closest to target
                        error = abs(avg_coverage - target_coverage)
                        if error < min_error:
                            min_error = error
                            best_delta = delta
                    
                    edges = distance_matrices[j] <= best_delta
                elif self.args.delta_ss=='budget_scaled_median':
                    # Scale median distance based on budget ratio
                    median_dist = np.median(pairwise_distances(remaining_features[1], metric="cosine"))
                    budget_ratio = max_size / len(remaining_features[1])
                    
                    # Scale factor: larger budget -> smaller delta, smaller budget -> larger delta
                    # Use inverse relationship with some smoothing
                    scale_factor = 1.0 / (budget_ratio + 0.1)  # +0.1 prevents division by zero
                    
                    # Apply bounds to prevent extreme values
                    scale_factor = np.clip(scale_factor, 0.5, 2.0)  # Scale between 0.5x and 2x
                    
                    delta_scaled = median_dist * scale_factor
                    edges = distance_matrices[j] <= delta_scaled
                elif self.args.delta_ss=='median_1nn_half':
                    # Use median 1-NN distance divided by half
                    delta_1nn = self.optimal_delta(max_size=len(remaining_features[1]), num_log=num_log, X=remaining_features[1])
                    delta_half = delta_1nn / 2.0
                    edges = distance_matrices[j] <= delta_half
                else:
                    raise("Error! undefined delta model based")
            edge_matrices.append(edges)

        # Initialize weights for each probcover
        k = calculate_weight(max_size, remaining_features, weights, weight_method=self.weight_method)
        print("weights of model based is:", weights[0])
        print("weights of self-supervised is:", weights[1])
        print("k value used for weight calculation:", k)
        
        # Save weights to file
        # Determine the actual features type for weight saving
        actual_features_type = self.ss_method
        if hasattr(self.args, 'nvidia') and self.args.nvidia and self.ss_method in ['dino', 'simclr', 'vicreg']:
            actual_features_type = f"{self.ss_method}_nvidia"
        
        episode_info = {
            "dataset": self.dataset_name,
            "seed": self.seed,
            "features_type": actual_features_type,
            "integrated_features": self.integrated,
            "alpha": getattr(self, 'alpha', 0.5),
            "episode": self.episode,
            "class_id": cur_class,
            "selection_strategy": self.selection_strategy
        }
        
        # Get experiment directory from args if available
        exp_dir = getattr(self.args, 'exp_dir', './experiments')
        
        # Save individual class weights
        weight_file = save_weights_to_file(
            weights=weights,
            weight_method=self.weight_method,
            episode_info=episode_info,
            exp_dir=exp_dir,
            class_id=cur_class
        )
        
        # Store weight data for episode summary
        weight_data = {
            "class_id": cur_class,
            "model_based_weight": float(weights[0]),
            "self_supervised_weight": float(weights[1]),
            "weight_ratio": float(weights[0] / weights[1]) if weights[1] != 0 else float('inf'),
            "k_value": int(k),
            "weight_method": self.weight_method,
            "timestamp": datetime.now().isoformat()
        }
        self.episode_weights.append(weight_data)
        # Perform selection for max_size samples
        self.probcover(edge_matrices, is_candidate, max_size, n_samples, num_probcovers, query_indices,
                       remaining_indices, weights)
        # self.only_selected_probcover(edge_matrices, is_candidate, max_size, n_samples, num_probcovers, query_indices,
        #                         remaining_indices, weights)
        if self.ss_method == 'model_based' or self.concatenated:
            all_indices = np.arange(len(self.features))
        else:
            all_indices = np.arange(len(self.features_ss))
        # Remove the query_indices from all_indices
        remaining_indices = np.setdiff1d(all_indices, query_indices)
        # Shuffle the remaining indices
        np.random.shuffle(remaining_indices)
        self.buffer_indices = query_indices
        
        # Save top 5 selected features_ss (if available and if we have at least 5)
        if hasattr(self, 'features_ss') and self.features_ss is not None and len(query_indices) >= 5:
            self.selected_features_ss = self.features_ss[query_indices[:5]]  # Top 5 selected
        elif hasattr(self, 'features_ss') and self.features_ss is not None:
            self.selected_features_ss = self.features_ss[query_indices]  # All selected if less than 5
        else:
            self.selected_features_ss = None
            
        # Save remaining features_ss (everything not selected) for KNN testing
        if hasattr(self, 'features_ss') and self.features_ss is not None:
            # Get remaining indices (not selected)
            remaining_indices_for_features = np.setdiff1d(np.arange(len(self.features_ss)), query_indices)
            if len(remaining_indices_for_features) > 0:
                self.remaining_features_ss = self.features_ss[remaining_indices_for_features]  # Remaining features for KNN testing
            else:
                self.remaining_features_ss = None
        else:
            self.remaining_features_ss = None
            
        # Concatenate query_indices with the shuffled remaining indices
        new_order = np.concatenate((query_indices, remaining_indices)).astype(int)
        return new_order

    def save_episode_summary(self):
        """
        Persist the weight records collected during the current episode.

        Intended to be called once at the end of an episode. The records
        accumulated in ``self.episode_weights`` are handed to
        ``save_episode_weights_summary`` together with a description of the
        run, and the accumulator is then reset to an empty list.

        Returns
        -------
        Whatever ``save_episode_weights_summary`` returns, or ``None`` when
        no weights were collected (nothing is written in that case).
        """
        collected = self.episode_weights
        print(f"DEBUG: Episode weights count: {len(collected)}")
        if not collected:
            print("No weights to save for this episode.")
            return None

        # The "_nvidia" suffix is only applied to the self-supervised backbones
        # listed here, and only when the run arguments carry a truthy `nvidia`.
        uses_nvidia = getattr(self.args, 'nvidia', False)
        if uses_nvidia and self.ss_method in ['dino', 'simclr', 'vicreg']:
            features_label = f"{self.ss_method}_nvidia"
        else:
            features_label = self.ss_method

        run_description = {
            "dataset": self.dataset_name,
            "seed": self.seed,
            "features_type": features_label,
            "integrated_features": self.integrated,
            "alpha": getattr(self, 'alpha', 0.5),
            "episode": self.episode,
            "selection_strategy": self.selection_strategy
        }

        summary_file = save_episode_weights_summary(
            episode_weights=collected,
            episode_info=run_description,
            exp_dir=getattr(self.args, 'exp_dir', './experiments')
        )

        # Start the next episode with a fresh accumulator.
        self.episode_weights = []
        return summary_file

    def only_selected_probcover(self, edge_matrices, is_candidate, max_size, n_samples, num_probcovers, query_indices,
                                remaining_indices, weights):
        """
        Greedily pick up to ``max_size`` samples by weighted out-degree,
        discounting only the samples that were already selected.

        In every round, the incoming edges of already-selected samples are
        removed from each edge matrix, the out-degree of every remaining
        candidate is computed per matrix, and the degrees are combined with
        ``weights``. The candidate with the highest combined utility is
        selected.

        Parameters
        ----------
        edge_matrices : list
            One square edge matrix per probcover; modified in place.
        is_candidate : np.ndarray of bool, shape (n_samples,)
            True for samples that may still be selected; updated in place.
        max_size : int
            Maximum number of samples to select.
        n_samples : int
            Number of samples described by each edge matrix.
        num_probcovers : int
            Number of edge matrices to combine.
        query_indices : list
            Output list; the selected original indices are appended to it.
        remaining_indices : sequence
            Maps a local sample position to its original index.
        weights : sequence of float
            Weight of each probcover's utility.

        Notes
        -----
        Non-candidates are excluded from the arg-max. Without that, a round in
        which every candidate has zero utility would re-select position 0 (or
        another already selected sample) and put duplicates into
        ``query_indices``. Selection stops early once no candidate is left.
        """
        for _ in range(max_size):
            if not is_candidate.any():
                # Nothing left to choose from; more rounds could only repeat a sample.
                break

            already_selected = ~is_candidate
            all_utilities = np.zeros(n_samples)
            for j in range(num_probcovers):
                edges = edge_matrices[j]
                # Edges pointing at selected samples no longer contribute.
                edges[:, already_selected] = False
                edge_matrices[j] = edges
                # Utility of a candidate is its remaining out-degree.
                utilities = np.zeros(n_samples)
                utilities[is_candidate] = edges[is_candidate].sum(axis=1)
                all_utilities += weights[j] * utilities

            # Restrict the arg-max to candidates so ties at zero cannot
            # resolve to a sample that was already picked.
            idx = np.argmax(np.where(is_candidate, all_utilities, -np.inf))
            query_indices.append(remaining_indices[idx])
            is_candidate[idx] = False


    def probcover(self, edge_matrices, is_candidate, max_size, n_samples, num_probcovers, query_indices,
                  remaining_indices, weights):
        """
        Greedy weighted coverage selection over several edge matrices.

        In every round, each edge matrix first loses all incoming edges of
        samples that are covered by an already-selected sample (i.e. that have
        an edge from a non-candidate row). The out-degree of every remaining
        candidate is then computed per matrix and the degrees are combined with
        ``weights``. The candidate with the highest combined utility is
        selected.

        Parameters
        ----------
        edge_matrices : list
            One square edge matrix per probcover; modified in place.
        is_candidate : np.ndarray of bool, shape (n_samples,)
            True for samples that may still be selected; updated in place.
        max_size : int
            Maximum number of samples to select.
        n_samples : int
            Number of samples described by each edge matrix.
        num_probcovers : int
            Number of edge matrices to combine.
        query_indices : list
            Output list; the selected original indices are appended to it.
        remaining_indices : sequence
            Maps a local sample position to its original index.
        weights : sequence of float
            Weight of each probcover's utility.

        Notes
        -----
        Non-candidates are excluded from the arg-max. Once everything is
        covered all utilities are zero, and an unrestricted arg-max would keep
        returning position 0 (or another already selected sample), filling
        ``query_indices`` with duplicates. Selection stops early once no
        candidate is left.
        """
        for _ in range(max_size):
            if not is_candidate.any():
                # Nothing left to choose from; more rounds could only repeat a sample.
                break

            all_utilities = np.zeros(n_samples)
            for j in range(num_probcovers):
                edges = edge_matrices[j]
                # A sample is covered when any selected (non-candidate) row points at it.
                is_covered = edges[~is_candidate].any(axis=0)
                edges[:, is_covered] = False
                edge_matrices[j] = edges
                # Utility of a candidate is the number of still-uncovered samples it reaches.
                utilities = np.zeros(n_samples)
                utilities[is_candidate] = edges[is_candidate].sum(axis=1)
                all_utilities += weights[j] * utilities

            # Restrict the arg-max to candidates so ties at zero cannot
            # resolve to a sample that was already picked.
            idx = np.argmax(np.where(is_candidate, all_utilities, -np.inf))
            query_indices.append(remaining_indices[idx])
            is_candidate[idx] = False

    def knn_density(self, embeddings: np.ndarray, k: int = 5) -> np.ndarray:
        """
        Score every row of `embeddings` by the inverse of its summed k-NN distance.

        The score of point i is

            rho_i = 1 / sum_{j=1}^k d_{i, j}

        with d_{i, j} the cosine distance from point i to its j-th nearest
        neighbor. Note that the numerator is 1, not k, so the value is
        proportional to (not equal to) the usual k / sum density estimate.

        Parameters
        ----------
        embeddings : np.ndarray, shape (n_samples, n_features)
            The learned representations / embedding vectors.
        k : int
            Number of neighbors to use for density estimation.

        Returns
        -------
        densities : np.ndarray, shape (n_samples,)
            Score rho_i for each point.
        """
        # Query k + 1 neighbours: the points are searched against themselves,
        # so the closest hit is normally the point itself at distance 0.
        knn_model = NearestNeighbors(n_neighbors=k + 1, algorithm='auto', metric='cosine')
        knn_model.fit(embeddings)
        neighbor_dists = knn_model.kneighbors(embeddings)[0]

        # Skip the first column (the self match) and sum the k remaining distances.
        summed_dists = neighbor_dists[:, 1:].sum(axis=1)
        return 1 / summed_dists

    def init_features(self, data, model):
        """
        Populate ``self.features`` and ``self.features_ss`` for ``data``.

        Nothing happens unless ``self.integrated`` is truthy. Otherwise the
        model-based features are computed by the TEAL base implementation and
        rescaled, and the self-supervised features are fetched through
        ``pretrained_representations`` and stored as one stacked NumPy array.
        When ``self.args.buffer`` equals 300, both feature sets are also passed
        to ``visualize_features``.
        """
        if not self.integrated:
            return

        # Model-based features from the base strategy.
        TEALExemplarsSelectionStrategy.init_features(self, data, model)
        # Dividing by max(norm, 1) shrinks rows longer than 1 to unit length
        # and leaves shorter rows as they are.
        row_norms = np.linalg.norm(self.features, axis=1, keepdims=True)
        self.features = self.features / np.maximum(row_norms, 1)

        # Self-supervised features, stacked row-wise into a single array.
        ss_chunks = pretrained_representations(data, self.dataset_name, self.ss_method, seed=self.seed, order=self.args.order, nvidia=self.args.nvidia)
        self.features_ss = torch.tensor(np.vstack(ss_chunks)).cpu().numpy()

        if self.args.buffer == 300:
            self.visualize_features("Model Based", self.features)
            self.visualize_features(self.ss_method, self.features_ss)
        # elif self.ss_method == 'model_based':
        #     TEALExemplarsSelectionStrategy.init_features(self, data, model)
        #     norms = np.linalg.norm(self.features, axis=1, keepdims=True)
        #     self.features = self.features / np.maximum(norms, 1)
        # elif self.ss_method == 'dino' or self.ss_method == 'simclr' or self.ss_method == 'vicreg':
        #     features = pretrained_representations(data, self.dataset_name, self.ss_method, seed=self.seed)
        #     features = np.vstack(features)
        #     features = torch.tensor(features)
        #     self.features_ss = features.cpu().numpy()

    def visualize_features(self, method, features, plot_support_vs_k=True):
        """
        Plot cosine k-NN distance histograms and shared-neighbour support
        statistics for the features of the current class.

        For a fixed list of neighbourhood sizes k (derived as ``500 // |B|``
        for several per-class buffer sizes |B|) the method

        1. draws a histogram of all k-NN cosine distances with the median bin
           highlighted and saves the bin counts to CSV,
        2. computes, for every point i, the "multiplicity support": the number
           of ordered pairs (u, v) of neighbours of i such that u is also a
           neighbour of v, raw and normalised by ``k * (k - 1)``, and saves the
           per-point values to CSV,
        3. saves the per-k mean/median of that support to CSV and, if
           requested, plots the normalised curves against k.

        All files are written into ``self.args.exp_dir``; file names and titles
        also use ``self.cur_class`` and ``self.args.seed``.

        Parameters
        ----------
        method : str
            Name of the feature type; used in titles and file names.
        features : array-like, shape (n_samples, n_features)
            Feature vectors of one class.
        plot_support_vs_k : bool
            Whether to also save the multiplicity-vs-k line plot.

        Notes
        -----
        * Values of k that need more neighbours than there are samples are
          skipped instead of making the neighbour search fail.
        * The subplot grid is sized from the number of k values so none of
          them is silently dropped.
        * Only the figures created here are closed (always, regardless of
          ``plot_support_vs_k``), so repeated calls do not accumulate figures.
        """
        os.makedirs(self.args.exp_dir, exist_ok=True)

        n_samples = len(features)
        buffer_sizes = (500, 30, 15, 10, 8, 6, 5, 4, 3)
        # k -> per-class buffer size |B| it corresponds to (500 samples per class assumed).
        n_B = {500 // d: d for d in buffer_sizes}
        # k + 1 neighbours are requested below (the point itself included),
        # so k has to stay below the number of available samples.
        k_vals = [500 // d for d in buffer_sizes]
        k_vals = [k for k in k_vals if k + 1 <= n_samples]
        if not k_vals:
            print(f"visualize_features: need at least 2 samples to compute k-NN statistics, got {n_samples}; skipping.")
            return

        # ----------------------------
        # 1) Histograms
        # ----------------------------
        # Two rows, with as many columns as needed to show every k.
        n_cols = (len(k_vals) + 1) // 2
        fig, axes = plt.subplots(2, n_cols, figsize=(4.5 * n_cols, 10), squeeze=False)
        axes = axes.flatten()

        # rows: [k, mean_count, median_count, mean_norm, median_norm]
        support_k_summary = []

        for ax, k in zip(axes, k_vals):
            nn_model = NearestNeighbors(n_neighbors=k + 1, metric='cosine').fit(features)

            dists, idx = nn_model.kneighbors(features, return_distance=True)
            # drop the first column (normally the point itself at distance 0)
            k_dist = dists[:, 1:].flatten()
            knn_idx = idx[:, 1:]  # shape (n, k)
            n_points = knn_idx.shape[0]

            counts, bin_edges = np.histogram(k_dist, bins=30)
            _, _, patches = ax.hist(
                k_dist, bins=bin_edges, color="skyblue", edgecolor="black", alpha=0.7
            )

            # Median line + highlight median bin
            median_val = np.median(k_dist)
            ax.axvline(median_val, color='red', linestyle='--', linewidth=2,
                       label=f"Median = {median_val:.2f}")

            # Bin i spans [edge_i, edge_{i+1}); the last bin also includes its
            # right edge, hence the clamp for a median equal to the maximum.
            median_bin = int(np.searchsorted(bin_edges, median_val, side='right')) - 1
            median_bin = min(median_bin, len(bin_edges) - 2)
            patches[median_bin].set_facecolor("red")
            patches[median_bin].set_alpha(0.6)
            patches[median_bin].set_linewidth(2.5)
            patches[median_bin].set_edgecolor("darkred")

            ax.set_title(f"{k}-NN Distances\n (|B| = {n_B[k]} samples per class in the Buffer)", fontsize=13)
            ax.set_xlabel("Cosine Distance", fontsize=11)
            ax.set_ylabel("Frequency", fontsize=11)
            ax.grid(True, alpha=0.3)
            ax.legend(fontsize=9)

            # Save histogram counts
            # NOTE(review): this file name (and the per-point one below) has no
            # class/seed component, so later calls overwrite earlier ones - confirm intended.
            np.savetxt(f"{self.args.exp_dir}/counts_k{k}_{method}.csv", counts, delimiter=",")

            # --------------------------------------------------------
            # 2) MULTIPLICITY SUPPORT
            #    For each u in N_k(i), count # of v in N_k(i) s.t. u in N_k(v),
            #    then sum over u. Pairs with v == u are not excluded.
            # --------------------------------------------------------
            # is_neighbor[v, u] is True iff u is in N_k(v). Dense n x n boolean
            # matrix: fine for per-class sample counts in the hundreds/thousands.
            is_neighbor = np.zeros((n_points, n_points), dtype=bool)
            is_neighbor[np.arange(n_points)[:, None], knn_idx] = True
            multiplicity_support_count = np.array(
                [is_neighbor[np.ix_(row, row)].sum() for row in knn_idx], dtype=int
            )

            denom = k * (k - 1) if k > 1 else 1  # avoid div-by-zero for k=1
            multiplicity_support_norm = multiplicity_support_count / denom

            # Save per-point values for this k (both raw count and normalized)
            np.savetxt(
                f"{self.args.exp_dir}/snn_multiplicity_per_point_k{k}_{method}.csv",
                np.column_stack([multiplicity_support_count, multiplicity_support_norm]),
                delimiter=",",
                header="multiplicity_support_count,multiplicity_support_norm",
                comments=""
            )

            # Summaries for this k
            support_k_summary.append([
                k,
                float(np.mean(multiplicity_support_count)),
                float(np.median(multiplicity_support_count)),
                float(np.mean(multiplicity_support_norm)),
                float(np.median(multiplicity_support_norm)),
            ])

        # turn off empty subplots (if any)
        for ax in axes[len(k_vals):]:
            ax.axis("off")

        method_title = {'simclr': 'SimCLR', 'dino': 'DINO', 'vicreg': 'VICReg'}.get(method, method)
        fig.suptitle(
            f"{method_title} – Cosine K-NN Distances Histograms for Class {self.cur_class}\n",
            fontsize=18, y=0.97
        )
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(
            f"{self.args.exp_dir}/knn_histograms_median_bold_{method}_class_{self.cur_class}_seed_{self.args.seed}.pdf",
            bbox_inches='tight', dpi=300
        )
        # Release the figure whether or not the second plot is requested.
        plt.close(fig)

        # --------------------------------------------
        # 3) Save and plot the MULTIPLICITY-vs-k curves
        # --------------------------------------------
        support_k_summary = np.array(support_k_summary, dtype=float)
        # CSV columns: k, mean_count, median_count, mean_norm, median_norm
        np.savetxt(
            f"{self.args.exp_dir}/snn_multiplicity_vs_k_{method}_class_{self.cur_class}_seed_{self.args.seed}.csv",
            support_k_summary,
            delimiter=",",
            header="k,mean_multiplicity_count,median_multiplicity_count,mean_multiplicity_norm,median_multiplicity_norm",
            comments=""
        )

        if plot_support_vs_k:
            ks = support_k_summary[:, 0]
            mean_norm = support_k_summary[:, 3]
            median_norm = support_k_summary[:, 4]

            curve_fig, curve_ax = plt.subplots(figsize=(7, 5))
            curve_ax.plot(ks, mean_norm, marker='o', label='Mean multiplicity (normalized)')
            curve_ax.plot(ks, median_norm, marker='o', label='Median multiplicity (normalized)')
            curve_ax.set_xlabel("k")
            curve_ax.set_ylabel("Multiplicity support (normalized)")
            curve_ax.set_title(f"SNN Multiplicity vs k – {method_title} (Class {self.cur_class})")
            curve_ax.grid(True, alpha=0.3)
            curve_ax.legend()
            curve_fig.tight_layout()
            curve_fig.savefig(
                f"{self.args.exp_dir}/snn_multiplicity_vs_k_plot_{method}_class_{self.cur_class}_seed_{self.args.seed}.pdf",
                bbox_inches='tight', dpi=300
            )
            plt.close(curve_fig)
class SoloTEALExemplarsSelectionStrategy(TEALExemplarsSelectionStrategy):
    """
    TEAL exemplar selection that additionally loads pretrained
    self-supervised features next to the model-based ones.

    Configuration is read from ``args``: ``features_type``, ``dataset``,
    ``sel_strategy``, ``seed``, ``integrated_features``, ``weight_method`` and
    ``alpha`` (plus ``order`` and ``nvidia`` when features are initialised).
    """

    def __init__(self, args, device, extra_args=None):
        super().__init__(args, device, extra_args)
        self.ss_method = args.features_type
        self.dataset_name = args.dataset
        self.selection_strategy = args.sel_strategy
        self.seed = args.seed
        self.features_ss = None
        # The flag is compared as text ("true", any case / surrounding
        # whitespace). Going through str() keeps that behaviour for strings
        # and also accepts a real bool instead of failing on `.strip()`.
        self.integrated = str(args.integrated_features).strip().lower() == "true"
        self.weight_method = args.weight_method
        self.alpha = args.alpha
        self.args = args

    def init_features(self, data, model):
        """
        Populate ``self.features`` and ``self.features_ss`` for ``data``.

        Only acts when ``self.integrated`` is True: model-based features come
        from the TEAL base implementation and are rescaled, self-supervised
        features come from ``pretrained_representations`` and are stored as
        one stacked NumPy array.
        """
        if self.integrated:
            TEALExemplarsSelectionStrategy.init_features(self, data, model)
            # Dividing by max(norm, 1) shrinks rows longer than 1 to unit
            # length and leaves shorter rows as they are.
            norms = np.linalg.norm(self.features, axis=1, keepdims=True)
            self.features = self.features / np.maximum(norms, 1)
            features = pretrained_representations(data, self.dataset_name, self.ss_method, seed=self.seed, order=self.args.order, nvidia=self.args.nvidia)
            features = np.vstack(features)
            features = torch.tensor(features)
            self.features_ss = features.cpu().numpy()
        # elif self.ss_method == 'model_based':
        #     TEALExemplarsSelectionStrategy.init_features(self, data, model)
        #
        # elif self.ss_method == 'dino' or self.ss_method == 'simclr' or self.ss_method == 'vicreg':
        #     features = pretrained_representations(data, self.dataset_name, self.ss_method, seed=self.seed)
        #     features = np.vstack(features)
        #     features = torch.tensor(features)
        #     self.features_ss = features.cpu().numpy()

    # def init_clusters(self, ll, features=None):
    #     num_clusters = ll if ll / len(features) < 0.2 else ll // 5
    #     num_clusters = min(num_clusters, self.MAX_NUM_CLUSTERS)
    #     print(f'Clustering into {num_clusters} clusters...')
    #     return kmeans(features, num_clusters=num_clusters)

    # def make_sorted_indices_one_time(
    #         self, strategy: "SupervisedTemplate", data: AvalancheDataset
    # ) -> List[int]:
    #     """
    #     TEAL One-time
    #     """
    #     cur_class = list(data.targets.uniques)[0]
    #     if self.group_to_len.get(cur_class) is None:  # Starting a new buffer update
    #         self.start_buffer_update(storage_policy=strategy.plugins[1].storage_policy)
    #
    #     ll = self.group_to_len[cur_class]
    #
    #     if cur_class in self.seen_groups:  # Buffer was already updated
    #         return list(range(ll))
    #
    #     if cur_class in self.groups_in_buffer:
    #         self.seen_groups.add(cur_class)
    #         return list(range(ll))  # takes the top ll from the original TEAL "selected" order - Verified
    #
    #     self.init_features(data, strategy.model)
    #     clusters_results = {}
    #
    #     # Loop over both types of features
    #     for name, features in [('modelbased', self.features), ('ss', self.features_ss)]:
    #         if features is None:
    #             continue
    #
    #         clusters = self.init_clusters(ll, features=features)
    #         labels = np.copy(clusters)
    #
    #         # Count cluster sizes
    #         cluster_ids, cluster_sizes = np.unique(labels, return_counts=True)
    #         clusters_df = pd.DataFrame({
    #             'cluster_id': cluster_ids,
    #             'cluster_size': cluster_sizes
    #         })
    #         clusters_df['neg_cluster_size'] = -clusters_df['cluster_size']
    #
    #         # Filter and sort clusters
    #         clusters_df = clusters_df[clusters_df.cluster_size > self.MIN_CLUSTER_SIZE]
    #         clusters_df = clusters_df.sort_values('neg_cluster_size')
    #
    #         # Store results
    #         clusters_results[name] = {
    #             'clusters': clusters,
    #             'labels': labels,
    #             'clusters_df': clusters_df
    #         }
    #     selected = []
    #     typicalities = []
    #     weights = [self.alpha, 1 - self.alpha]
    #     max_size = strategy.plugins[1].storage_policy.max_size // len(strategy.plugins[1].storage_policy.seen_groups)
    #     remaining_features = []
    #     ss_method = self.ss_method
    #     integrated = self.integrated
    #     features = self.features
    #     features_ss = self.features_ss
    #     remaining_indices = calculate_remaining_indices(features, features_ss, integrated, remaining_features,
    #                                                     ss_method)
    #     calculate_weight(max_size=ll, remaining_features=remaining_features, weights=weights,
    #                      weight_method=self.weight_method)
    #     print(f"weights of model based is: {weights[0]}")
    #     print(f"weights of self-supervised is: {weights[1]}")
    #     for i in range(ll):
    #         typ_map = {}
    #         for c, (name, features) in enumerate([('modelbased', self.features), ('ss', self.features_ss)]):
    #             indices, j = [], 0
    #             while (len(indices) == 0):  # skip empty clusters
    #                 clusters_df = clusters_results[name]['clusters_df']
    #                 cluster = clusters_df.iloc[(i + j) % len(clusters_df)].cluster_id
    #                 indices = (clusters_results[name]['labels'] == cluster).nonzero()[0]
    #                 j += 1
    #             rel_feats = features[indices]
    #             # in case we have too small cluster, calculate density among half of the cluster
    #             typicality = calculate_typicality(rel_feats, min(self.K_NN, len(indices) // 2))
    #             weighted_vals = weights[c] * typicality
    #             typ_map.update({
    #                 idx: typ_map.get(idx, 0.0) + wv
    #                 for idx, wv in zip(indices, weighted_vals)
    #             })
    #
    #         all_indices = np.array(list(typ_map.keys()))
    #         all_scores = np.array([typ_map[idx] for idx in all_indices])
    #
    #         idx = all_indices[np.argmax(all_scores)]
    #         selected.append(idx)
    #
    #         # You may want to mark this idx as "used" in all label arrays:
    #         for name in clusters_results:
    #             clusters_results[name]['labels'][idx] = -1
    #
    #     self.seen_groups.add(cur_class)
    #     self.groups_in_buffer.add(cur_class)
    #     return selected


# class SoloHerdingExemplarsSelectionStrategy(HerdingSelectionStrategy):
#     def __init__(self, args, device, extra_args=None):
#         super().__init__(args, device)
#         self.ss_method = args.features_type
#         self.dataset_name = args.dataset
#         self.selection_strategy = args.sel_strategy
#         self.seed = args.seed
#         self.features_ss = None
#         self.features = None
#         self.features_ss = None
#
#     def init_features(self, data, model, device):
#         if self.integrated:
#             TEALExemplarsSelectionStrategy.init_features(self, data, model)
#             norms = np.linalg.norm(self.features, axis=1, keepdims=True)
#             self.features = self.features / np.maximum(norms, 1)
#             features = pretrained_representations(data, self.dataset_name, self.ss_method, seed=self.seed, order=self.args.order, nvidia=self.args.nvidia)
#             features = np.vstack(features)
#             features = torch.tensor(features)
#             self.features_ss = features.cpu().numpy()
