import os

from helper_functions import get_environments, get_unique_environments, create_transition_matrix, shortest_path_ssx
from helper_functions import compute_numpaths_count, check_coverage, get_subgoals, cluster_ssx

import utils

import numpy as np
import random
import pickle

import torch

from copy import deepcopy

from scipy.sparse.csgraph import johnson # for computing shortest paths
from sklearn.cluster import SpectralClustering

import matplotlib.pyplot as plt
# Inputs:
# filename_states - filename for boards (assumed in saved_boards directory)
# filename_model - filename for policy (assumed in models directory)
# steps - index into the saved board sequence selecting the initial board from which to interpret the policy
# max_steps - maximum number of steps to use for computing the local state space
# num_clusters - number of metastates to create
# num_subgoals - number of subgoals to learn per metastate
# rho - penalty on tradeoff between number of paths through a state and diversity with other subgoals
# use_count_subgoals - True to score subgoals by numpaths_count, False to score them by numpaths_prob
# eps_add_goal - percent change required in objective to add a new subgoal
# cluster_method - 1 for cluster_ssx, any other value (e.g., 2) for sklearn.cluster.SpectralClustering
# eta - regularization parameter for cluster_ssx
# projection_type - spectral projection type for cluster_ssx, 1 if L = D^{-1/2}*A*D^{-1/2} (use largest eigenvectors), 2 if L = D^{-1}*A (use largest eigenvectors), 3 if L = D - A (use smallest eigenvectors)
# Outputs:
# subgoals_list - list of lists where subgoals_list[i] are the subgoal states for cluster i
# shortest_dist - shortest_dist[i, j] is shortest distance between states i and j
# predecessors - predecessors[i, j] is predecessor of state j on path from state i to state j
# transition_matrix - one step transition matrix of states based on input policy
# env_list_final - list of environments local to the input environment
# state_list_final - list of grid worlds (i.e., states) local to the input environment
# number_steps_final - list of number of steps to each environment in env_list_final from initial environment
# end_states_list - list of lists where end_states_list[i] are the states in cluster i from which there are no reachable local environments (no outgoing transitions)
# labels_spectral - numpy array where labels_spectral[i] is the cluster id of state i
# diversity - diversity[i,j] is defined as exp(-minimum(shortest_dist[i, j], shortest_dist[j, i]))
# ind_initial - index of the initial state (the state reached in 0 steps)
def doorkey_run_flow(filename_states, filename_model, steps, max_steps, num_clusters, num_subgoals, rho, use_count_subgoals, eps_add_goal, cluster_method=1, eta=.0001, projection_type=1):
    """Build the local state space around a saved board, cluster it and pick subgoals.

    Loads element ``steps`` of the pickled board sequence in ``saved_boards/``,
    loads the policy ``filename_model``, enumerates the environments reachable
    within ``max_steps`` moves, builds the policy's one-step transition matrix
    and shortest paths, clusters the states into ``num_clusters`` metastates and
    selects subgoals for every cluster.

    Args:
        filename_states: pickle file in ``saved_boards/`` holding a sequence of
            boards; element ``steps`` is used as the starting board.
        filename_model: model name resolved through ``utils.get_model_dir``.
        steps: index of the starting board in the loaded sequence.
        max_steps: maximum number of moves explored from the starting board.
        num_clusters: number of metastates to create.
        num_subgoals: number of subgoals to learn per metastate.
        rho: tradeoff penalty between number of paths through a state and
            diversity with other subgoals (forwarded to ``get_subgoals``).
        use_count_subgoals: True to score subgoals by ``numpaths_count``,
            False to score them by ``numpaths_prob``.
        eps_add_goal: change in objective required to add a new subgoal.
        cluster_method: 1 uses ``cluster_ssx``; any other value uses
            ``sklearn.cluster.SpectralClustering`` on ``exp(-gamma * dist)``.
        eta: regularization parameter for ``cluster_ssx``.
        projection_type: spectral projection type forwarded to ``cluster_ssx``
            (1: D^{-1/2} A D^{-1/2}, 2: D^{-1} A, 3: D - A).

    Returns:
        Tuple ``(subgoals_list, shortest_dist, predecessors, transition_matrix,
        env_list_final, state_list_final, number_steps_final, end_states_list,
        labels_spectral, diversity, ind_initial)`` where

        - ``subgoals_list[i]`` are the subgoal states chosen for cluster ``i``;
        - ``shortest_dist`` / ``predecessors`` come from ``shortest_path_ssx``;
        - ``transition_matrix`` is the policy's one-step transition matrix;
        - ``env_list_final`` / ``state_list_final`` / ``number_steps_final`` are
          the unique local environments, their grid states and step counts;
        - ``end_states_list[c]`` lists the states of cluster ``c`` whose
          transition-matrix row has no nonzero entry;
        - ``labels_spectral[i]`` is the cluster id of state ``i``;
        - ``diversity[i, j] = exp(-min(shortest_dist[i, j], shortest_dist[j, i]))``;
        - ``ind_initial`` is the index of the first state reached in 0 steps.
    """
    argmax = False
    memory = False
    text = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n")

    # `with` guarantees the board file is closed even if unpickling fails.
    # NOTE: pickle.load can execute arbitrary code; only load trusted board files.
    with open(os.path.join('saved_boards', filename_states), 'rb') as loadfile:
        env_steps = pickle.load(loadfile)
    env = env_steps[steps]

    # Load agent
    model_dir = utils.get_model_dir(filename_model)
    actor_critic = utils.Agent(env.observation_space, env.action_space, model_dir,
                               device=device, argmax=argmax, use_memory=memory, use_text=text)
    print("Agent loaded\n")

    # Fixed seeds so the state-space enumeration and clustering are reproducible.
    random.seed(121)
    np.random.seed(1211)

    # Enumerate every environment reachable within max_steps moves, ignoring
    # actions whose policy probability is below the threshold.
    action_prob_threshold = 0.00001
    env_list, number_steps_list = get_environments(env, actor_critic, max_steps, action_prob_threshold)

    # Deduplicate environments.
    env_list_final, state_list_final, number_steps_final = get_unique_environments(env_list, number_steps_list)
    num_states = len(env_list_final)

    # The starting board is the (first) state reached in zero steps.
    ind_initial = np.where(np.array(number_steps_final) == 0)[0][0]

    transition_matrix = create_transition_matrix(env_list_final, state_list_final, actor_critic)
    shortest_dist, predecessors = shortest_path_ssx(transition_matrix)

    # Spectral clustering into metastates.
    gamma = .9
    eps_clustering = .001
    max_iters_clustering = 10
    # projection_type is the caller's argument; it used to be overwritten with 1
    # here, which made the documented parameter a no-op.
    if cluster_method == 1:
        labels_spectral = cluster_ssx(predecessors, shortest_dist, num_states, num_clusters, gamma, eta,
                                      projection_type, eps_clustering, max_iters_clustering)
    else:
        affinity = np.exp(-gamma * shortest_dist)
        clustering = SpectralClustering(n_clusters=num_clusters, assign_labels="kmeans", random_state=0,
                                        affinity='precomputed').fit(affinity)
        labels_spectral = clustering.labels_  # labels_spectral are the metastates

    # Number of paths (and their -log likelihoods) through nodes that change clusters.
    numpaths_count, numpaths_prob, _numpaths_count_pred = compute_numpaths_count(
        predecessors, labels_spectral, shortest_dist, num_states, num_clusters)

    # exp(-dist) maps shortest distances to path "probabilities"; the elementwise
    # max with the transpose makes it symmetric, i.e.
    # diversity[i, j] = exp(-min(shortest_dist[i, j], shortest_dist[j, i])).
    diversity = np.exp(-shortest_dist)
    diversity = np.maximum(diversity, diversity.T)

    print('Initial State is ' + str(ind_initial) + ' in Cluster ' + str(labels_spectral[ind_initial]))
    subgoals_list = get_subgoals(numpaths_count, numpaths_prob, labels_spectral, diversity, num_subgoals, rho,
                                 num_states, num_clusters, use_count_subgoals, eps_add_goal)

    # End states: rows of the transition matrix with no nonzero entry (no
    # outgoing transitions inside the local state space), grouped by cluster.
    end_states_list = [[] for _ in range(num_clusters)]
    for i in range(num_states):
        if len(np.where(transition_matrix[i, :])[0]) == 0:
            end_states_list[labels_spectral[i]].append(i)

    # Optional sanity check (disabled):
    # check_coverage(state_list_final, labels_spectral, subgoals_list, end_states_list, predecessors)

    return (subgoals_list, shortest_dist, predecessors, transition_matrix, env_list_final, state_list_final,
            number_steps_final, end_states_list, labels_spectral, diversity, ind_initial)