from coordinator.CoordinatorEnv import Stat
from utils import generate_arrays


def extensive_evaluate(tf_env, py_env, policy, players):
    """Roll out one evaluation episode for every setting from ``generate_arrays``.

    For each setting an initial-state vector is derived: when the setting
    contains no ``Stat.Abort`` value but does contain ``Stat.Lost``, the first
    ``Lost`` entry is replaced by ``Abort`` in the initial state only. The
    Python environment is then reset with ``py_env.my_reset`` and a single
    episode is run via ``compute_one_episode``.

    Args:
        tf_env: Environment that ``compute_one_episode`` steps through.
        py_env: Underlying Python environment; must provide ``my_reset``.
        policy: Policy object exposing ``action(time_step)``.
        players: Forwarded to ``generate_arrays`` to enumerate all settings.

    Returns:
        The number of settings whose episode was reported as failed. The
        count is also printed, as before.
    """
    # Constant for the whole run; no need to recompute per setting.
    abort_value = float(Stat.Abort.value)
    lost_value = float(Stat.Lost.value)

    failed_case = 0
    for setting in generate_arrays(players):
        print(setting)
        # Shallow copy so that `setting` itself keeps its Lost entries.
        # (`.index` and item assignment imply a list; a NumPy slice would be
        # a view, so revisit this if generate_arrays ever yields arrays.)
        init_setting = setting[:]
        if abort_value not in setting and lost_value in setting:
            # IMPORTANT: only one initial state is set to Abort. Any other
            # Lost entries remain; Lost should not appear in initial states,
            # but this is tolerated for now.
            init_setting[setting.index(lost_value)] = abort_value
        has_crash = lost_value in setting
        print(has_crash)
        py_env.my_reset(init_setting, setting, has_crash)
        if not compute_one_episode(tf_env, py_env, policy):
            failed_case += 1

    print(f"FAILED CASE: {failed_case}")
    return failed_case

def compute_one_episode(tf_env, py_env, policy):
    """Run a single episode of ``tf_env`` under ``policy`` with verbose tracing.

    Each pre-step time step, the chosen action and the resulting reward are
    printed. As soon as a transition yields a negative reward, the offending
    observation, ``py_env._init_states`` and the full action step are dumped
    and the rollout is abandoned.

    Returns:
        ``True`` if the last step was reached without a negative reward,
        ``False`` otherwise.
    """
    current = tf_env.reset()
    while not current.is_last():
        print(current)
        chosen = policy.action(current)
        print(chosen.action)
        following = tf_env.step(chosen.action)
        print(following.reward)
        if not following.reward < 0:
            current = following
            continue
        # Negative reward: dump debugging context and report failure.
        print(current.observation)
        print(py_env._init_states)
        print(chosen)
        return False
    return True

def random_evaluate(environment, policy, agent, num_episodes=10):
    """Estimate the average undiscounted return of ``policy`` over episodes.

    Every transition with a negative reward is counted as a failure. For each
    one, the pre-step observation, the agent's Q-values for it and the action
    step are printed. The total failure count is printed at the end.

    Args:
        environment: Environment exposing ``reset()`` and ``step(action)``.
        policy: Policy object exposing ``action(time_step)``.
        agent: Agent whose ``_q_network`` is queried for debugging output.
        num_episodes: Number of episodes to average over; must be positive.

    Returns:
        The average return. When the summed rewards expose ``.numpy()``
        (presumably batched TF tensors), the first batch element is returned,
        as before. Otherwise the plain average is returned.

    Raises:
        ValueError: If ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError(f"num_episodes must be positive, got {num_episodes}")

    failed_case = 0
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0

        while not time_step.is_last():
            action_step = policy.action(time_step)
            next_time_step = environment.step(action_step.action)

            # Credit the reward produced by this transition. Summing
            # time_step.reward instead would count the reset step and
            # silently drop the terminal reward of every episode.
            episode_return += next_time_step.reward
            if next_time_step.reward < 0:
                print(time_step.observation)
                q_values, _ = agent._q_network(time_step.observation, step_type=time_step.step_type)
                print(q_values)
                print(action_step)
                failed_case += 1
            time_step = next_time_step

        total_return += episode_return

    print(f"FAILED: {failed_case}")
    avg_return = total_return / num_episodes
    # If no transition ever happened the sum is still a plain float, which
    # has no `.numpy()`; only unwrap when the result actually is a tensor.
    if hasattr(avg_return, "numpy"):
        return avg_return.numpy()[0]
    return avg_return