from tqdm import tqdm

import numpy as np
import torch
import torch.nn.functional as F
from joblib import Parallel, delayed

from main.shapley.helpers.helper_knn import rank_neighbor
from main.utils.random import set_random_seed


def best_threshold(actual, predicted):
    """
    Choose the cut-off on `predicted` that best reproduces the labels in `actual`.

    Every distinct value of `predicted` is tried as a cut-off ``tau``: entries with
    ``predicted > tau`` are classified as 1, all others as 0, and the number of
    positions where this disagrees with `actual` is counted.
    Args:
        actual: Sequence of reference labels, aligned element-wise with `predicted`.
        predicted: Sequence of scores to be thresholded.
    Returns:
        The candidate cut-off with the fewest mismatches (the smallest such value
        on ties), or None when there is no acceptable candidate (e.g. empty input).
    """
    labels = np.array(actual)
    scores = np.array(predicted)

    # A candidate is only accepted if its mismatch count is below this ceiling.
    ceiling = len(labels) + 1

    # np.unique returns the candidates sorted ascending, which fixes the tie-break.
    candidates = np.unique(scores)
    if candidates.size == 0:
        return None

    mismatch_counts = [
        np.sum((scores > cut).astype(int) != labels) for cut in candidates
    ]

    # argmin returns the first minimum, i.e. the smallest cut-off among ties.
    winner = int(np.argmin(mismatch_counts))
    if mismatch_counts[winner] >= ceiling:
        return None
    return candidates[winner]


def threshold_majority(ps_shapleys,
                       X_train, y_train, X_val, y_val,
                       K=5,
                       dis_metric='cosine',
                       n_samples=10000,
                       seed=None):
    """
    Calculate the Shapley thresholds where the actual utility (0 or 1) is determined by majority vote.

    For each validation datum, `n_samples` random training subsets are scored twice:
    the actual utility is 1 when the validation label is among the most frequent
    labels of the subset's K nearest neighbours, and the predicted utility is the
    sum of that datum's Shapley values over the subset. `best_threshold` then picks
    the cut-off on the predicted utilities that best matches the actual ones.
    Args:
        ps_shapleys (torch.Tensor): Per-validation-datum Shapley values, indexed as [train index, validation index].
        X_train (torch.Tensor): Training data features.
        y_train (torch.Tensor): Training data labels (non-negative integers, as required by np.bincount).
        X_val (torch.Tensor): Validation data features.
        y_val (np.ndarray): Validation data labels.
        K (int): Number of neighbors to consider.
        dis_metric (str): Distance metric to use for neighbor calculation.
            NOTE(review): currently not forwarded to `rank_neighbor`, whose signature
            is not visible here - confirm whether it should be.
        n_samples (int): Number of samples to use for determining thresholds.
        seed (int | None): If given, passed to `set_random_seed` before sampling.
    Returns:
        np.ndarray: Thresholds for each validation datum.
    """
    if seed is not None:
        set_random_seed(seed)

    # Constant for every vote below, so compute it once instead of per call.
    n_classes = np.unique(y_train).size

    def is_most_frequent(label, topk_labels):
        # Ties count: the label only has to be one of the most frequent classes.
        counts = np.bincount(topk_labels, minlength=n_classes)
        return label in np.flatnonzero(counts == counts.max())

    n_train = len(y_train)
    n_val = len(y_val)

    # Random subsets of the training set: uniform size in [1, n_train], drawn
    # without replacement. The same subsets are reused for every validation datum.
    samples = []
    for _ in range(n_samples):
        sample_size = np.random.randint(1, n_train + 1)
        samples.append(np.random.choice(n_train, size=sample_size, replace=False))

    thresholds = np.zeros(n_val)

    for j in tqdm(range(n_val)):
        order = rank_neighbor(X_val[j], X_train)

        # Inverse permutation: rank[i] is the position of training point i in
        # `order`. It depends only on j, so it is built once per validation datum.
        rank = np.empty_like(order)
        rank[order] = np.arange(len(order))

        actual_utilities = []
        predicted_utilities = []
        for sample in samples:
            n_neighbors = min(K, len(sample))

            # Members of the sample with the smallest ranks. Their relative order
            # is irrelevant for a majority vote, so a partial sort is enough and
            # the ranks stay exact integers.
            nearest = np.argpartition(rank[sample], n_neighbors - 1)[:n_neighbors]
            topK = sample[nearest]

            actual_utilities.append(int(is_most_frequent(y_val[j], y_train[topK])))

            # .item() yields a plain Python number, so the list converts to a
            # NumPy array cheaply and independently of the tensor's device.
            predicted_utilities.append(ps_shapleys[sample, j].sum().item())

        thresholds[j] = best_threshold(actual_utilities, predicted_utilities)

    return thresholds


# NOTE: The code below has not been reviewed in detail, but in my tests it produced results identical to the implementation above.
def rank_neighbor_torch(xv: torch.Tensor, Xtr: torch.Tensor, metric='cosine') -> torch.Tensor:
    """
    Rank the rows of `Xtr` from nearest to farthest with respect to `xv`.
    Args:
        xv (torch.Tensor): Query vector (assumed 1-D, matching the row length of Xtr).
        Xtr (torch.Tensor): Candidate points, one per row.
        metric (str): 'cosine' uses the negated dot product divided by each row's
            norm; any other value falls back to the Euclidean distance.
    Returns:
        torch.Tensor: int64 row indices of `Xtr`, sorted by ascending distance.
    """
    if metric == 'cosine':
        row_norms = torch.linalg.norm(Xtr, dim=1)
        # Only the rows are normalised; xv is deliberately left un-normalised.
        scores = -(Xtr @ xv) / row_norms
        # Zero-norm rows would otherwise produce 0/0; give them +inf distance instead.
        scores = scores.masked_fill(row_norms == 0, torch.inf)
    else:
        # Plain (non-squared) Euclidean norm of the row-wise differences.
        scores = torch.linalg.norm(Xtr - xv.unsqueeze(0), dim=1)

    return torch.argsort(scores, dim=0, descending=False).long()


@torch.no_grad()
def threshold_majority_parallel(ps_shapleys,
                                X_train, y_train, X_val, y_val,
                                K=5,
                                n_samples=10000,
                                batch_size=1024,
                                seed=None):
    """
    Batched variant of the majority-vote Shapley threshold computation.

    Random training subsets are drawn with np.random; for every validation datum
    each subset gets an actual utility (1 if the validation label is among the most
    frequent labels of the subset's K nearest neighbours, else 0) and a predicted
    utility (sum of that datum's Shapley values over the subset). `best_threshold`
    then selects one cut-off per validation datum.

    Validation points and samples are both processed in chunks of `batch_size`;
    the intermediate tensors have shape [B, Mb, Lb] with B, Mb <= batch_size and
    Lb the largest sample of the chunk (up to n_train), so memory grows with
    batch_size**2 * n_train.
    Args:
        ps_shapleys (torch.Tensor): Shapley values, indexed as [train index, validation index];
            its device is used for all tensor work.
        X_train (torch.Tensor): Training data features, one row per point.
        y_train (torch.Tensor): Training data labels (non-negative integers).
        X_val (torch.Tensor): Validation data features, one row per point.
        y_val (torch.Tensor | np.ndarray): Validation data labels.
        K (int): Number of neighbors to consider.
        n_samples (int): Number of random training subsets.
        batch_size (int): Chunk size for both validation points and samples.
        seed (int | None): If given, passed to `set_random_seed` before sampling.
    Returns:
        np.ndarray: Thresholds for each validation datum.
    """
    device = ps_shapleys.device
    X_train = X_train.to(device)
    # F.one_hot below requires int64 labels.
    y_train = y_train.to(device=device, dtype=torch.long)
    X_val = X_val.to(device)
    # Validation labels are kept on the host as NumPy; accept tensors and arrays alike.
    if isinstance(y_val, torch.Tensor):
        y_val = y_val.detach().cpu().numpy()
    else:
        y_val = np.asarray(y_val)

    n_train = len(y_train)
    n_val = len(y_val)

    if seed is not None:
        # presumably seeds the global np.random state used below - helper not visible here
        set_random_seed(seed)

    # Random subsets of the training set: uniform size in [1, n_train], drawn
    # without replacement, shared by all validation points.
    samples = []
    for _ in range(n_samples):
        sample_size = np.random.randint(1, n_train + 1)
        samples.append(np.random.choice(n_train, size=sample_size, replace=False))

    thresholds = np.zeros(n_val, dtype=float)

    # Labels are used as raw indices into the class axis (one_hot / gather), so
    # that axis must reach the largest label seen in either split, not merely
    # the number of distinct training labels.
    label_max = int(y_train.max()) if n_train else -1
    if n_val:
        label_max = max(label_max, int(y_val.max()))
    n_classes = label_max + 1

    def make_order_block(vb0, vb1):
        # One neighbour ordering (nearest first) per validation point, as columns.
        cols = [
            rank_neighbor_torch(X_val[j], X_train).to(device=device, dtype=torch.long)
            for j in range(vb0, vb1)
        ]
        return torch.stack(cols, dim=1)                                   # [n_train, B]

    for vb0 in tqdm(range(0, n_val, batch_size), desc='Batch'):
        vb1 = min(vb0 + batch_size, n_val)
        B = vb1 - vb0

        # Invert each ordering: ranks_B[i, b] is the position of training point i
        # in the neighbour list of validation point b (0 = nearest).
        order_B = make_order_block(vb0, vb1)                              # [n_train, B]
        ranks_B = torch.empty_like(order_B)
        arange_nt = torch.arange(n_train, device=device).view(-1, 1).expand(-1, B)
        ranks_B.scatter_(0, order_B, arange_nt)

        yv_batch = torch.as_tensor(y_val[vb0:vb1], device=device, dtype=torch.long)  # [B]

        # Shapley values of this validation chunk; independent of the sample chunk.
        ps_b = ps_shapleys[:, vb0:vb1].T                                  # [B, n_train]

        actual_cat = []
        predicted_cat = []

        for sb0 in range(0, n_samples, batch_size):
            sb1 = min(sb0 + batch_size, n_samples)
            Mb = sb1 - sb0
            batch = samples[sb0:sb1]

            # Pad the ragged samples of this chunk with -1 up to the longest one.
            Lb = max(len(s) for s in batch)
            padded = np.full((Mb, Lb), -1, dtype=np.int64)
            for t, s in enumerate(batch):
                padded[t, :len(s)] = s

            samples_idx = torch.as_tensor(padded, device=device, dtype=torch.long)   # [Mb, Lb]
            pad_mask = samples_idx.eq(-1)                                            # [Mb, Lb]
            # Padding slots point at index 0 so gathers stay in range; they are masked out below.
            samples_idx_clamped = samples_idx.clamp_min(0)                           # [Mb, Lb]
            idx_3d = samples_idx_clamped.unsqueeze(0).expand(B, -1, -1)              # [B, Mb, Lb]

            # ----- actual utilities (KNN majority within sample) -----
            ranks_exp = ranks_B.T.unsqueeze(1).expand(B, Mb, n_train)                # [B, Mb, n_train]
            ranks_on_samples = torch.gather(ranks_exp, 2, idx_3d)                    # [B, Mb, Lb]
            # Push padding behind every real rank (real ranks are < n_train).
            BIG = n_train + 1
            ranks_on_samples.masked_fill_(pad_mask, BIG)

            # Samples shorter than K_eff pick up padding slots here; valid_topk removes them.
            K_eff = min(K, Lb)
            topk_pos = torch.topk(-ranks_on_samples, k=K_eff, dim=2,
                                  largest=True, sorted=False).indices                # [B, Mb, K]
            valid_topk = torch.gather(ranks_on_samples, 2, topk_pos).ne(BIG)         # [B, Mb, K]

            # Labels of the selected neighbours.
            y_samples = y_train.index_select(0, samples_idx_clamped.view(-1)).view(Mb, Lb)  # [Mb, Lb]
            y_s_3d = y_samples.unsqueeze(0).expand(B, -1, -1)                        # [B, Mb, Lb]
            topk_labels = torch.gather(y_s_3d, 2, topk_pos)                          # [B, Mb, K]

            # Per-class vote counts over the valid neighbours.
            one_hot = F.one_hot(topk_labels.clamp_min(0), num_classes=n_classes).to(torch.int32)  # [B, Mb, K, C]
            one_hot = one_hot * valid_topk.unsqueeze(-1).to(one_hot.dtype)
            counts = one_hot.sum(dim=2)                                              # [B, Mb, C]

            # Utility is 1 when the validation label ties for (or is) the most voted class.
            yvj_idx = yv_batch.view(B, 1, 1).expand(-1, Mb, 1)                       # [B, Mb, 1]
            yvj_counts = counts.gather(2, yvj_idx).squeeze(2)                        # [B, Mb]
            max_counts = counts.max(dim=2).values                                    # [B, Mb]
            actual_MB = (yvj_counts == max_counts).to(torch.int32)                   # [B, Mb]

            # ----- predicted utilities (sum Shapley over FULL sample) -----
            vals = torch.gather(ps_b.unsqueeze(1).expand(-1, Mb, -1), 2, idx_3d)     # [B, Mb, Lb]
            vals.masked_fill_(pad_mask, 0)
            pred_MB = vals.sum(dim=2)                                                # [B, Mb]

            # Stash as [Mb, B] on CPU so best_threshold can consume columns.
            actual_cat.append(actual_MB.transpose(0, 1).contiguous().cpu().numpy())
            predicted_cat.append(pred_MB.transpose(0, 1).contiguous().cpu().numpy())

        # Stack all sample chunks: one row per sample, one column per validation point.
        actual_full = np.concatenate(actual_cat, axis=0)                  # [M, B]
        predicted_full = np.concatenate(predicted_cat, axis=0)            # [M, B]

        for b in range(B):
            thresholds[vb0 + b] = best_threshold(actual_full[:, b], predicted_full[:, b])

    return thresholds