import numpy as np
from knn import flat_mnist, precompute_D2
from sampling import sample_linear_extension
import time


def _insert_neighbor(dist_col, label, best_dists, best_labels, counts, fill, row_idx, k):
    """Offer one training point to every evaluation row's k-nearest set, in place.

    For each evaluation row the point is inserted when the row still has an
    unused slot (``fill < k``) or when its distance beats the row's current
    worst kept distance. The per-row class vote table ``counts`` and the
    per-row occupancy ``fill`` are kept in sync with the neighbor buffers.
    """
    # Unused slots hold +inf, so argmax lands on an unused slot while a row is
    # not yet full, and on the farthest kept neighbor once it is.
    worst_pos = np.argmax(best_dists, axis=1)
    worst_val = best_dists[row_idx, worst_pos]

    improve = (fill < k) | (dist_col < worst_val)
    if not np.any(improve):
        return

    rows = row_idx[improve]
    pos = worst_pos[improve]

    # A slot whose label is not the -1 sentinel holds a real neighbor that is
    # being evicted: withdraw its vote before overwriting it.
    old_labels = best_labels[rows, pos]
    evicted = old_labels != -1
    if np.any(evicted):
        counts[rows[evicted], old_labels[evicted]] -= 1

    best_dists[rows, pos] = dist_col[improve]
    best_labels[rows, pos] = label
    counts[rows, label] += 1

    # Rows that filled a previously unused slot now hold one more neighbor.
    fresh = ~evicted
    if np.any(fresh):
        fill[rows[fresh]] += 1


def compute_value(
    method,
    sampler_kwargs,
    X_train_all, y_train_all,
    X_eval, y_eval,
    *,
    flatten_fn,
    group_index_dict=None,
    k=5,
    num_permutations=10000,
    rng=None,
    record_marginal_contrib=False,
    provider_ids=None,
):
    """Estimate per-point data values for a k-NN classifier by sampling orderings.

    For each of ``num_permutations`` orderings returned by
    ``sample_linear_extension(method, rng=rng, **sampler_kwargs)``, training
    points are added one at a time to an initially empty k-NN classifier. The
    change in accuracy on the evaluation set caused by each addition is
    credited to the added point, and the credits are averaged over all
    orderings. The empty classifier is assigned accuracy 0, so the credits of
    one ordering sum to the accuracy reached with all points present.

    Parameters
    ----------
    method
        Forwarded unchanged to ``sample_linear_extension``; also used in the
        progress message.
    sampler_kwargs : dict or None
        Extra keyword arguments for ``sample_linear_extension``. ``None`` is
        treated as no extra arguments.
    X_train_all, y_train_all
        Training samples and their integer class labels. Labels are used as
        column indices of the vote table and -1 is reserved as the
        "empty slot" marker, so labels are assumed to be non-negative.
    X_eval, y_eval
        Evaluation samples and their integer class labels.
    flatten_fn : callable
        Applied to ``X_train_all`` and ``X_eval`` before distances are
        computed with ``precompute_D2``.
    group_index_dict : dict or None
        Optional mapping ``name -> indices into the training set``. When
        given, the summed value of each group is returned.
    k : int
        Number of neighbors; must be at least 1.
    num_permutations : int
        Number of orderings to sample. Values <= 0 run no orderings and leave
        every value at 0.0.
    rng
        Random source handed to the sampler; defaults to the ``numpy.random``
        module.
    record_marginal_contrib : bool
        When true, the result carries a ``"marginal_contribution"`` entry.
        Individual contributions are only appended when ``provider_ids`` is
        also given; otherwise the recorded lists stay empty. One entry is
        appended per training point per ordering.
    provider_ids
        Optional sequence indexed by training-point index, giving the provider
        recorded alongside each marginal contribution.

    Returns
    -------
    dict
        ``"phi_individual"``: float64 array with one value per training point.
        ``"phi_group_sum"``: dict of per-group sums, or ``None`` when
        ``group_index_dict`` is ``None``.
        ``"num_permutations"``: number of orderings actually processed.
        ``"elapsed_seconds"``: duration of the call in seconds.
        ``"marginal_contribution"``: only when ``record_marginal_contrib`` is
        true; dict with the parallel lists ``"coalition_sizes"`` (number of
        points already present before the addition),
        ``"marginal_contributions"`` and ``"providers"``.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1, or if the training or evaluation set is
        empty.
    """
    if rng is None:
        rng = np.random
    if sampler_kwargs is None:
        sampler_kwargs = {}
    # perf_counter is monotonic, so the reported duration cannot be distorted
    # by system clock adjustments.
    t0 = time.perf_counter()

    # Fail fast, before the potentially expensive distance precomputation.
    k = int(k)
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    X_train_all = np.asarray(X_train_all)
    y_train_all = np.asarray(y_train_all, dtype=np.int64)
    X_eval = np.asarray(X_eval)
    y_eval = np.asarray(y_eval, dtype=np.int64)

    Ntr = X_train_all.shape[0]
    Ne = X_eval.shape[0]
    if Ntr == 0 or Ne == 0:
        raise ValueError(
            "X_train_all and X_eval must each contain at least one sample "
            f"(got {Ntr} training and {Ne} evaluation samples)"
        )

    F_tr = flatten_fn(X_train_all)
    F_ev = flatten_fn(X_eval)
    # Indexed below as D2_all[:, i]: one entry per evaluation sample for
    # training point i (presumably squared distances -- see precompute_D2).
    D2_all = precompute_D2(F_ev, F_tr)

    # Number of vote-table columns: the largest label seen anywhere, plus one.
    C = max(int(y_train_all.max()), int(y_eval.max())) + 1

    phi = np.zeros(Ntr, dtype=np.float64)
    row_idx = np.arange(Ne, dtype=np.int64)

    marginal_contrib_data = None
    if record_marginal_contrib:
        marginal_contrib_data = {
            "coalition_sizes": [],
            "marginal_contributions": [],
            "providers": [],
        }
    # Loop-invariant: contributions are only logged when providers are known.
    recording = marginal_contrib_data is not None and provider_ids is not None

    total_permutations = max(int(num_permutations), 0)

    for t in range(1, total_permutations + 1):
        if t % 100 == 0:
            print(f"[{method}] permutation {t} started...")
        perm = np.asarray(
            sample_linear_extension(method, rng=rng, **sampler_kwargs),
            dtype=np.int64,
        )
        assert perm.shape[0] == Ntr, (
            f"sampler returned {perm.shape[0]} indices, expected {Ntr}"
        )

        # Fresh, empty k-NN state for this ordering. Labels are stored as
        # int64 (same as the label arrays) so large class ids are not
        # truncated; -1 marks an unused slot.
        best_dists = np.full((Ne, k), np.inf, dtype=np.float32)
        best_labels = np.full((Ne, k), -1, dtype=np.int64)
        counts = np.zeros((Ne, C), dtype=np.int32)
        fill = np.zeros(Ne, dtype=np.int32)

        # The empty coalition is assigned accuracy 0.
        acc_prev = 0.0

        for j, i in enumerate(perm.tolist()):
            yi = int(y_train_all[i])
            _insert_neighbor(
                D2_all[:, i], yi,
                best_dists, best_labels, counts, fill, row_idx, k,
            )

            # Majority vote; np.argmax resolves ties toward the smallest
            # class index.
            preds = np.argmax(counts, axis=1)
            acc = float(np.mean(preds == y_eval))
            marginal_contrib = acc - acc_prev
            phi[i] += marginal_contrib

            if recording:
                # j points were already present before this one was added.
                marginal_contrib_data["coalition_sizes"].append(j)
                marginal_contrib_data["marginal_contributions"].append(marginal_contrib)
                marginal_contrib_data["providers"].append(provider_ids[i])

            acc_prev = acc

    # With no orderings processed there is nothing to average; skipping the
    # division avoids 0/0 turning every value into NaN.
    if total_permutations > 0:
        phi /= float(total_permutations)

    phi_group_sum = None
    if group_index_dict is not None:
        # An empty index list selects nothing and therefore sums to 0.0.
        phi_group_sum = {
            g_name: float(phi[np.asarray(idx, dtype=np.int64)].sum())
            for g_name, idx in group_index_dict.items()
        }

    elapsed_seconds = float(time.perf_counter() - t0)

    result = {
        "phi_individual": phi,
        "phi_group_sum": phi_group_sum,
        "num_permutations": total_permutations,
        "elapsed_seconds": elapsed_seconds,
    }

    if marginal_contrib_data is not None:
        result["marginal_contribution"] = marginal_contrib_data

    return result


