from tf_agents.environments import tf_py_environment

from checkpoint import restore_policy
from MultiSM import MultiStateMachine
from atomic_commit.AtomicCommitEnv import AtomicCommitEnv
from config import IS_ENCODED_HISTORY
from atomic_commit.State import State, meanings
from utils import parse_arg
import itertools

# Raw State values that make up each protocol round; the enumeration helpers
# below take the cartesian product of these lists across all players.
round2 = [
    state.value
    for state in (State.Commit, State.Abort, State.DoNothingCommit, State.Lost_R2)
]
round1 = [
    state.value
    for state in (State.Lost_R1, State.LocalCommit, State.LocalAbort)
]

# Hard-coded three-element combinations of raw state values: every ordering of
# the multisets {1, 1, 3}, {1, 3, 3}, {2, 2, 3} and {2, 3, 3}.
# NOTE(review): not referenced elsewhere in this module as visible here;
# presumably these are combinations never observed in practice -- confirm.
never_seen_list = [
    # orderings of {1, 1, 3}
    [1, 1, 3],
    [1, 3, 1],
    [3, 1, 1],
    # orderings of {1, 3, 3}
    [3, 1, 3],
    [3, 3, 1],
    [1, 3, 3],
    # orderings of {2, 2, 3}
    [2, 2, 3],
    [2, 3, 2],
    [3, 2, 2],
    # orderings of {2, 3, 3}
    [3, 2, 3],
    [3, 3, 2],
    [2, 3, 3],
]

def generate_pb_transition(players):
    """Print a rule-based transition table for the primary-backup protocol.

    One line is printed per combination of per-player state values, in the
    form ``v1,v2,...:next``. All round-1 combinations are printed first,
    followed by all round-2 combinations.

    Args:
        players: Number of positions in each combination (the ``repeat``
            argument of ``itertools.product``).
    """
    # Imported here rather than at module level, as in the original code.
    from primary_backup.State import State as PBState

    local_zero = PBState.LocalZero.value
    local_one = PBState.LocalOne.value
    lost = PBState.Lost.value
    zero = PBState.Zero.value
    one = PBState.One.value
    idle_zero = PBState.DoNothing_Zero.value
    idle_one = PBState.DoNothing_One.value

    def emit(combo, outcome):
        # Shared output format: comma-joined state values, a colon, the result.
        print(f"{','.join(map(str, combo))}:{outcome}")

    one_or_lost = {local_one, lost}
    zero_or_lost = {local_zero, lost}
    for combo in itertools.product([local_zero, local_one, lost], repeat=players):
        seen = set(combo)
        # DoNothing_Zero only when the combination fits {LocalZero, Lost} but
        # not {LocalOne, Lost}; every other case (all lost, only ones, or a
        # mix of zeros and ones) maps to DoNothing_One.
        if not seen <= one_or_lost and seen <= zero_or_lost:
            emit(combo, idle_zero)
        else:
            emit(combo, idle_one)

    second_round = [zero, one, lost, idle_zero, idle_one]
    for combo in itertools.product(second_round, repeat=players):
        if one in combo or idle_one in combo:
            # Any "one"-flavoured state wins over zeros.
            emit(combo, one)
        elif zero in combo or idle_zero in combo:
            emit(combo, zero)
        else:
            # Only Lost entries remain. NOTE(review): the literal 4 is a raw
            # state value whose meaning is not visible in this file -- confirm.
            emit(combo, 4)


def traverse_combs(combs, policy, env, tf_env, log=False):
    """Ask the policy for the next state of every given state combination.

    For each combination, ``env.states`` is overwritten with it, ``tf_env`` is
    reset to obtain a time step, and the policy's action for that time step
    (plus one) is looked up as a ``State`` value.

    Args:
        combs: Iterable of sequences of raw ``State`` values.
        policy: Object exposing ``action(time_step)`` whose result has an
            ``action`` attribute convertible with ``int``.
        env: Environment whose ``states`` attribute is assigned per combination.
        tf_env: Environment whose ``reset()`` supplies the time step.
            NOTE(review): presumably wraps ``env`` so the reset observation
            reflects ``env.states`` -- confirm against the caller.
        log: When true, print ``v1,v2,...:next_value`` for every combination.

    Returns:
        A list of ``[state_names, next_state_name]`` pairs, one per
        combination, in input order.
    """
    name_of = {member.value: member.name for member in State}
    results = []
    for raw in combs:
        values = list(raw)
        # TODO: don't hard code 7 and "Commit"
        if all(value == 7 for value in values):
            # Combinations made only of 7s skip the policy entirely.
            results.append([[name_of[value] for value in values], "Commit"])
            if log:
                print(f"{','.join(map(str, values))}:1")
            continue

        env.states = values
        step = tf_env.reset()
        # The policy's action is shifted by one before being used as a State
        # value (presumably actions are 0-based -- confirm).
        chosen = int(policy.action(step).action) + 1
        if log:
            print(f"{','.join(map(str, values))}:{chosen}")
        results.append([[name_of[value] for value in values], name_of[chosen]])
    return results

def get_transitions_by_round(players, policy, env, tf_env, log=False):
    """Collect policy transitions for round 1, then round 2.

    For each round, every ``players``-long combination of that round's state
    values is generated. When ``IS_ENCODED_HISTORY`` is set the combinations
    are first expanded by ``add_local``. The result is then passed to
    ``traverse_combs``.

    Args:
        players: Length of each generated combination.
        policy: Forwarded to ``traverse_combs``.
        env: Forwarded to ``traverse_combs``.
        tf_env: Forwarded to ``traverse_combs``.
        log: Forwarded to ``traverse_combs``.

    Returns:
        Round-1 transitions followed by round-2 transitions, as one list.
    """
    collected = []
    for round_values in (round1, round2):
        candidates = list(itertools.product(round_values, repeat=players))
        if IS_ENCODED_HISTORY:
            candidates = add_local(candidates)
        collected += traverse_combs(candidates, policy, env, tf_env, log)
    return collected

def round1_transition(players, policy, env, tf_env):
    """Print the policy's choice for every round-1 combination, then a summary.

    The first pass prints ``v1,v2,...:next_value`` per combination and groups
    the first row of each reset observation by the chosen value. The second
    pass prints, for each chosen value and each of LocalCommit / LocalAbort,
    a ``<state> --> <chosen>`` header followed by the grouped observations
    that contain that state value.

    Args:
        players: Length of each generated combination.
        policy: Object exposing ``action(time_step)``.
        env: Environment whose ``states`` attribute is assigned per combination.
        tf_env: Environment whose ``reset()`` supplies the time step.
    """
    observations_by_choice = {}
    for raw in itertools.product(round1, repeat=players):
        values = list(raw)
        env.states = values
        step = tf_env.reset()
        chosen = int(policy.action(step).action) + 1
        print(f"{','.join(map(str, values))}:{chosen}")
        observation = step.observation.numpy()[0].tolist()
        observations_by_choice.setdefault(chosen, []).append(observation)

    local_votes = (State.LocalCommit.value, State.LocalAbort.value)
    for chosen, observations in observations_by_choice.items():
        for vote in local_votes:
            print(f"{meanings[vote]} --> {meanings[chosen]}")
            for observation in observations:
                if vote in observation:
                    print(observation)

def all_transisions(players, policy, env, tf_env, log=False):
    """Collect policy transitions for every combination of state values 1..7.

    The per-combination work (assign ``env.states``, reset ``tf_env``, query
    the policy, map values to ``State`` names) is exactly what
    ``traverse_combs`` does, so this function only builds the combinations
    and delegates. The previous implementation repeated that loop inline and
    also accumulated an ``edges`` dict that was never read or returned.

    Args:
        players: Length of each generated combination.
        policy: Forwarded to ``traverse_combs``.
        env: Forwarded to ``traverse_combs``.
        tf_env: Forwarded to ``traverse_combs``.
        log: When true, each combination and its resulting value are printed.

    Returns:
        A list of ``[state_names, next_state_name]`` pairs, one per
        combination, in ``itertools.product`` order.
    """
    # NOTE(review): 1..7 is presumably the full range of State values; kept
    # as a literal because the State definition is not visible here.
    every_comb = itertools.product([1, 2, 3, 4, 5, 6, 7], repeat=players)
    return traverse_combs(every_comb, policy, env, tf_env, log)

def get_transitions(players, policy, protocol):
    """Build a fresh environment and return the policy's per-round transitions.

    Args:
        players: Passed to ``MultiStateMachine``, ``AtomicCommitEnv`` and
            ``get_transitions_by_round``.
        policy: Policy queried by ``get_transitions_by_round``.
        protocol: Passed to ``MultiStateMachine``.

    Returns:
        Whatever ``get_transitions_by_round`` returns, with logging disabled.
    """
    state_machines = MultiStateMachine(players, protocol)
    # The 0 and None arguments are passed through unchanged; their meaning is
    # defined by AtomicCommitEnv, which is not visible in this file.
    py_env = AtomicCommitEnv(players, 0, state_machines, None)
    wrapped_env = tf_py_environment.TFPyEnvironment(py_env)
    return get_transitions_by_round(
        players, policy, py_env, wrapped_env, log=False
    )

def visualize():
    """Placeholder with no implementation yet; always returns ``None``."""
    return None

def add_local(combs):
    """Expand each combination by prefixing one of its own state values.

    For every combination, one new list is produced per distinct value in it,
    with that value inserted at position 0. Values equal to ``Lost_R1`` or
    ``Lost_R2`` are never used as a prefix, so a combination made only of
    lost values yields nothing.

    Args:
        combs: Iterable of sequences of raw state values.

    Returns:
        A list of lists, each one element longer than its source combination.
    """
    lost_values = (int(State.Lost_R1.value), int(State.Lost_R2.value))
    expanded = []
    for comb in combs:
        for candidate in set(comb):
            if candidate in lost_values:
                continue
            expanded.append([candidate, *comb])
    return expanded


if __name__ == "__main__":
    cli_args = parse_arg()

    # Disabled policy-inspection path, kept for reference.
    # NOTE(review): it refers to AtomicCommitSMEnv and round2_transition,
    # neither of which is defined or imported in this module as visible here.
    #
    # if cli_args.load_dir:
    #     policy = restore_policy(cli_args.load_dir)
    # else:
    #     print("ERROR! no load dir provided")
    #     exit()
    # multi_sm = MultiStateMachine(cli_args.players, cli_args.protocol)
    # env = AtomicCommitSMEnv(cli_args.players, 0, multi_sm, None)
    # tf_env = tf_py_environment.TFPyEnvironment(env)
    # round1_transition(cli_args.players, policy, env, tf_env)
    # round2_transition(cli_args.players, policy, env, tf_env)

    # Active path: print the primary-backup transition table.
    generate_pb_transition(cli_args.players)