import copy
import random
from typing import List, Callable, Tuple
import os
import json
from datetime import datetime

import math
import pandas as pd
from sklearn.metrics.pairwise import rbf_kernel, euclidean_distances
import numpy as np
import torch
from avalanche.benchmarks.utils import AvalancheDataset
from torch import nn
from torchvision import transforms, models

# from MERS.sampling_strategies.herding import HerdingSelectionStrategy
from MERS.sampling_strategies.teal import TEALExemplarsSelectionStrategy, calculate_typicality, kmeans
from sklearn.metrics.pairwise import cosine_distances
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import pairwise_distances

import matplotlib.pyplot as plt

# -----------------------------
# Optional ILP backend (exact)
# -----------------------------
try:
    import pulp  # Exact ILP if available
    _HAVE_PULP = True
except Exception:
    _HAVE_PULP = False


# =========================
# Utilities
# =========================

def knn_density(embeddings: np.ndarray, k: int = 5, metric: str = 'cosine') -> np.ndarray:
    """Return a k-nearest-neighbour density score for every row of ``embeddings``.

    The score of a row is the reciprocal of the summed distances to its ``k``
    nearest neighbours. ``k + 1`` neighbours are queried and the first column
    is dropped; that column is assumed to be the query point itself.

    Args:
        embeddings: 2-D array, one sample per row.
        k: number of neighbours to aggregate over.
        metric: any metric accepted by ``sklearn.neighbors.NearestNeighbors``.

    Returns:
        1-D array with one density per row. A row whose ``k`` neighbour
        distances are all zero yields ``inf``.
    """
    index = NearestNeighbors(n_neighbors=k + 1, algorithm='auto', metric=metric)
    index.fit(embeddings)
    dists, _ = index.kneighbors(embeddings)
    neighbour_dist_sum = dists[:, 1:].sum(axis=1)
    return 1 / neighbour_dist_sum

def _as_2d_float(x):
    x = np.asarray(x, dtype=float)
    if x.ndim == 1: x = x.reshape(-1, 1)
    return x

def _l2_normalize_rows(X, eps=1e-12):
    return X / (np.linalg.norm(X, axis=1, keepdims=True) + eps)

def median_radial_density_fullSi(embeddings: np.ndarray) -> np.ndarray:
    """For each sample, the median density of all samples no farther from the centroid.

    Procedure:
      1. Rows are L2-normalised; their mean, re-normalised, is the centroid.
      2. For sample ``i``, the member set is every sample whose cosine distance
         to the centroid is ``<=`` that of ``i``.
      3. Each member's density is ``1 / (sum of its cosine distances to the
         other members + 1e-12)``.
      4. ``out[i]`` is the median of those densities, or 0.0 when the member
         set has at most one element.

    The nested work is O(n^3) in the worst case.
    """
    pts = _l2_normalize_rows(_as_2d_float(embeddings))
    centroid = pts.mean(axis=0, keepdims=True)
    centroid /= (np.linalg.norm(centroid, axis=1, keepdims=True) + 1e-12)
    radius = pairwise_distances(pts, centroid, metric="cosine").ravel()

    dist = pairwise_distances(pts, pts, metric="cosine")
    # inf on the diagonal so self-distances are dropped by the isfinite filter.
    np.fill_diagonal(dist, np.inf)

    result = np.empty(len(pts), dtype=float)
    for i, r_i in enumerate(radius):
        members = np.flatnonzero(radius <= r_i)
        if members.size <= 1:
            result[i] = 0.0
            continue
        block = dist[np.ix_(members, members)]
        member_density = np.array(
            [1.0 / (np.sum(row[np.isfinite(row)]) + 1e-12) for row in block],
            dtype=float,
        )
        result[i] = float(np.median(member_density))
    return result

def knn_based_delta(X, max_size, embedding_type='unknown', metric='cosine', scale_method='cv'):
    """Estimate a coverage radius from k-NN distances, scaled by a global statistic.

    ``base_delta`` is the median of the pooled distances from each point to
    its ``k`` nearest neighbours, with
    ``k = clip(len(X) // max_size, 1, len(X) - 1)``. It is multiplied by a
    factor derived from the distribution of all pairwise distances:

    * ``'cv'`` (also used for unrecognised names): std / mean.
    * ``'skewness'``: ``1 / (1 + s)`` for positive skew ``s``, else ``1 + |s|``.
    * ``'cv_skewness'``: the product of the two factors above.
    * ``'inverse_mean'``: ``1 / mean``.
    * ``'percentile_25'``: 25th percentile of the pairwise distances.

    Args:
        X: 2-D array of samples.
        max_size: selection budget used to derive ``k``. Values below 1 are
            treated as 1.
        embedding_type: unused; kept for backward compatibility.
        metric: metric for both the k-NN search and the pairwise distances.
        scale_method: one of the names above.

    Returns:
        The scaled radius, or 0.0 when ``X`` has fewer than two rows.
    """
    from scipy.stats import skew

    n = len(X)
    if n < 2:
        # No neighbours and no pairs: every statistic below would be NaN.
        return 0.0
    # Clamp k to [1, n-1]. A budget larger than n used to give k = 0, which
    # left no neighbour distances and made the median (and the delta) NaN.
    k = max(1, min(n - 1, n // max(1, max_size)))
    nn_model = NearestNeighbors(n_neighbors=k + 1, metric=metric).fit(X)
    distances, _ = nn_model.kneighbors(X)
    base_delta = np.median(distances[:, 1:])  # drop the self-distance column

    all_distances = pairwise_distances(X, metric=metric)
    distances_flat = all_distances[np.triu_indices_from(all_distances, k=1)]
    mean_dist = np.mean(distances_flat)
    std_dist = np.std(distances_flat)
    cv = std_dist / (mean_dist + 1e-8)

    def skew_scale():
        # Shrink for right-skewed distance distributions, enlarge otherwise.
        s = skew(distances_flat)
        return (1.0 / (1.0 + s)) if s > 0 else (1.0 + abs(s))

    if scale_method == 'skewness':
        scale_factor = skew_scale()
    elif scale_method == 'cv_skewness':
        scale_factor = cv * skew_scale()
    elif scale_method == 'inverse_mean':
        scale_factor = 1.0 / (mean_dist + 1e-8)
    elif scale_method == 'percentile_25':
        scale_factor = np.percentile(distances_flat, 25)
    else:
        # 'cv' and any unrecognised name.
        scale_factor = cv
    return base_delta * scale_factor


def calculate_weight(max_size, remaining_features, weights, weight_method='ratio_median_knn_density_k_1'):
    k = min(len(remaining_features[0]) - 1, len(remaining_features[0]) // max_size)
    if weight_method == 'ratio_median_knn_density_k_1':
        weights[0] = np.median(knn_density(remaining_features[0], k)) / np.median(knn_density(remaining_features[0], 1))
        weights[1] = np.median(knn_density(remaining_features[1], k)) / np.median(knn_density(remaining_features[1], 1))
        print("The weight method is ratio_median_knn_density_k_1")
    if weight_method == 'heuristic':
        weights[0] = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=max_size, X=remaining_features[1])
        weights[1] = ProbCoverExemplarsSelectionStrategy.optimal_delta(max_size=len(remaining_features[0]), X=remaining_features[0])
    if weight_method == 'importance_sampling':
        weights[0] = knn_density(remaining_features[0], k)
        weights[1] = knn_density(remaining_features[1], k)
        print("The weight method is importance_sampling_ess")
    if weight_method == 'euclidean_ratio_median_knn_density_k_1':
        metric = "euclidean"
        weights[0] = np.median(knn_density(remaining_features[0], k, metric)) / np.median(knn_density(remaining_features[0], 1, metric))
        weights[1] = np.median(knn_density(remaining_features[1], k, metric)) / np.median(knn_density(remaining_features[1], 1, metric))
        print("The weight method is euclidean_ratio_median_knn_density_k_1")
    elif weight_method == 'ratio_median_knn_density_1_k':
        weights[0] = np.median(knn_density(remaining_features[0], 1)) / np.median(knn_density(remaining_features[0], k))
        weights[1] = np.median(knn_density(remaining_features[1], 1)) / np.median(knn_density(remaining_features[1], k))
        print("The weight method is ratio_median_knn_density_1_k")
    elif weight_method == 'median_knn_density_knn':
        weights[0] = np.median(knn_density(remaining_features[0], k))
        weights[1] = np.median(knn_density(remaining_features[1], k))
        print("The weight method is median_knn_density_knn")
    elif weight_method == 'euclidean_median_knn_density_knn':
        metric = "euclidean"
        weights[0] = np.median(knn_density(remaining_features[0], k, metric))
        weights[1] = np.median(knn_density(remaining_features[1], k, metric))
    elif weight_method == 'median_knn_density_1':
        weights[0] = np.median(knn_density(remaining_features[0], 1))
        weights[1] = np.median(knn_density(remaining_features[1], 1))
        print("The weight method is median_knn_density_1")
    elif weight_method == 'inverse_median_knn_density_knn':
        weights[0] = 1 / np.median(knn_density(remaining_features[0], k))
        weights[1] = 1 / np.median(knn_density(remaining_features[1], k))
        print("The weight method is inverse_median_knn_density_knn")
    elif weight_method == 'inverse_median_knn_density_1':
        weights[0] = 1 / np.median(knn_density(remaining_features[0], 1))
        weights[1] = 1 / np.median(knn_density(remaining_features[1], 1))
        print("The weight method is inverse_median_knn_density_1")
    elif weight_method == 'ratio_median_knn_density_k_1_only_supervised':
        weights[0] = np.median(knn_density(remaining_features[0], k)) / np.median(knn_density(remaining_features[0], 1))
        weights[1] = 1
        print("The weight method is ratio_median_knn_density_k_1_only_supervised")
    elif weight_method == 'ratio_median_knn_density_1_k_only_supervised':
        weights[0] = np.median(knn_density(remaining_features[0], 1)) / np.median(knn_density(remaining_features[0], k))
        weights[1] = 1
        print("The weight method is ratio_median_knn_density_1_k_only_supervised")
    elif weight_method == 'knn_all':
        weights[0] = np.median(knn_density(remaining_features[0], len(remaining_features[0]) - 1))
        weights[1] = np.median(knn_density(remaining_features[1], len(remaining_features[1]) - 1))
        print("The weight method is median_cosine")
    elif weight_method == 'median_knn_density_radial_to_centroid':
        w0 = median_radial_density_fullSi(remaining_features[0])
        w1 = median_radial_density_fullSi(remaining_features[1])
        weights[0] = float(np.median(w0))
        weights[1] = float(np.median(w1))
        print("The weight method is median_knn_density_radial_to_centroid (cosine, adaptive k)")
    elif weight_method == 'adaptive_entropy':
        try:
            eff1 = herding_effectiveness(remaining_features[0], max_size)
            eff2 = herding_effectiveness(remaining_features[1], max_size)
            total_eff = eff1 + eff2
            if total_eff > 0:
                weights[0] = eff1 / total_eff
                weights[1] = eff2 / total_eff
            print(f"The weight method is herding_specific_balance: {weights}")
        except Exception as e:
            print(f"Herding specific balance failed: {e}, using equal weights")
    elif weight_method == 'mean_approximation_quality':
        def mean_approx_quality(X, max_size):
            true_mean = np.mean(X, axis=0)
            subset_sizes = [1, max_size // 4, max_size // 2, max_size]
            best_approx = float('inf')
            for size in subset_sizes:
                if size >= len(X) or size < 1:
                    continue
                for _ in range(min(20, len(X))):
                    indices = np.random.choice(len(X), size, replace=False)
                    subset_mean = np.mean(X[indices], axis=0)
                    dist = np.linalg.norm(subset_mean - true_mean)
                    best_approx = min(best_approx, dist)
            return 1.0 / (best_approx + 1e-8)
        try:
            qual1 = mean_approx_quality(remaining_features[0], max_size)
            qual2 = mean_approx_quality(remaining_features[1], max_size)
            total_qual = qual1 + qual2
            if total_qual > 0:
                weights[0] = qual1 / total_qual
                weights[1] = qual2 / total_qual
            print(f"The weight method is mean_approximation_quality: {weights}")
        except Exception as e:
            print(f"Mean approximation quality failed: {e}, using equal weights")
    elif weight_method == 'spread_vs_compactness':
        def spread_compactness_ratio(X):
            center = np.mean(X, axis=0)
            distances_to_center = np.linalg.norm(X - center, axis=1)
            compactness = np.mean(distances_to_center)
            X_sample = X
            pairwise_dists = []
            for i in range(len(X_sample)):
                for j in range(i + 1, len(X_sample)):
                    pairwise_dists.append(np.linalg.norm(X_sample[i] - X_sample[j]))
            spread = np.mean(pairwise_dists) if pairwise_dists else 1.0
            return spread / (compactness + 1e-8)
        try:
            ratio1 = spread_compactness_ratio(remaining_features[0])
            ratio2 = spread_compactness_ratio(remaining_features[1])
            total_ratio = ratio1 + ratio2
            if total_ratio > 0:
                weights[0] = ratio1 / total_ratio
                weights[1] = ratio2 / total_ratio
            print(f"The weight method is spread_vs_compactness: {weights}")
        except Exception as e:
            print(f"Spread vs compactness failed: {e}, using equal weights")
    elif weight_method == '':
        print("The weight method is not set, using args weights")
    return k


def save_weights_to_file(weights, weight_method, episode_info, exp_dir, class_id=None, task_id=None):
    """Write one JSON record with the two embedding weights and echo it to stdout.

    The file is placed in ``<exp_dir>/weights/``. Its name includes
    ``class_id`` if given, otherwise ``task_id`` if given, otherwise only the
    episode.

    Args:
        weights: pair ``(model_based, self_supervised)`` convertible to float.
        weight_method: name of the rule that produced the weights.
        episode_info: dict; ``'episode'`` and ``'features_type'`` are read.
        exp_dir: experiment root directory.
        class_id: optional class identifier used in the file name.
        task_id: optional task identifier used when ``class_id`` is None.

    Returns:
        Path of the written file.
    """
    out_dir = os.path.join(exp_dir, "weights")
    os.makedirs(out_dir, exist_ok=True)

    mb, ss = weights[0], weights[1]
    record = {
        "timestamp": datetime.now().isoformat(),
        "episode_info": episode_info,
        "weight_method": weight_method,
        "model_based_weight": float(mb),
        "self_supervised_weight": float(ss),
        "class_id": class_id,
        "task_id": task_id,
        "total_weight": float(mb + ss),
        # A zero self-supervised weight is reported as an infinite ratio.
        "weight_ratio": float(mb / ss) if ss != 0 else float('inf'),
        "is_nvidia": episode_info.get('features_type', '').endswith('_nvidia'),
    }

    episode_tag = episode_info.get('episode', 'unknown')
    if class_id is not None:
        stem = f"weights_class_{class_id}_episode_{episode_tag}"
    elif task_id is not None:
        stem = f"weights_task_{task_id}_episode_{episode_tag}"
    else:
        stem = f"weights_episode_{episode_tag}"
    path = os.path.join(out_dir, stem + ".json")

    with open(path, 'w') as fh:
        json.dump(record, fh, indent=2)

    features_type = episode_info.get('features_type', 'unknown')
    print(f"Weights saved to: {path}")
    print(f"Model-based weight: {mb:.6f}")
    print(f"Self-supervised weight ({features_type}): {ss:.6f}")
    print(f"Weight ratio (mb/ss): {record['weight_ratio']:.6f}")
    if record['is_nvidia']:
        print(f"Using NVIDIA implementation: {features_type}")
    return path

def save_episode_weights_summary(episode_weights, episode_info, exp_dir):
    """Aggregate the per-class weight records of one episode into a summary JSON.

    Args:
        episode_weights: list of dicts with keys ``'model_based_weight'``,
            ``'self_supervised_weight'`` and ``'weight_ratio'``. Infinite
            ratios are excluded from the ratio statistics.
        episode_info: dict; ``'episode'`` is used in the file name.
        exp_dir: experiment root; output goes to ``<exp_dir>/weights/``.

    Returns:
        Path of the written summary file.
    """
    out_dir = os.path.join(exp_dir, "weights")
    os.makedirs(out_dir, exist_ok=True)

    mb = [rec['model_based_weight'] for rec in episode_weights]
    ss = [rec['self_supervised_weight'] for rec in episode_weights]
    finite_ratios = [rec['weight_ratio'] for rec in episode_weights
                     if rec['weight_ratio'] != float('inf')]

    def describe(values):
        # Summary statistics; all zeros when there is nothing to summarise.
        if not values:
            return dict.fromkeys(("mean", "std", "min", "max", "median"), 0.0)
        return {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "median": float(np.median(values)),
        }

    summary = {
        "timestamp": datetime.now().isoformat(),
        "episode_info": episode_info,
        "total_classes": len(episode_weights),
        "class_weights": episode_weights,
        "statistics": {
            "model_based": describe(mb),
            "self_supervised": describe(ss),
            "weight_ratios": describe(finite_ratios),
        },
    }

    episode_tag = episode_info.get('episode', 'unknown')
    path = os.path.join(out_dir, f"episode_weights_summary_episode_{episode_tag}.json")
    with open(path, 'w') as fh:
        json.dump(summary, fh, indent=2)

    stats = summary['statistics']
    print(f"Episode weights summary saved to: {path}")
    print(f"Episode {episode_tag} - {len(episode_weights)} classes processed")
    print(f"Average MB weight: {stats['model_based']['mean']:.6f}")
    print(f"Average SS weight: {stats['self_supervised']['mean']:.6f}")
    return path


def pretrained_representations(data, dataset, ss_method, seed, order=None, nvidia=None,
                               root='/cs/labs/daphna/danit.yanowsky/CL/Plugins'):
    """Load precomputed features for the class held by ``data``.

    The file is chosen relative to ``root``:
      * ``nvidia`` truthy: the DINOv2 ViT-B/14 dump under ``dino_nvidia/``.
        ``'tinyimg'`` is mapped to ``'tinyimagenet'``, the seed in the path is
        fixed to 0, and ``ss_method``/``order``/``seed`` are not used.
      * otherwise ``order`` truthy: ``order_seed_<order>/...``.
      * otherwise: ``regular_order/...``.

    Rows with L2 norm greater than 1 are rescaled to unit norm; shorter rows
    are left unchanged.

    Args:
        data: object whose ``targets.uniques`` is read. Its first element is
            the class id used in the file name.
        dataset: dataset name used in the path.
        ss_method: feature-extractor name used in the path.
        seed: seed used in the path (non-NVIDIA files only).
        order: optional class-order seed selecting ``order_seed_<order>``.
        nvidia: if truthy, load the NVIDIA DINOv2 representations.
        root: directory that contains the representation dumps.

    Returns:
        The (rescaled) ``'features'`` array stored in the file.
    """
    class_indicator = list(data.targets.uniques)[0]
    # Only the file that is actually used is read. Previously the regular/order
    # file was always loaded first, and that failed if it was missing even
    # when nvidia features were requested.
    if nvidia:
        if dataset == "tinyimg":
            dataset = "tinyimagenet"
        path = (f'{root}/dino_nvidia/representations_trained_{dataset}_dinov2_torch_dinov2_vitb14_seed_0/'
                f'{dataset}_dinov2_torch_dinov2_vitb14_all_{class_indicator}.npy')
    elif order:
        path = (f'{root}/order_seed_{order}/representations_trained_{dataset}_{ss_method}_seed_{seed}/'
                f'{dataset}_{ss_method}_all_{class_indicator}.npy')
    else:
        path = (f'{root}/regular_order/representations_trained_{dataset}_{ss_method}_seed_{seed}/'
                f'{dataset}_{ss_method}_all_{class_indicator}.npy')

    # allow_pickle is needed because the file stores a pickled dict
    # (read via .item()). Only load trusted files.
    obj = np.load(path, allow_pickle=True)
    if nvidia:
        print(f"Loading representations for dataset {dataset} with nvidia dino")
    elif order:
        print(f"Loading representations for dataset {dataset} with order {order} and seed {seed}")

    features = obj.item()['features']
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    features = features / np.maximum(norms, 1)
    print(f"features of class {class_indicator} loaded, length:{len(features)}")
    return features


def calculate_remaining_indices(features, features_ss, integrated, remaining_features, ss_method):
    """Append per-row feature copies to ``remaining_features`` and return candidate indices.

    * Model-based rows are appended when ``ss_method == 'model_based'`` or
      ``integrated`` is truthy.
    * Self-supervised rows are appended when ``ss_method`` is ``'dino'``,
      ``'simclr'`` or ``'vicreg'``. They are processed last, so in that case
      the returned indices refer to ``features_ss``.

    Args:
        features: model-based features, one row per sample.
        features_ss: self-supervised features, one row per sample.
        integrated: whether model-based features are used alongside others.
        remaining_features: list that receives one list of row copies per
            feature set used (mutated in place).
        ss_method: feature type name.

    Returns:
        ``list(np.arange(n))`` for the last feature set appended.

    Raises:
        ValueError: if no feature set applies. Previously this path hit an
            ``UnboundLocalError``.
    """
    remaining_indices = None
    if ss_method == 'model_based' or integrated:
        print(f"{ss_method} features are used, {integrated} integrated features are used")
        remaining_features.append([row.copy() for row in features])
        remaining_indices = list(np.arange(len(features)))
    if ss_method in ('dino', 'simclr', 'vicreg'):
        print(f"{ss_method} features are used")
        remaining_features.append([row.copy() for row in features_ss])
        remaining_indices = list(np.arange(len(features_ss)))
    if remaining_indices is None:
        raise ValueError(
            f"Unsupported ss_method {ss_method!r} with integrated={integrated!r}: "
            "no feature set selected"
        )
    return remaining_indices


# =====================================================
# Other selection strategies (not defined in this module)
#   - MaxHerding
#   - ProbCoverExemplarsSelectionStrategy
#   Their implementations are not included in this file; import them from
#   the module that defines them before relying on them here.
# =====================================================


# =======================================
# New helpers for budgeted coverage model
# =======================================

def _components_under_delta(X: np.ndarray, delta: float, metric: str = "cosine") -> Tuple[np.ndarray, int, np.ndarray]:
    """
    Connected components of the threshold graph with an edge wherever dist(i, j) <= delta.

    Components are numbered 0..C-1 in the order of their smallest member index.

    Returns:
      comp_id: int array; comp_id[i] is the component of row i.
      C: number of components.
      sizes: int array of length C; sizes[c] is the number of rows in component c.

    Builds the full n x n distance matrix, so memory use is O(n^2).
    """
    if len(X) == 0:
        return np.array([], dtype=int), 0, np.array([], dtype=int)

    adjacency = pairwise_distances(X, metric=metric) <= delta
    np.fill_diagonal(adjacency, True)  # every node neighbours itself

    num_nodes = adjacency.shape[0]
    labels = np.full(num_nodes, -1, dtype=int)
    num_comps = 0
    for seed in range(num_nodes):
        if labels[seed] >= 0:
            continue
        labels[seed] = num_comps
        frontier = np.array([seed])
        # Level-synchronous BFS: expand the whole frontier at once.
        while frontier.size:
            reached = adjacency[frontier].any(axis=0) & (labels == -1)
            frontier = np.flatnonzero(reached)
            labels[frontier] = num_comps
        num_comps += 1

    sizes = np.bincount(labels, minlength=num_comps)
    return labels, num_comps, sizes


class _BudgetedEdgeCoverageGreedy:
    """
    Greedy (1-1/e)-approximation for *weighted* budgeted edge coverage.
    Edges are (u, v, idx) where u is component id in E1, v in E2, idx is original sample index.
    """
    def __init__(self, budget, comp_sizes_e1, comp_sizes_e2, w_e1=1.0, w_e2=1.0):
        self.budget = int(budget)
        self.comp_sizes_e1 = np.asarray(comp_sizes_e1, dtype=float)
        self.comp_sizes_e2 = np.asarray(comp_sizes_e2, dtype=float)
        self.w_e1 = float(w_e1)
        self.w_e2 = float(w_e2)

    def select(self, uv_edges: np.ndarray, n_left: int, n_right: int):
        if uv_edges.size == 0 or self.budget <= 0:
            return []
        u = uv_edges[:, 0].astype(np.int64)
        v = uv_edges[:, 1].astype(np.int64)
        idx = uv_edges[:, 2].astype(np.int64)

        coveredL = np.zeros(n_left, dtype=bool)
        coveredR = np.zeros(n_right, dtype=bool)
        alive = np.ones(len(uv_edges), dtype=bool)
        chosen = []

        for _ in range(self.budget):
            if not np.any(alive): break
            gains = np.zeros(len(uv_edges), dtype=float)
            a = np.where(alive)[0]
            for j in a:
                g = 0.0
                if not coveredL[u[j]]:
                    g += self.w_e1 * self.comp_sizes_e1[u[j]]
                if not coveredR[v[j]]:
                    g += self.w_e2 * self.comp_sizes_e2[v[j]]
                gains[j] = g
            j = int(np.argmax(gains))
            if gains[j] <= 0: break
            coveredL[u[j]] = True
            coveredR[v[j]] = True
            chosen.append(int(idx[j]))
            alive[j] = False
        return chosen


class _BudgetedEdgeCoverageILP:
    """
    Exact ILP for *weighted* budgeted edge coverage.

    Variables: x_i (edge i chosen), y_u / z_v (left / right component covered).
        maximise    w_e1 * sum_u size1[u] * y_u + w_e2 * sum_v size2[v] * z_v
        subject to  sum_i x_i <= budget
                    y_u <= sum_{i: u_i = u} x_i
                    z_v <= sum_{i: v_i = v} x_i
    Requires PuLP and is solved with CBC. Falls back to greedy if PuLP is
    unavailable (handled by the caller).
    """
    def __init__(self, budget: int, comp_sizes_e1=None, comp_sizes_e2=None,
                 w_e1=1.0, w_e2=1.0, time_limit: int = None):
        self.budget = int(budget)
        # Passed to PULP_CBC_CMD(timeLimit=...); None means no limit.
        self.time_limit = time_limit
        self.comp_sizes_e1 = np.array(comp_sizes_e1, dtype=float)
        self.comp_sizes_e2 = np.array(comp_sizes_e2, dtype=float)
        self.w_e1 = float(w_e1)
        self.w_e2 = float(w_e2)

    def select(self, uv_edges: np.ndarray, n_left: int, n_right: int):
        """Solve the ILP and return the chosen sample indices (column 2 of ``uv_edges``).

        If CBC stops (e.g. on the time limit) without assigning a value to an
        edge variable, that edge is treated as not selected.
        """
        if uv_edges.size == 0 or self.budget <= 0:
            return []
        u = uv_edges[:, 0].astype(int)
        v = uv_edges[:, 1].astype(int)
        idx = uv_edges[:, 2].astype(int)
        m = len(uv_edges)

        prob = pulp.LpProblem("WeightedBudgetedCoverage", pulp.LpMaximize)
        x = pulp.LpVariable.dicts("x", range(m), cat="Binary")
        y = pulp.LpVariable.dicts("y", range(n_left), cat="Binary")
        z = pulp.LpVariable.dicts("z", range(n_right), cat="Binary")

        # Budget
        prob += pulp.lpSum([x[i] for i in range(m)]) <= self.budget

        # Group edge ids by endpoint once (O(m)) instead of rescanning every
        # edge for each component (O(m * (n_left + n_right))).
        edges_by_left = [[] for _ in range(n_left)]
        edges_by_right = [[] for _ in range(n_right)]
        for i in range(m):
            edges_by_left[u[i]].append(i)
            edges_by_right[v[i]].append(i)

        # A component counts as covered only if some chosen edge touches it.
        for u_id in range(n_left):
            prob += y[u_id] <= pulp.lpSum([x[i] for i in edges_by_left[u_id]])
        for v_id in range(n_right):
            prob += z[v_id] <= pulp.lpSum([x[i] for i in edges_by_right[v_id]])

        # Weighted objective
        prob += (
            self.w_e1 * pulp.lpSum([self.comp_sizes_e1[j] * y[j] for j in range(n_left)]) +
            self.w_e2 * pulp.lpSum([self.comp_sizes_e2[j] * z[j] for j in range(n_right)])
        )

        solver = pulp.PULP_CBC_CMD(msg=False, timeLimit=self.time_limit)
        prob.solve(solver)

        selected = []
        for i in range(m):
            val = pulp.value(x[i])
            # val is None when the solver left the variable unassigned.
            if val is not None and val > 0.5:
                selected.append(int(idx[i]))
        return selected


# ======================================================
# NEW: BudgetedCoverageSelectionStrategy (weighted 2-view)
# ======================================================

class BudgetedCoverageSelectionStrategy(TEALExemplarsSelectionStrategy):
    """
    Weighted budgeted coverage over two embeddings (model-based + ss).

    - In each embedding, group samples into the connected components of the
      threshold graph dist(i, j) <= δ (one δ per embedding).
    - Each original sample i induces an edge (comp_e1[i], comp_e2[i], i).
    - Select at most B edges (samples) to maximize:
        w_e1 * sum_{covered comps in E1} |comp|
      + w_e2 * sum_{covered comps in E2} |comp|

    ``make_sorted_indices`` returns the selected samples first, followed by
    the remaining samples in random order.
    """

    def __init__(self, args, device, extra_args=None):
        super().__init__(args, device, extra_args)
        self.args = args
        self.device = device
        self.ss_method = args.features_type
        self.dataset_name = args.dataset
        # integrated_features may arrive as a bool or a string such as "True".
        self.integrated = str(args.integrated_features).strip().lower() == "true"
        self.alpha = getattr(args, "alpha", 0.5)
        self.weight_method = getattr(args, "weight_method", "ratio_median_knn_density_k_1")
        self.delta_mb_rule = getattr(args, "delta_mb", "median_cosine")
        self.delta_ss_rule = getattr(args, "delta_ss", "median_cosine")
        # Parsed instead of bool(): a CLI string such as "False" is truthy.
        self._ilp_requested = self._as_bool(getattr(args, "use_ilp", False))
        self.use_ilp = self._ilp_requested and _HAVE_PULP
        self.ilp_timelimit = int(getattr(args, "ilp_timelimit", 60))
        self.features = None
        self.features_ss = None
        self.buffer_indices = []
        self.new_order = []
        self.group_to_len = {}
        self.episode = 0
        self.episode_weights = []

    @staticmethod
    def _as_bool(value) -> bool:
        """Interpret a flag that may be a bool/int or a string ("true"/"1"/"yes")."""
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1", "yes")
        return bool(value)

    # --------- helpers carried over from ProbCover ----------
    @staticmethod
    def optimal_delta(max_size: int = 1, X=None, metric='cosine'):
        """Median of the pooled distances from every point to its k nearest neighbours.

        ``k = min(len(X) - 1, max(1, len(X) // max_size))``. A ``max_size``
        below 1 is treated as 1 instead of raising ``ZeroDivisionError``.
        """
        k = min(len(X) - 1, max(1, len(X) // max(1, max_size)))
        nn_model = NearestNeighbors(n_neighbors=k + 1, metric=metric)
        nn_model.fit(X)
        distances, _ = nn_model.kneighbors(X)
        # Column 0 is the query point itself at distance 0.
        return float(np.median(distances[:, 1:]))

    def probcover_delta_geometric(self, X, B, metric='cosine'):
        """Geometric mean of the median 1-NN distance and the median k-NN distance.

        ``k = clip(len(X) // B, 1, len(X) - 1)``. Returns 0.0 when ``X`` has at
        most one row or ``B <= 0``.
        """
        n = len(X)
        if n <= 1 or B <= 0:
            return 0.0
        k = max(1, min(n - 1, n // B))
        nn = NearestNeighbors(n_neighbors=k + 1, metric=metric).fit(X)
        distances, _ = nn.kneighbors(X, return_distance=True)
        d1 = float(np.median(distances[:, 1]))
        dk = float(np.median(distances[:, k]))
        return float(np.sqrt(d1 * dk))
    # -----------------------------------------------------------

    def init_features(self, data, model):
        """Populate ``self.features`` (model-based) and ``self.features_ss`` (pretrained).

        Model-based rows with norm greater than 1 are rescaled to unit norm.

        Raises:
            ValueError: if integrated features are disabled.
        """
        if not self.integrated:
            raise ValueError("BudgetedCoverageSelectionStrategy requires integrated features (two embeddings).")
        TEALExemplarsSelectionStrategy.init_features(self, data, model)
        norms = np.linalg.norm(self.features, axis=1, keepdims=True)
        self.features = self.features / np.maximum(norms, 1)
        feats = pretrained_representations(
            data, self.dataset_name, self.ss_method, seed=self.args.seed,
            order=getattr(self.args, 'order', None), nvidia=getattr(self.args, 'nvidia', None)
        )
        # The former torch.tensor(...).cpu().numpy() round-trip was a no-op copy.
        self.features_ss = np.vstack(feats)

    def _delta_from_rule(self, rule: str, X: np.ndarray, B: int, metric: str, other_weight_ratio: float = 1.0):
        """Return the distance threshold δ for embedding ``X`` according to ``rule``.

        Rules: '1nn', 'knn', 'geometric', 'median_l2' (always euclidean),
        'median_cosine' (always cosine), 'knn_based', 'knn_based_<scale>',
        'knn_cross_ratio' (knn radius times ``other_weight_ratio``),
        'coverage_optimal', 'budget_scaled_median', 'median_1nn_half'.

        Raises:
            ValueError: for an unknown rule.
        """
        if rule == '1nn':
            return self.optimal_delta(max_size=len(X), X=X, metric=metric)
        if rule == 'knn':
            return self.optimal_delta(max_size=B, X=X, metric=metric)
        if rule == 'geometric':
            return self.probcover_delta_geometric(X, B, metric=metric)
        if rule == 'median_l2':
            return float(np.median(pairwise_distances(X, metric='euclidean')))
        if rule == 'median_cosine':
            return float(np.median(pairwise_distances(X, metric='cosine')))
        if rule == 'knn_based':
            return float(knn_based_delta(X, B, metric=metric, scale_method='cv'))
        if rule.startswith('knn_based_'):
            scale_method = rule.split('_', 2)[2]
            return float(knn_based_delta(X, B, metric=metric, scale_method=scale_method))
        if rule == 'knn_cross_ratio':
            base = self.optimal_delta(max_size=B, X=X, metric=metric)
            return float(base * other_weight_ratio)
        if rule == 'coverage_optimal':
            # Pick the pairwise-distance percentile whose median neighbour
            # count is closest to the target n / B.
            n = len(X)
            target = n / max(1, B)
            D = pairwise_distances(X, metric=metric)
            cand = np.percentile(D.flatten(), [20, 30, 40, 50, 60, 70, 80, 90])
            best, best_err = cand[0], float('inf')
            for dlt in cand:
                counts = (D <= dlt).sum(axis=1) - 1
                err = abs(np.median(counts) - target)
                if err < best_err:
                    best, best_err = dlt, err
            return float(best)
        if rule == 'budget_scaled_median':
            med = float(np.median(pairwise_distances(X, metric=metric)))
            ratio = max(1, B) / max(1, len(X))
            scale = float(np.clip(1.0 / (ratio + 0.1), 0.5, 2.0))
            return float(med * scale)
        if rule == 'median_1nn_half':
            d_1nn = self.optimal_delta(max_size=len(X), X=X, metric=metric)
            return float(d_1nn / 2.0)
        raise ValueError(f"Unknown delta rule: {rule}")

    @staticmethod
    def _per_class_budget(strategy) -> int:
        """Buffer capacity divided evenly over the groups seen so far."""
        # NOTE(review): assumes the replay/storage plugin is strategy.plugins[1].
        policy = strategy.plugins[1].storage_policy
        return policy.max_size // max(1, len(policy.seen_groups))

    def _solve(self, uv_edges, n1, n2, sizes1, sizes2, B, w_e1, w_e2):
        """Run the ILP solver (when enabled and PuLP is installed) or the greedy one."""
        if self.use_ilp:
            print("[BudgetedCoverage] Using ILP (exact).")
            solver = _BudgetedEdgeCoverageILP(
                budget=B, comp_sizes_e1=sizes1, comp_sizes_e2=sizes2,
                w_e1=w_e1, w_e2=w_e2, time_limit=self.ilp_timelimit
            )
        else:
            if self._ilp_requested and not _HAVE_PULP:
                print("[BudgetedCoverage] PuLP not found. Falling back to greedy.")
            solver = _BudgetedEdgeCoverageGreedy(
                budget=B, comp_sizes_e1=sizes1, comp_sizes_e2=sizes2,
                w_e1=w_e1, w_e2=w_e2
            )
        return solver.select(uv_edges, n1, n2)

    def _episode_info(self, **extra):
        """Metadata dict shared by the per-class and per-episode weight files."""
        features_type = self.ss_method
        if getattr(self.args, 'nvidia', False) and self.ss_method in ['dino', 'simclr', 'vicreg']:
            features_type = f"{self.ss_method}_nvidia"
        info = {
            "dataset": self.dataset_name,
            "seed": self.args.seed,
            "features_type": features_type,
            "integrated_features": self.integrated,
            "alpha": getattr(self.args, 'alpha', 0.5),
            "episode": self.episode,
        }
        info.update(extra)
        info["selection_strategy"] = "BudgetedCoverageSelectionStrategy"
        return info

    def _persist_class_outputs(self, cur_class, w_e1, w_e2, k_val, delta_e1, delta_e2, n1, n2, n_selected):
        """Write the per-class weight JSON, record the weights, and dump debug metadata."""
        exp_dir = getattr(self.args, 'exp_dir', './experiments')
        save_weights_to_file(
            weights=[w_e1, w_e2],
            weight_method=self.weight_method,
            episode_info=self._episode_info(class_id=cur_class),
            exp_dir=exp_dir,
            class_id=cur_class
        )
        self.episode_weights.append({
            "class_id": cur_class,
            "model_based_weight": float(w_e1),
            "self_supervised_weight": float(w_e2),
            "weight_ratio": float(w_e1 / w_e2) if w_e2 != 0 else float('inf'),
            "k_value": int(k_val),
            "weight_method": self.weight_method,
            "timestamp": datetime.now().isoformat()
        })

        weights_dir = os.path.join(exp_dir, "weights")
        os.makedirs(weights_dir, exist_ok=True)
        meta = {
            "timestamp": datetime.now().isoformat(),
            "class_id": cur_class,
            "episode": self.episode,
            "strategy": "BudgetedCoverageWeighted",
            "delta_e1": float(delta_e1),
            "delta_e2": float(delta_e2),
            "num_comp_e1": int(n1),
            "num_comp_e2": int(n2),
            "weights": {"embedding1": w_e1, "embedding2": w_e2},
            "selected_count": int(n_selected),
            "use_ilp": bool(self.use_ilp)
        }
        meta_path = os.path.join(weights_dir, f"budgeted_coverage_class_{cur_class}_ep_{self.episode}.json")
        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2)

    def make_sorted_indices(self, strategy: "SupervisedTemplate", data: AvalancheDataset) -> List[int]:
        """Return an ordering of ``data`` whose prefix is the coverage-selected exemplars.

        Selection runs only on the first call for a class. Later calls for an
        already-seen class return ``list(range(len(data)))``. Side effects:
        sets ``buffer_indices`` and ``new_order``, appends to
        ``episode_weights``, and writes two JSON files under
        ``<exp_dir>/weights``.

        Raises:
            ValueError: if the two embeddings have different numbers of rows.
        """
        classes = list(data.targets.uniques)
        if len(classes) < 1:
            return []
        cur_class = classes[0]
        if self.group_to_len.get(cur_class) is not None:
            return list(range(len(data)))
        self.group_to_len[cur_class] = cur_class
        self.episode = strategy.experience.current_experience

        self.init_features(data, strategy.model)
        n = self.features.shape[0]
        if len(self.features_ss) != n:
            raise ValueError(
                f"Embedding size mismatch: {n} model-based vs "
                f"{len(self.features_ss)} self-supervised samples"
            )
        B = self._per_class_budget(strategy)

        # Per-embedding weights (written in place by calculate_weight).
        weights = [self.alpha, 1 - self.alpha]
        k_val = calculate_weight(max_size=B, remaining_features=[self.features, self.features_ss],
                                 weights=weights, weight_method=self.weight_method)
        w_e1, w_e2 = float(weights[0]), float(weights[1])
        print(f"[BudgetedCoverage] weights: E1={w_e1:.6f}, E2={w_e2:.6f}; k used for weight calc: {k_val}")

        # 'knn_cross_ratio' on E1 scales its radius by w_e2 / w_e1.
        other_ratio = w_e2 / (w_e1 + 1e-12)
        delta_e1 = self._delta_from_rule(self.delta_mb_rule, self.features, B, metric='cosine',
                                         other_weight_ratio=other_ratio)
        delta_e2 = self._delta_from_rule(self.delta_ss_rule, self.features_ss, B, metric='cosine',
                                         other_weight_ratio=1.0)
        print(f"[BudgetedCoverage] delta_e1={delta_e1:.6f}, delta_e2={delta_e2:.6f}")

        comp1, n1, sizes1 = _components_under_delta(self.features, delta_e1, metric='cosine')
        comp2, n2, sizes2 = _components_under_delta(self.features_ss, delta_e2, metric='cosine')
        # Sample i becomes the edge (comp1[i], comp2[i], i).
        uv_edges = np.column_stack([comp1, comp2, np.arange(n, dtype=int)]).astype(int)

        chosen = [int(i) for i in self._solve(uv_edges, n1, n2, sizes1, sizes2, B, w_e1, w_e2)]
        rest = np.setdiff1d(np.arange(n), np.array(chosen, dtype=int))
        np.random.shuffle(rest)
        self.buffer_indices = chosen
        self.new_order = np.concatenate([np.array(chosen, dtype=int), rest]).astype(int)

        self._persist_class_outputs(cur_class, w_e1, w_e2, k_val, delta_e1, delta_e2, n1, n2, len(chosen))
        return self.new_order.tolist()

    def save_episode_summary(self):
        """Write the accumulated per-class weights as an episode summary, then reset them.

        Returns:
            Path of the summary file, or None when no weights were recorded.
        """
        if not self.episode_weights:
            print("No weights to save for this episode.")
            return None
        summary_file = save_episode_weights_summary(
            episode_weights=self.episode_weights,
            episode_info=self._episode_info(),
            exp_dir=getattr(self.args, 'exp_dir', './experiments')
        )
        self.episode_weights = []
        return summary_file

    def get_selected_indices(self):
        """Indices chosen by the last selection (in pick order)."""
        return self.buffer_indices

    def get_sorted_order(self):
        """Full order from the last selection, or [] if no selection has run yet."""
        # new_order starts as a plain list, which has no .tolist().
        return np.asarray(self.new_order, dtype=int).tolist()
