from automata_register import AUTOMATA_REGISTER
from counting_sampling import sample_state


# Automaton under study, looked up by name in the shared register.
AUTOMATA_NAME = "canonical_parity"
MACHINE = AUTOMATA_REGISTER[AUTOMATA_NAME]

# State singled out for the sampling run: its name, and the id that
# MACHINE.state_map_rev associates with that name.
STATE_NAME = "q2"
TARGET_STATE_ID = MACHINE.state_map_rev[STATE_NAME]

# Values handed to sample_state(): the two positional sizes, then the
# acceptance probability and the random seed.
NUM_SAMPLES = 1000
NUM_TARGETS = 500
ACCEPT_PROB = 0.01
SEED = 0


def get_corpus():
    """Draw the training corpus using the module-level sampling settings.

    Calls ``sample_state`` with ``NUM_SAMPLES`` and ``NUM_TARGETS`` as
    positional arguments and ``SEED``, ``ACCEPT_PROB``, ``AUTOMATA_NAME`` and
    ``TARGET_STATE_ID`` as keyword arguments.

    Returns:
        The four values produced by ``sample_state``, as the tuple
        ``(train_samples, sampler, counts, stats)``.
    """
    sampling_options = {
        "seed": SEED,
        "accept_prob": ACCEPT_PROB,
        "name": AUTOMATA_NAME,
        "tgt_state": TARGET_STATE_ID,
    }
    corpus, state_sampler, target_counts, run_stats = sample_state(
        NUM_SAMPLES, NUM_TARGETS, **sampling_options
    )
    return corpus, state_sampler, target_counts, run_stats


def get_vanilla_corpus(sampler, test_size):
    """Return ``test_size`` samples obtained from ``sampler.sample_original``.

    Args:
        sampler: Object exposing a ``sample_original(n)`` method.
        test_size: Number of samples to request.

    Returns:
        Whatever ``sampler.sample_original(test_size)`` returns, unchanged.
    """
    return sampler.sample_original(test_size)


def summarize_target_visits(samples, target_state):
    """Count how often sampled paths enter ``target_state``.

    Args:
        samples: Iterable of dicts, each with a hashable ``"sampled_string"``
            and a ``"transitions"`` sequence of ``(src, tgt, symbol)`` arcs.
        target_state: State id compared against the ``tgt`` slot of each arc.

    Returns:
        A tuple ``(visits, strings_with_visit, unique_strings)`` where
        ``visits`` is the total number of arcs entering ``target_state``,
        ``strings_with_visit`` is the number of samples with at least one
        such arc, and ``unique_strings`` is the set of distinct
        ``"sampled_string"`` values.
    """
    visits = 0
    strings_with_visit = 0
    unique_strings = set()
    for sample in samples:
        unique_strings.add(sample["sampled_string"])
        hits = sum(
            1 for _, tgt, _ in sample["transitions"] if tgt == target_state
        )
        visits += hits
        # Decide "this string reached the target" from the path itself; the
        # sampled string holds symbols, so testing a state id for membership
        # in it does not answer that question.
        if hits:
            strings_with_visit += 1
    return visits, strings_with_visit, unique_strings


def collect_target_prefixes(samples, target_state):
    """Collect every path prefix that ends by entering ``target_state``.

    Args:
        samples: Iterable of dicts with a ``"transitions"`` sequence of
            ``(src, tgt, symbol)`` arcs.
        target_state: State id compared against the ``tgt`` slot of each arc.

    Returns:
        A list of tuples of arcs, one per visit to ``target_state``, in the
        order the visits are encountered. Duplicates are kept.
    """
    prefixes = []
    for sample in samples:
        path = []
        for arc in sample["transitions"]:
            _, tgt, _ = arc
            path.append(arc)
            if tgt == target_state:
                prefixes.append(tuple(path))
    return prefixes


if __name__ == "__main__":
    # Sample corpus under intervention
    train_samples, sampler, counts, stats = get_corpus()

    assert len(train_samples) == NUM_SAMPLES
    assert sum(counts) == NUM_TARGETS

    seen_target, seen_target_strings, train_unique = summarize_target_visits(
        train_samples, TARGET_STATE_ID
    )

    assert seen_target == NUM_TARGETS
    # Guard the average so a run with no target visits reports 0.0 instead of
    # dying on a division by zero.
    avg_target_occurrences = (
        seen_target / seen_target_strings if seen_target_strings else 0.0
    )
    print(f"Avg tgt occs {avg_target_occurrences}")
    print(f"Unique train strings: {len(train_unique)}")

    # Check the weight pairs on sampler.A_l: an arc entering the target state
    # must have an infinite first component and a finite second component;
    # every other arc must have the opposite pattern.
    # NOTE(review): the meaning of the two components is defined by the
    # sampler, not visible here -- confirm against counting_sampling.
    for arc, wgt in sampler.A_l.transition_weights():
        if arc.state_to == TARGET_STATE_ID:
            assert not wgt[1].isinf()
            assert wgt[0].isinf()
        else:
            assert not wgt[0].isinf()
            assert wgt[1].isinf()

    # Sample from the sampler's original distribution and record each path
    # prefix that ends in the target state.
    test_samples = get_vanilla_corpus(sampler, 4000)
    seen_target_prefixes = collect_target_prefixes(
        test_samples, TARGET_STATE_ID
    )

    unique_prefixes = set(seen_target_prefixes)
    print(f"Seen in test data {len(seen_target_prefixes)}")
    print(f"Unique seen in test data {len(unique_prefixes)}")

    # Drop into the debugger so the objects built above can be inspected
    # interactively.
    breakpoint()