import gc
import sys
import os

# Must run before `import tensorflow` below: TensorFlow reads this variable when
# it is loaded. Level "3" is intended to silence TensorFlow's native (C++) log
# output — see the TensorFlow logging docs for the exact level meanings.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from typing import List

# Add the project root directory to sys.path
# (the parent of the directory containing this file), so the project-local
# modules imported below can be resolved from there. Must precede those imports.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Debug output: echo the directory that was just appended to sys.path.
print(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from create_crash_info import create_info_manually
from Timer import Timer
import tensorflow as tf
import random
import numpy as np
import multiprocessing
from multiprocessing import Pool
from import_functions import get_action_space, import_states
from Enviroment import EnvManager
from MCTS import MCTS
# NOTE(review): wildcard import; presumably this is where `get_fix_info` (used in
# main()) comes from, since no other import names it — consider importing it explicitly.
from fix_info import *
from fix_logic import fix_specific, unfix
from util import get_input_dim, softmax
from model import AlphaNet
from utils import hash_observation, log, parse_arg
from primary_backup.State import State
from action_freezer import ActionFreezer
from informal_verifier import Verifier, informal_verify
from StateActionTracker import StateActionStack
from generate_combs import Info, generate_all_combs, generate_all_combs_nocrash, parse_info_from_string

# Print NumPy arrays with 3 decimal places and without scientific notation
# for small values, keeping the policy vectors printed in main() readable.
np.set_printoptions(precision=3, suppress=True)


def main(max_steps: int = 1):
    """Run MCTS on one generated test setting and print diagnostics.

    Parses the command-line options and builds an ``AlphaNet``, loading weights
    from ``opt.load_dir`` when one is given. It then picks setting
    ``opt.test_id`` from the generated input combinations and runs up to
    ``max_steps`` MCTS steps on it.

    After each search it prints, for every player whose chosen action is not
    ``StateClass.Lost``, three things: the player's state, the MCTS policy
    target, and the network's softmaxed policy output. Finally it prints the
    environment states, the search tree and the rewards.

    Args:
        max_steps: Maximum number of MCTS steps to execute. Defaults to 1, so
            the printed tree is the one built by the first search. The run
            also stops early once the environment reports it is done.
    """
    opt = parse_arg()
    print(opt)

    # Init model
    input_dim = get_input_dim(opt)
    curr_model = AlphaNet(input_dim, get_action_space(opt.protocol), opt.rounds, opt.model_type)

    # Generate all inputs: primary_backup and atomic_commit use
    # generate_all_combs; every other protocol uses the _nocrash variant.
    all_combs: List[Info]
    if opt.protocol in ("primary_backup", "atomic_commit"):
        all_combs = generate_all_combs(range(2**opt.players), opt.players, opt.rounds)
    else:
        all_combs = generate_all_combs_nocrash(range(2**opt.players), opt.players, opt.rounds)

    # Init StateActionStack
    dfs_tracker = StateActionStack()
    dfs_tracker.load_from_stack_info(get_fix_info(opt.protocol, opt.players))
    dfs_tracker.lock_all_states()

    # Load existing model
    if opt.load_dir:
        curr_model.load(opt.load_dir)

    players = opt.players
    env_mgr = EnvManager(players, opt.protocol, opt.rounds, opt.history, opt.encode_id)
    StateClass = import_states(opt.protocol)
    # To debug a hand-written scenario instead: setting = create_info_manually()
    setting = all_combs[opt.test_id]
    print(setting)
    env_mgr.init(setting)
    init_key = env_mgr.construct_crash_key(setting.get_crash_info(0))
    mcts = MCTS(curr_model, players, setting.get_crash_info(0), init_key, opt.protocol, opt.rounds)

    steps_taken = 0
    while not env_mgr.is_done() and steps_taken < max_steps:
        action, pi_target = mcts.execute(env_mgr, dfs_tracker)
        for i in range(players):
            if action[i] == StateClass.Lost.value:
                continue
            # Fetch the state once; it is both fed to the network and printed.
            state = env_mgr.get_states(i)
            policy_output = softmax(np.squeeze(curr_model.predict_policy(state)))
            print(f"state: {state}, policy_target: {pi_target[i]}, policy_output: {policy_output}")
        next_round = env_mgr.get_zero_based_round() + 1
        crash_info = setting.get_crash_info(next_round)
        env_mgr.step(action, crash_info)
        steps_taken += 1
        if env_mgr.is_done():
            break
        # Feed the taken action and new crash info back into the tree,
        # which is only needed while the episode is still running.
        mcts.update_tree(env_mgr, action, crash_info)
    env_mgr.print_current_states()
    mcts.visualize_tree(mcts.root, 0)
    rewards = env_mgr.get_rewards()
    print("Final rewards: ", rewards)


if __name__ == "__main__":
    main()
