import sys
import os
import time

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from typing import List

# Add the project root directory to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
print(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import random
import numpy as np
from multiprocessing import Pool
from import_functions import get_action_space, import_states
from Enviroment import EnvManager
from MCTS import MCTS
from fix_info import *
from fix_logic import select_candidate_to_fix, store_candidates, unfix
from util import get_input_dim, softmax
from model import AlphaNet
from utils import log, parse_arg
from action_freezer import ActionFreezer
from informal_verifier import Verifier, informal_verify
from StateActionTracker import StateActionStack
from sample_settings import sample_settings_from_combs
from filter_combs import filter_combs, filter_combs_by_initial_and_round, filter_combs_by_initial_states
from generate_combs import Info, generate_all_combs, generate_all_combs_nocrash

# os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# tf.get_logger().setLevel("ERROR")

# Print numpy arrays with three decimals and without scientific notation for small values.
np.set_printoptions(suppress=True, precision=3)

# Accumulated wall-clock seconds spent in each phase of the training loop.
timing_info = dict(mcts=0, train=0, validate=0)


def prepare_train_data(
    model, value_buffer, policy_buffer, actions, states, policies, rewards, lost_value, is_greedy=False, debug=False
):
    """Append ``(state, policy_target)`` samples from one finished game to ``policy_buffer``.

    ``states``, ``policies`` and ``actions`` are indexed as ``[round][player]``.
    A player's sample for a round is skipped when:

    * its action for that round equals ``lost_value`` (the player had already
      crashed in an earlier round), or
    * the round is not the last one and the player's action in the next round
      equals ``lost_value`` (the player crashed during this round).

    Args:
        model: Object exposing ``predict_policy(state)``; only consulted when
            ``debug`` is true, to print the current network output.
        value_buffer: Accepted for interface compatibility; not written here.
        policy_buffer: List that receives the ``(state, policy_target)`` tuples.
        actions: Per-round, per-player chosen actions.
        states: Per-round, per-player states.
        policies: Per-round, per-player policy targets.
        rewards: Per-player rewards; only used in the debug print.
        lost_value: Marker value meaning "this player has crashed".
        is_greedy: Accepted for interface compatibility; unused.
        debug: When true, print every stored sample with the model's output.
    """
    assert len(states) == len(policies)
    if not states:
        # Nothing was played, so there is nothing to store.
        return
    assert len(states[0]) == len(actions[0])

    final_round = len(states) - 1
    for round_idx, round_states in enumerate(states):
        for i, state in enumerate(round_states):
            # Compare by value, not identity: ``is`` on integers only works by
            # accident of CPython's small-int cache.
            if actions[round_idx][i] == lost_value:
                # The player crashed in a previous round.
                continue

            # The player crashed in this (non-final) round: its action is not
            # a useful training target.
            if round_idx < final_round and actions[round_idx + 1][i] == lost_value:
                continue

            policy_target = policies[round_idx][i]
            policy_buffer.append((state, policy_target))
            if debug:
                # The network output is only needed for the debug print, so
                # the forward pass is skipped otherwise.
                policy_output = softmax(np.squeeze(model.predict_policy(state)))
                print(
                    f"Store train sample: state: {state}, policy_target: {policy_target}, policy_output: {policy_output}, reward: {rewards[i]}"
                )


# Returns a truthy value when the fixing logic was triggered (an action was fixed or unfixed)
def self_play(
    opt,
    model,
    dfs_tracker: StateActionStack,
    value_buffer,
    policy_buffer,
    all_combs,
    subset_combs,
    failed_id,
    is_fix=False,
    num_simulations=100,
):
    """Run a batch of MCTS self-play games and collect policy training samples.

    For each game a setting is sampled with ``sample_settings_from_combs``, the
    environment is played to completion using ``MCTS.execute`` to choose the
    joint action, and the visited states / policy targets are passed to
    ``prepare_train_data``.

    Fixing logic:
      * When ``is_fix`` is true, candidates are collected after every game via
        ``store_candidates`` (once with the MCTS policy targets and once with
        the network's policy output) and ``select_candidate_to_fix`` is called
        once after the whole batch.
      * Whenever a game ends with a negative reward, ``unfix`` is called.

    Args:
        opt: Parsed options; reads ``players``, ``protocol``, ``rounds``,
            ``history``, ``encode_id`` and ``debug``.
        model: Network exposing ``predict_policy``.
        dfs_tracker: State/action stack consulted by MCTS and the fixing logic.
        value_buffer: Forwarded to ``prepare_train_data``.
        policy_buffer: List that receives the training samples.
        all_combs: All settings, forwarded to the sampler.
        subset_combs: Settings currently sampled from.
        failed_id: Settings that failed the last evaluation, forwarded to the sampler.
        is_fix: Whether to try fixing an action after this batch.
        num_simulations: Number of self-play games to run (default 100).

    Returns:
        Truthy if an action was fixed after the batch or unfixed during any
        game of the batch.
    """
    func = "self_play"
    players = opt.players
    env_mgr = EnvManager(players, opt.protocol, opt.rounds, opt.history, opt.encode_id)
    candidates = []
    fixed = False  # Whether an action is fixed for one state
    unfixed = False  # Whether an action was unfixed in any simulation of this batch
    StateClass = import_states(opt.protocol)
    lost_value = StateClass.Lost.value

    for _ in range(num_simulations):
        store_data = True  # Whether to store data in this simulation for training (currently always on)
        setting_id, setting = sample_settings_from_combs(subset_combs, failed_id, all_combs)
        env_mgr.init(setting)
        init_key = env_mgr.construct_crash_key(setting.get_crash_info(0))
        mcts = MCTS(model, players, setting.get_crash_info(0), init_key, opt.protocol, opt.rounds)

        # Per-round, per-player records of this game.
        states = [[] for _ in range(opt.rounds)]
        policies = [[] for _ in range(opt.rounds)]
        policy_output = [[] for _ in range(opt.rounds)]
        actions = [[] for _ in range(opt.rounds)]

        while not env_mgr.is_done():
            current_round = env_mgr.get_zero_based_round()
            action, pi_target = mcts.execute(env_mgr, dfs_tracker)

            log(func, 1, f"selected move, actions: {action}")
            for i in range(players):
                player_state = env_mgr.get_states(i)
                states[current_round].append(player_state)
                policies[current_round].append(pi_target[i])
                policy_output[current_round].append(softmax(np.squeeze(model.predict_policy(env_mgr.get_states(i)))))
                actions[current_round].append(action[i])

            next_crash_info = setting.get_crash_info(current_round + 1)
            env_mgr.step(action, next_crash_info)
            if env_mgr.is_done():
                break
            # The search tree is only advanced while the game continues.
            mcts.update_tree(env_mgr, action, setting.get_crash_info(current_round + 1))

        env_mgr.print_current_states()
        if opt.debug:
            mcts.visualize_tree(mcts.root, 0)
        rewards = env_mgr.get_rewards()

        # Fix action for simulation: collect candidates from both the MCTS
        # targets and the raw network output.
        if is_fix:
            store_candidates(opt.rounds, dfs_tracker, players, states, policies, actions, lost_value, setting_id, candidates)
            store_candidates(
                opt.rounds, dfs_tracker, players, states, policy_output, actions, lost_value, setting_id, candidates
            )

        if any(r < 0 for r in rewards):
            # Accumulate instead of overwriting: an unfix in an earlier game
            # must not be forgotten because a later game did not unfix anything.
            if unfix(opt.rounds, dfs_tracker, players, states, rewards, actions, lost_value):
                unfixed = True
                print("[FIXING LOGIC TRIGGERED]: unfix")

        log(func, 1, f"rewards: {rewards}")
        # if store_data or dfs_tracker.latest_key() is None:
        if store_data or dfs_tracker.is_empty():
            # NOTE(review): debug is hard-coded to True here, so every stored
            # sample is printed regardless of opt.debug - confirm this is intended.
            prepare_train_data(
                model,
                value_buffer,
                policy_buffer,
                actions,
                states,
                policies,
                rewards,
                lost_value=lost_value,
                is_greedy=False,
                debug=True,
            )
        del mcts

    if is_fix:
        fixed = select_candidate_to_fix(candidates, dfs_tracker)
        if fixed:
            print("[FIXING LOGIC TRIGGERED]: fix")
    return fixed or unfixed


def evaluate(opt, verifier, input_dim, model_path, eval_combs):
    """Check the model saved at ``model_path`` against ``eval_combs``.

    Delegates to ``informal_verify`` using the verifier's full combination
    list and process count, with ``is_testall=True`` and logging disabled.

    Returns:
        The two-element result of ``informal_verify``; callers in this file
        treat it as ``(pass_all, failed_cases)``.
    """
    outcome = informal_verify(
        opt,
        model_path,
        input_dim,
        verifier,
        verifier.all_combs,
        eval_combs,
        verifier.num_process,
        is_testall=True,
        is_log=False,
    )
    all_passed, failures = outcome
    return all_passed, failures


def main():
    """Train the model with MCTS self-play until the verifier passes every case.

    Each iteration runs one batch of self-play, trains the model on the
    collected buffers, saves it, and re-evaluates it. Evaluation first re-checks
    the cases that failed last time; once those pass, the whole current sampling
    set is checked. When the whole sampling set passes, the set is enlarged by
    lowering ``sample_round`` (and finally to all combinations); training stops
    once the full set passes.

    For protocols without a known initial-state filter the curriculum is
    skipped and the full set is used from the start.
    """
    opt = parse_arg()
    print(opt)

    # Init model
    input_dim = get_input_dim(opt)
    curr_model = AlphaNet(input_dim, get_action_space(opt.protocol), opt.rounds, opt.model_type)

    # MCTS simulation sample settings
    sample_round = opt.rounds - 1  # Use this parameter to control the sampling set size. Start with the last round

    # Generate all inputs
    all_combs: List[Info]
    if opt.protocol in ("primary_backup", "atomic_commit"):
        all_combs = generate_all_combs(range(2**opt.players), opt.players, opt.rounds)
    else:
        all_combs = generate_all_combs_nocrash(opt.players, opt.rounds)

    # Init verifier
    freezer = ActionFreezer(opt.protocol, enable_freeze=opt.freeze)
    v = Verifier(16, all_combs, freezer)

    # Init StateActionStack
    dfs_tracker = StateActionStack()
    dfs_tracker.load_from_stack_info(get_fix_info(opt.protocol, opt.players, opt.rounds))
    dfs_tracker.lock_all_states()
    value_buffer = []
    policy_buffer = []
    buffer_size = 2000
    iter_idx = 0

    if opt.protocol == "primary_backup":
        filter_initial_states = [0, 2**opt.players - 1]
    elif opt.protocol == "atomic_commit":
        filter_initial_states = [2**opt.players - 2, 2**opt.players - 1]
    else:
        # No initial-state filter is defined for this protocol.
        filter_initial_states = None

    if filter_initial_states is None:
        # Without a filter there is no curriculum: sample from every
        # combination right away instead of failing on an undefined name.
        print(f"No initial-state filter for protocol {opt.protocol}, sampling from the full set")
        subset_combs = range(len(all_combs))
        is_full_set = True
    else:
        # Get subset_combs that we sample from all combs, start with combinations with single values
        subset_combs = filter_combs_by_initial_and_round(filter_initial_states, sample_round, opt.players, all_combs)
        is_full_set = False
    # subset_combs = range(len(all_combs))

    # Load existing model
    if opt.load_dir:
        curr_model.load(opt.load_dir)
        _, failed_cases = evaluate(opt, v, input_dim, opt.load_dir, subset_combs)
        # A loaded model that already passes still trains on the whole subset.
        failed_cases = subset_combs if len(failed_cases) == 0 else failed_cases
    else:
        failed_cases = subset_combs
    best_failed = len(failed_cases)

    # seed = 19970127
    # random.seed(seed)
    # print(f"random seed: {seed}")

    start = time.time()
    while len(failed_cases) > 0:
        if iter_idx == opt.pre_train:
            # Drop the samples gathered during pre-training.
            policy_buffer.clear()
            value_buffer.clear()
        is_fix = opt.fix and (iter_idx % opt.interval == 0) and iter_idx >= opt.pre_train
        s = time.time()
        ret = self_play(
            opt, curr_model, dfs_tracker, value_buffer, policy_buffer, all_combs, subset_combs, failed_cases, is_fix=is_fix
        )
        e = time.time()
        timing_info["mcts"] += e - s
        print(f"self_play time: {e - s} seconds, iter_idx: {iter_idx}")

        dfs_tracker.print_track_info()

        if ret:
            # The fixing logic changed something: snapshot the model.
            dfs_tracker.print_stack_info()
            print(f"Save fix model: {opt.save_dir}_{iter_idx}_fix")
            curr_model.save(opt.save_dir + f"_{iter_idx}_fix")

        print(f"Current value buffer size: {len(value_buffer)}, policy buffer size: {len(policy_buffer)}")
        s = time.time()
        curr_model.train(value_buffer, policy_buffer, 5)
        e = time.time()
        timing_info["train"] += e - s
        # Once a buffer overflows, keep only its most recent half.
        if len(value_buffer) > buffer_size:
            value_buffer = value_buffer[-(buffer_size // 2) :]
        if len(policy_buffer) > buffer_size:
            policy_buffer = policy_buffer[-(buffer_size // 2) :]

        curr_model.save(opt.save_dir)
        print(f"updated model at iter: {iter_idx}, model path: {opt.save_dir}")
        curr_model.clear_map()

        ## Evaluate updated model
        s = time.time()
        # step-1: evaluate failed cases from previous iteration, if passed all then evaluate the rest cases
        pass_all_failed, failed_cases = evaluate(opt, v, input_dim, opt.save_dir, failed_cases)
        if len(failed_cases) <= 20:
            print(f"Failed cases: {failed_cases}")
        # step-2: evaluate all cases (remaining cases)
        if pass_all_failed:
            print("pass all previous failed cases")
            pass_all, failed_cases = evaluate(opt, v, input_dim, opt.save_dir, subset_combs)

            if len(failed_cases) < best_failed:
                best_failed = len(failed_cases)
                print(f"Save best model: {opt.save_dir}_{iter_idx}_best")
                curr_model.save(opt.save_dir + f"_{iter_idx}_best")

            print(f"Iteration: {iter_idx}, failed_cases: {len(failed_cases)}")

            if len(failed_cases) <= 20:
                print(f"Failed cases: {failed_cases}")

            if pass_all:
                print("All cases passed")
                if is_full_set:
                    print("Already evaluated the full set, exit")
                    print(timing_info)
                    end = time.time()
                    print(f"Total time: {end - start} seconds")
                    return
                print(f"Current simulation set is sampled from round: {sample_round}, total cases: {len(subset_combs)}")
                if sample_round > 0:
                    sample_round -= 1
                    subset_combs = filter_combs_by_initial_and_round(filter_initial_states, sample_round, opt.players, all_combs)
                    print("Increase sample size by one round")
                else:
                    print("Increase the sample size to the full set")
                    is_full_set = True
                    subset_combs = range(len(all_combs))
                # Restart from the whole (enlarged) sampling set.
                failed_cases = subset_combs
        else:
            print(f"Iteration: {iter_idx}, remaining: {len(failed_cases)}")
        e = time.time()
        timing_info["validate"] += e - s
        print(f"evaluation time: {e - s} seconds")

        iter_idx += 1


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
