import sys
import time

from tf_agents.environments import tf_py_environment

from tfagents.utils import parse_arg

sys.path.append("../")
from math_func.Math import reset_envs, step_envs
from math_func.MathEnv import MathEnv
from MultiSM import MultiStateMachine
from checkpoint import restore_policy
from utils import log


def informal_verify(policy, envs, tf_envs, combs, players, protocol, is_testall, failed_thresh, is_log=False):
    """Roll out ``policy`` on every input in ``combs`` and check the final rewards.

    For each input the environments are reset to that input, then every player
    picks an action through ``policy.action`` until player 0's time step
    reports ``is_last()``. The input counts as failed when any player's final
    time step has a negative reward.

    Args:
        policy: Object whose ``action(time_step)`` returns a step exposing
            ``.action``.
        envs: Per-player Python environments, or ``None`` to build fresh
            ``MathEnv`` instances sharing one ``MultiStateMachine``.
        tf_envs: Per-player TF wrappers around ``envs``. Ignored (rebuilt)
            whenever ``envs`` is ``None``.
        combs: Iterable of hashable inputs to test, or ``None`` to skip
            verification entirely.
        players: Number of players.
        protocol: Passed to ``MultiStateMachine`` when environments are built;
            also printed when ``is_log`` is set.
        is_testall: If false, return at the first failing input. If true,
            keep going and collect the failing inputs.
        failed_thresh: Only used with ``is_testall``: stop as soon as the
            number of failed inputs equals this value. A value the counter
            never reaches (e.g. ``-1``) disables the early stop.
        is_log: If true, print one ``observation:action`` line per step.

    Returns:
        ``(True, set())`` when ``combs`` is ``None`` or no input failed;
        ``(False, None)`` at the first failure when ``is_testall`` is false;
        ``(False, failed_cases)`` otherwise, where ``failed_cases`` is the set
        of failing inputs seen before stopping.
    """
    this_func = "distributed_locking/eval.py: informal verify"
    if combs is None:
        return True, set()

    if is_log:
        print(f"Verifying {protocol}")

    if envs is None:
        # No environments supplied: build one per player. Each env is wrapped
        # right after it is created, all of them sharing one state machine.
        multi_sm = MultiStateMachine(players, protocol)
        envs = []
        tf_envs = []
        for p_idx in range(players):
            env = MathEnv(players, p_idx, multi_sm, conf=None)
            envs.append(env)
            tf_envs.append(tf_py_environment.TFPyEnvironment(env))

    ret = True
    failed_cnt = 0
    failed_cases = set()

    # Traverse the input space lazily so that an early exit does not pay for
    # materializing every remaining input.
    for comb in combs:
        reset_envs(envs, players, input=comb)
        time_steps = [tf_envs[p_idx].reset() for p_idx in range(players)]
        if is_log:
            print("----------------")

        # Player 0's time step decides when the episode is over.
        while not time_steps[0].is_last():
            log(this_func, f"{time_steps}")
            action_steps = [policy.action(ts) for ts in time_steps]
            actions = [int(step.action) for step in action_steps]
            if is_log:
                # Only one player is printed; pair player 0's observation with
                # player 0's own action so the logged line is self-consistent.
                observation = time_steps[0].observation.numpy().squeeze()
                print(f"{','.join(map(str, observation))}:{actions[0]}")
            log(this_func, f"action selected: {actions}")
            step_envs(envs, actions, players)
            time_steps = [tf_envs[p_idx].step(action_steps[p_idx]) for p_idx in range(players)]

        # A negative final reward for any player marks this input as failed.
        if any(ts.reward < 0 for ts in time_steps):
            if not is_testall:
                return False, None
            ret = False
            failed_cnt += 1
            failed_cases.add(comb)

        if is_testall and failed_cnt == failed_thresh:
            break

    if is_testall:
        # Each failing input must have been recorded exactly once.
        assert (failed_cnt == len(failed_cases))
    return ret, failed_cases

if __name__ == "__main__":
    # Script entry point: restore a saved policy and verify it on every input.
    cli_args = parse_arg()
    print(cli_args)

    restored_policy = restore_policy(cli_args.load_dir)

    started_at = time.time()
    # Test all inputs (is_testall=True) with logging enabled; a threshold of
    # -1 is never matched by the failure counter, so the run does not stop early.
    verdict = informal_verify(
        restored_policy,
        None,
        None,
        range(2 ** cli_args.players),
        cli_args.players,
        cli_args.protocol,
        True,
        -1,
        True,
    )
    finished_at = time.time()

    print(f"elapsed time: {finished_at - started_at} seconds")
    print(verdict)