import multiprocessing
import time
from Enviroment import EnvManager
from import_functions import get_action_space, import_states
from model import AlphaNet
from distributed_locking.DistributedLockingEnv import DistributedLockingEnv
from distributed_locking import DistributedLocking
from math_func.MathEnv import MathEnv
from math_func import Math
from primary_backup.PrimaryBackupEnv import PrimaryBackupEnv
from primary_backup import PrimaryBackup
from config import CUTOF_INDEX
from utils import hash_observation, log, divide_list
from atomic_commit.AtomicCommitEnv import AtomicCommitEnv
from atomic_commit import AtomicCommit
import primary_backup.State as PBState

import numpy as np
from multiprocessing import Pool
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # INFO and WARNING messages are not printed
import tensorflow as tf

tf.get_logger().setLevel("ERROR")


# Lookup table from protocol name to the pair of names imported for that
# protocol at the top of this file: (environment class, protocol import).
# NOTE(review): whether the second element is a class or a module cannot be
# told from this file — confirm against the protocol packages.
class_map = {
    "math": (MathEnv, Math),
    "distributed_locking": (DistributedLockingEnv, DistributedLocking),
    "primary_backup": (PrimaryBackupEnv, PrimaryBackup),
    "atomic_commit": (AtomicCommitEnv, AtomicCommit),
}

# Presumably restricts TensorFlow to one intra-op and one inter-op thread per process.
# NOTE(review): these are set after `import tensorflow` above — confirm that
# TensorFlow still honors them when set at this point.
os.environ["TF_NUM_INTRAOP_THREADS"] = "1"
os.environ["TF_NUM_INTEROP_THREADS"] = "1"


# Ground-truth (hand-written) action for the primary-backup protocol only. TODO: add more protocols.
def gt_action(obs, round):
    """Return the hand-written action for a primary-backup observation.

    The first element of ``obs`` is dropped; the remaining entries are
    searched for ``PBState.State`` values to choose the action for the given
    zero-based ``round``.

    Args:
        obs: Observation sequence; only ``obs[1:]`` is inspected.
        round: Zero-based round index.

    Returns:
        The selected ``PBState.State`` value converted to ``int``.
    """
    states = PBState.State
    seen = obs[1:]

    def observed(*members):
        # True when at least one of the given states occurs in the observation.
        return any(member.value in seen for member in members)

    if round == 0:
        # LocalOne decides the outcome whether or not LocalZero is also present.
        if observed(states.LocalOne):
            return int(states.DoNothing_One.value)
        if observed(states.LocalZero):
            return int(states.DoNothing_Zero.value)
        # With neither local value present, control deliberately continues to
        # the "any other round" rules at the bottom (round 0 is not round 1).

    if round == 1:
        # First match wins, in this priority order.
        for candidate in (states.One, states.Zero, states.DoNothing_One):
            if observed(candidate):
                return int(candidate.value)
        return int(states.DoNothing_Zero.value)

    # Every round other than 1, including a round 0 that matched nothing above.
    if observed(states.DoNothing_One, states.One):
        return int(states.One.value)
    if observed(states.DoNothing_Zero, states.Zero):
        return int(states.Zero.value)
    return int(states.DoNothing_One.value)


def softmax(x):
    """Return the softmax of ``x``, normalised along axis 0.

    The global maximum of ``x`` is subtracted before exponentiating so that
    large inputs do not overflow.
    """
    shifted = x - np.max(x)
    exponentials = np.exp(shifted)
    normaliser = exponentials.sum(axis=0)
    return exponentials / normaliser


def take_action(obs, model, mapping, cur_round, num_round, verbose=False):
    """
    Pick the highest-probability action for ``obs`` from the policy network.

    IMPORTANT: call ``model.policy_model(...)`` directly instead of
    ``model.policy_model.predict(...)``; the latter causes a deadlock under
    multiprocessing, possibly because an internal TensorFlow thread is not
    released. Using the "spawn" start method instead of "fork" avoids the
    problem, but is slower.

    In the last round every action index is eligible; in earlier rounds only
    indices from ``CUTOF_INDEX`` upwards are considered. When ``verbose`` is
    set, the (policies, action) pair is recorded in ``mapping`` for this round.

    Returns:
        The index of the selected action.
    """
    batched_obs = tf.expand_dims(obs, 0)
    raw_output = np.squeeze(model.policy_model(batched_obs).numpy())
    policies = softmax(raw_output).tolist()

    # Earlier rounds skip the first CUTOF_INDEX entries; the last round skips none.
    is_last_round = cur_round == num_round - 1
    offset = 0 if is_last_round else CUTOF_INDEX
    index = np.argmax(np.array(policies[offset:])) + offset

    if verbose:
        update_map(obs, index, policies, mapping, cur_round)
    return index


def update_map(state, action, policies, mapping, round):
    """Record ``(policies, action)`` for ``state`` in ``mapping[round]``.

    The entry is keyed by ``hash_observation(state)``; an existing entry for
    the same key is left untouched.
    """
    key = hash_observation(state)
    per_round = mapping[round]
    if key in per_round:
        return
    per_round[key] = (policies, action)


def get_action(state, cache, round: int):
    """Return the cached action for ``state`` in ``round``, or ``None`` if absent.

    Cache entries are ``(policies, action)`` tuples keyed by
    ``hash_observation(state)``; only the action part is returned.
    """
    key = hash_observation(state)
    per_round = cache[round]
    if key not in per_round:
        return None
    _, action = per_round[key][0], per_round[key][1]
    return action


def launch_verify(
    p_id,
    q_network,
    model_type,
    input_size,
    freezer,
    all_combs,
    combs_key,
    players,
    num_round,
    protocol,
    history,
    encode_id,
    shared_action_cache,
    is_testall=True,
    is_log=False,
    verbose=False,
    is_gt=False,
):
    """Run the model (or the hand-written policy) over a set of input settings.

    For every key in ``combs_key`` the environment is initialised with
    ``all_combs[key]`` and stepped until done; a setting fails when any final
    reward is negative.

    Args:
        p_id: Worker id, used only in log messages.
        q_network: Passed to ``AlphaNet.load`` unless ``is_gt`` is set.
        model_type, input_size: Forwarded to the ``AlphaNet`` constructor.
        freezer: Unused here; kept for interface compatibility.
        all_combs: Mapping from key to input setting.
        combs_key: Keys of ``all_combs`` this call should check.
        players, num_round, protocol, history, encode_id: Forwarded to
            ``EnvManager``.
        shared_action_cache: Mutable mapping shared between workers. Written
            only when ``is_log`` is set, keyed by hashed observation.
        is_testall: When False, return as soon as the first failure is found.
        is_log: Merge this worker's observation-action mapping into
            ``shared_action_cache`` so the caller can print it.
        verbose: Track the observation-action mapping together with the
            policy values of each action, and print it per round.
            (TODO: could probably be combined with ``is_log``.)
        is_gt: Ground-truth flag; when True the hand-written action is always
            used and no neural network is involved.

    Returns:
        List of failing keys (a single-element list when ``is_testall`` is
        False and a failure was found; empty when everything passed).
    """
    # Per-round mapping: hashed observation -> (policies, action).
    state_action_cache = {round_idx: {} for round_idx in range(num_round)}

    this_func = f"Launch_verify: process id: {p_id}"
    # The return value is not needed here; the call is kept in case it has
    # import side effects. TODO: confirm and drop if it does not.
    import_states(protocol)
    env_mgr = EnvManager(players, protocol, num_round, history, encode_id)

    curr_model = AlphaNet(input_size, get_action_space(protocol), num_round, model_type)
    if not is_gt:
        curr_model.load(q_network)

    failed = []
    # Traverse every input setting assigned to this worker.
    for key in combs_key:
        input_setting = all_combs[key]
        env_mgr.init(input_setting)
        while not env_mgr.is_done():
            zero_based_round = env_mgr.get_zero_based_round()
            actions = []
            for p_idx in range(players):
                state = env_mgr.get_states(p_idx)
                action = get_action(state, state_action_cache, zero_based_round)
                if action is None:
                    if is_gt:
                        action = gt_action(state, zero_based_round)
                        update_map(state, action, None, state_action_cache, zero_based_round)
                    else:
                        # take_action records the mapping itself when verbose is set.
                        action = take_action(
                            state, curr_model, state_action_cache, zero_based_round, num_round, verbose
                        )
                actions.append(action)
            log(this_func, 2, f"action selected: {actions}")
            env_mgr.step(actions, input_setting.get_crash_info(zero_based_round + 1))

        # A negative final reward for any player means the model got this setting wrong.
        if any(r < 0 for r in env_mgr.get_rewards()):
            if not is_testall:
                return [key]
            failed.append(key)

    if verbose:
        for round_idx, round_cache in state_action_cache.items():
            print(f"Round {round_idx} Observation-Action Mapping:")
            for obs, entry in round_cache.items():
                print(f"{obs}: {entry}")

    if is_log:
        # The shared cache is flat (hashed observation -> entry) because that is
        # how parallel_verify reads it back. The first entry written wins, also
        # when the same observation occurs in more than one round.
        for round_cache in state_action_cache.values():
            for obs, entry in round_cache.items():
                if obs not in shared_action_cache:
                    shared_action_cache[obs] = entry

    return failed


class Verifier:
    """Bundle a worker pool with the inputs shared by verification runs."""

    def __init__(self, num_process, all_combs, freezer):
        """Create a pool of ``num_process`` workers and keep the given inputs."""
        self.num_process = num_process
        self.pool = Pool(processes=num_process)
        self.freezer = freezer
        self.all_combs = all_combs

    def finish(self):
        """Close the pool to new work, then wait for its workers to exit."""
        pool = self.pool
        pool.close()
        pool.join()


def sequential_verify(
    opt, q_network, input_size, verifier, all_combs, combs_id, num_process, is_testall, is_log, is_verbose, is_gt=False
):
    """Verify all of ``combs_id`` in the current process.

    Args:
        opt: Options object; ``model_type``, ``players``, ``rounds``,
            ``protocol``, ``history`` and ``encode_id`` are read from it.
        q_network: Forwarded to ``launch_verify``.
        input_size: Forwarded to ``launch_verify``.
        verifier: Only ``verifier.freezer`` is used.
        all_combs: Mapping from key to input setting.
        combs_id: Keys of ``all_combs`` to check.
        num_process: Unused; kept so the signature matches ``parallel_verify``.
        is_testall, is_log, is_verbose, is_gt: Forwarded to ``launch_verify``.

    Returns:
        ``(ok, failed)`` where ``ok`` is True when ``failed`` is empty.
    """
    failed = launch_verify(
        0,
        q_network,
        opt.model_type,
        input_size,
        verifier.freezer,
        all_combs,
        combs_id,
        opt.players,
        opt.rounds,
        opt.protocol,
        opt.history,
        opt.encode_id,
        # launch_verify expects the shared action cache in this position; leaving
        # it out would shift every following flag one parameter to the left.
        # A plain dict is enough because nothing is shared across processes
        # here; it is discarded when this call returns.
        {},
        is_testall=is_testall,
        is_log=is_log,
        verbose=is_verbose,
        is_gt=is_gt,
    )
    return len(failed) == 0, failed


def parallel_verify(
    opt, q_network, input_size, verifier, all_combs, combs_id, num_process, is_testall, is_log, is_verbose, is_gt=False
):
    """Verify ``combs_id`` by fanning the work out over ``verifier.pool``.

    ``combs_id`` is split with ``divide_list`` and one ``launch_verify`` task
    is submitted per process index; the failing keys of all tasks are
    concatenated.

    Args:
        opt: Options object; ``model_type``, ``players``, ``rounds``,
            ``protocol``, ``history`` and ``encode_id`` are read from it.
        q_network: Forwarded to ``launch_verify``.
        input_size: Forwarded to ``launch_verify``.
        verifier: Provides the worker ``pool`` and ``freezer``.
        all_combs: Mapping from key to input setting.
        combs_id: Keys of ``all_combs`` to check.
        num_process: Number of tasks to submit.
        is_testall, is_log, is_verbose, is_gt: Forwarded to ``launch_verify``.
            When ``is_log`` is set, the shared observation-action cache is
            printed after all tasks have finished.

    Returns:
        ``(ok, failed)`` where ``ok`` is True when ``failed`` is empty.
    """
    failed = []
    divided_id = divide_list(combs_id, num_process)

    # The manager starts a helper server process. The ``with`` block shuts it
    # down when leaving (also on error) instead of leaking one per call, so
    # every use of the shared dict has to happen inside the block.
    with multiprocessing.Manager() as manager:
        shared_action_cache = manager.dict()

        results = [
            verifier.pool.apply_async(
                launch_verify,
                (
                    i,
                    q_network,
                    opt.model_type,
                    input_size,
                    verifier.freezer,
                    all_combs,
                    divided_id[i],
                    opt.players,
                    opt.rounds,
                    opt.protocol,
                    opt.history,
                    opt.encode_id,
                    shared_action_cache,
                    is_testall,
                    is_log,
                    is_verbose,
                    is_gt,
                ),
            )
            for i in range(num_process)
        ]
        # get() blocks until the task is done and re-raises worker exceptions.
        for r in results:
            failed.extend(r.get())

        if is_log:
            # One round trip to the manager instead of one lookup per key.
            for obs, entry in shared_action_cache.items():
                formatted_obs = obs.replace("[", "").replace("]", "").replace(" ", ",")
                print(f"{formatted_obs}:{entry}")

    return len(failed) == 0, failed


def informal_verify(
    opt, q_network, input_size, verifier, all_combs, combs_id, num_process, is_testall, is_log, is_verbose=False, is_gt=False
):
    """Verify ``combs_id`` by delegating to ``parallel_verify``.

    All arguments are passed through unchanged and the ``(ok, failed)`` result
    of ``parallel_verify`` is returned as is.
    """
    return parallel_verify(
        opt=opt,
        q_network=q_network,
        input_size=input_size,
        verifier=verifier,
        all_combs=all_combs,
        combs_id=combs_id,
        num_process=num_process,
        is_testall=is_testall,
        is_log=is_log,
        is_verbose=is_verbose,
        is_gt=is_gt,
    )
