import logging

# Raise the threshold of the "tensorflow" logger to ERROR so that it only emits errors.
logging.getLogger("tensorflow").setLevel(logging.ERROR)
import os
import itertools
import numpy as np

# Set before the project imports that follow; presumably this has to happen before
# TensorFlow itself is imported for the native log suppression to apply — TODO confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from MCTS.model import AlphaNet
from MCTS.util import get_input_dim
from config import CUTOF_INDEX
from import_functions import get_action_space
from utils import parse_arg
from primary_backup.State import State as PBState


def generate_pb_transition(players, model: AlphaNet, protocol, num_rounds, output_file):
    """Write the model's greedy action for every enumerated state of every round.

    For each round ``r`` in ``0 .. num_rounds - 1`` the function enumerates all
    combinations of per-player values, queries ``model.predict_policy`` with
    ``[r, v_1, ..., v_players]`` and records the arg-max action index.

    Each output line has the form ``<r>,<v_1>,...,<v_players>:<index>``.

    Args:
        players: Number of per-player entries in each enumerated state.
        model: Network whose ``predict_policy`` is called once per state.
        protocol: Protocol identifier handed to ``get_action_space``.
        num_rounds: Number of rounds to enumerate.
        output_file: Path of the text file to create or overwrite.
    """
    # The context manager guarantees the file is closed even if a prediction raises.
    with open(output_file, "w") as f:
        print(f"Generating transitions, and writing to {output_file} ...")

        # Use the ``protocol`` argument rather than a module-level global so the
        # function also works when imported and called from another module.
        action_space = get_action_space(protocol)
        lost_value = PBState.Lost.value
        initial_states = [PBState.LocalOne.value, PBState.LocalZero.value]
        all_actions = [action.value for action in list(PBState)[:action_space]]

        # Before the last round the chosen action is restricted to indices from
        # CUTOF_INDEX onwards (see the arg-max below), so these are the only action
        # values that can show up as a player's state at the start of a later round.
        non_final_actions = all_actions[CUTOF_INDEX:]

        last_round = num_rounds - 1
        for r in range(num_rounds):
            if r == 0:
                # Round 0: every player holds one of the initial values or Lost.
                per_player_values = initial_states + [lost_value]
            else:
                # Later rounds: every player holds a non-final action value or Lost.
                per_player_values = non_final_actions + [lost_value]

            # Iterate the product lazily; there is no need to hold it all in memory.
            for input_state in itertools.product(per_player_values, repeat=players):
                state = [r] + list(input_state)
                policy_output = np.squeeze(model.predict_policy(state))
                if r == last_round:
                    # Last round: any action index may be selected.
                    index = np.argmax(policy_output)
                else:
                    # Earlier rounds: ignore the first CUTOF_INDEX entries, then
                    # shift the result back to an index into the full policy.
                    index = np.argmax(policy_output[CUTOF_INDEX:]) + CUTOF_INDEX
                f.write(",".join(map(str, state)) + ":" + str(index) + "\n")

    print(f"Transitions generated and saved to {output_file}")


if __name__ == "__main__":
    # Parse the run options and echo them to stdout.
    opt = parse_arg()
    print(opt)

    # Build the network from the option-derived dimensions, then load it from opt.load_dir.
    net_input_dim = get_input_dim(opt)
    net_action_space = get_action_space(opt.protocol)
    loaded_model = AlphaNet(net_input_dim, net_action_space, opt.rounds, opt.model_type)
    loaded_model.load(opt.load_dir)

    # The output file is named after the last path component of the load directory.
    load_dir_name = os.path.basename(os.path.normpath(opt.load_dir))
    transitions_path = load_dir_name + "_transitions.txt"
    generate_pb_transition(
        opt.players,
        loaded_model,
        opt.protocol,
        opt.rounds,
        transitions_path,
    )
