import pandas as pd
import numpy as np
import igraph as ig
import leidenalg as la
from typing import Optional

def compute_difficulty(df: pd.DataFrame) -> pd.Series:
    """Return the per-item difficulty, i.e. the error rate of each row.

    The score is ``1 - mean(row)``, so for a 0/1 correctness matrix it is
    the fraction of columns in which the item was answered incorrectly.

    Args:
        df: Correctness matrix with one row per item.

    Returns:
        Series named ``"difficulty"`` indexed like ``df``.
    """
    accuracy = df.mean(axis=1)
    return (1 - accuracy).rename("difficulty")



def compute_uniqueness(df: pd.DataFrame) -> pd.Series:
    """Return the mean conditional entropy of the other items given each item.

    For every ordered pair of rows (i, j) the 2x2 contingency table is built
    from the binarised data and ``H(Vj | Vi)`` is evaluated in bits. The score
    of item i is the average of ``H(Vj | Vi)`` over all j != i: a high value
    means knowing item i's outcome leaves the other items largely undetermined.

    Args:
        df: Matrix with one row per item; entries > 0 are treated as 1 and
            everything else as 0.

    Returns:
        float32 Series named ``"uniqueness"`` indexed like ``df``. With fewer
        than two items there is nothing to compare against, so the score is
        0.0 rather than NaN.
    """
    # Always binarise so that non 0/1 inputs cannot corrupt the counts.
    X = (df.values > 0).astype(np.float32)      # (Q, N)
    Q, N = X.shape

    # Pairwise contingency counts; entry [i, j] refers to (Vi, Vj).
    n11 = X @ X.T                               # both 1
    ones = X.sum(axis=1, dtype=np.float32)      # (Q,) number of 1s per item
    n10 = ones[:, None] - n11                   # Vi=1, Vj=0
    n01 = ones[None, :] - n11                   # Vi=0, Vj=1
    n00 = N - n11 - n10 - n01                   # both 0

    # Guard against tiny negative values produced by float32 round-off.
    n11, n10, n01, n00 = (np.clip(cell, 0.0, N) for cell in (n11, n10, n01, n00))

    eps = 1e-12
    n1 = n11 + n10                              # count of Vi=1
    n0 = n01 + n00                              # count of Vi=0

    # Conditional probabilities; 0.5 is a placeholder where the condition never
    # occurs (its weight P(Vi=.) is then 0, so it does not affect the result).
    p1 = np.where(n1 > 0, n11 / (n1 + eps), 0.5).astype(np.float32)  # P(Vj=1 | Vi=1)
    p0 = np.where(n0 > 0, n01 / (n0 + eps), 0.5).astype(np.float32)  # P(Vj=1 | Vi=0)

    def _binary_entropy(p):
        """Binary entropy in bits, defined as 0 at p in {0, 1}."""
        p = np.clip(np.asarray(p, dtype=np.float32), 0.0, 1.0)
        interior = (p > 0.0) & (p < 1.0)
        H = np.zeros_like(p, dtype=np.float32)
        if np.any(interior):
            pm = np.clip(p[interior], eps, 1.0 - eps)
            H[interior] = -(pm * np.log2(pm) + (1.0 - pm) * np.log2(1.0 - pm))
        return H

    H1 = _binary_entropy(p1)                    # H(Vj | Vi=1)
    H0 = _binary_entropy(p0)                    # H(Vj | Vi=0)

    P1 = (n1 / (N + eps)).astype(np.float32)    # P(Vi=1)
    P0 = (n0 / (N + eps)).astype(np.float32)    # P(Vi=0)

    H_cond = (P1 * H1 + P0 * H0).astype(np.float32)  # (Q, Q): H(Vj | Vi)
    np.fill_diagonal(H_cond, 0.0)               # an item is not compared to itself

    # max(..., 1) avoids 0/0 when only a single item is present.
    n_others = max(Q - 1, 1)
    uniq = (H_cond.sum(axis=1, dtype=np.float32) / n_others).astype(np.float32)
    return pd.Series(uniq, index=df.index, name="uniqueness")



def compute_risk(df: pd.DataFrame, eps: float = 1e-9) -> pd.Series:
    """Return each item's mean weighted Jaccard overlap of errors with the others.

    An entry counts as an error when it is not > 0. Each column is weighted by
    ``log((Q + 1) / (errors_in_column + 1) + eps)``, so columns that rarely err
    contribute more. For items i and j the weighted Jaccard similarity of their
    error sets is ``sum_w(i and j) / sum_w(i or j)``; the score of item i is the
    mean of that similarity over all j != i.

    Args:
        df: Matrix with one row per item and one column per model.
        eps: Small constant added inside the logarithm and to the denominator.

    Returns:
        float32 Series named ``"risk"`` indexed like ``df``. With fewer than
        two items the score is 0.0 rather than NaN.
    """
    # Binarise unconditionally: "1 - X" on a non-binary uint8 array would wrap.
    correct = df.values > 0
    E = (~correct).astype(np.float32)                                # (Q, N) errors
    Q, N = E.shape

    errs_per_model = E.sum(axis=0, dtype=np.float32)                 # (N,)
    w = np.log((Q + 1.0) / (errs_per_model + 1.0) + eps).astype(np.float32)

    Ew = (E * w[None, :]).astype(np.float32)                         # (Q, N)
    inter = Ew @ E.T                                                 # (Q, Q) weighted |i and j|
    s = Ew.sum(axis=1, dtype=np.float32)                             # (Q,) weighted |i|
    union = (s[:, None] + s[None, :] - inter).astype(np.float32)

    # np.where evaluates the quotient everywhere, including entries where the
    # union is not positive, hence the suppressed warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        J = np.where(union > 0, inter / (union + eps), 0.0).astype(np.float32)

    np.fill_diagonal(J, 0.0)                                         # skip self-similarity
    # max(..., 1) avoids 0/0 when only a single item is present.
    n_others = max(Q - 1, 1)
    risk = (J.sum(axis=1, dtype=np.float32) / n_others).astype(np.float32)
    return pd.Series(risk, index=df.index, name="risk")


def compute_surprise(df: pd.DataFrame, eps: float = 1e-9, alpha: float = 1.0) -> pd.DataFrame:
    """Score how unexpected each item's failures and successes are (TF-IDF style).

    Each column gets an accuracy (mean over items) and an error rate, both
    z-scored across columns. For every item:

    * ``tf_wrong`` is the mean accuracy z-score of the columns that got it wrong,
    * ``tf_right`` is the mean error-rate z-score of the columns that got it right,
    * ``idf_*`` is ``log((N + 1) / (count + 1) + eps) ** alpha`` for the
      corresponding count,
    * ``surprise_error`` / ``surprise_success`` are the tf * idf products and
      ``surprise`` is their average.

    Args:
        df: Matrix with one row per item; non-uint8 data is binarised with ``> 0``.
        eps: Small constant guarding divisions and the logarithm.
        alpha: Exponent applied to the idf terms.

    Returns:
        DataFrame indexed like ``df`` holding the counts, tf, idf and surprise
        columns.
    """
    data = df.values
    if data.dtype != np.uint8:
        data = (data > 0).astype(np.uint8)
    _, n_cols = data.shape
    correct = data.astype(np.float32)

    def _zscore(values):
        # Standardise across columns; eps keeps a zero spread from dividing by 0.
        centre = values.mean(dtype=np.float32)
        spread = values.std(dtype=np.float32) + eps
        return ((values - centre) / spread).astype(np.float32)

    def _mean_over(total, count):
        # Average of the z-scores over `count` columns, 0 where count is 0.
        return np.where(count > 0, total / (count + eps), 0.0).astype(np.float32)

    def _idf(count):
        # Rarer outcomes receive a larger weight.
        return (np.log((n_cols + 1.0) / (count + 1.0) + eps) ** alpha).astype(np.float32)

    accuracy = correct.mean(axis=0, dtype=np.float32)        # (N,)
    error_rate = (1.0 - accuracy).astype(np.float32)
    accuracy_z = _zscore(accuracy)
    error_rate_z = _zscore(error_rate)

    n_right = correct.sum(axis=1, dtype=np.float32)          # (Q,)
    n_wrong = (n_cols - n_right).astype(np.float32)

    tf_wrong = _mean_over((1.0 - correct) @ accuracy_z, n_wrong)
    tf_right = _mean_over(correct @ error_rate_z, n_right)

    idf_wrong = _idf(n_wrong)
    idf_right = _idf(n_right)

    surprise_error = (tf_wrong * idf_wrong).astype(np.float32)
    surprise_success = (tf_right * idf_right).astype(np.float32)
    surprise = (0.5 * (surprise_error + surprise_success)).astype(np.float32)

    columns = {
        "n_wrong": n_wrong,
        "n_right": n_right,
        "tf_wrong": tf_wrong,
        "tf_right": tf_right,
        "idf_wrong": idf_wrong,
        "idf_right": idf_right,
        "surprise_error": surprise_error,
        "surprise_success": surprise_success,
        "surprise": surprise,
    }
    return pd.DataFrame(columns, index=df.index)



# ================================
# ======= cluster analysis =======
#  Representativeness and bridge
def _group_identical_rows(df: pd.DataFrame):
    """Partition row positions into groups of byte-identical rows.

    Rows are cast to uint8 and compared via their raw bytes. Groups are
    numbered in order of first appearance.

    Returns:
        A tuple ``(groups, unique_idx, row_to_group)`` where ``groups`` is a
        list of lists of row positions, ``unique_idx`` holds the first row
        position of every group, and ``row_to_group`` is an int array mapping
        each row position to its group number.
    """
    patterns = df.values.astype(np.uint8)
    group_of_pattern = {}
    groups = []
    row_to_group = np.empty(patterns.shape[0], dtype=int)

    for row_no, pattern in enumerate(patterns):
        key = pattern.tobytes()
        if key not in group_of_pattern:
            # First time this pattern is seen: open a new group for it.
            group_of_pattern[key] = len(groups)
            groups.append([])
        gid = group_of_pattern[key]
        groups[gid].append(row_no)
        row_to_group[row_no] = gid

    unique_idx = [members[0] for members in groups]
    return groups, unique_idx, row_to_group


def phi_similarity_matrix(df: pd.DataFrame, eps: float = 1e-12) -> tuple[np.ndarray, list]:
    """Build the pairwise phi-coefficient similarity between rows.

    The phi coefficient of every row pair is computed from the 2x2
    contingency counts of the binarised data and mapped from [-1, 1] to
    [0, 1] via ``(phi + 1) / 2``. The diagonal is set to 1.

    Args:
        df: Matrix with one row per item; non-uint8 data is binarised with ``> 0``.
        eps: Added under the square root so constant rows do not divide by zero.

    Returns:
        ``(S, index)`` with ``S`` a float32 (Q, Q) array and ``index`` the row
        labels of ``df`` as a list.
    """
    binary = df.values
    if binary.dtype != np.uint8:
        binary = (binary > 0).astype(np.uint8)
    binary = binary.astype(np.float32)                      # (Q, N)
    n_cols = binary.shape[1]

    # Contingency counts for the pair (i, j).
    both = binary @ binary.T                                # n11
    totals = binary.sum(axis=1, dtype=np.float32)           # (Q,)
    only_i = totals[:, None] - both                         # n10
    only_j = totals[None, :] - both                         # n01
    neither = n_cols - both - only_i - only_j               # n00

    # Remove small negative values caused by float32 round-off.
    both, only_i, only_j, neither = (
        np.clip(cell, 0.0, n_cols) for cell in (both, only_i, only_j, neither)
    )

    margins = (both + only_i) * (only_j + neither) * (both + only_j) * (only_i + neither)
    scale = np.sqrt(margins + eps).astype(np.float32)
    cross = (both * neither - only_i * only_j).astype(np.float32)
    phi = np.clip(cross / scale, -1.0, 1.0).astype(np.float32)

    similarity = ((phi + 1.0) * 0.5).astype(np.float32)
    np.fill_diagonal(similarity, 1.0)
    return similarity, df.index.to_list()


def _mask_similarity(S: np.ndarray, thresh: float) -> np.ndarray:
    """Return a copy of ``S`` with sub-threshold entries and the diagonal zeroed.

    The input array is left untouched.
    """
    masked = S.copy()
    np.putmask(masked, masked < thresh, 0.0)
    np.fill_diagonal(masked, 0.0)
    return masked


def cluster_with_leiden(
    S: np.ndarray,
    groups: list[list[int]],
    unique_idx: list[int],
    threshold: float = 0.9,
) -> np.ndarray:
    """Cluster rows with Leiden community detection on the thresholded graph.

    Only one representative per group of identical rows (``unique_idx``) becomes
    a graph vertex; edges connect representatives whose similarity survives the
    threshold and carry that similarity as weight. Every row then inherits the
    label of its group's representative. If no edge survives, each
    representative forms its own cluster.

    Args:
        S: (Q, Q) similarity matrix over all rows.
        groups: Row positions per group; ``groups[k]`` belongs to ``unique_idx[k]``.
        unique_idx: Row position of the representative of each group.
        threshold: Similarities below this value are dropped.

    Returns:
        Integer array of length Q with the cluster label of every row.
    """
    n_unique = len(unique_idx)
    sim_unique = _mask_similarity(S[np.ix_(unique_idx, unique_idx)].copy(), threshold)

    # Upper triangle only, so every undirected edge is listed once.
    src, dst = np.nonzero(np.triu(sim_unique, k=1) > 0.0)
    edge_weights = sim_unique[src, dst].astype(float)

    graph = ig.Graph(
        n=n_unique,
        edges=list(zip(src.tolist(), dst.tolist())),
        directed=False,
    )
    if len(edge_weights) > 0:
        graph.es["weight"] = edge_weights

    if graph.ecount() == 0:
        # Nothing is connected: every unique pattern is its own cluster.
        labels_unique = np.arange(n_unique, dtype=int)
    else:
        partition = la.find_partition(
            graph,
            la.RBConfigurationVertexPartition,
            weights=graph.es["weight"],
            resolution_parameter=1.0,
            seed=123,
        )
        labels_unique = np.asarray(partition.membership, dtype=int)

    # Broadcast each representative's label to all rows sharing its pattern.
    labels_full = np.empty(S.shape[0], dtype=int)
    for gid, label in enumerate(labels_unique):
        labels_full[groups[gid]] = label
    return labels_full


def compute_typicality(
    S: np.ndarray,
    labels: np.ndarray,
    index: list,
    threshold: float = 0.9,
    singleton_value: float = 0,
):
    """Compute typicality for each item (average similarity within cluster).

    Within every cluster an item's typicality is the mean of its
    above-threshold similarities to the other members (0 when none survive
    the threshold). Items in single-member clusters get ``singleton_value``.

    Args:
        S: (Q, Q) similarity matrix.
        labels: Cluster label per row of ``S``.
        index: Row labels used as the index of the returned objects.
        threshold: Similarities below this value are ignored.
        singleton_value: Typicality assigned to single-member clusters.

    Returns:
        ``(typicality, reps_df, clusters)``: the float32 typicality Series, a
        DataFrame with the most typical item of each cluster (sorted by
        cluster size, then typicality, both descending), and the cluster
        labels as a Series.
    """
    item_index = pd.Index(index)
    masked = _mask_similarity(S, threshold)
    clusters = pd.Series(labels, index=item_index, name="cluster")

    scores = np.zeros(masked.shape[0], dtype=np.float32)
    for cluster_id in np.unique(labels):
        members = np.flatnonzero(labels == cluster_id)
        if members.size <= 1:
            scores[members] = float(singleton_value)
            continue

        # The diagonal of `masked` is already zero, so self-similarity is excluded.
        within = masked[np.ix_(members, members)]
        n_links = np.count_nonzero(within > 0, axis=1)
        # Average over surviving links only; max(..., 1) avoids dividing by 0.
        scores[members] = within.sum(axis=1) / np.maximum(n_links, 1)

    typicality = pd.Series(scores, index=item_index, name="typicality")

    representatives = []
    for cluster_id, member_scores in typicality.groupby(clusters):
        best = member_scores.idxmax()
        representatives.append({
            "cluster": int(cluster_id),
            "representative_probe": best,
            "typicality": float(member_scores.loc[best]),
            "cluster_size": int(np.count_nonzero(labels == cluster_id)),
        })

    reps_df = pd.DataFrame(representatives)
    reps_df = reps_df.sort_values(["cluster_size", "typicality"], ascending=False)
    reps_df = reps_df.reset_index(drop=True)

    return typicality, reps_df, clusters



def compute_bridge(
    S: np.ndarray,
    labels: np.ndarray,
    index: list,
    threshold: float = 0.9,
):
    """Compute bridge score for each item (cross-cluster connectivity).

    For every item, its *external neighbours* are the items whose similarity
    survives the threshold and that carry a different cluster label. From them:

    * ``strength`` is the mean similarity to the external neighbours,
    * ``k`` is the number of distinct foreign clusters they belong to,
    * ``participation`` is ``k / (k + median_k)`` where ``median_k`` is the
      median of ``k`` over items with at least one external neighbour
      (1.0 if there are none),
    * ``bridge`` is ``participation * strength``.

    Items without external neighbours score 0 on all three.

    Args:
        S: (Q, Q) similarity matrix.
        labels: Cluster label per row of ``S``.
        index: Row labels used as the index of the returned Series.
        threshold: Similarities below this value are ignored.

    Returns:
        ``(bridge, participation, strength)`` as float32 Series.
    """
    df_idx = pd.Index(index)
    S0 = _mask_similarity(S, threshold)
    labels = np.asarray(labels)
    Q = S0.shape[0]

    # Single pass: the external neighbourhood of each row is derived once and
    # feeds both the cluster count and the mean outgoing strength.
    k_counts = np.zeros(Q, dtype=int)
    strength_out_mean = np.zeros(Q, dtype=np.float32)
    for i in range(Q):
        row = S0[i]
        nbrs = np.flatnonzero(row > 0.0)
        ext_nbrs = nbrs[labels[nbrs] != labels[i]]
        if ext_nbrs.size == 0:
            continue
        k_counts[i] = np.unique(labels[ext_nbrs]).size
        strength_out_mean[i] = row[ext_nbrs].mean()

    has_external = k_counts > 0
    positive_k = k_counts[has_external]
    k_median = float(np.median(positive_k)) if positive_k.size else 1.0

    # Saturating transform of k, normalised by the typical (median) k.
    coverage = np.zeros(Q, dtype=np.float32)
    k_float = positive_k.astype(np.float64)
    coverage[has_external] = k_float / (k_float + k_median)

    bridge_vals = coverage * strength_out_mean
    bridge = pd.Series(bridge_vals, index=df_idx, name="bridge")
    participation = pd.Series(coverage, index=df_idx, name="participation")
    strength = pd.Series(strength_out_mean, index=df_idx, name="strength")
    return bridge, participation, strength


def compute_cluster_intra_stats(S: np.ndarray, labels: np.ndarray, threshold: float = 0.9) -> pd.DataFrame:
    """Compute basic per-cluster stats (size, intra-cluster mean similarity).

    ``intra_mean`` is the mean of the thresholded similarity over all
    unordered member pairs; pairs below the threshold contribute 0.
    Single-member clusters get ``intra_mean`` 0.0.

    Returns:
        DataFrame with columns ``cluster``, ``size`` and ``intra_mean``,
        sorted by size and then intra_mean, both descending.
    """
    masked = _mask_similarity(S, threshold)

    records = []
    for cluster_id in np.unique(labels):
        members = np.flatnonzero(labels == cluster_id)
        size = len(members)
        intra_mean = 0.0
        if size > 1:
            within = masked[np.ix_(members, members)]
            # Strict upper triangle: each unordered pair exactly once.
            pair_values = within[np.triu_indices(size, k=1)]
            if pair_values.size:
                intra_mean = float(pair_values.mean())
        records.append({"cluster": int(cluster_id), "size": size, "intra_mean": intra_mean})

    stats = pd.DataFrame(records)
    stats = stats.sort_values(["size", "intra_mean"], ascending=False)
    return stats.reset_index(drop=True)


def compute_clustering_metrics(
    df: pd.DataFrame,
    sim_threshold: float = 0.9,
    return_all: bool = True,
):
    """Run the full clustering and network analysis pipeline.

    Steps: group identical rows, build the phi similarity matrix, cluster with
    Leiden, then derive typicality, bridge scores and per-cluster statistics,
    all using the same similarity threshold. A one-line summary is printed.

    Args:
        df: Matrix with one row per item.
        sim_threshold: Similarity threshold shared by every step.
        return_all: When true return a dict with all artefacts, otherwise only
            the per-item result frame.

    Returns:
        Either the per-item DataFrame (cluster, typicality, bridge,
        participation, strength) or a dict with keys ``result_df``,
        ``cluster_stats``, ``representatives``, ``labels`` and
        ``similarity_matrix``.
    """
    groups, unique_idx, _ = _group_identical_rows(df)
    S, idx = phi_similarity_matrix(df)
    labels = cluster_with_leiden(S, groups, unique_idx, threshold=sim_threshold)

    typicality, reps_df, clusters = compute_typicality(S, labels, idx, threshold=sim_threshold)
    bridge, participation, strength = compute_bridge(S, labels, idx, threshold=sim_threshold)
    cluster_stats = compute_cluster_intra_stats(S, labels, threshold=sim_threshold)

    per_item = [clusters, typicality, bridge, participation, strength]
    result_df = pd.concat(per_item, axis=1)

    n_clusters = result_df["cluster"].nunique()
    print(f"[compute_clustering_metrics] similarity_threshold={sim_threshold}, "
          f"#clusters={n_clusters}")

    if not return_all:
        return result_df

    return {
        "result_df": result_df,
        "cluster_stats": cluster_stats,
        "representatives": reps_df,
        "labels": labels,
        "similarity_matrix": S,
    }

    

def compute_mean_jaccard(df: pd.DataFrame) -> pd.Series:
    """Return each item's mean Jaccard similarity of error patterns with the others.

    An entry counts as an error when it is not > 0. For items i and j the
    Jaccard similarity is ``|errors_i & errors_j| / |errors_i | errors_j|``
    (0 when neither item has any error); the score of item i is the mean over
    all j != i.

    Args:
        df: Matrix with one row per item.

    Returns:
        float64 Series named ``"mean_jaccard"`` indexed like ``df``. With
        fewer than two items the score is 0.0 rather than NaN.
    """
    # Work in float64: a uint8 matrix product wraps modulo 256 as soon as two
    # items share 256 or more errors, which silently corrupts the counts.
    E = (~(df.values > 0)).astype(np.float64)          # (Q, N) error indicators
    Q = E.shape[0]

    inter = E @ E.T                                    # (Q, Q) shared errors
    errors_per_item = E.sum(axis=1)
    union = errors_per_item[:, None] + errors_per_item[None, :] - inter

    # Divide only where the union is non-empty; other entries stay 0.
    W = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
    np.fill_diagonal(W, 0.0)                           # skip self-similarity

    # max(..., 1) avoids 0/0 when only a single item is present.
    mean_jaccard = W.sum(axis=1) / max(Q - 1, 1)
    return pd.Series(mean_jaccard, index=df.index, name="mean_jaccard")



def analyze_test_items(
    df: pd.DataFrame,
    cluster_result_df: Optional[pd.DataFrame] = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Compute all core metrics and return main and debug DataFrames.

    Args:
        df: Matrix with one row per item.
        cluster_result_df: Pre-computed per-item clustering frame providing at
            least the ``typicality`` and ``bridge`` columns. When omitted it
            is computed with ``compute_clustering_metrics`` and its defaults.

    Returns:
        ``(main_df, debug_df)``. ``main_df`` holds difficulty, uniqueness,
        risk, mean_jaccard, surprise, typicality and bridge; ``debug_df``
        holds difficulty, uniqueness, risk, every surprise column and every
        clustering column.
    """
    # Per-item attributes.
    difficulty = compute_difficulty(df)
    uniqueness = compute_uniqueness(df)
    risk = compute_risk(df)
    surprise_df = compute_surprise(df)
    # NOTE(review): this was tagged "for debug" in the original, yet it is
    # emitted in main_df and not in debug_df -- confirm the intended placement.
    mean_jaccard = compute_mean_jaccard(df)

    if cluster_result_df is None:
        cluster_result_df = compute_clustering_metrics(df, return_all=True)["result_df"]

    main_columns = [
        difficulty,
        uniqueness,
        risk,
        mean_jaccard,
        surprise_df['surprise'],
        cluster_result_df['typicality'],
        cluster_result_df['bridge'],
    ]
    debug_columns = [
        difficulty,
        uniqueness,
        risk,
        surprise_df,
        cluster_result_df,
    ]

    main_df = pd.concat(main_columns, axis=1)
    debug_df = pd.concat(debug_columns, axis=1)
    return main_df, debug_df
