import itertools
import time

from tf_agents.environments import tf_py_environment
import sys

sys.path.append("../")
from distributed_locking.verifier import SMVerifier
from distributed_locking.verifier import distributed_locking
from distributed_locking.State import State
from MultiSM import MultiStateMachine
from checkpoint import restore_policy
from distributed_locking.DistributedLocking import reset_envs, step_envs
from distributed_locking.DistributedLockingEnv import DistributedLockingEnv
from distributed_locking.State import meaning_mapping
from utils import log


## Does NOT use the z3 solver to verify the model for now; it replays the policy on concrete inputs instead.
def informal_verify(policy, envs, tf_envs, combs, players, protocol, is_testall, failed_thresh, is_log=False):
    """Roll out ``policy`` on each input in ``combs`` and look for negative rewards.

    For every element of ``combs`` the environments are reset with that input,
    the policy is queried once per player at each step, and the episode runs
    until player 0's time step reports ``is_last()``. An input counts as a
    failure when any player's final time step has a reward below zero.

    Args:
        policy: object whose ``action(time_step)`` returns an action step.
        envs: per-player Python environments, or ``None`` to build fresh ones
            (in which case ``tf_envs`` is rebuilt as well).
        tf_envs: TF wrappers matching ``envs`` index by index.
        combs: iterable of inputs handed to ``reset_envs``; ``None`` skips
            verification entirely.
        players: number of players.
        protocol: protocol identifier forwarded to ``MultiStateMachine``.
        is_testall: when false, stop at the first failing input.
        failed_thresh: with ``is_testall``, stop once this many inputs failed.
        is_log: print the observation/action pairs while rolling out.

    Returns:
        ``(True, set())`` if ``combs`` is ``None``; ``(False, None)`` on the
        first failure when ``is_testall`` is false; otherwise
        ``(passed, failed_inputs)``.
    """
    log_tag = "distributed_locking/eval.py: informal verify"
    if combs is None:
        return True, set()

    all_passed = True
    num_failed = 0
    failing_inputs = set()
    if is_log:
        print(f"Verifying {protocol}")

    # TODO: callers should pass existing envs so that new ones need not be created.
    if envs is None:
        envs = []
        tf_envs = []
        shared_sm = MultiStateMachine(players, protocol)
        for player in range(players):
            py_env = DistributedLockingEnv(players, player, shared_sm, conf=None)
            envs.append(py_env)
            tf_envs.append(tf_py_environment.TFPyEnvironment(py_env))

    # Traverse the whole supplied input space.
    for comb in combs:
        reset_envs(envs, players, input=comb)
        time_steps = [tf_envs[p].reset() for p in range(players)]
        if is_log:
            print("----------------")

        # Player 0's time step decides when the episode is over.
        while not time_steps[0].is_last():
            log(log_tag, f"{time_steps}")
            action_steps = []
            actions = []
            for current in time_steps:
                step = policy.action(current)
                chosen = int(step.action)
                action_steps.append(step)
                actions.append(chosen)
                if is_log:
                    obs_text = ",".join(map(str, current.observation.numpy().squeeze()))
                    print(f"{obs_text}:{chosen}")
            log(log_tag, f"action selected: {actions}")
            step_envs(envs, actions, players)
            time_steps = [tf_envs[p].step(action_steps[p]) for p in range(players)]

        # The final time steps tell whether the model handled this input correctly.
        if any(ts.reward < 0 for ts in time_steps):
            if not is_testall:
                return False, None
            all_passed = False
            num_failed += 1
            failing_inputs.add(comb)

        if is_testall and num_failed == failed_thresh:
            break

    if is_testall:
        assert num_failed == len(failing_inputs)
    return all_passed, failing_inputs

def generate_transitions(players, policy, env, tf_env):
    """Tabulate the policy's action for every (node, start-state assignment) pair.

    Every assignment of ``State.NoNeed`` / ``State.Need`` values to ``players``
    slots is enumerated. For each node index the assignment is written to
    ``env.states`` (prefixed with the node index), ``tf_env`` is reset, and the
    policy is asked for a single action.

    Returns:
        A list of ``[[node, *state_names], action_name]`` entries, where the
        names are the ``State`` member names of the corresponding values.
    """
    start_values = [State.NoNeed.value, State.Need.value]
    name_of = {member.value: member.name for member in State}

    transitions = []
    for assignment in itertools.product(start_values, repeat=players):
        assignment_names = [name_of[value] for value in assignment]
        for node in range(players):
            # The env is driven by overwriting its state before the reset.
            env.states = [node] + list(assignment)
            first_step = tf_env.reset()
            chosen = int(policy.action(first_step).action)
            transitions.append([[node] + assignment_names, name_of[chosen]])
    return transitions
def verify(policy, envs, tf_envs, players, protocol, is_log=False):
    """Check the policy with ``SMVerifier`` using its tabulated transitions.

    If ``envs`` is ``None`` a fresh set of per-player environments (and their
    TF wrappers) is built around one shared ``MultiStateMachine``. Only the
    first environment pair is used to generate the transition table.

    Returns:
        The result of ``SMVerifier.verify()`` when ``is_log`` is true,
        otherwise the result of ``SMVerifier.verify_nolog()``.
    """
    if envs is None:
        envs = []
        tf_envs = []
        shared_sm = MultiStateMachine(players, protocol)
        for player in range(players):
            py_env = DistributedLockingEnv(players, player, shared_sm, conf=None)
            envs.append(py_env)
            tf_envs.append(tf_py_environment.TFPyEnvironment(py_env))

    checker = SMVerifier(State, players, mode="node_dependent")
    checker.add_transitions(generate_transitions(players, policy, envs[0], tf_envs[0]))

    # Possible states for nodes to start in.
    start_types = [checker.get_type("Need"), checker.get_type("NoNeed")]
    # Presumably the possible final/output states - confirm against SMVerifier.assert_protocol.
    end_types = [checker.get_type("Enter"), checker.get_type("NoEnter")]
    checker.assert_protocol(
        distributed_locking,  # function that asserts properties of the protocol
        1,  # number of rounds
        start_types,
        None,  # types that represent lost messages
        end_types,
    )

    run_check = checker.verify if is_log else checker.verify_nolog
    return run_check()

if __name__ == "__main__":
    # Script entry point: restore a checkpointed policy and replay it on every
    # start-state assignment, reporting the elapsed time and the result.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <policy_version>")

    players = 9
    file = f"../chkpt/dl_{players}p_v{sys.argv[1]}_policy"
    policy = restore_policy(file)

    # The previous call omitted `combs`, which shifted every later positional
    # argument by one slot (players -> combs, protocol -> players, ...).
    # The input space below mirrors the one enumerated in generate_transitions.
    # NOTE(review): confirm this is the input format reset_envs expects.
    combs = itertools.product([State.NoNeed.value, State.Need.value], repeat=players)

    s = time.perf_counter()
    ret = informal_verify(
        policy,
        None,
        None,
        combs,
        players,
        "distributed_locking",
        is_testall=True,
        failed_thresh=-1,  # never matches the failure count, so all inputs are tested
        is_log=True,
    )
    e = time.perf_counter()
    print(f"elapsed time: {e-s} seconds")
    print(ret)