import numpy as np
import tensorflow as tf


def compute_expectation(track_dict):
    """Print per-action reward statistics and return the mean reward of each action.

    Args:
        track_dict: nested mapping ``{action: {reward: count}}`` (the layout
            ``update_dicts`` builds). Not modified.

    Returns:
        list of mean rewards, one per action, in ascending action order.
        An action whose reward map is empty gets ``float("nan")`` instead of
        raising ``ZeroDivisionError``.
    """
    total_cnt = 0
    mean_scores = []
    # Sort by action so the printed report (and returned list) is deterministic.
    for action, reward_counts in sorted(track_dict.items()):
        reward_sum = 0
        action_cnt = 0
        for reward, count in reward_counts.items():
            reward_sum += reward * count
            action_cnt += count
        total_cnt += action_cnt
        mean = reward_sum / action_cnt if action_cnt else float("nan")
        print(
            f"action: {action}  (cnt: {action_cnt}) sum: {reward_sum:.4f}, mean: {mean:.4f}"
        )
        mean_scores.append(mean)
    print(f"total cnt: {total_cnt}")
    return mean_scores


def update_dicts(exp, arrays, track_dict):
    """Tally (action, reward) occurrences extracted from ``exp`` into ``track_dict``.

    ``track_dict`` is mutated in place into the layout
    ``{action: {reward: count}}``; nothing is returned.

    Args:
        exp: experience object passed straight to ``extract_exp``.
        arrays: target observations passed straight to ``extract_exp``.
        track_dict: accumulator dict updated in place.
    """
    acts, rews = extract_exp(exp, arrays)
    for act, rew in zip(acts, rews):
        # Create the per-action counter on first sight, then bump this reward.
        counts = track_dict.setdefault(act, {})
        counts[rew] = counts.get(rew, 0) + 1


def extract_exp(exp, arrays):
    """Return actions and rewards for entries whose first-step observation matches a target.

    ``exp.observation.numpy()`` is compared against every array in ``arrays``.
    An entry matches a target when all elements along axis 2 are equal. Only
    rows whose match occurs at index 0 of axis 1 are kept. Presumably the axes
    are (batch, time, feature), as in a TF-Agents trajectory; confirm against
    callers.

    Args:
        exp: object exposing ``observation``, ``reward`` and ``action``, each
            with a ``.numpy()`` method. Observations must have at least 3 dims.
        arrays: iterable of target observations, broadcast-compatible with
            the observation's trailing axes.

    Returns:
        ``(actions, rewards)``: numpy arrays taken at ``[row, 0]`` for the
        matching rows, or ``([], [])`` when nothing matches or ``arrays`` is
        empty.
    """
    observations = exp.observation.numpy()
    masks = [np.all(observations == array, axis=2) for array in arrays]
    if not masks:
        # np.logical_or.reduce([]) collapses to a 0-d bool that cannot be
        # indexed like the 2-D mask below, so bail out explicitly.
        return [], []
    matches = np.logical_or.reduce(masks)
    nonzero = np.nonzero(matches)
    # Keep the axis-0 indices whose match occurred at axis-1 index 0.
    index = nonzero[0][nonzero[1] == 0]
    if index.size == 0:
        return [], []

    rewards = exp.reward.numpy()[index, 0]
    actions = exp.action.numpy()[index, 0]
    return actions, rewards


def track_one_state(state, py_env, tf_env, agent):
    """Reset the environment into ``state`` and print the agent's response.

    Assigns ``state`` to ``py_env.states``, resets ``tf_env``, and queries the
    agent's policy for an action. It then prints the first element of the
    agent's Q-network output for the reset observation, followed by
    ``"<state> ---> <action>"``. Returns None.

    NOTE(review): assumes ``tf_env`` wraps ``py_env``, so the assignment is
    picked up by ``reset()``; confirm. Relies on the private
    ``agent._q_network`` attribute.
    """
    agent_policy = agent.policy
    py_env.states = state
    first_step = tf_env.reset()
    chosen = agent_policy.action(first_step)
    network_out = agent._q_network(first_step.observation)
    print(network_out[0].numpy())
    print(f"{state} ---> {int(chosen.action)}")
