from models import random_var
from env_rob_benchmarking import eval_agent_on_mdp, load_agent_api

from a2c_ppo_acktr.model import Policy

import os
import sys
import platform
import shutil
import argparse
import torch
import numpy as np

from overcooked_ai_py.mdp.actions import Action, Direction


def stay_target(all_grads: np.ndarray, choices_diffs, top_k=10, target_id=4):
    """Rank candidate input changes by alignment with one action's gradient.

    The gradients are summed over the first axis (states), the slice that
    belongs to ``target_id`` is selected, and every entry of
    ``choices_diffs`` is scored by the sum of its element-wise product with
    that slice. Candidates are ranked from the highest score to the lowest.

    Args:
        all_grads: 5-D array laid out as
            n_states * n_actions * input_channels * map_size * map_size.
        choices_diffs: Iterable of arrays that broadcast against a single
            (input_channels, map_size, map_size) gradient slice.
        top_k: Maximum number of candidate indices to return.
        target_id: Index along the action axis whose gradient is used.
            Defaults to 4, the "stay" action.

    Returns:
        1-D integer array holding the indices into ``choices_diffs`` of the
        ``top_k`` highest-scoring candidates, best first.
    """
    assert len(all_grads.shape) == 5  # n_states * n_actions * input_channels * map_size * map_size

    # Collapse the state axis first, then pick the target action's gradient.
    target_grads = np.sum(all_grads, axis=0)[target_id]

    grad_vs = np.array(
        [np.sum(np.multiply(diff, target_grads)) for diff in choices_diffs]
    )
    # Negate so argsort's ascending order puts the largest scores first.
    choice_rank = np.argsort(-grad_vs)
    return choice_rank[:top_k]


def no_target(all_grads: np.ndarray, choices_diffs, act_probs, top_k=10, debug=False):
    """Rank candidate input changes against the currently preferred actions.

    For every state the action with the largest entry in ``act_probs`` is
    looked up, the gradient of that action is taken from ``all_grads``, and
    these per-state gradients are added together. Each entry of
    ``choices_diffs`` is scored by the sum of its element-wise product with
    the accumulated gradient; candidates are ordered from the lowest score to
    the highest.

    Args:
        all_grads: 5-D array laid out as
            n_states * n_actions * input_channels * map_size * map_size.
        choices_diffs: Iterable of arrays that broadcast against a single
            (input_channels, map_size, map_size) gradient slice.
        act_probs: Per-state action scores; the argmax over the last axis
            selects the action for each state.
        top_k: Maximum number of candidate indices returned when ``debug``
            is false.
        debug: When true, return the complete ranking together with the raw
            scores instead of only the first ``top_k`` indices.

    Returns:
        The first ``top_k`` indices of the ascending ranking, or the tuple
        ``(full_ranking, scores)`` when ``debug`` is true.
    """
    assert all_grads.ndim == 5  # n_states * n_actions * input_channels * map_size * map_size

    # One preferred action per state, paired with its state index below.
    preferred = np.argmax(act_probs, axis=-1).ravel()
    state_rows = range(all_grads.shape[0])
    summed_grad = all_grads[state_rows, preferred].sum(axis=0)

    scores = []
    for candidate in choices_diffs:
        scores.append(np.sum(np.multiply(candidate, summed_grad)))
    scores = np.array(scores)

    # Ascending order: the lowest-scoring candidates come first.
    ordering = np.argsort(scores)
    if debug:
        return ordering, scores
    return ordering[:top_k]


