import numpy as np
import math
from mcmc import make_mcmc_state, mcmc_take
from poset import topo_sort


def make_init_perm(P, lamb, booster_ids, rng=None):
    """Build a starting permutation that puts every booster after all other vertices.

    The non-booster vertices are ordered by ``topo_sort`` applied to the
    sub-structure induced on them (edges touching a booster are dropped).
    The boosters are then appended in a random order drawn from ``rng``.

    Args:
        P: mapping with ``"n"`` (vertex count) and per-vertex ``"preds"`` /
            ``"succs"`` collections of vertex ids, indexable by vertex id.
        lamb: accepted for interface compatibility; not used here.
        booster_ids: ids of the vertices to place at the end.
        rng: random source exposing ``permutation``; defaults to ``np.random``.

    Returns:
        ``np.ndarray`` of dtype int64 with the non-boosters' topological
        order followed by the shuffled boosters.

    NOTE(review): since boosters are always placed last, an order constraint
    requiring a booster to precede a non-booster would not be respected.
    Presumably boosters have no non-booster successors -- confirm with callers.
    """
    if rng is None:
        rng = np.random
    boosters = np.asarray(booster_ids, dtype=np.int64)
    is_booster = {int(b) for b in boosters.tolist()}

    kept = np.array(
        [v for v in range(P["n"]) if v not in is_booster], dtype=np.int64
    )
    # Original vertex id -> position within `kept` (also serves as membership test).
    relabel = {int(v): pos for pos, v in enumerate(kept.tolist())}

    def _restrict(neighbours):
        # Keep only neighbours that are non-boosters, expressed in the new labels.
        return {relabel[int(u)] for u in neighbours if int(u) in relabel}

    sub_preds, sub_succs = [], []
    for v in kept.tolist():
        sub_preds.append(_restrict(P["preds"][v]))
        sub_succs.append(_restrict(P["succs"][v]))

    sub_poset = {"n": int(kept.size), "preds": sub_preds, "succs": sub_succs}
    order = topo_sort(sub_poset, rng)

    head = kept[order]
    tail = rng.permutation(boosters)
    return np.concatenate([head, tail]).astype(np.int64)

def sample_linear_extension(
    method,
    *,
    n_all,
    owner_ids,
    non_owner_ids,
    P_full,
    lam_uniform,
    lam_priority,
    limit_case="none",
    G_ids=None,
    anchor_ids=None,
    U_a_ids=None,
    rest_ids=None,
    state_PSV=None,
    state_PASV=None,
    steps_between=50,
    rng=None,
):
    """Draw one permutation of the items according to ``method``.

    Methods:
        ``"SV"``: uniform random permutation of ``range(n_all)``.
        ``"WSV"``: shuffled ``owner_ids`` followed by shuffled ``non_owner_ids``.
        ``"PSV"``: advance the MCMC state ``state_PSV`` by ``steps_between``
            steps and return its current permutation.
        ``"PASV"``: if ``limit_case`` is ``"none"`` (``None`` is treated the
            same), advance ``state_PASV`` and return its permutation.
            Otherwise advance ``state_PSV`` and use its output as indices
            into ``rest_ids``, then append a shuffled ``G_ids``.

    ``P_full``, ``lam_uniform``, ``lam_priority``, ``anchor_ids`` and
    ``U_a_ids`` are not read here. They are accepted so that the dict
    returned by ``make_sampler_kwargs`` can be passed through unchanged.

    Returns:
        ``np.ndarray`` of dtype int64.

    Raises:
        ValueError: if the MCMC state required by ``method`` is ``None``, or
            if ``method`` is not one of the names above.
    """
    if rng is None:
        rng = np.random
    steps = int(steps_between)

    if method == "SV":
        return rng.permutation(int(n_all)).astype(np.int64)

    if method == "WSV":
        perm_owner = rng.permutation(np.asarray(owner_ids, dtype=np.int64))
        perm_rest = rng.permutation(np.asarray(non_owner_ids, dtype=np.int64))
        return np.concatenate([perm_owner, perm_rest]).astype(np.int64)

    if method == "PSV":
        if state_PSV is None:
            raise ValueError("state_PSV is required for method='PSV'")
        perm = mcmc_take(state_PSV, steps)
        return perm.astype(np.int64)

    if method == "PASV":
        if limit_case is None:
            limit_case = "none"
        limit_case = str(limit_case)
        if limit_case != "none":
            if state_PSV is None:
                raise ValueError(
                    f"state_PSV is required for method='PASV' with limit_case={limit_case!r}"
                )
            G_ids = np.asarray(G_ids, dtype=np.int64)
            rest_ids = np.asarray(rest_ids, dtype=np.int64)
            # The chain's output is used as positions into rest_ids.
            # NOTE(review): presumably the chain runs over a poset on
            # len(rest_ids) elements -- confirm how state_PSV is built.
            rest_positions = np.asarray(mcmc_take(state_PSV, steps), dtype=np.int64)
            perm_rest = rest_ids[rest_positions]
            perm_G = rng.permutation(G_ids)
            return np.concatenate([perm_rest, perm_G]).astype(np.int64)

        if state_PASV is None:
            raise ValueError("state_PASV is required for method='PASV'")
        perm = mcmc_take(state_PASV, steps)
        return perm.astype(np.int64)

    # Previously fell through and returned None, which surfaced far from the cause.
    raise ValueError(f"unknown sampling method: {method!r}")


def _parse_lam_exponents(lam_exponents):
    """Parse ``lam_exponents`` into a list of floats; return None if it cannot be parsed.

    Accepts a comma-separated string such as ``"0,0,1,0,0"`` (blank fields are
    skipped) or any iterable of numbers / numeric strings.
    """
    try:
        if isinstance(lam_exponents, str):
            return [float(tok.strip()) for tok in lam_exponents.split(",") if tok.strip() != ""]
        return [float(x) for x in lam_exponents]
    except (TypeError, ValueError):
        return None


def _burned_in_mcmc_state(P, lam, rng, burnin_steps, **state_kwargs):
    """Create an MCMC state (laziness 0.1) over ``P`` and advance it ``burnin_steps`` steps."""
    state = make_mcmc_state(P, lam, rng, laziness=0.1, **state_kwargs)
    if burnin_steps > 0:
        mcmc_take(state, int(burnin_steps))
    return state


def make_sampler_kwargs(
    method,
    rng,
    *,
    n_all,
    owner_ids,
    non_owner_ids,
    P_full,
    lam_uniform,
    lam_priority,
    limit_case="none",
    G_ids=None,
    anchor_ids=None,
    U_a_ids=None,
    rest_ids=None,
    steps_between=1000,
    burnin_steps=20000,
    lam_exponents=None,
    booster_ids=None,
):
    """Build the keyword arguments for ``sample_linear_extension``.

    For ``"PSV"``, and for ``"PASV"`` with ``limit_case`` in
    {"booster", "copier", "poisoner"}, a burned-in MCMC state over ``P_full``
    with ``lam_uniform`` is stored under ``"state_PSV"``. For any other
    ``"PASV"`` case, a burned-in state over ``P_full`` with ``lam_priority``
    (prefix cache on) is stored under ``"state_PASV"``.

    When ``limit_case`` is "none" (``None`` is treated the same), boosters are
    given, and ``lam_exponents`` parses to exactly (0, 0, 1, 0, 0) for
    (owner, anchor, booster, copier, poisoner), the PASV chain is seeded from
    ``make_init_perm``. Seeding is best-effort: if it fails, a warning is
    emitted and the chain starts without an initial permutation.

    Other methods get a dict with both states set to ``None``.

    Returns:
        dict of keyword arguments; ``limit_case`` is stored as given.
    """
    kwargs = dict(
        n_all=n_all,
        owner_ids=owner_ids,
        non_owner_ids=non_owner_ids,
        P_full=P_full,
        lam_uniform=lam_uniform,
        lam_priority=lam_priority,
        limit_case=limit_case,
        G_ids=G_ids,
        anchor_ids=anchor_ids,
        U_a_ids=U_a_ids,
        rest_ids=rest_ids,
        state_PSV=None,
        state_PASV=None,
        steps_between=int(steps_between),
    )
    # Normalize None -> "none" the same way sample_linear_extension does.
    case = "none" if limit_case is None else str(limit_case)

    if method == "PSV" or (method == "PASV" and case in {"booster", "copier", "poisoner"}):
        kwargs["state_PSV"] = _burned_in_mcmc_state(
            P_full, lam_uniform, rng, burnin_steps, use_prefix_cache=False
        )
    elif method == "PASV":
        init_perm = None
        if (
            case == "none"
            and lam_exponents is not None
            and booster_ids is not None
            and len(booster_ids) > 0
        ):
            exp_vec = _parse_lam_exponents(lam_exponents)
            # Only the "booster-only" weighting gets a booster-last starting point.
            if exp_vec is not None and tuple(exp_vec) == (0.0, 0.0, 1.0, 0.0, 0.0):
                try:
                    init_perm = make_init_perm(P_full, lam_priority, booster_ids, rng)
                except Exception as err:
                    # Best-effort seeding: fall back to the default start, but say so.
                    import warnings

                    warnings.warn(
                        f"make_init_perm failed ({err!r}); starting PASV chain without init_perm",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    init_perm = None

        kwargs["state_PASV"] = _burned_in_mcmc_state(
            P_full,
            lam_priority,
            rng,
            burnin_steps,
            use_prefix_cache=True,
            init_perm=init_perm,
        )
    return kwargs


