import numpy as np
import pandas as pd
from typing import Iterable, Optional, Tuple, List
from scipy.stats import norm


def logit_rank(s: pd.Series, tau: float = 0.2) -> pd.Series:
    """Map a numeric series to weights in (0, 1) via rank percentiles and a sigmoid.

    Values are ranked (ties share their average rank), converted to the
    mid-rank percentile ``(rank - 0.5) / n_observed``, centred on 0.5,
    divided by ``tau`` and passed through the logistic function.

    Args:
        s: Values to transform; cast to float.
        tau: Temperature of the sigmoid. Smaller values push the weights
            towards 0 and 1. Floored at ``1e-12`` to avoid division by zero.

    Returns:
        A float series on the index of ``s``. Missing values stay NaN.
        When ``s`` has at most one distinct non-missing value, every
        position (including missing ones) is set to 0.5.
    """
    s = s.astype(float)
    idx = s.index
    if s.nunique(dropna=True) <= 1:
        # No ordering information: fall back to a neutral, uniform weight.
        return pd.Series(np.full(len(s), 0.5), index=idx, dtype=float)

    r = s.rank(method="average", na_option="keep")
    # Ranks only run over observed values, so the percentile must be
    # normalised by the observed count; dividing by len(r) would skew every
    # percentile towards 0 as soon as a single value is missing.
    n_observed = r.count()
    u = (r - 0.5) / n_observed

    # Centre the percentile and squash it with a sigmoid of temperature tau.
    z = (u - 0.5) / max(tau, 1e-12)
    w = 1.0 / (1.0 + np.exp(-z))   # (0,1)

    return pd.Series(w, index=idx, dtype=float)


def weighted_acc(w: pd.Series, X: pd.DataFrame) -> pd.Series:
    """Return the weighted mean of every column of ``X`` using row weights ``w``.

    Args:
        w: Row weights. Aligned to ``X.index``; rows without a weight (or
            with a NaN weight) get weight 0.
        X: Values to average, one result per column.

    Returns:
        A float series indexed by the columns of ``X``. When the weights sum
        to zero the plain column mean is returned instead.

    Raises:
        ValueError: If any aligned weight is negative.

    Note:
        NaN cells in ``X`` are skipped when summing the numerator, but their
        weight is still part of the denominator.
    """
    weights = w.reindex(X.index).fillna(0.0).astype(float)
    if (weights < 0).any():
        raise ValueError("All weights in 'w' must be non-negative.")

    total = weights.sum()
    if total <= 0:
        # Nothing to weight by: degrade to the unweighted column mean.
        return X.mean(axis=0).astype(float)

    weighted_sum = X.mul(weights, axis=0).sum(axis=0)
    return (weighted_sum / total).astype(float)


def add_ranks(scores: pd.DataFrame, cols: list[str], rank_method: str = 'min', suffix: str = '_rank') -> None:
    """Add a descending rank column for each requested score column, in place.

    For every name in ``cols`` that exists in ``scores``, a new column
    ``name + suffix`` is written where rank 1 is the highest value.

    Args:
        scores: Frame to modify in place.
        cols: Columns to rank; names missing from ``scores`` are skipped.
        rank_method: Tie-breaking method passed to ``Series.rank``.
        suffix: Suffix appended to the column name for the rank column.

    Returns:
        None. Rank columns are float-valued whole numbers; rows whose score
        is NaN keep a NaN rank.
    """
    for c in cols:
        if c in scores.columns:
            r = scores[c].rank(ascending=False, method=rank_method)
            # Floor in float space instead of casting to int: an integer cast
            # of the whole series raises as soon as one rank is NaN, whereas
            # floor leaves NaN untouched and yields the same whole numbers.
            scores[c + suffix] = np.floor(r)


def compute_metric_rank_correlations(
    scores: pd.DataFrame,
    metrics: Optional[List[str]] = None,
    acc_rank_col: str = "ACC_rank",
    method: str = "pearson",
) -> pd.DataFrame:
    """Correlate each metric's rank column with the accuracy rank column.

    Args:
        scores: Frame holding ``acc_rank_col`` and ``<metric>_rank`` columns.
        metrics: Metric names to look up; defaults to
            ``['RKSP', 'TYDF', 'BRDF', 'UQDF', 'SP', 'Composite']``.
            Metrics without a ``<metric>_rank`` column are skipped.
        acc_rank_col: Name of the reference rank column.
        method: Correlation method passed to ``Series.corr``.

    Returns:
        A frame indexed by ``metric`` with a single ``corr_with_ACC_rank``
        column. The frame is empty, but still carries that column, when
        ``acc_rank_col`` is missing or no metric rank column is found.
    """
    value_col = 'corr_with_ACC_rank'
    if metrics is None:
        metrics = ['RKSP', 'TYDF', 'BRDF', 'UQDF', 'SP', 'Composite']

    rows = []
    if acc_rank_col in scores.columns:
        acc_r = scores[acc_rank_col].astype(float)
        for m in metrics:
            rcol = f"{m}_rank"
            if rcol in scores.columns:
                corr = float(acc_r.corr(scores[rcol].astype(float), method=method))
                rows.append({"metric": m, value_col: corr})

    if not rows:
        # Keep one schema for every empty outcome so callers can always
        # select the correlation column without checking which path ran.
        empty = pd.DataFrame(columns=[value_col], dtype=float)
        empty.index.name = "metric"
        return empty

    return pd.DataFrame(rows).set_index("metric")


def calculate_metrics(
    X: pd.DataFrame,
    attrs: pd.DataFrame,
    beta: float = 0.5,
    shrink_clip: Tuple[float, float] = (0.0, 1.0),
    include_acc: bool = True,
    X_full: Optional[pd.DataFrame] = None,
    acc_items: Optional[Iterable] = None,
    rank_method: str = 'min',
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Score every column of ``X`` with attribute-weighted means and rank them.

    Row attributes are squashed to (0, 1) with ``logit_rank`` and combined
    into five row-weight vectors; each yields one weighted mean per column
    of ``X``. ``Composite`` is the mean of those five scores.

    Args:
        X: Values to score; rows are weighted, columns are scored
            (presumably items x models -- confirm against the caller).
        attrs: Row attributes, reindexed to ``X.index``. Must provide the
            columns ``difficulty``, ``risk``, ``surprise``, ``typicality``,
            ``bridge``, ``uniqueness`` and ``cluster``.
        beta: Exponent of the cluster-size shrinkage ``size ** (-beta)``.
        shrink_clip: ``(lower, upper)`` bounds applied to the shrinkage, or
            ``None`` to leave it unclipped.
        include_acc: Whether to add an unweighted ``ACC`` column.
        X_full: If given, ``ACC`` is the column mean of this frame.
        acc_items: Otherwise, if given, ``ACC`` uses only the rows of ``X``
            whose labels appear here; when none match, all rows are used.
        rank_method: Tie-breaking method for the rank columns.

    Returns:
        ``(scores, item_weights, corr_df)``: the per-column scores with their
        ``*_rank`` columns sorted by ``Composite`` descending, the per-row
        scaled attributes and weights, and the correlation of each metric
        rank with ``ACC_rank``.
    """
    weighted_names = ['RKSP', 'TYDF', 'BRDF', 'UQDF', 'SP']

    attrs = attrs.reindex(X.index)

    # Squash each raw attribute into a (0, 1) weight.
    scaled = {
        col: logit_rank(attrs[col])
        for col in ('difficulty', 'risk', 'surprise', 'typicality', 'bridge', 'uniqueness')
    }

    # Rows in large clusters are down-weighted by size ** (-beta).
    cluster = attrs['cluster']
    cluster_size = cluster.map(cluster.value_counts()).astype(float)
    shrink = cluster_size ** (-beta)
    if shrink_clip is not None:
        lower, upper = shrink_clip
        shrink = shrink.clip(lower=lower, upper=upper)

    item_w = {
        'RKSP': scaled['risk'] * scaled['surprise'],
        'TYDF': scaled['typicality'] * scaled['difficulty'] * shrink,
        'BRDF': scaled['bridge'] * scaled['difficulty'],
        'UQDF': scaled['uniqueness'] * scaled['difficulty'],
        'SP': scaled['surprise'],
    }

    scores = pd.DataFrame({name: weighted_acc(item_w[name], X) for name in weighted_names})
    scores['Composite'] = scores[weighted_names].mean(axis=1, skipna=True)
    metric_cols = weighted_names + ['Composite']

    if include_acc:
        # Pick the frame the plain accuracy is averaged over.
        X_acc = X
        if X_full is not None:
            X_acc = X_full
        elif acc_items is not None:
            keep = X.index.intersection(pd.Index(list(acc_items)))
            if len(keep) > 0:
                X_acc = X.loc[keep]
        scores['ACC'] = X_acc.mean(axis=0).astype(float)
        metric_cols = ['ACC'] + metric_cols

    add_ranks(scores, cols=metric_cols, rank_method=rank_method, suffix='_rank')

    # Lay the columns out as metric, metric_rank pairs, then anything else.
    ordered_cols = [
        name
        for c in metric_cols
        if c in scores.columns
        for name in (c, c + '_rank')
    ]
    rest = [c for c in scores.columns if c not in ordered_cols]
    scores = scores.reindex(columns=ordered_cols + rest)

    sort_key = 'Composite' if 'Composite' in scores.columns else metric_cols[0]
    scores = scores.sort_values(sort_key, ascending=False)

    corr_df = compute_metric_rank_correlations(
        scores,
        metrics=['RKSP', 'TYDF', 'BRDF', 'UQDF', 'SP', 'Composite'],
        acc_rank_col='ACC_rank',
        method='pearson',
    )

    item_weights = pd.DataFrame({
        'difficulty_scale': scaled['difficulty'],
        'risk_scale': scaled['risk'],
        'surprise_scale': scaled['surprise'],
        'typicality_scale': scaled['typicality'],
        'bridge_scale': scaled['bridge'],
        'uniqueness_scale': scaled['uniqueness'],
        'cluster': cluster,
        'cluster_size': cluster_size,
        'shrink': shrink,
        'w_RKSP': item_w['RKSP'],
        'w_TYDF': item_w['TYDF'],
        'w_BRDF': item_w['BRDF'],
        'w_UQDF': item_w['UQDF'],
        'w_SP': item_w['SP'],
    }, index=X.index)

    return scores, item_weights, corr_df
