import os
import sys

import numpy as np
import random
import pickle
import argparse

import torch
import utils

from copy import deepcopy

# Default to the CPU; switch to the CUDA device when torch reports one is available.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
print(f"Device: {device}\n")

from helper_functions import create_subgoal_paths
from doorkey_run_flow import doorkey_run_flow

import numpy as np
import random
import pickle

import matplotlib.pyplot as plt

# Command-line interface: --mode selects which saved DoorKey board is analysed.
parser = argparse.ArgumentParser()
# `choices` rejects unknown values up front; without it any value other than
# 'unlocked' would silently fall through to the locked-board branch.
parser.add_argument("--mode", type=str, required=True, choices=["locked", "unlocked"],
                    help="mode to use: locked | unlocked (REQUIRED)")
args = parser.parse_args()

mode = args.mode

USE_CUDA = torch.cuda.is_available()


def Variable(*args, **kwargs):
    """Build a ``torch.autograd.Variable`` and move it to the GPU when CUDA is available.

    Positional and keyword arguments are forwarded unchanged to
    ``torch.autograd.Variable``. The name is resolved through the module-level
    ``torch`` import because no bare ``autograd`` name is imported in this file.
    """
    var = torch.autograd.Variable(*args, **kwargs)
    return var.cuda() if USE_CUDA else var

if __name__ == '__main__':

    files = os.listdir(r'saved_boards')
    if mode == 'unlocked':
        file = [f for f in files if 'unlocked' in f][0]
        z = file.split('_')
        scenario_steps = int(z[2][:-5])
        model = "DoorKey_unlocked"
    else: # regular locked DoorKey
        file = [f for f in files if 'unlocked' not in f][0]
        z = file.split('_')
        scenario_steps = int(z[1][:-5])
        model = "DoorKey"

    filename_model = model
    max_steps = 6
    div = int(np.ceil(scenario_steps / 10))

    steps = 0
    filename_states = model+'_'+str(scenario_steps)+'steps'

    while steps + max_steps <= scenario_steps:
        print('Working on:')
        print('Step ' + str(steps) + ' from '+ filename_states)

        num_clusters = 5
        num_subgoals = 3
        rho = 1.0
        use_count_subgoals = False
        eps_add_goal = 0.1
        eta = 100
        projection_type = 1

        random.seed(121)
        np.random.seed(1211)

        (subgoals_list, shortest_dist, predecessors, transition_matrix, env_list_final, state_list_final, number_steps_final, end_states_list, labels_spectral, diversity, ind_initial) = doorkey_run_flow(filename_states, filename_model, steps, max_steps, num_clusters, num_subgoals, rho, use_count_subgoals, eps_add_goal, 1, eta, projection_type)

        all_paths = create_subgoal_paths(labels_spectral, subgoals_list, env_list_final, transition_matrix, number_steps_final)

        # plot all the paths of nodes to subgoals per cluster
        plt.clf()
        num_paths = len([len(x[i]) for x in all_paths for i in range(len(x))])
        if num_paths > 0:
            max_path = max([len(x[i]) for x in all_paths for i in range(len(x))]) # maximum length of a path in all_paths
        else:
            max_path = 0

        # only create directory for figures if there are figures to put
        dirname = model+'_'+str(steps)+'_of_'+str(scenario_steps)+'steps'
        if max_path > 0  and not os.path.exists(os.path.join('figures','strategic_states', dirname)):
            os.mkdir(os.path.join('figures','strategic_states', dirname))
        for i in range(len(all_paths)): # loop over clusters
            for j in range(len(all_paths[i])): # loop over subgoals in cluster i
                path_temp = all_paths[i][j]
                cluster_id = labels_spectral[path_temp[0]]
                for k in range(max_path): # this is in order to add count for shorter paths
                    if k < len(path_temp):
                        plt.xticks([])
                        plt.yticks([])
                        if k==0:
                            plt.ylabel('Cluster ' + str(i),fontsize=40)
                        # get Doorkey image
                        env_temp = env_list_final[path_temp[k]]
                        img = env_temp.get_obs_render(env_temp.gen_obs()['image'])/256.
                        (m,n,p) = img.shape
                        new_img = np.ones((m+20, n+20,3))
                        new_img[10:(m+10), 10:(n+10),:] = img
                        img = new_img
                        if k == (len(path_temp)-1): # this is the subgoal so we make it pink
                            [inds_x, inds_y] = np.where(np.multiply(np.multiply(img[:,:,0] == 1, img[:,:,1] == 1), img[:,:,2] == 1))
                            img[inds_x, inds_y, 0] = 221./256
                            img[inds_x, inds_y, 1] = 160./256
                            img[inds_x, inds_y, 2] = 221./256
                        plt.imshow(img)
                        savefile = model+'_'+str(steps)+'_of_'+str(scenario_steps)+'steps_Cluster'+str(cluster_id)+'_Row'+str(j)+'_Column'+str(k)+'.jpg'
                        plt.savefig(os.path.join('figures','strategic_states', dirname,savefile), bbox_inches='tight')
                        plt.close()
        steps += div