from collections import defaultdict
from counting_sampling_v2 import sample_state
import math
import numpy as np
import torch
import math

# Automaton name; same value as the literal passed to sample_state() in
# test_intervention_symbol.
TEST_AUTOMATON: str = "canonical_parity"
# Deviation threshold.  NOTE(review): not referenced in the visible part of
# this file — confirm whether it is still needed.
DEV_THRESHOLD: int = 10


def check_machine(A_l, tol=1e-6):
    """Verify that every state of ``A_l`` carries total outgoing probability 1.

    Weights are stored as log probabilities (possibly vectors), so each one is
    exponentiated and summed.  For every state, the mass of all transitions
    leaving it plus its accept mass must equal 1 within ``tol``.  States that
    appear only in the accept weights (no outgoing transitions) are checked
    too: their accept mass alone must be 1.

    Args:
        A_l: automaton exposing ``_transitions`` (mapping from transition
            objects with a ``state_from`` attribute to log-probability
            weights) and ``_accept_weights`` (mapping from state to a
            log-probability weight).
        tol: absolute tolerance on each state's total.  Defaults to the
            previously hard-coded ``1e-6``.

    Returns:
        True when every state is normalised.

    Raises:
        AssertionError: if some state's total deviates from 1 by ``tol`` or
            more.  Raised explicitly (not via ``assert``) so the check still
            runs under ``python -O``.
    """

    def _mass(log_weight):
        # Detach/move tensors first: np.exp on a tensor that requires grad
        # or lives on a GPU raises instead of computing the mass.
        if isinstance(log_weight, torch.Tensor):
            log_weight = log_weight.detach().cpu().numpy()
        return float(np.exp(log_weight).sum())

    # Group the log-weights of outgoing transitions by source state.
    by_state = defaultdict(list)
    for transition, log_weight in A_l._transitions.items():
        by_state[transition.state_from].append(log_weight)

    accept_weights = A_l._accept_weights

    # Membership tests on a defaultdict do not insert keys.
    states = list(by_state)
    states.extend(s for s in accept_weights if s not in by_state)

    for state in states:
        tot_prob = sum(_mass(w) for w in by_state.get(state, ()))
        if state in accept_weights:
            tot_prob += _mass(accept_weights[state])

        # `not (... < tol)` also rejects NaN totals, like the original assert.
        if not abs(tot_prob - 1) < tol:
            raise AssertionError(
                f"State {state} has total probability {tot_prob}, should be 1"
            )

    return True


def count_in_samples(samples, tgt):
    """Count transitions leaving state ``tgt`` across all samples.

    Every transition whose source (``trans[0]``) equals ``tgt`` is counted,
    so a sample that passes through ``tgt`` several times contributes
    several times.

    Args:
        samples: iterable of dicts, each with a ``"transitions"`` sequence
            whose items have the source state at index 0.
        tgt: state to count.

    Returns:
        int number of matching transitions (0 for no samples).
    """
    return sum(
        1
        for samp in samples
        for trans in samp["transitions"]
        if trans[0] == tgt
    )


def get_len_target(samples, tgt):
    """Return the mean length of sampled strings whose path leaves ``tgt``.

    A sample counts once if any of its transitions starts at ``tgt``
    (``trans[0] == tgt``), no matter how often it does so.

    Args:
        samples: iterable of dicts with keys ``"transitions"`` (items have
            the source state at index 0) and ``"sampled_string"``.
        tgt: state to look for.

    Returns:
        np.float64 mean of ``len(sample["sampled_string"])`` over matching
        samples, or NaN when nothing matches (test_intervention_symbol
        filters NaNs out).
    """
    lengths = [
        len(samp["sampled_string"])
        for samp in samples
        if any(trans[0] == tgt for trans in samp["transitions"])
    ]
    if not lengths:
        # np.mean([]) also yields NaN, but emits two RuntimeWarnings.
        return np.float64(np.nan)
    return np.mean(lengths)


def get_rejection(sampler, K, N, tgt, max_iters=None):
    """Rejection-sample batches until one has exactly ``N`` hits on ``tgt``.

    Repeatedly draws ``K`` samples via ``sampler.sample_original(K)`` and
    returns the first batch for which ``count_in_samples(batch, tgt) == N``.
    Every 100 iterations, the running average hit count and the current
    count are printed.

    Args:
        sampler: object providing ``sample_original(K)``.
        K: batch size per draw.
        N: required number of transitions leaving ``tgt`` in the batch.
        tgt: target state.
        max_iters: optional cap on the number of batches drawn.  ``None``
            (the default) keeps the unbounded loop.

    Returns:
        The accepted batch of samples.

    Raises:
        RuntimeError: if ``max_iters`` batches are drawn without success.
    """
    it = 0
    # Running sum instead of a growing list: only the mean is ever reported.
    total_seen = 0
    while True:
        if max_iters is not None and it >= max_iters:
            raise RuntimeError(
                f"No batch with exactly {N} hits on state {tgt} "
                f"after {it} iterations"
            )
        it += 1
        samples = sampler.sample_original(K)
        found = count_in_samples(samples, tgt)
        total_seen += found
        if it % 100 == 0:
            print(f" | - Done iterations: {it}. Seen avg {total_seen / it}, cur {found}")
        if found == N:
            return samples
        

def test_intervention_symbol():
    """Compare target-visiting string lengths across three sampling schemes.

    For seeds 100..109 this:
      1. draws samples with ``sample_state(K, N, ..., tgt_state=tgt_state)``
         (presumably intervened towards ``tgt_state`` — see
         counting_sampling_v2) and records the mean length of strings that
         leave ``tgt_state``;
      2. draws ``K`` samples from ``sampler.sample_original`` and records the
         same statistic.
    It then rejection-samples, per seeded sampler, an original batch with
    exactly ``N`` hits (``get_rejection``) and reports the statistic again.
    Results are only printed; nothing is asserted.
    """
    K = 1000
    N = 300
    tgt_state = 4

    mean_interv = []
    mean_orig = []
    samplers = []

    for seed in range(100, 110):
        samples, sampler, counts, stats = sample_state(
            K, N, name=TEST_AUTOMATON, tgt_state=tgt_state, seed=seed
        )
        sampler.set_seed(seed)
        samplers.append(sampler)

        mean_interv.append(get_len_target(samples, tgt_state))

        orig_samples = sampler.sample_original(K)
        mean_orig.append(get_len_target(orig_samples, tgt_state))

    # NaN marks a run in which no string left tgt_state.  Drop such runs
    # from both mean and std so the two statistics describe the same data.
    valid_interv = [a for a in mean_interv if not math.isnan(a)]
    valid_orig = [a for a in mean_orig if not math.isnan(a)]
    print(" -- mean-intervs:   ", mean_interv)
    print(" -- mean-origs:     ", mean_orig)
    print("Mean over 10 runs: ", np.mean(valid_interv), np.std(valid_interv))
    print("Mean orig over 10 runs: ", np.mean(valid_orig))

    rejections = []
    for sampler in samplers:
        rej_samples = get_rejection(sampler, K, N, tgt_state)
        rej_mean = get_len_target(rej_samples, tgt_state)
        rejections.append(rej_mean)
        print(" | Mean with rejection: ", rej_mean)
    print("Mean mean rejection: ", np.mean(rejections))
    # The leftover `breakpoint()` was removed: under pytest it would drop
    # into the debugger and hang non-interactive runs.

if __name__ == "__main__":
    # Runs only when this file is executed as a script, not when imported.
    test_intervention_symbol()