from lrayuela.fsa import sampler
from lrayuela.base.symbol import EOS, Sym
import random


def zero_arc(A, i, a, j):
    """Remove the arc from ``i`` on symbol ``a`` to ``j`` by zeroing its weight.

    The weight is set to the semiring zero ``A.R.zero`` in both the forward
    table ``A.δ[i][a][j]`` and the inverse table ``A.δ_inv[j][a][i]``, so the
    two stay in agreement. ``A`` is mutated in place and also returned, which
    lets callers chain the call.
    """
    A.δ[i][a][j] = A.δ_inv[j][a][i] = A.R.zero
    return A


def get_arcs(A):
    """List every arc of ``A`` whose weight is non-zero.

    Each arc is a ``(source, symbol, target, weight)`` tuple. Order follows
    ``A.Q`` for the source, then ``A.Sigma`` for the symbol, then whatever
    order ``A.a_out_arcs(source, symbol)`` yields its ``(target, weight)``
    pairs in. Arcs with ``weight.value == 0`` are left out.
    """
    return [
        (src, sym, dst, w)
        for src in A.Q
        for sym in A.Sigma
        for dst, w in A.a_out_arcs(src, sym)
        if not (w.value == 0)
    ]


def _sample_or_none(machine, n, seed):
    """Draw ``n`` samples from ``machine``; return ``None`` if sampling raises.

    The sampler is constructed outside the ``try`` so construction errors still
    propagate; only a failing ``sample`` call is treated as "no samples".
    """
    smplr = sampler.Sampler(machine, seed=seed)
    try:
        return smplr.sample(n)
    except Exception:  # sampler failure modes are not visible from here
        return None


def sample_from_machine_with_arcs_interventions(
    A, tgt_arc, N=None, test_N=None, seed=0, arc_occs=500
):
    """Build a training set with a controlled number of samples tied to one arc.

    Three sample sets are drawn:

    * ``test_smpls``: ``test_N`` samples from the unmodified ``A``.
    * ``na_smpls``: ``N - arc_occs`` samples from a copy of ``A`` in which
      ``tgt_arc`` is zeroed out (then trimmed and renormalized).
    * ``smpls``: ``arc_occs`` strings, each a prefix and a suffix joined by a
      space. Prefixes are sampled from a copy of ``A`` whose only final state
      is the arc's target. Suffixes are sampled from a copy whose only initial
      state is the arc's target.

    Args:
        A: Automaton to sample from. Only ``A.copy()`` results are modified.
            This assumes ``copy`` is deep enough for ``ρ``/``λ``/``δ``
            (TODO: confirm).
        tgt_arc: ``(state, symbol, target, weight)`` tuple, as produced by
            ``get_arcs``. ``weight`` is not used.
        N: Total training-set size, ``len(na_smpls) + len(smpls)``. It is
            required despite its ``None`` default, which is kept for call
            compatibility.
        test_N: Number of test samples; passed unchanged to the sampler.
        seed: Seed given to every ``sampler.Sampler``.
        arc_occs: Size of ``smpls``; must satisfy ``0 <= arc_occs <= N``.

    Returns:
        ``(na_smpls, smpls, test_smpls)``, or ``([], [], [])`` if any
        sampler's ``sample`` call raises.

    Raises:
        ValueError: If ``N`` is ``None`` or ``arc_occs`` is outside ``[0, N]``.
    """
    # Previously these cases surfaced as a TypeError / negative sample count
    # that the catch-all handler turned into a silent empty result.
    if N is None:
        raise ValueError("N (total number of training samples) is required")
    if not 0 <= arc_occs <= N:
        raise ValueError(f"arc_occs must be within [0, N={N}], got {arc_occs}")

    state, symbol, target, _weight = tgt_arc

    # NOTE(review): the same `seed` is reused for every sampler, so draws from
    # these closely related machines may be correlated (e.g. test vs. no-arc
    # samples, prefixes vs. suffixes) — confirm this is intended.
    test_smpls = _sample_or_none(A, test_N, seed)
    if test_smpls is None:
        return [], [], []

    # Samples that can never use the arc: zero it, drop states that become
    # useless, and renormalize.
    na_A = zero_arc(A.copy(), state, symbol, target).trim().normalize()
    na_smpls = _sample_or_none(na_A, N - arc_occs, seed)
    if na_smpls is None:
        return [], [], []

    # Prefixes: make `target` the only final state. Zero every existing final
    # weight, then set ρ[target] explicitly so `target` becomes final even if
    # it was not final in `A` (it may have no entry in ρ at all).
    # NOTE(review): a prefix only has to *reach* `target`, not arrive through
    # `tgt_arc`, so a joined sample need not traverse the arc when other arcs
    # also enter `target`. Ending prefixes at `state` and joining as
    # f"{prf} {symbol} {sfx}" would guarantee it — confirm the intent.
    prf_A = A.copy()
    for s in list(prf_A.ρ):
        prf_A.ρ[s] = A.R.zero
    prf_A.ρ[target] = A.R.one
    prf_A = prf_A.trim().normalize()
    prf_smpls = _sample_or_none(prf_A, arc_occs, seed)
    if prf_smpls is None:
        return [], [], []

    # Suffixes: make `target` the only initial state.
    sfx_A = A.copy()
    sfx_A.λ = {s: A.R.one if s == target else A.R.zero for s in sfx_A.Q}
    sfx_A = sfx_A.trim().normalize()
    sfx_smpls = _sample_or_none(sfx_A, arc_occs, seed)
    if sfx_smpls is None:
        return [], [], []

    smpls = [f"{prf} {sfx}" for prf, sfx in zip(prf_smpls, sfx_smpls)]

    # Sanity-check sampler output sizes (zip would silently truncate).
    assert len(smpls) == arc_occs
    assert len(na_smpls) == N - arc_occs

    return na_smpls, smpls, test_smpls


