import numpy as np
import random
import pickle

import torch

from copy import deepcopy

from scipy.sparse.csgraph import johnson # for computing shortest paths
from sklearn.cluster import SpectralClustering

import matplotlib.pyplot as plt

# Flatten the current observation of env into a 1-D state vector.
# Input:
# env - environment object; env.gen_obs()['image'] must be a 3-D array
#       (presumably height x width x 3, one (OBJECT_IDX, COLOR_IDX, STATE)
#       triple per tile — confirm against the environment)
# Output:
# state - 1-D numpy array holding the image entries in row-major order
def env_to_state(env):
    image = env.gen_obs()['image']
    # Unpacking exactly three dimensions keeps the check that the image is 3-D.
    rows, cols, depth = image.shape
    flat_length = rows * cols * depth
    return image.reshape(flat_length)

# Each tile of the observation image is encoded as a 3-element tuple: (OBJECT_IDX, COLOR_IDX, STATE)

# env.action_space consists of the following moves, given as [row, col] offsets:
#  [0, 0]   - stay
#  [0, 1]   - right
#  [-1, 0]  - up
#  [0, -1]  - left
#  [1, 0]   - down

# Query the policy ac for its action probabilities in the current state of env.
# Input:
# env - environment whose current observation (env.gen_obs()) is fed to the policy
# ac  - agent wrapper exposing preprocess_obss, device, acmodel and,
#       for recurrent models, memories
# Output:
# probs - numpy array built from dist.probs; one row per observation, and a
#         single observation is passed here (presumably shape (1, num_actions))
# NOTE: for a recurrent acmodel this overwrites ac.memories as a side effect.
def get_probs(env, ac):
    batch = ac.preprocess_obss([env.gen_obs()], device=ac.device)
    model = ac.acmodel

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        if not model.recurrent:
            dist, _value = model(batch)
        else:
            dist, _value, ac.memories = model(batch, ac.memories)

    # Going through a Python list yields a float64 numpy array on the host,
    # regardless of the tensor's device or dtype.
    return np.array(dist.probs.tolist())

# Create the list of environments that can be reached from env by taking action.
# The input env is left untouched: the action is applied to a deep copy.
# NOTE: if there is an adversary, the returned list could hold several
#       environments to account for the adversary moving after the action;
#       the current implementation always returns exactly one.
# Input:
# env - environment from which to start (must support copy.deepcopy)
# action - index of action to take
# Output:
# env_list - list of new environments after taking action
def get_environment_from_action(env, action):
    next_env = deepcopy(env)
    # Only the mutated copy matters, so the step result is discarded. Not
    # unpacking it keeps this working with both the 4-tuple
    # (obs, reward, done, info) and the 5-tuple
    # (obs, reward, terminated, truncated, info) step APIs.
    next_env.step(action)
    return [next_env]

# Hook for discarding subgoals according to environment-specific rules.
# Input:
# subgoals - list of subgoal indices
# env_list - list of unique environments (not consulted by the default rules)
# Output:
# the remaining subgoal indices; with no custom rules defined this is the
# very same list object that was passed in (no copy is made)
# NOTE: add custom filtering rules here when an environment needs them.
def remove_subgoals_custom_rules(subgoals, env_list):
    # No custom rules are defined, so every subgoal is kept.
    return subgoals

# Remove states from subgoal_path based on custom rules.
# Input:
# subgoal_path - list of indices of a path to a subgoal (last index is the subgoal)
# env_list - list of unique environments
# params - optional dictionary of extra parameters for custom rules
#          (defaults to an empty dict)
# Output:
# subgoal_path_ret - remaining indices in the path to the subgoal; with no
#                    custom rules defined this is the same list object passed in
# NOTE: If there are no custom rules to remove states,
#       the path is returned unchanged.
def curate_subgoal_path_custom_rules(subgoal_path, env_list, params=None):
    # Use a None sentinel instead of a mutable {} default, so that no dict
    # object is shared between calls if a custom rule ever mutates params.
    if params is None:
        params = {}
    subgoal_path_ret = subgoal_path
    return subgoal_path_ret
