import os
import sys
sys.path.insert(0, os.path.join(os.getcwd(), 'common'))

from helper_functions import create_subgoal_paths
from minipacman_run_flow import minipacman_run_flow

import numpy as np
import random
import pickle
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.autograd as autograd

from common.multiprocessing_env import SubprocVecEnv
from common.minipacman import MiniPacman
from common.actor_critic import ActorCritic, RolloutStorage
from common.deepmind import update_2d_pos
from copy import deepcopy

from scipy.sparse.csgraph import johnson # for computing shortest paths
from sklearn.cluster import SpectralClustering

import matplotlib.pyplot as plt

# Command-line interface: both options are mandatory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--mode",
    type=str,
    required=True,
    help="mode to use: hunt | eat (REQUIRED)",
)
parser.add_argument(
    "--scenario",
    type=int,
    required=True,
    help="scenario to run (1, 2, or 3) (REQUIRED)",
)
args = parser.parse_args()

# Expose the parsed options as module-level names used by the script below.
mode, scenario = args.mode, args.scenario

# True when torch reports a usable CUDA device.
USE_CUDA = torch.cuda.is_available()


def Variable(*var_args, **var_kwargs):
    """Build an ``autograd.Variable`` and move it to the GPU when CUDA is available.

    All positional and keyword arguments are forwarded unchanged to
    ``autograd.Variable``.
    """
    wrapped = autograd.Variable(*var_args, **var_kwargs)
    if USE_CUDA:
        return wrapped.cuda()
    return wrapped

def find_scenario_file(file_names, mode, scenario):
    """Pick the saved-board file name that belongs to ``mode`` and ``scenario``.

    The match is made on the full ``_<mode>_scenario<N>_`` tag rather than on
    the bare scenario number, so that digits of the step count (e.g. the ``1``
    in ``100steps``) cannot select a file of a different scenario.

    Args:
        file_names: iterable of candidate file names.
        mode: mode string as it appears in the file name.
        scenario: scenario number as it appears in the file name.

    Returns:
        The alphabetically first matching file name.

    Raises:
        FileNotFoundError: if no candidate carries the tag.
    """
    tag = '_' + mode + '_scenario' + str(scenario) + '_'
    matches = sorted(name for name in file_names if tag in name)
    if not matches:
        raise FileNotFoundError(
            'no saved board found for mode ' + repr(mode)
            + ' and scenario ' + str(scenario))
    return matches[0]


def parse_scenario_steps(file_name):
    """Extract the step count from a ``pacman_<mode>_scenario<N>_<K>steps`` name.

    The fourth underscore-separated field is expected to be ``<K>steps``;
    anything following ``steps`` (such as a file extension) is ignored.

    Raises:
        ValueError: if the name does not follow that layout.
    """
    fields = file_name.split('_')
    if len(fields) < 4:
        raise ValueError('unexpected saved board name: ' + repr(file_name))
    count_text, marker, _ = fields[3].partition('steps')
    if not marker:
        raise ValueError('unexpected saved board name: ' + repr(file_name))
    try:
        return int(count_text)
    except ValueError:
        raise ValueError('unexpected saved board name: ' + repr(file_name))


def highlight_subgoal(img):
    """Recolour, in place, every pixel whose first three channels all equal 1.

    Such pixels are overwritten with (221, 160, 221) / 256 (pink), which is
    how a subgoal frame is told apart from the other frames of a path.

    Returns:
        The same ``img`` object, for convenience.
    """
    white = (img[:, :, 0] == 1) & (img[:, :, 1] == 1) & (img[:, :, 2] == 1)
    img[white, 0] = 221./256
    img[white, 1] = 160./256
    img[white, 2] = 221./256
    return img


if __name__ == '__main__':

    # Locate the recorded scenario and read its length from the file name.
    board_file = find_scenario_file(os.listdir(r'saved_boards'), mode, scenario)
    scenario_steps = parse_scenario_steps(board_file)

    filename_states = 'pacman_'+mode+'_scenario'+str(scenario)+'_'+str(scenario_steps)+'steps'
    filename_model = 'actor_critic_' + mode +'_10_by_7'

    max_steps = 6
    # Stride between analysed start steps: a tenth of the scenario, rounded up.
    div = int(np.ceil(scenario_steps / 10))

    # Settings forwarded to minipacman_run_flow; identical for every start step.
    num_clusters = 5
    num_subgoals = 3
    rho = 0.1
    use_count_subgoals = False
    eps_add_goal = 0.1
    eta = 100
    projection_type = 1

    figures_root = os.path.join('figures', 'strategic_states')

    steps = 0

    while steps + max_steps <= scenario_steps:

        print('Working on:')
        print('Step ' + str(steps) + ' from '+ filename_states)

        # Re-seed before every run so each start step uses the same seeds.
        random.seed(121)
        np.random.seed(1211)

        (subgoals_list, shortest_dist, predecessors, transition_matrix,
         env_list_final, state_list_final, number_steps_final,
         end_states_list, labels_spectral, diversity,
         ind_initial) = minipacman_run_flow(
            filename_states, filename_model, steps, max_steps, num_clusters,
            num_subgoals, rho, use_count_subgoals, eps_add_goal, 1, eta,
            projection_type)

        # A fresh dict per iteration, as the callee receives it by reference.
        params = {'min_moves_from_pacman': 2}
        all_paths = create_subgoal_paths(labels_spectral, subgoals_list, env_list_final, transition_matrix, number_steps_final, params)

        # plot all the paths of nodes to subgoals per cluster
        plt.clf()  # start from a blank current figure

        # only create directory for figures if there are figures to put
        dirname = mode+'_scenario'+str(scenario)+'_'+str(steps)+'_of_'+str(scenario_steps)+'steps'
        out_dir = os.path.join(figures_root, dirname)
        has_frames = any(len(path) > 0 for cluster_paths in all_paths for path in cluster_paths)
        if has_frames:
            # makedirs also creates missing parents and tolerates an existing dir
            os.makedirs(out_dir, exist_ok=True)

        for i, cluster_paths in enumerate(all_paths): # loop over clusters
            for j, path_temp in enumerate(cluster_paths): # loop over subgoals in cluster i
                if len(path_temp) == 0:
                    # nothing to draw, and no first node to look a label up for
                    continue
                cluster_id = labels_spectral[path_temp[0]]
                last = len(path_temp) - 1
                for k, node in enumerate(path_temp):
                    plt.xticks([])
                    plt.yticks([])
                    if k == 0:
                        plt.ylabel('Cluster ' + str(i),fontsize=40)
                    (__, __, img) = env_list_final[node].observation()
                    if k == last: # this is the subgoal so we make it pink
                        highlight_subgoal(img)

                    plt.imshow(img)
                    savefile = mode+'_scenario'+str(scenario)+'_'+str(steps)+'_of_'+str(scenario_steps)+'steps_Cluster'+str(cluster_id)+'_Row'+str(j)+'_Column'+str(k)+'.jpg'
                    plt.savefig(os.path.join(out_dir, savefile), bbox_inches='tight')
                    # one figure per image: close it so the next frame starts clean
                    plt.close()
        steps += div