import numpy as np
import random
import pickle

import torch
import torch.nn.functional as F
import torch.autograd as autograd

from common.multiprocessing_env import SubprocVecEnv
from common.minipacman import MiniPacman
from common.actor_critic import ActorCritic, RolloutStorage
from common.deepmind import update_2d_pos
from copy import deepcopy

def env_to_state(env):
    """Convert an environment to a state represented as a flat numpy vector.

    Layout of the returned vector (built with np.concatenate):
      * the food grid, flattened row by row (food must be 2-D),
      * pacman ('pillman') position,
      * position of the first ghost, or [-1, -1] when no ghost is on the board,
      * position of the first pill, or [-1, -1] when no pill is on the board.

    Input:  env   - environment object exposing ``world_state``
    Output: state - numpy vector that stores the state
    """
    world = env.world_state
    absent = np.array([-1, -1])  # placeholder for an entity that is not on the board

    ghosts = world['ghosts']
    pills = world['pills']
    ghost_pos = ghosts[0]['pos'] if len(ghosts) > 0 else absent
    pill_pos = pills[0]['pos'] if len(pills) > 0 else absent

    food = world['food']
    rows, cols = food.shape
    parts = (food.reshape(rows * cols), world['pillman']['pos'], ghost_pos, pill_pos)
    return np.concatenate(parts)

# env.action_space is:
#  [0, 0]   - stay
#  [0, 1]   - right
#  [-1, 0]  - up
#  [0, -1]  - left
#  [1, 0]   - down

def get_probs(env, ac):
    """Get the probabilities of the different actions in environment env under policy ac.

    Input:
    env - environment whose ``observation()`` returns a 3-tuple; the third
          element is the network input. It is transposed with (2, 0, 1),
          presumably HWC -> CHW (confirm against MiniPacman).
    ac  - actor-critic module; called on a batch of one observation and
          expected to return ``(logit, value)``.
    Output:
    probs - softmax of ``logit`` over its last dimension (batch dim kept).
            It is on the GPU when CUDA is available.
    """
    _, _, observation = env.observation()
    state = torch.FloatTensor(observation.transpose(2, 0, 1)).unsqueeze(0)
    if torch.cuda.is_available():
        state = state.cuda()

    # Call the module (not .forward) so any registered hooks run.
    logit, _value = ac(state)

    # Explicit dim: softmax without `dim` is deprecated. Assuming logit is
    # (batch, n_actions), dim=-1 equals the old implicit choice (dim=1).
    return F.softmax(logit, dim=-1)

def get_environment_from_action(env, action):
    """Create the list of environments that can be reached from env by taking action.

    Pacman is moved first, IN PLACE on ``env`` itself; callers that need the
    original environment must pass a copy. Then, for every move available to
    the first ghost, a deep copy of the post-move env gets that ghost move
    applied. The returned list therefore covers each possible reply of the
    adversary. When no ghost is on the board, a single copy of the post-move
    env is returned.

    Input:
    env    - environment from which to start (mutated: pacman is moved)
    action - index of action to take
    Output:
    env_list - list of new environments after taking action
    """
    env._move_pillman(action)
    ghosts = env.world_state['ghosts']

    if len(ghosts) == 0:
        # No adversary on the board: the only outcome is the post-move env.
        env_copy = deepcopy(env)
        env_copy._make_image()
        return [env_copy]

    env_list = []
    for ghost_action in env._ghost_available_actions(ghosts[0]):
        env_copy = deepcopy(env)  # independent copy for each possible ghost move
        ghost = env_copy.world_state['ghosts'][0]
        pos = ghost['pos']
        # pos is passed as both input and output, so the ghost's position is updated in place
        update_2d_pos(env_copy.map, pos, ghost_action, pos)
        ghost['dir'] = ghost_action
        env_copy._make_image()
        env_list.append(env_copy)
    return env_list

def remove_subgoals_custom_rules(subgoals, env_list):
    """Remove subgoals based on custom rules.

    Two subgoals count as duplicates when their states (from env_to_state)
    are exactly equal once the ghost-position entries (state[-4:-2]) are
    blanked out. In other words, only food, pacman position and pill
    position are compared. The first occurrence of each distinct state is
    kept and input order is preserved.

    Input:
    subgoals - list of subgoal indices into env_list
    env_list - list of unique environments
    Output:
    subgoals_remaining - remaining list of subgoal indices
    NOTE: If there are no custom rules to remove subgoals,
          set subgoals_remaining = subgoals
    """
    seen_states = set()
    subgoals_remaining = []
    for subgoal in subgoals:
        state = env_to_state(env_list[subgoal])
        state[-4:-2] = -1  # remove the ghost position from the comparison
        # Exact equality is transitive, so one hash lookup per subgoal replaces
        # comparing against every earlier subgoal.
        key = tuple(state.tolist())
        if key not in seen_states:
            seen_states.add(key)
            subgoals_remaining.append(subgoal)
    return subgoals_remaining

def curate_subgoal_path_custom_rules(subgoal_path, env_list, params=None):
    """Remove states from subgoal_path based on custom rules.

    1. Every state other than the subgoal in which pacman stands on the same
       square as in the subgoal state is dropped.
    2. The path is kept only if some remaining state is at least
       ``min_moves_from_pacman`` (Manhattan distance) away from the subgoal's
       pacman position. Otherwise the path is rejected and [] is returned.

    Input:
    subgoal_path - list of indices of a path to subgoal (last index is the subgoal)
    env_list     - list of unique environments
    params       - optional dict of extra parameters; recognises
                   'min_moves_from_pacman' (default 3)
    Output:
    subgoal_path_ret - remaining indices in path to subgoal (order kept), or []
    NOTE: If there are no custom rules to remove subgoals,
          set subgoal_path_ret = subgoal_path
    """
    if params is None:
        params = {}
    min_moves_from_pacman = params.get('min_moves_from_pacman', 3)

    if len(subgoal_path) == 0:
        return []

    subgoal = subgoal_path[-1]
    goal_pos = env_list[subgoal].world_state['pillman']['pos']

    subgoal_path_remaining = []
    distances = []
    for state in subgoal_path:
        if state == subgoal:
            subgoal_path_remaining.append(state)  # never remove the subgoal itself
            continue
        pos = env_list[state].world_state['pillman']['pos']
        if pos[0] == goal_pos[0] and pos[1] == goal_pos[1]:
            continue  # same pacman position as in the subgoal: drop it
        subgoal_path_remaining.append(state)
        distances.append(np.abs(pos[0] - goal_pos[0]) + np.abs(pos[1] - goal_pos[1]))

    # path is only valid if pacman moves at least min_moves_from_pacman moves
    # within the cluster to get to the subgoal
    if distances and max(distances) >= min_moves_from_pacman:
        return subgoal_path_remaining
    return []
