from torch import inverse
import dataset.json_graph as json_graph
import json
import os
import numpy as np
import pickle
import tqdm

def process_data_from_path(data_path, task, jgraph, epi_files):
    """Convert every listed episode step in *data_path* into graph data.

    Args:
        data_path: Directory holding the ``<episode>_<step>.json`` files.
        task: Task name forwarded unchanged to ``jgraph.process``.
        jgraph: Graph builder. ``jgraph.process(json_obj, task)`` must return
            a mapping with ``"node_value"`` and ``"edge_list"`` entries, and
            ``jgraph.node_type`` is copied into the result.
        epi_files: Per-episode lists of 0-based step indices (as produced by
            ``create_file_list``). Episode ``e`` / step ``s`` is read from
            ``"<e + 1>_<s + 1>.json"``.

    Returns:
        dict with keys:
          * ``node_feat``: every step's ``node_value`` stacked along a new
            axis 0 with ``np.stack``;
          * ``graph_lists``: every step's ``edge_list``, in the same order;
          * ``epi_length``: number of steps per episode, one entry per
            element of *epi_files* (0 for an empty episode);
          * ``node_type``: ``jgraph.node_type``.

    Raises:
        ValueError: if *epi_files* lists no steps at all (nothing to stack).
    """
    node_feat = []
    graph_lists = []
    epi_length = []

    for epi, steps in enumerate(tqdm.tqdm(epi_files)):
        epi_length.append(len(steps))

        for step in steps:
            # File names are 1-based while the indices in epi_files are 0-based.
            fname = f"{epi + 1}_{step + 1}.json"
            with open(os.path.join(data_path, fname)) as file:
                json_obj = json.load(file)

            graph = jgraph.process(json_obj, task)

            node_feat.append(graph["node_value"])
            graph_lists.append(graph["edge_list"])

    if not node_feat:
        # np.stack would fail with an opaque "need at least one array" error.
        raise ValueError(f"no episode steps listed for {data_path!r}")

    # (total_steps, num_nodes) — assumes each node_value is a 1-D array of
    # length num_nodes; TODO confirm against JsonToGraph.process.
    node_feat = np.stack(node_feat, axis=0)

    data = dict(node_feat=node_feat, graph_lists=graph_lists, epi_length=epi_length, node_type=jgraph.node_type)
    return data

def create_file_list(data_path, num_epi_train):
    """Group the ``<episode>_<step>.json`` files in *data_path* by episode.

    File names are 1-based (``1_1.json`` is the first step of the first
    episode); the returned indices are 0-based. The step number is the text
    between the first ``_`` and the following ``.``.

    Args:
        data_path: Directory to scan. Entries not ending in ``.json`` are
            ignored.
        num_epi_train: Number of episode slots to allocate. Files whose
            episode number is greater than this are skipped.

    Returns:
        A list of ``num_epi_train`` lists. Entry ``e`` holds the sorted 0-based
        step indices found for episode ``e + 1``, or is empty if none exist.

    Raises:
        ValueError: if a ``.json`` name lacks the ``_`` separator, has a
            non-integer episode/step field, or has an episode number < 1.
    """
    epi_files = [[] for _ in range(num_epi_train)]
    for fname in os.listdir(data_path):
        if not fname.endswith(".json"):
            continue

        epi_str, sep, rest = fname.partition('_')
        if not sep:
            raise ValueError(f"expected '<episode>_<step>.json', got {fname!r}")
        epi = int(epi_str) - 1
        step = int(rest.partition('.')[0]) - 1

        # Episode 0 (or a negative number) would silently index from the end
        # of epi_files; reject it loudly instead of exiting the process.
        if epi < 0:
            raise ValueError(f"episode numbers start at 1, got {fname!r}")
        # Limit by episode, not by how many files happened to be listed first:
        # os.listdir order is arbitrary.
        if epi >= num_epi_train:
            continue
        epi_files[epi].append(step)

    for steps in epi_files:
        steps.sort()
    return epi_files

def _episode_files(data_path):
    """Return per-episode step lists for *data_path* without padding episodes.

    create_file_list is sized by the number of JSON files (an upper bound on
    the episode count), which leaves a long tail of empty episodes that would
    otherwise show up as zeros in ``epi_length``. Only the trailing empties
    are dropped, so list position still maps to the episode file number.
    """
    num_json = len([f for f in os.listdir(data_path) if 'json' in f])
    epi_files = create_file_list(data_path, num_json)
    while epi_files and not epi_files[-1]:
        epi_files.pop()
    return epi_files


def _feed_jsons(jgraph, data_path, task, epi_files):
    """Run every listed JSON step through ``jgraph.process``, discarding results.

    Used on the training split before ``switch_mode("processing")``;
    presumably this lets jgraph accumulate its node vocabulary — TODO confirm
    in json_graph.JsonToGraph.
    """
    for epi, steps in enumerate(tqdm.tqdm(epi_files)):
        for step in steps:
            fname = f"{epi + 1}_{step + 1}.json"
            with open(os.path.join(data_path, fname)) as file:
                json_obj = json.load(file)
            jgraph.process(json_obj, task)


def _dump_pickle(obj, path):
    """Pickle *obj* to *path*, overwriting any existing file."""
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def _process_split(data_path, task, jgraph):
    """Build and save ``gridworld.pkl`` and ``jgraph.pkl`` for one evaluation split."""
    epi_files = _episode_files(data_path)
    data = process_data_from_path(data_path, task, jgraph, epi_files)
    _dump_pickle(data, os.path.join(data_path, "gridworld.pkl"))
    _dump_pickle(jgraph, os.path.join(data_path, "jgraph.pkl"))


def main():
    """Preprocess the train split, then the validation (and optional test) split."""
    train_data_path = '/home/plymper/data/polycraftv2/normal/jsons/'
    val_data_path = '/home/plymper/data/polycraftv2/normal/val/'
    # Set to a directory to also build a held-out test split, e.g.
    # '/home/plymper/data/gridworldsData/novel_woodgift_persistent'.
    test_data_path = None
    # Other datasets used before:
    #   train '/home/plymper/data/gridworldsData/train_data'
    #   train '/home/plymper/data/polycraftv2/mapless',
    #   val   '/home/plymper/data/polycraftv2/val_mapless'

    task = "polycraft"

    jgraph = json_graph.JsonToGraph()

    epi_files = _episode_files(train_data_path)
    _feed_jsons(jgraph, train_data_path, task, epi_files)

    jgraph.switch_mode("processing")

    # tidy up before storing the json objects
    jgraph.tidyup()
    jgraph_path = os.path.join(train_data_path, "jgraph.pkl")
    _dump_pickle(jgraph, jgraph_path)

    data = process_data_from_path(train_data_path, task, jgraph, epi_files)
    _dump_pickle(data, os.path.join(train_data_path, "gridworld.pkl"))

    for split_path in (val_data_path, test_data_path):
        if split_path is None:
            continue
        # Start each evaluation split from the graph state saved right after
        # tidyup, so any state changed while processing other splits does
        # not leak in.
        with open(jgraph_path, 'rb') as handle:
            split_jgraph = pickle.load(handle)
        _process_split(split_path, task, split_jgraph)


if __name__ == "__main__":
    main()