
from atomic_commit import AtomicCommit
from atomic_commit.AtomicCommitEnv import AtomicCommitEnv
from distributed_locking import DistributedLocking
from distributed_locking.DistributedLockingEnv import DistributedLockingEnv
from math_func import Math
from math_func.MathEnv import MathEnv
from primary_backup import PrimaryBackup
from primary_backup.PrimaryBackupEnv import PrimaryBackupEnv


# Registry keyed by protocol name. Each value is a 2-tuple of the
# "<Protocol>Env" name and the "<Protocol>" name imported at the top of this
# file (presumably the environment and the protocol implementation -- confirm
# against the packages themselves).
class_map = dict(
    math=(MathEnv, Math),
    distributed_locking=(DistributedLockingEnv, DistributedLocking),
    primary_backup=(PrimaryBackupEnv, PrimaryBackup),
    atomic_commit=(AtomicCommitEnv, AtomicCommit),
)

def import_driver(protocol):
    """Return the ``collect_data`` callable of the driver for *protocol*.

    The driver module is imported inside the function, so it is only loaded
    when the matching protocol is requested.

    Args:
        protocol: One of "atomic_commit", "distributed_locking", "counter",
            "math" or "primary_backup".

    Returns:
        The ``collect_data`` object exported by the selected driver module.

    Raises:
        ValueError: If *protocol* is not one of the names listed above.
        ImportError: If the selected driver module cannot be imported.
    """
    if protocol == "atomic_commit":
        from atomic_commit.AtomicCommitSM_driver import collect_data
    elif protocol == "distributed_locking":
        from distributed_locking.DistributedLocking_driver import collect_data
    elif protocol == "counter":
        from simple_counter.Counter_driver import collect_data
    elif protocol == "math":
        from math_func.Math_driver import collect_data
    elif protocol == "primary_backup":
        from primary_backup.PrimaryBackup_driver import collect_data
    else:
        # Without this branch the return below would fail with an opaque
        # UnboundLocalError; report the bad name explicitly instead.
        raise ValueError(f"Protocol {protocol} not supported")
    return collect_data

def import_reset_envs(protocol):
    """Return the ``reset_envs`` callable for *protocol*.

    The protocol module is imported inside the function, so it is only loaded
    when the matching protocol is requested.

    Args:
        protocol: One of "atomic_commit", "distributed_locking", "math" or
            "primary_backup".

    Returns:
        The ``reset_envs`` object exported by the selected protocol module.

    Raises:
        ValueError: If *protocol* is not one of the names listed above.
        ImportError: If the selected module cannot be imported.
    """
    if protocol == "atomic_commit":
        from atomic_commit.AtomicCommit import reset_envs
    elif protocol == "distributed_locking":
        from distributed_locking.DistributedLocking import reset_envs
    elif protocol == "math":
        from math_func.Math import reset_envs
    elif protocol == "primary_backup":
        from primary_backup.PrimaryBackup import reset_envs
    else:
        # Without this branch the return below would fail with an opaque
        # UnboundLocalError; report the bad name explicitly instead.
        raise ValueError(f"Protocol {protocol} not supported")
    return reset_envs

def import_step_envs(protocol):
    """Return the ``step_envs`` callable for *protocol*.

    The protocol module is imported inside the function, so it is only loaded
    when the matching protocol is requested.

    Args:
        protocol: One of "atomic_commit", "distributed_locking", "math" or
            "primary_backup".

    Returns:
        The ``step_envs`` object exported by the selected protocol module.

    Raises:
        ValueError: If *protocol* is not one of the names listed above.
        ImportError: If the selected module cannot be imported.
    """
    if protocol == "atomic_commit":
        from atomic_commit.AtomicCommit import step_envs
    elif protocol == "distributed_locking":
        from distributed_locking.DistributedLocking import step_envs
    elif protocol == "math":
        from math_func.Math import step_envs
    elif protocol == "primary_backup":
        from primary_backup.PrimaryBackup import step_envs
    else:
        # Without this branch the return below would fail with an opaque
        # UnboundLocalError; report the bad name explicitly instead.
        raise ValueError(f"Protocol {protocol} not supported")
    return step_envs

def import_states(protocol):
    """Return the ``State`` object defined for *protocol*.

    The ``State`` module is imported inside the function, so it is only
    loaded when the matching protocol is requested.

    Args:
        protocol: One of "atomic_commit", "distributed_locking" or
            "primary_backup".

    Returns:
        The ``State`` object exported by the protocol's ``State`` module.

    Raises:
        ValueError: If *protocol* is not one of the names listed above.
        ImportError: If the selected module cannot be imported.
    """
    if protocol == "atomic_commit":
        from atomic_commit.State import State
    elif protocol == "distributed_locking":
        from distributed_locking.State import State
    elif protocol == "primary_backup":
        from primary_backup.State import State
    else:
        # Without this branch the return below would fail with an opaque
        # UnboundLocalError; report the bad name explicitly instead.
        raise ValueError(f"Protocol {protocol} not supported")

    return State

def get_action_space(protocol):
    """Return the integer action-space value registered for *protocol*.

    Args:
        protocol: "primary_backup", "distributed_locking" or "atomic_commit".

    Returns:
        4 for "primary_backup" and "atomic_commit", 2 for
        "distributed_locking".

    Raises:
        ValueError: If *protocol* is any other value.
    """
    # Guard clauses, one per distinct result; the tuple membership test uses
    # plain equality, so unhashable arguments still end in ValueError.
    if protocol in ("primary_backup", "atomic_commit"):
        return 4
    if protocol == "distributed_locking":
        return 2
    raise ValueError(f"Protocol {protocol} not supported")

def import_verifier(protocol):
    """Return the verification callable for *protocol*.

    The ``eval`` module is imported inside the function, so it is only loaded
    when the matching protocol is requested. "atomic_commit" exposes
    ``verify``; the other protocols expose ``informal_verify``, which is
    returned under the same local name.

    Args:
        protocol: One of "atomic_commit", "distributed_locking", "counter" or
            "math". There is no entry for "primary_backup" here.

    Returns:
        The verifier object exported by the protocol's ``eval`` module.

    Raises:
        ValueError: If *protocol* is not one of the names listed above.
        ImportError: If the selected module cannot be imported.
    """
    if protocol == "atomic_commit":
        from atomic_commit.eval import verify
    elif protocol == "distributed_locking":
        from distributed_locking.eval import informal_verify as verify
    elif protocol == "counter":
        from simple_counter.eval import informal_verify as verify
    elif protocol == "math":
        from math_func.eval import informal_verify as verify
    else:
        # Without this branch the return below would fail with an opaque
        # UnboundLocalError; report the bad name explicitly instead.
        raise ValueError(f"Protocol {protocol} not supported")
    return verify


def print_states(protocol):
    """Print every member of the ``State`` object for *protocol*.

    The module providing ``State`` is imported inside the function. Each
    member is written to standard output as ``"<name>: <value>"``, one per
    line, in the iteration order of ``State``.

    Args:
        protocol: One of "atomic_commit", "distributed_locking", "counter",
            "math" or "primary_backup".

    Raises:
        ValueError: If *protocol* is not one of the names listed above.
        ImportError: If the selected module cannot be imported.
    """
    if protocol == "atomic_commit":
        from atomic_commit.AtomicCommitEnv import State
    elif protocol == "distributed_locking":
        from distributed_locking.DistributedLockingEnv import State
    elif protocol == "counter":
        from simple_counter.ComplexCounterEnv import State
    elif protocol == "math":
        from math_func.MathEnv import State
    elif protocol == "primary_backup":
        from primary_backup.State import State
    else:
        # Without this branch the loop below would fail with an opaque
        # UnboundLocalError; report the bad name explicitly instead.
        raise ValueError(f"Protocol {protocol} not supported")

    # Relies on State being iterable with members exposing .name and .value
    # (presumably an enum.Enum -- confirm in the protocol packages).
    for state in State:
        print(f"{state.name}: {state.value}")
