import math
import random
from collections import defaultdict

import numpy as np
import torch
import dgl

def build_similarity_graph_from_leaves(
    binned_feature_tensor: torch.Tensor,
    top_k: int = 20,
    use_idf: bool = True,
    sym: str = 'max',
    max_per_leaf: int = 500,
):
    """Build an undirected weighted similarity graph from shared bin/leaf ids.

    Two rows are similar when they hold the same value in the same column.
    For every row ``i`` the weights of all columns in which another row ``j``
    carries the same value are summed, the ``top_k`` highest-scoring ``j`` are
    kept, and each kept pair becomes a pair of directed edges ``i -> j`` and
    ``j -> i`` that share one weight.

    Args:
        binned_feature_tensor: 2-D tensor of shape ``[N, T]``. Each entry is
            converted with ``int()`` and used as a bucket key within its
            column (presumably a tree-leaf / bin id -- confirm with caller).
        top_k: Number of best-scoring neighbours kept per row. It is applied
            as a slice bound, so ``None`` keeps every neighbour.
        use_idf: If true, a (column, value) bucket containing ``df`` rows is
            weighted ``log((N + 1) / (df + 1)) + 1``; otherwise every bucket
            weighs ``1.0``.
        sym: Accepted for interface compatibility but not read by this
            function. A pair selected from both endpoints always keeps the
            larger of the two scores.
        max_per_leaf: When positive, a bucket holding more rows than this is
            sub-sampled to ``max_per_leaf`` rows with ``random.sample`` each
            time it is visited. ``None`` or a non-positive value disables the
            cap. Sampling uses the global ``random`` state, so the result is
            only reproducible if the caller seeds it.

    Returns:
        A ``(g_sim, weights)`` tuple. ``g_sim`` is a DGL graph with ``N``
        nodes whose ``edata['w']`` is a float32 tensor of shape ``(E, 1)``;
        ``weights`` is the same data flattened to shape ``(E,)``.

    Raises:
        ValueError: If ``binned_feature_tensor`` is not 2-dimensional.
    """
    # detach() lets a tensor that tracks gradients be converted to NumPy.
    leaf_indices = binned_feature_tensor.detach().cpu().numpy()  # [N, T]
    if leaf_indices.ndim != 2:
        raise ValueError(
            "binned_feature_tensor must be 2-D with shape [N, T], got shape "
            f"{tuple(leaf_indices.shape)}"
        )
    N, T = leaf_indices.shape

    # Convert to plain Python ints once; both passes below reuse these rows
    # instead of performing N*T NumPy scalar reads each.
    leaf_rows = [[int(v) for v in row] for row in leaf_indices.tolist()]

    # buckets[t][leaf] -> rows whose column t holds `leaf`, in ascending order.
    buckets = [defaultdict(list) for _ in range(T)]
    for i, row in enumerate(leaf_rows):
        for t, leaf in enumerate(row):
            buckets[t][leaf].append(i)

    # Per-bucket weight: smoothed inverse document frequency, or a flat 1.0.
    if use_idf:
        idf = [
            {
                leaf: math.log((N + 1) / (len(nodes) + 1.0)) + 1.0
                for leaf, nodes in bucket.items()
            }
            for bucket in buckets
        ]
    else:
        idf = [dict.fromkeys(bucket, 1.0) for bucket in buckets]

    # The cap is loop-invariant; None means "never sub-sample".
    cap = max_per_leaf if max_per_leaf is not None and max_per_leaf > 0 else None

    # (u, v) with u < v -> best score seen for that unordered pair.
    pair_weight = {}
    for i, row in enumerate(leaf_rows):
        counter = defaultdict(float)
        for t, leaf in enumerate(row):
            neigh = buckets[t][leaf]
            if cap is not None and len(neigh) > cap:
                neigh = random.sample(neigh, cap)
            w = idf[t][leaf]
            for j in neigh:
                if j != i:
                    counter[j] += w
        if not counter:
            continue
        # Stable sort on the negated score: ties keep first-seen order.
        top_items = sorted(counter.items(), key=lambda kv: -kv[1])[:top_k]
        for j, w_ij in top_items:
            # j != i is guaranteed above, so the ordered key never collapses.
            key = (i, j) if i < j else (j, i)
            prev = pair_weight.get(key)
            pair_weight[key] = w_ij if prev is None else max(prev, w_ij)

    if not pair_weight:
        g_sim = dgl.graph(([], []), num_nodes=N)
        g_sim.edata['w'] = torch.zeros((0, 1), dtype=torch.float32)
        return g_sim, g_sim.edata['w'].view(-1)

    # Emit both directions of every pair so the graph is symmetric.
    src, dst, w_list = [], [], []
    for (u, v), w_uv in pair_weight.items():
        src.extend((u, v))
        dst.extend((v, u))
        w_list.extend((w_uv, w_uv))

    g_sim = dgl.graph(
        (
            torch.tensor(src, dtype=torch.int64),
            torch.tensor(dst, dtype=torch.int64),
        ),
        num_nodes=N,
    )
    eweight = torch.tensor(w_list, dtype=torch.float32).view(-1, 1)
    g_sim.edata['w'] = eweight
    return g_sim, eweight.view(-1)
