import gridworld_json2vec as gridworld_json2vec

import pandas as pd
import json
import torch
import importlib.resources
from pathlib import Path
import numpy as np
from sklearn.preprocessing import StandardScaler
import os
import pickle


from gridworld_json2vec import data_info

def read_episode(ep_root, ep_len, process=None):
    """Load and vectorize every JSON frame of one episode directory.

    Frames are the files directly inside *ep_root* whose names end with
    ``.json``. They are visited in numeric order of the text before the first
    ``.``, so ``2.json`` comes before ``10.json``.

    Args:
        ep_root: Episode directory (``str`` or ``pathlib.Path``).
        ep_len: Expected number of frames. Not used by this function; kept so
            existing callers keep working.
            NOTE(review): could be used to validate the frame count.
        process: Callable mapping one decoded JSON object to a vector.
            Defaults to ``gridworld_json2vec.process``.

    Returns:
        A list with one vector per frame, in frame order.

    Raises:
        ValueError: if a ``.json`` filename does not start with an integer.
    """
    if process is None:
        process = gridworld_json2vec.process

    root = Path(ep_root)
    # endswith (not a substring test) so leftovers such as "5.json.bak" are not
    # mistaken for frames.
    frame_names = sorted(
        (name for name in os.listdir(root) if name.endswith('.json')),
        key=lambda name: int(name.split('.')[0]),  # numeric, not lexicographic
    )

    ep_data = []
    for name in frame_names:
        # JSON text is UTF-8 by spec; do not depend on the platform locale.
        with open(root / name, 'r', encoding='utf-8') as fh:
            json_obj = json.load(fh)
        ep_data.append(process(json_obj))
    return ep_data



def read_split(splits_info, split, dataset_root=None):
    """Load and vectorize every episode assigned to one split.

    Args:
        splits_info: DataFrame with (at least) ``episode``, ``num_frames`` and
            ``split`` columns, one row per episode.
        split: Split label to select, e.g. ``'train'`` or ``'valid'``.
        dataset_root: Directory holding the episode sub-directories. When
            omitted, falls back to the module-level ``DATASET_ROOT``, which is
            only defined when this file runs as a script. This preserves the
            original behavior but lets importers pass the root explicitly.

    Returns:
        ``(data, ep_lengths)``. ``data`` is ``np.array`` of the frame vectors of
        all selected episodes, concatenated in row order. ``ep_lengths`` lists
        each selected episode's ``num_frames`` value from *splits_info*, in the
        same order.
    """
    root = Path(DATASET_ROOT if dataset_root is None else dataset_root)

    # NOTE(review): an earlier version also skipped episodes whose name lacks
    # 'normal'; that filter is disabled here, as it was in the original code.
    selected = splits_info[splits_info['split'] == split]

    data = []
    ep_lengths = []
    # .tolist() yields plain Python scalars, matching what per-row iloc access gave.
    for name, ep_len in zip(selected['episode'].tolist(),
                            selected['num_frames'].tolist()):
        data.extend(read_episode(root / name, ep_len))
        ep_lengths.append(ep_len)

    return np.array(data), ep_lengths


def _scale_columns(data, cols, transform):
    """Return *data* with the columns picked by *cols* replaced by ``transform`` of them.

    Args:
        data: 2-D array with one frame vector per row.
        cols: Column selector for the numeric features (the value of
            ``data_info['nume_entries']``; presumably an index list or boolean
            mask — TODO confirm).
        transform: Callable that takes and returns the 2-D block of selected
            columns, e.g. ``scaler.fit_transform`` or ``scaler.transform``.

    Returns:
        The scaled array. A floating-point array is updated in place and
        returned. Any other dtype is first copied to float64, because writing
        standardized (fractional, possibly negative) values back into an
        integer array would silently truncate them.
    """
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)
    data[:, cols] = transform(data[:, cols])
    return data


if __name__ == '__main__':
    # Hard-coded dataset location; must contain splits.csv plus one
    # sub-directory per episode.
    DATASET_ROOT = Path("/home/plymper/data/polycraftv2")

    splits_info = pd.read_csv(DATASET_ROOT / Path('splits.csv'))
    print(splits_info)

    # Fit the scaler on the training split only, then reuse it for validation so
    # both splits share one normalization and no validation statistics leak in.
    scaler = StandardScaler()
    nume_cols = data_info['nume_entries']

    train_data, train_lengths = read_split(splits_info, 'train')
    train_data = _scale_columns(train_data, nume_cols, scaler.fit_transform)
    train_set = dict(data=train_data, ep_lengths=train_lengths, scaler=scaler)

    valid_data, valid_lengths = read_split(splits_info, 'valid')
    valid_data = _scale_columns(valid_data, nume_cols, scaler.transform)
    valid_set = dict(data=valid_data, ep_lengths=valid_lengths, scaler=scaler)

    outputs = (
        ("json_normal_train_data.pkl", train_set),
        ("json_normal_valid_data.pkl", valid_set),
    )
    for filename, dataset in outputs:
        with open(DATASET_ROOT / Path(filename), 'wb') as f:
            pickle.dump(dataset, f)


    