import numpy as np
import pandas as pd
import torch
from xgboost import XGBClassifier
from sklearn.datasets import make_classification
from collections import Counter
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist
from sklearn.metrics import accuracy_score, f1_score
import os, sys
# Add parent directory to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class Curator:
    """
    Record, call after call, how a classifier scores every sample of a fixed set.

    Each ``on_epoch_end`` call appends one column to two (n_samples, n_calls)
    matrices:
      * gold-label probabilities: the predicted probability of each sample's true label;
      * max probabilities: the largest predicted class probability of each sample.
    The summary properties reduce the gold-label matrix row-wise (one value per sample).
    """

    def __init__(self, X, y, sparse_labels: bool = False, catboost: bool = False):
        """
        X: samples to score; passed unchanged to ``clf.predict_proba`` (samples x features).
        y: true labels; they are cast to int64 and used as predict_proba column
           indices, so they must be numeric class indices (e.g. 0/1).
        sparse_labels, catboost: stored on the instance; not read by this class's methods.
        """
        self.X = X
        self.y = np.asarray(y)
        self._sparse_labels = sparse_labels
        self.catboost = catboost
        # (n_samples, n_calls) histories; None until the first on_epoch_end call.
        self._gold_labels_probabilities = None
        self._true_probabilities = None

    @staticmethod
    def _append_column(history, column):
        """Return ``column`` when there is no history yet, else ``history`` with ``column`` appended on the right."""
        if history is None:
            return column
        return np.concatenate((history, column), axis=1)

    def on_epoch_end(self, clf, device="cpu", iteration=1, **kwargs):
        """
        Score all samples with the current model and store one new column per history.

        clf: any object exposing ``predict_proba(X)`` returning (n_samples, n_classes).
        device: torch device on which the gather / max are computed.
        iteration, **kwargs: accepted for callback compatibility; unused.
        """
        labels = torch.tensor(self.y, device=device)
        probs = torch.tensor(clf.predict_proba(self.X), device=device)

        # Pick, row by row, the column of the true label.
        row_ids = torch.arange(probs.shape[0])
        gold_col = probs[row_ids, labels.to(torch.int64)]
        # Highest class probability per row, whatever the label.
        top_col = probs.max(dim=1).values

        # Back to numpy as (n_samples, 1) columns so they can be concatenated on axis 1.
        gold_np = gold_col.cpu().numpy()[:, None]
        top_np = top_col.cpu().numpy()[:, None]

        self._gold_labels_probabilities = self._append_column(self._gold_labels_probabilities, gold_np)
        self._true_probabilities = self._append_column(self._true_probabilities, top_np)

    @property
    def gold_labels_probabilities(self):
        """(n_samples, n_calls) matrix of true-label probabilities, or None before any call."""
        return self._gold_labels_probabilities

    @property
    def true_probabilities(self):
        """(n_samples, n_calls) matrix of per-sample maximum class probabilities, or None before any call."""
        return self._true_probabilities

    @property
    def confidence(self):
        """Per-sample mean of the true-label probability across calls."""
        history = self._gold_labels_probabilities
        return np.mean(history, axis=-1)

    @property
    def variability(self):
        """Per-sample standard deviation of the true-label probability across calls."""
        history = self._gold_labels_probabilities
        return np.std(history, axis=-1)

    @property
    def correctness(self):
        """Per-sample fraction of calls in which the true-label probability exceeds 0.5."""
        history = self._gold_labels_probabilities
        return np.mean(history > 0.5, axis=-1)

    @property
    def aleatoric(self):
        """Per-sample mean of p * (1 - p), p being the true-label probability of each call."""
        history = self._gold_labels_probabilities
        return np.mean(history * (1 - history), axis=-1)
    


import numpy as np
from sklearn.metrics import balanced_accuracy_score

class BlockCuratorLite:
    """
    Minimal block evaluator: records predictions over repeated calls and reports
    two block-level numbers:
      - bal_acc: balanced accuracy of the per-sample majority vote
                 (plain accuracy is used as a fallback, see ``metrics``)
      - conf   : mean predicted probability of the true label over all samples and calls
    """
    def __init__(self, Xi, yi):
        self.Xi = Xi
        self.yi = np.asarray(yi)
        self._yhat_hist = []   # one (n_i,) array of argmax predictions per call
        self._gold_hist = []   # one (n_i,) array of true-label probabilities per call

    def on_epoch_end(self, clf):
        """Score the block with ``clf.predict_proba`` and store this call's predictions and gold probabilities."""
        scores = clf.predict_proba(self.Xi)                       # (n_i, n_classes)
        row_ids = np.arange(len(self.yi))
        predicted = np.argmax(scores, axis=1)                     # column index of the top class
        true_label_prob = scores[row_ids, self.yi.astype(int)]    # column of each sample's true label
        # Both values are computed before either list is touched.
        self._yhat_hist.append(predicted)
        self._gold_hist.append(true_label_prob)

    def _majority_vote(self):
        """Most frequent prediction per sample across calls; ties resolve to the smallest label."""
        votes = np.stack(self._yhat_hist, axis=1)                 # (n_i, n_calls)

        def _most_common(column):
            # np.unique returns labels sorted ascending, so argmax picks the smallest on ties.
            labels, freq = np.unique(column, return_counts=True)
            return labels[np.argmax(freq)]

        return np.apply_along_axis(_most_common, 1, votes)

    def metrics(self):
        """
        Return ``{"bal_acc": float, "conf": float}`` for the recorded calls.

        bal_acc falls back to plain accuracy of the majority vote when a voted
        label is absent from the true labels, when the balanced score is NaN, or
        when computing it raises any exception.
        """
        assert len(self._yhat_hist) > 0, "Please call on_epoch_end(clf) first"
        voted = self._majority_vote()

        def plain_accuracy():
            return float(np.mean(voted == self.yi))

        try:
            pred_within_truth = np.isin(np.unique(voted), np.unique(self.yi)).all()
            if pred_within_truth:
                bal_acc = float(balanced_accuracy_score(self.yi, voted))
                if np.isnan(bal_acc):
                    bal_acc = plain_accuracy()
            else:
                bal_acc = plain_accuracy()
        except Exception:
            bal_acc = plain_accuracy()

        # Grand mean over the (n_i, n_calls) matrix of true-label probabilities.
        conf = float(np.stack(self._gold_hist, axis=1).mean())

        return {"bal_acc": bal_acc, "conf": conf}



from collections import Counter


def is_better(current: dict, best: dict, metrics: list) -> bool:
    """
    Lexicographically compare two metric dictionaries.

    ``metrics`` is an ordered list of ``(key, mode)`` pairs, mode being
    'max' (larger value wins) or 'min' (smaller value wins). Keys are examined
    in order:
      * if ``best`` has no value for the key (missing or None), ``current`` wins;
      * equal values defer to the next key;
      * the first differing key decides; an unrecognised mode counts as a loss.
    Returns False when every listed key ties.
    """
    for key, mode in metrics:
        candidate = current[key]
        incumbent = best.get(key)
        if incumbent is None:
            return True
        if candidate != incumbent:
            return bool(
                (mode == 'max' and candidate > incumbent)
                or (mode == 'min' and candidate < incumbent)
            )
    return False

def categorical_distance(generated_data, train_data, categorical_features):
    """
    Pairwise count of categorical mismatches between two tables.

    Entry [i, j] of the returned float array (shape
    ``(len(generated_data), len(train_data))``) is the number of columns in
    ``categorical_features`` whose value differs between generated row i and
    training row j. The result is all zeros when no features are given.
    """
    def _column_mismatch(column):
        # Broadcast (n_gen, 1) against (1, n_train) to compare every pair at once.
        gen_vals = generated_data[column].values[:, None]
        train_vals = train_data[column].values[None, :]
        return (gen_vals != train_vals).astype(int)

    zeros = np.zeros((generated_data.shape[0], train_data.shape[0]))
    return sum((_column_mismatch(col) for col in categorical_features), zeros)


def compute_dcr_and_entropy(train_data, generated_data, numerical_features, categorical_features):
    """
    Assign every generated row to its nearest training row and measure how evenly
    the generated rows spread over the training rows.

    Distance between a generated row and a training row is
        L1 (cityblock) distance on standardized numerical features
      + number of mismatching categorical features.
    The StandardScaler is fit on ``train_data`` only and then applied to
    ``generated_data``.

    Parameters
    ----------
    train_data, generated_data : pandas.DataFrame containing the listed columns.
    numerical_features : column names; may be empty (then only categorical
        mismatches contribute to the distance).
    categorical_features : column names; may be empty or None.

    Returns
    -------
    nearest_idx : int array (n_generated,), positional index of the nearest training row.
    counts : int array (n_train,), how many generated rows chose each training row.
    entropy : Shannon entropy (natural log) of ``counts / counts.sum()``;
        0.0 when there are no generated rows.
    """
    n_gen = generated_data.shape[0]
    n_train = train_data.shape[0]

    # 1-2. Numerical part of the distance. StandardScaler rejects input with zero
    # columns, so an all-categorical dataset contributes a zero numerical distance.
    if numerical_features is not None and len(numerical_features) > 0:
        scaler = StandardScaler()
        train_num = scaler.fit_transform(train_data[numerical_features])
        gen_num = scaler.transform(generated_data[numerical_features])
        num_dist = cdist(gen_num, train_num, metric='cityblock')
    else:
        num_dist = np.zeros((n_gen, n_train))

    # len() instead of truthiness so that numpy arrays / pandas Index also work.
    if categorical_features is not None and len(categorical_features) > 0:
        cat_dist = categorical_distance(generated_data, train_data, categorical_features)
    else:
        cat_dist = 0
    total_dist = num_dist + cat_dist

    # 3. Nearest training row for each generated row.
    nearest_idx = np.argmin(total_dist, axis=1)

    # 4. Assignment frequencies and their empirical entropy.
    counts = np.bincount(nearest_idx, minlength=n_train)
    n_assigned = counts.sum()
    if n_assigned == 0:
        # Avoid 0/0; an empty assignment has zero entropy.
        return nearest_idx, counts, 0.0
    probs = counts / n_assigned
    probs_nonzero = probs[probs > 0]
    entropy = -np.sum(probs_nonzero * np.log(probs_nonzero))
    return nearest_idx, counts, entropy


def entropy_ratio_with_cap(freqs, tau=0.75, r_max=0.6):
    """
    Concentration score ``1 - H(p) / log(n)`` of a frequency vector.

    ``p`` is ``freqs`` normalised to sum to 1. The score is 0 for a uniform
    distribution and 1 when all mass sits in one entry.

    Degenerate inputs:
      * all-zero frequencies are treated as zero entropy (score 1.0);
      * fewer than two entries: ``log(n)`` is 0, so normalised entropy is
        undefined; the score is defined as 1.0 (the single entry holds all mass).

    tau, r_max: accepted for API compatibility but currently unused.
    NOTE(review): despite the name, no cap is applied; confirm whether r_max
    was meant to bound the result.
    """
    freqs = np.array(freqs, dtype=float)
    n = len(freqs)
    if n < 2:
        # Normalisation by log(1) == 0 (or log(0)) would give NaN.
        return 1.0
    total = freqs.sum()
    p = freqs / total if total > 0 else np.zeros_like(freqs)
    nz = p > 0
    H = -np.sum(p[nz] * np.log(p[nz]))
    H_norm = H / np.log(n)
    return 1.0 - H_norm

def find_elbow_with_global(p, metric='difference'):
    """
    Locate the "elbow" of a distribution using the uniform level as a global reference.

    The masses are normalised and sorted from largest to smallest. Each gap
    between neighbours is scored as an absolute difference or a ratio. The
    largest gap whose upper entry lies strictly above the uniform level 1/n is
    chosen; if no entry lies above 1/n, the largest gap overall is used.

    Parameters:
        p: non-negative masses (normalisation is done internally).
        metric: 'difference' (p_i - p_{i+1}) or 'ratio' (p_i / p_{i+1}).

    Returns:
        k: number of leading (largest) entries regarded as core clusters.
           len(p) is returned when p sums to zero or has fewer than two entries.

    Raises:
        ValueError: for an unknown metric (only reached after the early return above).
    """
    masses = np.asarray(p, dtype=float)
    mass_total = masses.sum()
    size = len(masses)
    if mass_total == 0 or size < 2:
        return size

    desc = np.sort(masses / mass_total)[::-1]
    upper, lower = desc[:-1], desc[1:]

    if metric == 'difference':
        gaps = upper - lower
    elif metric == 'ratio':
        gaps = upper / (lower + 1e-20)
    else:
        raise ValueError("metric must be 'difference' or 'ratio'")

    # Restrict the search to gaps that start above the uniform level, when any exist.
    above_uniform = upper > 1.0 / size
    if np.any(above_uniform):
        gaps = np.where(above_uniform, gaps, -np.inf)

    return int(np.argmax(gaps)) + 1

def find_core_by_hhi(p):
    """
    Effective number of categories derived from the Herfindahl–Hirschman Index.

    The input is normalised to shares s_i, HHI = sum(s_i ** 2), and the result
    is floor(1 / HHI). p may be unnormalised; 0 is returned when it sums to zero
    (including empty input).
    """
    mass = np.asarray(p, dtype=float)
    total_mass = mass.sum()
    if total_mass == 0:
        return 0
    shares = mass / total_mass
    concentration = np.square(shares).sum()
    return int(np.floor(1.0 / concentration))

def gini_ratio_with_cap(p):
    """
    Normalised Gini coefficient of a (possibly unnormalised) non-negative vector.

    With shares p_(1) <= ... <= p_(n) sorted in ASCENDING order,
        G = 1/(n-1) * sum_{i=1}^{n} (2i - n - 1) * p_(i),
    which is 0 for a uniform vector and 1 when all mass sits in one entry.

    Returns 0.0 for degenerate inputs: an all-zero or empty vector, and a single
    entry (the 1/(n-1) normalisation is undefined for n == 1).
    """
    p = np.asarray(p, dtype=float)
    n = len(p)
    total = p.sum()
    if total == 0 or n < 2:
        return 0.0
    # Ascending order is required by the (2i - n - 1) weights.
    sorted_p = np.sort(p / total)
    indices = np.arange(1, n + 1)
    return (1.0 / (n - 1)) * np.sum((2 * indices - n - 1) * sorted_p)

def find_elbow_with_bias(p, metric='difference', bias=0.1):
    """
    Elbow detection with a preference for cutting early.

    The masses are normalised and sorted from largest to smallest. The gap
    between each pair of neighbours (difference or ratio) is multiplied by
    ``1 - bias * i / (n - 1)``, i being the gap position (0-based). Only gaps
    whose upper entry lies strictly above 1/n are eligible, and the best-scoring
    one is chosen. When no gap is eligible, every score is -inf and the first
    position wins.

    Parameters:
        p: non-negative masses (normalisation is done internally).
        metric: 'difference' or 'ratio'.
        bias: early-cut incentive, intended range [0, 1); larger values favour earlier cuts.

    Returns:
        k: number of leading (largest) entries regarded as core clusters.
           len(p) is returned when p sums to zero or has fewer than two entries.

    Raises:
        ValueError: for an unknown metric (only reached after the early return above).
    """
    arr = np.asarray(p, dtype=float)
    s = arr.sum()
    count = len(arr)
    if s == 0 or count < 2:
        return count

    ranked = np.sort(arr / s)[::-1]
    hi, lo = ranked[:-1], ranked[1:]

    if metric == 'difference':
        gaps = hi - lo
    elif metric == 'ratio':
        gaps = hi / (lo + 1e-20)
    else:
        raise ValueError("metric must be 'difference' or 'ratio'")

    # Relative gap position i / (n - 1) for i = 0 .. n-2, so values lie in [0, 1).
    rel_pos = np.arange(count - 1) / (count - 1)
    eligible = hi > 1.0 / count
    scores = np.where(eligible, gaps * (1 - bias * rel_pos), -np.inf)

    return int(np.argmax(scores)) + 1



def _fixed_quota_per_block(weights, block_size):
    """
    Calculate "fixed quota per block" according to given weights (e.g., 5:3:2):
    First multiply by block_size proportionally, then round; if total > block_size, finally truncate uniformly to block_size.
    """
    w = np.asarray(weights, dtype=float)
    if w.sum() <= 0 or block_size <= 0:
        return [0] * len(w)
    raw = block_size * w / w.sum()
    quota = np.rint(raw).astype(int)              # Round to get quota per block
    # If total after rounding > block_size, truncate uniformly to block_size (simple and crude, consistent with "discard excess")
    extra = int(quota.sum() - block_size)
    if extra > 0:
        # Start reducing from highest weight until total == block_size (also keep it simple)
        order = np.argsort(-w)
        for i in order:
            if extra == 0: break
            take = min(extra, quota[i])
            quota[i] -= take
            extra -= take
    # If total < block_size, don't supplement, allow tail blocks to be smaller (following your "simple is fine" principle)
    return quota.tolist()

def build_fixed_ratio_blocks_simple(sel_indices, nearest_idx, selected_classes,
                                    weights, samples_per_block, random_state=42):
    """
    Build as many complete, identically composed blocks as the clusters allow.

    Every block draws, without replacement, a fixed number of rows from each
    selected cluster. The per-cluster counts come from
    ``_fixed_quota_per_block(weights, samples_per_block)``. Construction stops
    as soon as one cluster with a positive quota cannot fill another block;
    leftover rows are not used.

    Parameters
    ----------
    sel_indices : positions (into ``nearest_idx``) of the rows to partition.
    nearest_idx : array giving the cluster id of every row.
    selected_classes : cluster ids, aligned element-wise with ``weights``.
    weights : relative cluster weights from which the quota is derived.
    samples_per_block : target block size; a longer block is truncated to it.
    random_state : seed of the numpy Generator that shuffles each cluster.

    Returns
    -------
    blocks : list of lists of positions into ``sel_indices``.
    quota : per-cluster rows per block (list of int).
    B_max : number of complete blocks built.
    """
    rng = np.random.default_rng(random_state)
    redundant_clusters = nearest_idx[sel_indices]

    # Positions (within sel_indices) belonging to each cluster.
    pools = {}
    for cls in selected_classes:
        pools[cls] = np.flatnonzero(redundant_clusters == cls).tolist()
    # Shuffle once per cluster; pop() from the end then samples without replacement.
    for cls in pools:
        rng.shuffle(pools[cls])

    quota = _fixed_quota_per_block(weights, samples_per_block)
    # Only clusters with a positive quota take part in each block.
    contributing = [(q, cls) for q, cls in zip(quota, selected_classes) if q > 0]

    # Number of full blocks the scarcest contributing cluster can supply.
    B_max = min((len(pools[cls]) // q for q, cls in contributing), default=0)

    blocks = []
    for _ in range(B_max):
        drawn = [pools[cls].pop() for q, cls in contributing for _k in range(q)]
        if len(drawn) > samples_per_block:
            drawn = drawn[:samples_per_block]
        blocks.append(drawn)

    return blocks, quota, B_max



def filter_by_group_quality(
    X_small, y_small, X_large, y_large,
    numerical_features, categorical_features,
    samples_per_block=20,
    n_iterations=10,
    ratio2=0.75,
    curate = True
):
    """
    Select a subset of ``X_large``. Rows are split into "redundant" and
    "non-redundant" groups, and each group is filtered separately.

    Steps:
      1. Fit an XGBClassifier (100 trees) on (X_small, y_small); it is reused
         to score blocks in step 5.
      2. Assign each row of X_large to its nearest row of X_small
         (``compute_dcr_and_entropy``). Assignment counts become cluster frequencies.
      3. ``ratio`` = Gini coefficient of those frequencies. The smallest number
         ``t_star`` of largest clusters whose cumulative frequency reaches
         ``ratio`` are the redundant clusters (always at least one).
      4. Non-redundant rows: with ``curate`` True, only the "easy" rows
         returned by ``filtering.data_centric_curation`` are kept.
      5. Redundant rows:
         - Split them into fixed-ratio blocks (``build_fixed_ratio_blocks_simple``).
         - Score each block with ``BlockCuratorLite`` over ``n_iterations`` calls.
         - Rank blocks by (bal_acc, conf), descending.
         - Keep the best blocks until about ``r_opt * n_redundant + 1`` rows are
           covered, where ``r_opt = clip(0.15 * ln(ratio) + 0.55, 0, 1)``.

    Parameters
    ----------
    X_small, y_small : reference set (DataFrame, labels).
    X_large, y_large : candidate set to filter; rows are addressed positionally (``iloc``).
    numerical_features, categorical_features : column names used for the distance.
    samples_per_block : target number of rows per block.
    n_iterations : number of ``on_epoch_end`` calls per block.
    ratio2 : not used in the computation; returned unchanged.
    curate : run data-centric curation on the non-redundant rows.

    Returns
    -------
    final_indices : int array of positions into X_large (retained redundant
        rows first, then retained non-redundant rows).
    ratio : Gini coefficient of the cluster frequencies.
    ratio2 : the input ``ratio2``.
    beta : fraction of non-redundant rows retained (0.0 when there are none).
    """
    y_large_arr = np.asarray(y_large)
    n_large = len(y_large_arr)

    # Base model, trained once on the reference set and reused to score blocks.
    clf = XGBClassifier(n_estimators=100)
    clf.fit(X_small, y_small)

    # 1. Nearest-neighbour assignment of X_large rows to X_small rows ("clusters").
    nearest_idx, counts, _entropy = compute_dcr_and_entropy(
        train_data=X_small.copy(),
        generated_data=X_large.copy(),
        numerical_features=numerical_features,
        categorical_features=categorical_features
    )
    freqs = counts / counts.sum()

    # 2. Concentration of the cluster sizes decides how many clusters are redundant.
    order = np.argsort(-freqs)          # cluster ids, largest first
    p_sorted = freqs[order]
    ratio = gini_ratio_with_cap(p_sorted)
    print(p_sorted)

    # Smallest t such that the t largest clusters hold >= `ratio` of the mass.
    # searchsorted returns >= 0, so t_star >= 1 (at least one redundant cluster).
    cum_freqs = np.cumsum(p_sorted)
    t_star = np.searchsorted(cum_freqs, ratio) + 1
    print("gini = ", ratio)
    print("t_star = ", t_star)

    # 3. Split X_large into redundant rows (in a selected cluster) and the rest.
    selected_classes = order[:t_star]
    sel_mask = np.isin(nearest_idx, selected_classes)
    n_redundant = int(np.sum(sel_mask))
    n_other = n_large - n_redundant

    print("Redundant cluster size:", t_star, "/", len(y_small), "Redundant data size:", n_redundant, "+", n_other, "/", n_large)
    print("------------------------------------------------------")

    sel_indices = np.flatnonzero(sel_mask)
    other_indices = np.flatnonzero(~sel_mask)
    n_sel = len(sel_indices)

    # 4. Non-redundant rows.
    # NOTE(review): with curate=False every non-redundant row is dropped
    # (final_nonred_idxs stays empty); confirm this is intended rather than
    # "keep all non-redundant rows".
    final_nonred_idxs = np.array([], dtype=int)
    if curate and len(other_indices) > 0:
        # Project-local module; imported only when curation actually runs.
        from filtering import data_centric_curation
        easy, _ambiguous, _ = data_centric_curation(
            X_small, y_small,
            X_large.iloc[other_indices], y_large_arr[other_indices],
            curation_metric='aleatoric',
            retrain=False,
            nest=100,
            ratio=0
        )
        final_nonred_idxs = other_indices[np.array(easy, dtype=int)]

    # Guard against 0/0 when every row is redundant.
    beta = len(final_nonred_idxs) / n_other if n_other > 0 else 0.0
    print(f"beta = {beta:.4f}")

    # 5. Redundant rows: block-wise evaluation and selection.
    selected_redundant_idxs = np.array([], dtype=int)
    if n_sel > 0:
        X_sel = X_large.iloc[sel_indices].reset_index(drop=True)
        y_sel = y_large_arr[sel_indices]

        # Block composition follows the relative size of each redundant cluster.
        sel_nearest = nearest_idx[sel_indices]
        weights = [np.sum(sel_nearest == c) for c in selected_classes]

        blocks, quota, B_max = build_fixed_ratio_blocks_simple(
            sel_indices=sel_indices,
            nearest_idx=nearest_idx,
            selected_classes=selected_classes,     # sorted by frequency, largest first
            weights=weights,
            samples_per_block=samples_per_block,
            random_state=42
        )
        print(f"[Blocks] per-block quota={quota}, B_max={B_max}, built={len(blocks)}")

        block_metrics = []
        for blk in blocks:
            curator_blk = BlockCuratorLite(X_sel.iloc[blk], y_sel[blk])
            # NOTE(review): the same fitted clf is scored on every call, so the
            # n_iterations calls presumably record identical predictions.
            for _ in range(n_iterations):
                curator_blk.on_epoch_end(clf=clf)
            block_metrics.append(curator_blk.metrics())  # {"bal_acc", "conf"}

        # Best blocks first: by bal_acc, ties broken by conf.
        sorted_blocks = sorted(
            enumerate(block_metrics),
            key=lambda item: (item[1]['bal_acc'], item[1]['conf']),
            reverse=True
        )

        # Retention ratio. 0.15*ln(ratio)+0.55 is -inf for ratio == 0 (int() then
        # raises OverflowError) and negative for ratio < ~0.026. A negative
        # blocks_needed would make the slice below drop blocks from the END
        # instead of keeping the best ones, so clip into [0, 1].
        with np.errstate(divide='ignore'):
            r_opt = float(np.clip(0.15 * np.log(ratio) + 0.55, 0.0, 1.0))
        print(f"ratio 2 = {r_opt:.4f}")
        N_keep = int(np.floor(r_opt * n_sel)) + 1
        blocks_needed = int(np.ceil(N_keep / samples_per_block))

        top_block_indices = [idx for idx, _ in sorted_blocks[:blocks_needed]]
        # np.concatenate([]) raises, so only merge when at least one block was built.
        if top_block_indices:
            selected_redundant_idxs = np.concatenate([
                sel_indices[blocks[bi]] for bi in top_block_indices
            ]).astype(int)

    print("Filtered data size:", len(selected_redundant_idxs), "(", np.sum(sel_mask), ") +" , len(final_nonred_idxs), "(", len(y_large) - np.sum(sel_mask), ") /", len(y_large))
    final_indices = np.concatenate([selected_redundant_idxs, final_nonred_idxs]).astype(int)

    return final_indices, ratio, ratio2, beta


def grid_search_filter(
    X_small, y_small, X_large, y_large,
    X_test, y_test,
    numerical_features,
    categorical_features,
    block_sizes=(10, 20, 30),
    n_iterations=10,
    ratio2=0.75,
    curate = True
):
    """
    Try ``filter_by_group_quality`` for several block sizes and keep the best subset.

    For each block size the filtered subset of X_large trains an XGBClassifier
    (50 trees), which is evaluated on (X_test, y_test). Candidates are compared
    lexicographically with ``is_better`` on:
      1. mean surprisal -log p(true label)     (lower is better)
      2. accuracy                              (higher is better)
      3. mean predictive entropy over classes  (lower is better)
      4. mean p(true) * (1 - p(true))          (lower is better)
    Block sizes whose filtered subset is empty are skipped.

    y_test values are used directly as predict_proba column indices, so they
    are assumed to be integer class indices matching the trained model's
    classes. TODO confirm for datasets with non-contiguous labels.

    Returns
    -------
    dict with keys 'score', 'surprisal', 'entropy', 'ale', 'block_sizes',
    'idxs', 'ratio1', 'ratio2', 'beta' describing the best candidate. If every
    block size produced an empty subset, the initial placeholders are returned
    ('idxs' is then an empty int array).
    """
    y_large_arr = np.asarray(y_large)
    y_test_arr = np.asarray(y_test)   # also accepts plain lists
    n_test = len(y_test_arr)

    # Lexicographic priority used by is_better.
    priority = [
        ('surprisal', 'min'),   # Lower surprisal is better
        ('score', 'max'),       # Higher accuracy is better
        ('entropy', 'min'),     # Lower predictive entropy is better
        ('ale', 'min'),         # Lower aleatoric is better
    ]

    best = {
        'score': -1,
        'surprisal': None,       # None lets the first evaluated candidate win
        'entropy': None,
        'ale': None,
        'block_sizes': 0,
        'idxs': np.array([], dtype=int),  # present even if every candidate is skipped
        'ratio1': 0,
        'ratio2': 0,
        'beta': 0
    }

    for bs in block_sizes:
        idxs, ratio1, ratio2_used, beta = filter_by_group_quality(
            X_small, y_small, X_large, y_large,
            numerical_features=numerical_features,
            categorical_features=categorical_features,
            samples_per_block=bs,
            n_iterations=n_iterations,
            ratio2=ratio2,
            curate=curate
        )
        if len(idxs) == 0:
            continue

        clf2 = XGBClassifier(n_estimators=50)
        clf2.fit(X_large.iloc[idxs], y_large_arr[idxs])

        # 1) Accuracy (sklearn signature is accuracy_score(y_true, y_pred))
        preds = clf2.predict(X_test)
        acc = accuracy_score(y_test_arr, preds)

        # 2) Class-probability matrix (n_test, n_classes)
        probs = clf2.predict_proba(X_test)
        print(probs.shape)

        # 3) Probability assigned to each test sample's true label
        gold_probs = probs[np.arange(n_test), y_test_arr.astype(int)]
        # 4) Surprisal s_i = -log p_i (epsilon avoids log(0))
        mean_surprisal = (-np.log(gold_probs + 1e-12)).mean()
        # 5) Predictive entropy H_i = -sum_k p_{i,k} log p_{i,k}
        mean_entropy = (-np.sum(probs * np.log(probs + 1e-12), axis=1)).mean()
        # 6) "Aleatoric" term p_i (1 - p_i) of the true-label probability
        mean_aleatoric = (gold_probs * (1 - gold_probs)).mean()

        current = {
            'score': acc,
            'surprisal': mean_surprisal,
            'entropy': mean_entropy,
            'ale': mean_aleatoric,
            'block_sizes': bs,
            'idxs': idxs,
            'ratio1': ratio1,
            'ratio2': ratio2_used,
            'beta': beta
        }

        if is_better(current, best, priority):
            best.update(current)

        print(f"bs={bs}, acc={acc:.3f}, surprisal={mean_surprisal:.3f}, kept={len(idxs)}")

    print("best: score", best['score'], "surprisal:", best['surprisal'], "block size:", best['block_sizes'], "idxs:", len(best['idxs']), "ratio1:", best['ratio1'], "beta:", best['beta'])
    return best