import monopoly_json2vec as monopoly_json2vec

import pandas as pd
import json
import torch
import importlib.resources
from pathlib import Path
import numpy as np
from sklearn.preprocessing import StandardScaler
import os
import pickle

import monopoly_json2vec

from monopoly_json2vec import data_info

def condense(steps, feature_inds=None,
             players=('player_1', 'player_2', 'player_3', 'player_4')):
    """Collapse a run of frames into a single feature vector.

    The result starts as a copy of the final frame in ``steps``. Then, for each
    player, the ``<player>_die`` and ``<player>_current_position`` entries are
    overwritten with those from the latest frame whose die entry is non-zero.
    Players with no such frame keep the final frame's values.

    Args:
        steps: non-empty sequence of per-frame feature vectors. Each vector
            must support integer indexing and ``.copy()`` (e.g. ``list`` or
            ``numpy.ndarray``).
        feature_inds: mapping from feature name to vector index. Defaults to
            ``monopoly_json2vec.feature_inds``.
        players: player-name prefixes used to build the feature names.

    Returns:
        A new vector. Neither ``steps`` nor its elements are modified.

    Raises:
        IndexError: if ``steps`` is empty.
    """
    if feature_inds is None:
        feature_inds = monopoly_json2vec.feature_inds

    # Copy so the caller's last frame is not mutated as a side effect.
    condensed = steps[-1].copy()
    for p in players:
        dice_idx = feature_inds[p + "_die"]
        pos_idx = feature_inds[p + "_current_position"]
        for step in steps:
            # A non-zero die entry presumably marks a frame where this player
            # rolled; the last such frame wins.
            if step[dice_idx] != 0:
                condensed[dice_idx] = step[dice_idx]
                condensed[pos_idx] = step[pos_idx]
    return condensed
                
    


def read_episode(ep_root, ep_len):
    """Load every JSON frame of one episode and vectorize it.

    Args:
        ep_root: directory (``str`` or ``Path``) holding the episode's frame
            files, named ``<int>.json``.
        ep_len: expected number of frames (``read_split`` passes the
            ``num_frames`` column of ``splits.csv``). Currently unused; kept
            for interface compatibility.

    Returns:
        List of vectors produced by ``monopoly_json2vec.process``, one per
        frame file, ordered by the integer before the first ``.`` in the
        filename.
    """
    # Match on the suffix rather than a substring so files like "3.json.bak"
    # are not picked up; sort numerically so "10.json" follows "9.json".
    frame_paths = sorted(
        (p for p in Path(ep_root).iterdir() if p.suffix == '.json'),
        key=lambda p: int(p.name.split('.')[0]),
    )

    ep_data = []
    for frame_path in frame_paths:
        with open(frame_path, 'r', encoding='utf-8') as fh:
            json_obj = json.load(fh)
        ep_data.append(monopoly_json2vec.process(json_obj))

    # NOTE: merging each turn's frames with condense() was prototyped here but
    # is not enabled; the raw per-frame vectors are returned.
    return ep_data



def read_split(splits_info, split, dataset_root=None):
    """Load and concatenate the frames of every episode in one split.

    Args:
        splits_info: DataFrame with ``episode``, ``num_frames`` and ``split``
            columns (as read from ``splits.csv``).
        split: split label to select, e.g. ``'train'`` or ``'valid'``.
        dataset_root: directory containing one sub-directory per episode.
            Defaults to the module-level ``DATASET_ROOT``, which is only
            assigned when this file runs as a script. Pass it explicitly when
            calling this function from an importing module.

    Returns:
        ``(data, ep_lengths)``. ``data`` is ``np.array`` of all frame vectors
        of the selected episodes, concatenated in CSV row order.
        ``ep_lengths`` holds each selected episode's ``num_frames`` value, in
        the same order.
    """
    # Fall back to the script-level global only when no root is given.
    root = Path(DATASET_ROOT if dataset_root is None else dataset_root)

    selected = splits_info[splits_info['split'] == split]
    data = []
    ep_lengths = []
    for name, ep_len in zip(selected['episode'], selected['num_frames']):
        data.extend(read_episode(root / Path(name), ep_len))
        # NOTE(review): this records the CSV's num_frames, not the number of
        # frames actually read. Confirm the two always agree if consumers use
        # ep_lengths to split ``data`` back into episodes.
        ep_lengths.append(ep_len)

    return np.array(data), ep_lengths


def _standardize(data, scaler, cols, *, fit):
    """Scale the ``cols`` columns of ``data`` with ``scaler`` and return the matrix.

    Args:
        data: 2-D array of frame vectors (rows are frames).
        scaler: scikit-learn style scaler (``fit_transform``/``transform``).
        cols: column selector for the numeric features.
        fit: if True, fit ``scaler`` on these rows; otherwise reuse its
            already-fitted statistics.

    Returns:
        ``data`` with the selected columns replaced by their scaled values.
        It is modified in place, or replaced by a float copy if it had an
        integer/bool dtype.
    """
    # Scaled values are floats; writing them back into an integer/bool matrix
    # would silently truncate them, so promote such matrices first.
    if data.dtype.kind in 'biu':
        data = data.astype(np.float64)
    subset = data[:, cols]
    data[:, cols] = scaler.fit_transform(subset) if fit else scaler.transform(subset)
    return data


if __name__ == '__main__':

    DATASET_ROOT = Path("/home/plymper/data/monopoly/")
    #DATASET_ROOT = Path("/home/plymper/data/polycraftv2")

    splits_info = pd.read_csv(DATASET_ROOT / 'splits.csv')
    print(splits_info)

    numeric_cols = data_info['nume_entries']
    scaler = StandardScaler()

    # Fit the scaler on the training split only, then reuse it on the
    # validation split so both share the same normalization statistics.
    train_data, train_lengths = read_split(splits_info, 'train')
    train_data = _standardize(train_data, scaler, numeric_cols, fit=True)
    train_set = dict(data=train_data, ep_lengths=train_lengths, scaler=scaler)

    valid_data, valid_lengths = read_split(splits_info, 'valid')
    valid_data = _standardize(valid_data, scaler, numeric_cols, fit=False)
    valid_set = dict(data=valid_data, ep_lengths=valid_lengths, scaler=scaler)

    for split_name, split_set in (('train', train_set), ('valid', valid_set)):
        with open(DATASET_ROOT / f"json_normal_{split_name}_data.pkl", 'wb') as fh:
            pickle.dump(split_set, fh)

    