import pandas as pd
import numpy as np
import igraph as ig
import leidenalg as la


def compute_difficulty(df: pd.DataFrame) -> pd.Series:
    """Return each item's difficulty as one minus its mean score across models.

    Args:
        df: Items x models score matrix (rows are items).

    Returns:
        Series named ``"difficulty"`` indexed like ``df``; with 0/1 scores this
        is the fraction of models that got the item wrong.
    """
    mean_score = df.mean(axis=1)
    return (1 - mean_score).rename("difficulty")


def compute_uniqueness(df: pd.DataFrame) -> pd.Series:
    """Score each item by the average conditional entropy of the others given it.

    For item ``i`` the models (columns) are split into those scoring 1 on ``i``
    and those scoring 0. For every other item ``j`` the binary entropy (bits)
    of ``j`` inside each split is weighted by the split's share of all models
    and summed, giving H(j | i). The score is the mean of H(j | i) over all
    ``j != i`` (0.0 when there is only one item).

    Cells other than 0/1 in row ``i`` belong to neither split, but they are
    still counted in the model total used for the weights.

    Args:
        df: Items x models matrix, expected to hold 0/1 outcomes.

    Returns:
        Series named ``"uniqueness"`` indexed like ``df``.
    """
    def binary_entropy(p):
        # Degenerate splits (p == 0, p == 1, or NaN) contribute nothing.
        if not 0 < p < 1:
            return 0.0
        return -p * np.log2(p) - (1 - p) * np.log2(1 - p)

    n_questions, n_models = df.shape
    row_values = [df.iloc[k].values for k in range(n_questions)]

    scores = []
    for i, anchor in enumerate(row_values):
        # Model positions that scored 1 on item i, then those that scored 0.
        splits = (np.where(anchor == 1)[0], np.where(anchor == 0)[0])
        total_conditional_entropy = 0.0
        for j, other in enumerate(row_values):
            if j == i:
                continue
            h_cond = 0.0
            for members in splits:
                if len(members) > 0:
                    share = len(members) / n_models
                    h_cond += share * binary_entropy(other[members].mean())
            total_conditional_entropy += h_cond
        if n_questions > 1:
            scores.append(total_conditional_entropy / (n_questions - 1))
        else:
            scores.append(0.0)

    return pd.Series(scores, index=df.index, name="uniqueness")


def compute_risk(df: pd.DataFrame, eps: float = 1e-9) -> pd.Series:
    """Score each item by the weighted Jaccard overlap of its errors with the others.

    A model's weight is ``log((Q + 1) / (errors + 1))``. Models that fail many
    items weigh little, and a model that fails every item weighs exactly 0.
    For each pair of items the weighted Jaccard is
    ``sum(w over models both got wrong) / (sum(w over models either got wrong) + eps)``.
    An item's risk is the mean of that value over all other items (0.0 when
    there is only one item).

    Args:
        df: Items x models matrix; cells are cast to ``int`` and read as
            1 = correct, 0 = wrong.
        eps: Guards the Jaccard denominator against division by zero.

    Returns:
        Series named ``"risk"`` indexed like ``df``, with values in [0, 1].

    Note:
        Builds several Q x Q float64 intermediates.
    """
    E = (1 - df.values.astype(int)).astype(float)  # 1 where the model erred
    Q = E.shape[0]
    errs_per_model = E.sum(axis=0)
    # errs + 1 >= 1, so no eps is needed here. Adding one made the weight of an
    # always-wrong model slightly negative, which could yield negative Jaccard values.
    w = np.log((Q + 1.0) / (errs_per_model + 1.0))

    if Q <= 1:
        return pd.Series(np.zeros(Q, dtype=float), index=df.index, name="risk")

    inter_w = (E * w) @ E.T                     # weight of shared errors
    row_w = E @ w                               # weight of each item's errors
    union_w = row_w[:, None] + row_w[None, :] - inter_w
    jw = inter_w / (union_w + eps)
    np.fill_diagonal(jw, 0.0)                   # exclude self-overlap
    risk = jw.sum(axis=1) / (Q - 1)
    return pd.Series(risk, index=df.index, name="risk")


def compute_surprise(df: pd.DataFrame, eps: float = 1e-9, alpha: float = 1.0) -> pd.DataFrame:
    """Compute a TF-IDF-style surprise score for each item.

    Models (columns) are characterised by their mean score over all items.
    An item is surprising when high-scoring models get it wrong or
    low-scoring models get it right:

    * ``tf_wrong``: mean accuracy z-score of the models that scored 0 on
      the item (0.0 if none did).
    * ``tf_right``: mean error-rate z-score of the models that scored 1 on
      the item (0.0 if none did).
    * ``idf_wrong`` / ``idf_right``: ``log((N + 1) / (count + 1)) ** alpha``,
      which shrinks as more of the N models share that outcome.

    ``surprise`` is the average of ``tf_wrong * idf_wrong`` and
    ``tf_right * idf_right``.

    Args:
        df: Items x models matrix; cells are cast to ``int`` for counting.
        eps: Added to the standard deviation when z-scoring.
        alpha: Exponent applied to both IDF terms.

    Returns:
        DataFrame indexed like ``df`` with the counts, TF, IDF and partial
        scores, plus the final ``surprise`` column.
    """
    outcomes = df.values.astype(int)
    n_models = outcomes.shape[1]

    accuracy = df.mean(axis=0).values
    error_rate = 1.0 - accuracy

    def standardize(vec: np.ndarray, eps: float = 1e-9):
        return (vec - vec.mean()) / (vec.std() + eps)

    acc_z = standardize(accuracy, eps)
    err_z = standardize(error_rate, eps)

    n_wrong = (1 - outcomes).sum(axis=1)
    n_right = outcomes.sum(axis=1)

    tf_wrong = np.array([
        acc_z[row == 0].mean() if misses > 0 else 0.0
        for row, misses in zip(outcomes, n_wrong)
    ])
    tf_right = np.array([
        err_z[row == 1].mean() if hits > 0 else 0.0
        for row, hits in zip(outcomes, n_right)
    ])

    idf_wrong = np.log((n_models + 1) / (n_wrong + 1)) ** alpha
    idf_right = np.log((n_models + 1) / (n_right + 1)) ** alpha

    error_part = tf_wrong * idf_wrong
    success_part = tf_right * idf_right

    columns = {
        "n_wrong": n_wrong,
        "n_right": n_right,
        "tf_wrong": tf_wrong,
        "tf_right": tf_right,
        "idf_wrong": idf_wrong,
        "idf_right": idf_right,
        "surprise_error": error_part,
        "surprise_success": success_part,
        "surprise": 0.5 * (error_part + success_part),
    }
    return pd.DataFrame(columns, index=df.index)



# ================================
# ======= cluster analysis =======
#  Typicality (representativeness) and bridge scores
def _group_identical_rows(df: pd.DataFrame):
    # Group identical row vectors.
    X = df.values.astype(np.uint8)
    Q = X.shape[0]
    sig_to_gid = {}
    groups = []
    row_to_group = np.empty(Q, dtype=int)
    for i in range(Q):
        sig = X[i].tobytes()
        gid = sig_to_gid.get(sig)
        if gid is None:
            gid = len(groups)
            sig_to_gid[sig] = gid
            groups.append([i])
        else:
            groups[gid].append(i)
        row_to_group[i] = gid
    unique_idx = [g[0] for g in groups]
    return groups, unique_idx, row_to_group


def phi_similarity_matrix(df: pd.DataFrame, eps: float = 1e-12) -> tuple[np.ndarray, list]:
    """Return the pairwise phi-coefficient similarity between item rows.

    Cells are cast to int8; only cells equal to 1 or 0 enter the 2x2
    contingency counts. Phi is clipped to [-1, 1] and rescaled to
    ``(phi + 1) / 2`` in [0, 1]. A row that is constant (all 1 or all 0) has
    an undefined phi and gets 0.5 against every other row. The diagonal is
    forced to 1.

    Args:
        df: Items x models response matrix.
        eps: Added to the phi denominator to avoid division by zero.

    Returns:
        ``(S, index)``, where ``S`` is a Q x Q float32 symmetric matrix and
        ``index`` is ``df.index`` as a list.

    Note:
        All pairs are computed at once with matrix products, which needs a few
        Q x Q float64 intermediates.
    """
    X = df.values.astype(np.int8)  # (Q x M)
    ones = (X == 1).astype(np.float64)
    zeros = (X == 0).astype(np.float64)

    # 2x2 contingency counts for every pair (i, j) at once.
    a = ones @ ones.T      # i = 1, j = 1
    b = ones @ zeros.T     # i = 1, j = 0
    c = b.T                # i = 0, j = 1
    d = zeros @ zeros.T    # i = 0, j = 0

    # The margins (a+b)(c+d)(a+c)(b+d) factor into per-row terms. Computing
    # them in float64 avoids the int64 overflow the pairwise product hit for
    # very wide rows.
    marg = ones.sum(axis=1) * zeros.sum(axis=1)
    denom = np.sqrt(np.outer(marg, marg)) + eps
    phi = (a * d - b * c) / denom

    S = ((np.clip(phi, -1.0, 1.0) + 1.0) * 0.5).astype(np.float32)
    np.fill_diagonal(S, 1.0)
    return S, df.index.to_list()


def _mask_similarity(S: np.ndarray, thresh: float) -> np.ndarray:
    # Threshold similarity matrix.
    S_masked = S.copy()
    S_masked[S_masked < thresh] = 0.0
    np.fill_diagonal(S_masked, 0.0)
    return S_masked


def cluster_with_leiden(
    S: np.ndarray,
    groups: list[list[int]],
    unique_idx: list[int],
    threshold: float = 0.9,
    resolution: float = 1.0,
    seed: int = 123,
) -> np.ndarray:
    """Cluster items with Leiden on the thresholded graph of unique patterns.

    An undirected graph is built over the rows ``unique_idx`` of ``S``. Each
    pair whose similarity is at least ``threshold`` gets an edge weighted by
    that similarity. If no edges survive, every unique pattern is its own
    cluster. Otherwise Leiden is run with ``RBConfigurationVertexPartition``.
    Each unique pattern's label is then copied to every row listed in its
    group.

    Args:
        S: Q x Q similarity matrix over all rows.
        groups: ``groups[k]`` lists the rows sharing the pattern of
            ``unique_idx[k]`` (as produced by ``_group_identical_rows``).
        unique_idx: One representative row per group.
        threshold: Minimum similarity for an edge.
        resolution: Leiden ``resolution_parameter``; larger values give
            more, smaller clusters.
        seed: Random seed passed to ``leidenalg.find_partition``.

    Returns:
        Integer cluster label for each of the Q rows.
    """
    # Fancy indexing already returns a copy, which _mask_similarity copies again.
    Su = _mask_similarity(S[np.ix_(unique_idx, unique_idx)], threshold)
    iu, ju = np.where(np.triu(Su, k=1) > 0.0)  # upper triangle: one edge per pair
    n_unique = len(unique_idx)
    g = ig.Graph(n=n_unique, edges=list(zip(iu.tolist(), ju.tolist())), directed=False)

    if g.ecount() == 0:
        labels_unique = np.arange(n_unique, dtype=int)
    else:
        g.es["weight"] = Su[iu, ju].astype(float)
        part = la.find_partition(
            g,
            la.RBConfigurationVertexPartition,
            weights=g.es["weight"],
            resolution_parameter=resolution,
            seed=seed,
        )
        labels_unique = np.asarray(part.membership, dtype=int)

    labels_full = np.empty(S.shape[0], dtype=int)
    for gid, lbl in enumerate(labels_unique):
        labels_full[groups[gid]] = lbl
    return labels_full


def compute_typicality(
    S: np.ndarray,
    labels: np.ndarray,
    index: list,
    threshold: float = 0.9,
    singleton_value: float = 0,
):
    """
    Compute typicality for each item (average similarity within cluster).

    ``S`` is first thresholded, so entries below ``threshold`` and the
    diagonal become 0. An item's typicality is the mean of its nonzero
    thresholded similarities to members of its own cluster, or 0.0 if it has
    none. Items in single-member clusters get ``singleton_value``.

    Args:
        S: Q x Q similarity matrix.
        labels: Cluster label per row of ``S``.
        index: Item labels used as the index of the returned Series.
        threshold: Similarity cutoff applied before averaging.
        singleton_value: Typicality assigned to single-member clusters.

    Returns:
        typicality: float32 Series named ``"typicality"``.
        reps_df: One row per cluster holding its most typical item (first
            such item on ties), sorted by cluster size then typicality, both
            descending. The frame is empty, with the same columns, when
            ``labels`` is empty.
        clusters: Series of labels named ``"cluster"``.
    """
    df_idx = pd.Index(index)
    labels = np.asarray(labels)
    S_masked = _mask_similarity(S, threshold)
    Q = S_masked.shape[0]
    clusters = pd.Series(labels, index=df_idx, name="cluster")

    rep_vals = np.zeros(Q, dtype=np.float32)
    rows = []
    for cid in np.unique(labels):
        gidx = np.where(labels == cid)[0]
        if len(gidx) <= 1:
            rep_vals[gidx] = float(singleton_value)
        else:
            subS = S_masked[np.ix_(gidx, gidx)]
            np.fill_diagonal(subS, 0.0)  # never count self-similarity
            n_strong = (subS > 0).sum(axis=1)
            n_strong = np.where(n_strong == 0, 1, n_strong)  # isolated -> 0.0, not NaN
            rep_vals[gidx] = subS.sum(axis=1) / n_strong
        # Choose the representative by position rather than by label, so a
        # duplicated index cannot make the lookup ambiguous.
        best = gidx[int(np.argmax(rep_vals[gidx]))]
        rows.append({
            "cluster": int(cid),
            "representative_probe": df_idx[best],
            "typicality": float(rep_vals[best]),
            "cluster_size": int(len(gidx)),
        })

    typicality = pd.Series(rep_vals, index=df_idx, name="typicality")

    # Explicit columns keep sort_values working when there are no clusters.
    reps_df = pd.DataFrame(
        rows, columns=["cluster", "representative_probe", "typicality", "cluster_size"]
    ).sort_values(
        ["cluster_size", "typicality"], ascending=[False, False]
    ).reset_index(drop=True)

    return typicality, reps_df, clusters



def compute_bridge(
    S: np.ndarray,
    labels: np.ndarray,
    index: list,
    threshold: float = 0.9,
):
    """Score how strongly each item links to clusters other than its own.

    After thresholding ``S``, an item's external neighbours are the items with
    nonzero similarity to it that carry a different cluster label.

    * participation (coverage) = ``k / (k + k_median)``, where ``k`` is the
      number of distinct external clusters the item touches and ``k_median``
      is the median of the positive ``k`` values across all items (1.0 if
      there are none).
    * strength = mean similarity to the external neighbours.
    * bridge = participation * strength.

    Items with no external neighbours score 0 on all three.

    Returns:
        ``(bridge, participation, strength)`` as float32 Series indexed by
        ``index``.
    """
    df_idx = pd.Index(index)
    S0 = _mask_similarity(S, threshold)
    labels = np.asarray(labels)
    n_items = S0.shape[0]

    n_ext_clusters = np.zeros(n_items, dtype=int)
    strength_out_mean = np.zeros(n_items, dtype=np.float32)
    for i, row in enumerate(S0):
        linked = np.where(row > 0.0)[0]
        external = linked[labels[linked] != labels[i]]
        if external.size == 0:
            continue
        n_ext_clusters[i] = len(np.unique(labels[external]))
        strength_out_mean[i] = float(row[external].mean())

    touching = n_ext_clusters[n_ext_clusters > 0]
    k_median = float(np.median(touching)) if touching.size else 1.0

    # k == 0 gives 0 / k_median == 0.0, matching "no external neighbours".
    coverage = (n_ext_clusters / (n_ext_clusters + k_median)).astype(np.float32)

    bridge = pd.Series(coverage * strength_out_mean, index=df_idx, name="bridge")
    participation = pd.Series(coverage, index=df_idx, name="participation")
    strength = pd.Series(strength_out_mean, index=df_idx, name="strength")
    return bridge, participation, strength


def compute_cluster_intra_stats(S: np.ndarray, labels: np.ndarray, threshold: float = 0.9) -> pd.DataFrame:
    """Compute per-cluster size and mean intra-cluster thresholded similarity.

    ``intra_mean`` averages the thresholded similarity over every unordered
    pair of members. Pairs below ``threshold`` count as 0. Single-member
    clusters get 0.0.

    Args:
        S: Q x Q similarity matrix.
        labels: Cluster label per row (array or list).
        threshold: Similarity cutoff applied before averaging.

    Returns:
        DataFrame with columns ``cluster``, ``size`` and ``intra_mean``,
        sorted by size then intra_mean, both descending. The frame is empty,
        with those columns, when ``labels`` is empty.
    """
    labels = np.asarray(labels)  # lets `labels == c` work for list input too
    S0 = _mask_similarity(S, threshold)
    rows = []
    for c in np.unique(labels):
        gidx = np.where(labels == c)[0]
        m = len(gidx)
        if m <= 1:
            intra_mean = 0.0
        else:
            # m >= 2 guarantees at least one upper-triangle pair.
            pair_vals = S0[np.ix_(gidx, gidx)][np.triu_indices(m, k=1)]
            intra_mean = float(pair_vals.mean())
        rows.append({"cluster": int(c), "size": m, "intra_mean": intra_mean})
    # Explicit columns keep sort_values working when there are no clusters.
    stats = pd.DataFrame(rows, columns=["cluster", "size", "intra_mean"])
    return stats.sort_values(["size", "intra_mean"], ascending=[False, False]).reset_index(drop=True)


def compute_clustering_metrics(df: pd.DataFrame, sim_threshold: float = 0.9, return_all: bool = True):
    """Run the full clustering pipeline on an items x models matrix.

    Steps: group identical rows, build the phi similarity matrix, cluster the
    unique patterns with Leiden, then derive typicality, bridge and
    per-cluster statistics. Every stage uses the same ``sim_threshold``.

    Returns:
        If ``return_all`` is false, a DataFrame with the columns cluster,
        typicality, bridge, participation and strength. Otherwise a dict with
        keys ``result_df``, ``cluster_stats``, ``representatives``,
        ``labels`` and ``similarity_matrix``.
    """
    groups, unique_idx, _row_to_group = _group_identical_rows(df)
    S, idx = phi_similarity_matrix(df)
    labels = cluster_with_leiden(S, groups, unique_idx, threshold=sim_threshold)

    typicality, reps_df, clusters = compute_typicality(S, labels, idx, threshold=sim_threshold)
    bridge, participation, strength = compute_bridge(S, labels, idx, threshold=sim_threshold)
    cluster_stats = compute_cluster_intra_stats(S, labels, threshold=sim_threshold)

    per_item_columns = [clusters, typicality, bridge, participation, strength]
    result_df = pd.concat(per_item_columns, axis=1)

    if not return_all:
        return result_df
    return {
        "result_df": result_df,
        "cluster_stats": cluster_stats,
        "representatives": reps_df,
        "labels": labels,
        "similarity_matrix": S,
    }
    

def compute_mean_jaccard(df: pd.DataFrame) -> pd.Series:
    """Compute each item's mean (unweighted) Jaccard overlap of error patterns.

    An item's error set is the set of models that scored 0 on it. For each
    pair of items the Jaccard index is ``|A & B| / |A | B|``, or 0 when both
    sets are empty. The result is the mean over all other items, or 0.0 when
    there is only one item.

    Args:
        df: Items x models matrix of 0/1 outcomes (1 = correct).

    Returns:
        Series named ``"mean_jaccard"`` indexed like ``df``.
    """
    # int64, not uint8: matmul keeps the input dtype, so a uint8 product
    # wraps once two items share more than 255 failing models.
    E = 1 - df.values.astype(np.int64)
    Q = E.shape[0]
    if Q <= 1:
        # Avoid 0/0 -> NaN; match compute_risk/compute_uniqueness for Q <= 1.
        return pd.Series(np.zeros(Q, dtype=float), index=df.index, name="mean_jaccard")
    inter = E @ E.T
    row_sum = E.sum(axis=1)[:, None]
    union = row_sum + row_sum.T - inter
    with np.errstate(divide='ignore', invalid='ignore'):
        W = np.where(union > 0, inter / union, 0.0)
    np.fill_diagonal(W, 0.0)  # exclude self-overlap
    mean_jaccard = W.sum(axis=1) / (Q - 1)
    return pd.Series(mean_jaccard, index=df.index, name="mean_jaccard")


def analyze_test_items(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Compute every per-item metric and assemble a summary table and a debug table.

    Returns:
        main_df: difficulty, uniqueness, risk, mean_jaccard, surprise,
            typicality and bridge.
        debug_df: difficulty, uniqueness, risk, every intermediate column
            from ``compute_surprise``, and the full clustering result table.
            It does not include mean_jaccard.
    """
    difficulty = compute_difficulty(df)
    uniqueness = compute_uniqueness(df)
    risk = compute_risk(df)
    surprise = compute_surprise(df)
    mean_jaccard = compute_mean_jaccard(df)
    clustering = compute_clustering_metrics(df)  # dict of results
    cluster_table = clustering['result_df']

    headline = [
        difficulty,
        uniqueness,
        risk,
        mean_jaccard,
        surprise['surprise'],
        cluster_table['typicality'],
        cluster_table['bridge'],
    ]
    diagnostics = [difficulty, uniqueness, risk, surprise, cluster_table]

    main_df = pd.concat(headline, axis=1)
    debug_df = pd.concat(diagnostics, axis=1)
    return main_df, debug_df
