#!/usr/bin/env python

# Python imports.
import sys, os, random, time, math
import numpy as np


# Other imports.
from simple_rl.tasks import FourRoomMDP
from simple_rl.planning import ValueIteration

from scipy.sparse.csgraph import johnson # for computing shortest paths
from sklearn.cluster import SpectralClustering
import matplotlib.pyplot as plt
import matplotlib

import info_sa
from fourrooms_helper_functions import create_distance_matrix

if __name__ == '__main__':

    # Seed both random number generators so the expert policy is reproducible.
    seed = 5
    random.seed(seed)
    np.random.seed(seed)

    # Build the four-rooms MDP: start at (1, 1), goal at (grid_dim, grid_dim).
    grid_dim = 11
    # NOTE: we make slip_prob 0.0 here, different from paper
    mdp = FourRoomMDP(
        width=grid_dim,
        height=grid_dim,
        init_loc=(1, 1),
        slip_prob=0.0,
        goal_locs=[(grid_dim, grid_dim)],
        gamma=0.9,
    )

    # For comparing policies and visualizing:
    # parameter to create probabilities from q-value(state, action) pairs.
    beta_distance_prob = 0.9

    # Configuration flags; they are not read anywhere in the visible script.
    is_deterministic_ib = True
    is_agent_in_control = True

    # Solve the MDP to obtain the demo (expert) policy.
    vi = ValueIteration(mdp)
    _, val = vi.run_vi()

    # Pairwise state distances from the helper module, then all-pairs shortest
    # paths. shortest_dist[i, j] is the length of the shortest path i -> j and
    # predecessors[i, j] is the node visited just before j on that path.
    distance = create_distance_matrix(vi, mdp.get_actions(), beta_distance_prob)
    shortest_dist, predecessors = johnson(distance, return_predecessors=True)

    # Spectral clustering on a precomputed affinity that decays exponentially
    # with shortest-path distance.
    affinity_decay = .1
    affinity = np.exp(-affinity_decay * shortest_dist)
    num_clusters = 4
    spectral_model = SpectralClustering(
        n_clusters=num_clusters,
        assign_labels="kmeans",
        random_state=0,
        affinity='precomputed',
    )
    labels_spectral = spectral_model.fit(affinity).labels_
    states = vi.get_states()
    num_states = len(states)

    # Count, for every state, how often it is the last state inside its own
    # metastate before a shortest path crosses into a different metastate.
    # metastate_count[w, a, b] is the number of shortest-path-tree edges w -> j
    # (accumulated over all source states i) with phi(w) = a and phi(j) = b != a.
    metastate_count = np.zeros((num_states, num_clusters, num_clusters))
    for i in range(num_states):  # one shortest-path tree per source state i
        predecessors_meta = predecessors[i, :]
        # A non-negative predecessor means there is a path from i to that node.
        reachable = np.where(predecessors_meta >= 0)[0]
        for j in reachable:
            pred_j = predecessors_meta[j]  # node visited just before j
            label_from = labels_spectral[pred_j]
            label_to = labels_spectral[j]
            if label_from != label_to:
                metastate_count[pred_j, label_from, label_to] += 1

    # Member state indices of every cluster.
    cluster_ids = {
        j: np.where(labels_spectral == j)[0] for j in range(num_clusters)
    }

    # Report the states that attain the maximum crossing count. The threshold
    # is derived from the data rather than a hard-coded count, so it stays valid
    # when the grid size, seed or clustering changes. np.unique avoids listing
    # a state twice when it attains the maximum for several metastate pairs.
    max_count = np.max(metastate_count)
    state_inds = np.unique(np.where(metastate_count == max_count)[0])
    print('Goal States are:')
    for s in state_inds:
        print(states[s])

    # numpaths_count[w, c] counts the ordered pairs (s, t) whose shortest path
    # s -> t visits state w, where phi(s) = phi(w) = c and phi(t) != c. In
    # other words: how many paths leaving metastate c pass through w while
    # still inside c (metastates of all other states held fixed).
    no_predecessor = -9999  # value johnson() stores where no path exists
    numpaths_count = np.zeros((num_states, num_clusters))
    for src in range(num_states):
        src_label = labels_spectral[src]
        for dst in range(num_states):
            # Only paths that end in a different metastate are of interest.
            if labels_spectral[dst] == src_label:
                continue
            # Walk the path backwards from dst towards src. dst itself is never
            # counted; src is counted once the walk reaches it.
            node = dst
            while node != src and node != no_predecessor:
                node = predecessors[src, node]  # next predecessor along src -> dst
                if node == no_predecessor:
                    break  # dst is not reachable from src
                if labels_spectral[node] == src_label:
                    numpaths_count[node, src_label] += 1

    # Greedy method to choose up to m subgoals for each cluster.
    # diversity[i, j] is the symmetrised shortest-path distance: the shorter of
    # the two directed paths i -> j and j -> i.
    diversity = np.minimum(shortest_dist, shortest_dist.T)

    m = 2  # maximum number of subgoals per cluster
    rho = 50  # tradeoff penalty between subgoals that lead to new meta-states and diversity of subgoals within meta-state
    eps_add_subgoal = 0.1  # relative increase in objective required to add another subgoal
    subgoal_store = []
    for cluster in range(num_clusters):
        # First subgoal: the state lying on the most paths that leave `cluster`.
        subgoals = [np.argmax(numpaths_count[:, cluster])]
        inds = list(range(num_states))  # remaining candidate states
        obj = numpaths_count[subgoals[0], cluster]  # initial objective
        for i in range(m - 1):  # try to add up to m-1 more subgoals
            inds.remove(subgoals[i])
            # Candidate score = path count + rho * distance to every subgoal
            # chosen so far. Fancy indexing returns a copy, so the in-place
            # additions never touch numpaths_count.
            # NOTE(review): candidates are all remaining states, not only the
            # members of `cluster` -- confirm this is intended.
            score_temp = numpaths_count[inds, cluster]
            for chosen in subgoals:
                score_temp += rho * diversity[inds, chosen]
            best = int(np.argmax(score_temp))
            gain = score_temp[best]
            new_obj = obj + gain
            if obj > 0:
                rel_gain = (new_obj - obj) / obj
            else:
                # A zero objective would divide by zero; treat any positive
                # gain as an unbounded relative improvement and no gain as none.
                rel_gain = float('inf') if gain > 0 else 0.0
            if rel_gain >= eps_add_subgoal:
                subgoals.append(inds[best])
                obj = new_obj
            else:
                break
        print('The subgoals for cluster '+str(cluster)+' are:')
        # Iterate over the subgoals actually selected: the greedy loop may stop
        # early, leaving fewer than m entries.
        for sg in subgoals:
            print('state '+str(sg)+ ' which is '+ str(states[sg]))
        subgoal_store.append(subgoals)

    # Column vectors with the grid coordinates of every state, in the same
    # order as `states`.
    x = np.zeros((num_states, 1))
    y = np.zeros((num_states, 1))
    goal_state = 0  # index of the single goal state (stays 0 if none matches)
    for idx, state in enumerate(states):
        x_coord, y_coord = state.get_data()
        x[idx] = x_coord
        y[idx] = y_coord
        # The goal sits at (grid_dim, grid_dim).
        if x[idx, 0] == grid_dim and y[idx, 0] == grid_dim:
            goal_state = idx

    # Plot every state coloured by its cluster:
    #   cluster 0 is red, cluster 1 is green, cluster 2 is blue,
    #   cluster 3 is yellow (bottom left in green is (1,1)).
    # Subgoals are drawn as 'x' (the first subgoal of each cluster enlarged),
    # the goal state as a diamond and every other state as a dot.
    colors = ['red', 'green', 'blue', 'yellow']
    fig, ax = plt.subplots()
    for cluster in range(num_clusters):
        color = colors[cluster % len(colors)]
        members = list(np.where(labels_spectral == cluster)[0])
        for rank, sg in enumerate(subgoal_store[cluster]):
            # A subgoal is not guaranteed to be a member of this cluster;
            # an unguarded list.remove() would raise ValueError in that case.
            if sg in members:
                members.remove(sg)
            if rank == 0:
                ax.scatter(x[sg], y[sg], marker='x', c=color, s=100)
            else:
                ax.scatter(x[sg], y[sg], marker='x', c=color)
        # Highlight the goal in whichever cluster contains it. If the goal was
        # already drawn as a subgoal it has been removed and stays an 'x'.
        if goal_state in members:
            members.remove(goal_state)
            ax.scatter(x[goal_state], y[goal_state], marker='d', c=color, s=100)
        ax.scatter(x[members], y[members], marker='o', c=color)
    plt.xticks([])
    plt.yticks([])

    # savefig does not create missing directories, so make sure it exists.
    figure_dir = 'figures'
    if not os.path.isdir(figure_dir):
        os.makedirs(figure_dir)
    plt.savefig(os.path.join(figure_dir, 'four_rooms_strategic_states.jpg'), bbox_inches='tight')