import numpy as np
from typing import Dict, List, Callable

from sampling import make_linear_extension_sampler


def estimate_poset_shapley_with_rank(
    *,
    P,
    lamb,
    feat_names: List[str],
    v_global_fn: Callable[[List[str]], float],
    M: int = 1000,
    seed: int = 0,
    laziness: float = 0.1,
    burnin_steps: int = 20000,
    steps_between: int = 1000,
    verbose_every: int = 200,
):
    """Monte Carlo estimate of Shapley values over sampled linear extensions.

    Draws ``int(M)`` permutations from the sampler returned by
    ``make_linear_extension_sampler`` and, for each draw, credits every
    feature ``j`` with its marginal contribution ``v(S + [j]) - v(S)``, where
    ``S`` is the list of features preceding ``j`` in that permutation. The
    per-feature sums are averaged over the draws to give ``phi``. The 1-based
    position of each feature in each draw is averaged into ``mean_ranks``.

    Args:
        P, lamb: Forwarded unchanged to ``make_linear_extension_sampler``
            (presumably the poset and a sampler parameter -- confirm against
            that function).
        feat_names: Feature names. The first ``len(feat_names)`` entries of
            each drawn permutation are used as indices into this list, so each
            draw is assumed to be a permutation of ``range(len(feat_names))``.
        v_global_fn: Utility of a coalition, given as a list of feature names.
            Results are memoised by the *set* of names, so the function is
            expected to be order-independent.
        M: Number of permutations to draw; ``int(M)`` must be >= 1.
        seed, laziness, burnin_steps, steps_between: Forwarded to the sampler.
        verbose_every: Print progress every this many draws (values below 1
            are treated as 1, i.e. print on every draw).

    Returns:
        A dict with:
            ``phi`` / ``phi_dict``: Shapley estimates (array / name -> float).
            ``mean_ranks`` / ``mean_ranks_dict``: average 1-based positions.
            ``acceptance_rate``: ``total_moves / total_steps`` read from the
                sampler state.
            ``mcmc_state``: the sampler state object itself.
            ``V_cache_size``: number of distinct coalitions evaluated.
            ``linear_extensions``: every draw as a list of feature names.
            ``utility_values``: for every draw, the ``F + 1`` utilities along
                the permutation, starting with the empty coalition.

    Raises:
        ValueError: If ``int(M) < 1`` (the averages would otherwise be NaN).
    """
    # Coerce once so the loop count and the normaliser always agree; the
    # original looped int(M) times but divided by float(M).
    n_samples = int(M)
    if n_samples < 1:
        raise ValueError(f"M must be >= 1, got {M!r}")

    draw_perm, mcmc_state = make_linear_extension_sampler(
        P=P,
        lamb=lamb,
        seed=seed,
        laziness=laziness,
        burnin_steps=burnin_steps,
        steps_between=steps_between,
    )

    F = len(feat_names)
    phi = np.zeros(F, dtype=np.float64)
    rank_sum = np.zeros(F, dtype=np.float64)

    # Memo of coalition utility keyed by the unordered set of feature names.
    V: Dict[frozenset, float] = {}

    linear_extensions: List[List[str]] = []
    utility_values: List[List[float]] = []

    def v_cached(S_list: List[str]) -> float:
        """Return ``float(v_global_fn(S_list))``, memoised on ``frozenset(S_list)``."""
        key = frozenset(S_list)
        cached = V.get(key)
        if cached is not None:
            return cached
        # Give the callback its own copy: the caller keeps appending to
        # S_list, so a v_global_fn that retained or mutated its argument
        # would otherwise see (or cause) changes to the running coalition.
        val = float(v_global_fn(list(S_list)))
        V[key] = val
        return val

    v_empty = v_cached([])
    report_every = max(1, verbose_every)

    for t in range(n_samples):
        if t % report_every == 0:
            print(f"[MC] {t}/{M}")
        perm = draw_perm()
        # Only the first F entries are used, matching the feature count.
        order = [int(perm[pos]) for pos in range(F)]

        linear_extensions.append([feat_names[j] for j in order])

        perm_utilities = [v_empty]
        S: List[str] = []
        v_prev = v_empty
        for pos, j in enumerate(order):
            rank_sum[j] += pos + 1  # 1-based rank
            S.append(feat_names[j])
            v_curr = v_cached(S)
            perm_utilities.append(v_curr)
            phi[j] += v_curr - v_prev  # marginal contribution of feature j
            v_prev = v_curr

        utility_values.append(perm_utilities)

    phi /= float(n_samples)
    mean_ranks = rank_sum / float(n_samples)
    # max(1, ...) guards against a sampler that reports zero steps.
    acc_rate = mcmc_state["total_moves"] / max(1, mcmc_state["total_steps"])

    return {
        "phi": phi,
        "phi_dict": {feat_names[i]: float(phi[i]) for i in range(F)},
        "mean_ranks": mean_ranks,
        "mean_ranks_dict": {feat_names[i]: float(mean_ranks[i]) for i in range(F)},
        "acceptance_rate": float(acc_rate),
        "mcmc_state": mcmc_state,
        "V_cache_size": len(V),
        "linear_extensions": linear_extensions,
        "utility_values": utility_values,
    }









