import sys
import time

from tf_agents.environments import tf_py_environment
sys.path.append("../")
from MultiSM import MultiStateMachine
from checkpoint import restore_policy
from simple_counter.Counter import reset_envs, step_envs
from simple_counter.CounterEnv import CounterEnv
from simple_counter.ComplexCounterEnv import ComplexCounterEnv
from utils import log


def _build_envs(players, protocol):
    """Create one ComplexCounterEnv per player, all sharing one MultiStateMachine.

    Returns:
        ``(envs, tf_envs)`` where ``tf_envs[i]`` is a ``TFPyEnvironment``
        wrapping ``envs[i]``.
    """
    multi_sm = MultiStateMachine(players, protocol)
    envs = [ComplexCounterEnv(players, i, multi_sm, conf=None) for i in range(players)]
    tf_envs = [tf_py_environment.TFPyEnvironment(env) for env in envs]
    return envs, tf_envs


def _run_episode(policy, envs, tf_envs, players, input_idx, is_log, this_func):
    """Play one episode starting from input configuration ``input_idx``.

    Every player acts with the same ``policy``. The episode ends when
    player 0's time step is the last one.

    Returns:
        The list of final time steps, one per player.
    """
    # NOTE(review): reset_envs presumably encodes input_idx into the envs'
    # initial state (one bit per player?) -- confirm in simple_counter.Counter.
    reset_envs(envs, players, input=input_idx)
    time_steps = [tf_envs[p].reset() for p in range(players)]
    if is_log:
        print("----------------")

    while not time_steps[0].is_last():
        log(this_func, f"{time_steps}")
        action_steps = [policy.action(time_steps[p]) for p in range(players)]
        actions = [int(step.action) for step in action_steps]
        if is_log:
            # Only one player's view is printed; show player 0's observation
            # together with player 0's own action (not another player's).
            obs = time_steps[0].observation.numpy().squeeze()
            print(f"{','.join(map(str, obs))}:{actions[0]}")
        log(this_func, f"action selected: {actions}")
        # Advance the underlying Python envs first (as the original ordering
        # requires), then step each TF wrapper with the policy's action tensor.
        step_envs(envs, actions, players)
        time_steps = [tf_envs[p].step(action_steps[p].action) for p in range(players)]

    return time_steps


def informal_verify(policy, envs, tf_envs, players, protocol, is_testall, failed_thresh, is_log=False):
    """Empirically check ``policy`` on every one of the ``2 ** players`` inputs.

    For each input index, one episode is played with all players driven by
    ``policy``. An input counts as failed if any player's final reward is
    negative.

    Args:
        policy: A TF-Agents style policy exposing ``action(time_step)``.
        envs: Per-player Python environments, or ``None`` to build fresh
            ``ComplexCounterEnv`` instances for ``protocol``.
        tf_envs: ``TFPyEnvironment`` wrappers matching ``envs``. Ignored and
            rebuilt when ``envs`` is ``None``. When ``envs`` is given but this
            is ``None``, each env is wrapped here.
        players: Number of players. Inputs ``0 .. 2**players - 1`` are tried.
        protocol: Protocol name passed to ``MultiStateMachine``. It is also
            printed when ``is_log`` is set.
        is_testall: If falsy, return ``(False, None)`` on the first failing
            input. If truthy, keep going and collect failing inputs.
        failed_thresh: In test-all mode, stop once exactly this many inputs
            have failed. A value that can never be reached (e.g. ``-1``)
            disables the early stop.
        is_log: Print a per-step trace of player 0's observation and action.

    Returns:
        ``(ret, failed_cases)``. ``ret`` is ``True`` iff no examined input
        failed. ``failed_cases`` is the set of failing input indices, or
        ``None`` when returning early in non-test-all mode.
    """
    # NOTE(review): this label names distributed_locking/eval.py although the
    # module drives simple_counter envs -- confirm it is intended.
    this_func = "distributed_locking/eval.py: informal verify"
    ret = True
    failed_cnt = 0
    failed_cases = set()
    if is_log:
        print(f"Verifying {protocol}")

    if envs is None:
        envs, tf_envs = _build_envs(players, protocol)
    elif tf_envs is None:
        tf_envs = [tf_py_environment.TFPyEnvironment(env) for env in envs]

    # Traverse the whole input space.
    for i in range(2 ** players):
        final_steps = _run_episode(policy, envs, tf_envs, players, i, is_log, this_func)

        # A negative final reward for any player marks this input as failed.
        if any(ts.reward < 0 for ts in final_steps):
            if not is_testall:
                return False, None
            ret = False
            failed_cnt += 1
            failed_cases.add(i)

        if is_testall and failed_cnt == failed_thresh:
            break

    if is_testall:
        assert failed_cnt == len(failed_cases)
        print(f"failed at least {failed_cnt} cases")
    return ret, failed_cases

if __name__ == "__main__":
    # Usage: python <this script> <checkpoint-version>
    # Restores the 3-player policy checkpoint for that version and verifies it
    # over all inputs, reporting every failure (no early-stop threshold).
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <checkpoint-version>")
    players = 3
    checkpoint_path = f"../chkpt/complex_counter_transformer_{players}p_v{sys.argv[1]}_policy"
    policy = restore_policy(checkpoint_path)
    start = time.perf_counter()  # monotonic clock for interval timing
    ret = informal_verify(
        policy,
        None,
        None,
        players,
        "counter",
        is_testall=True,
        failed_thresh=-1,
        is_log=True,
    )
    elapsed = time.perf_counter() - start
    print(f"elapsed time: {elapsed} seconds")
    print(ret)