import json
import numpy as np
import copy
import sys
#nume_entries= [ 0,  1,  2,  4,  5,  6,  7,  8,  9, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 25, 26, 27]
#bin_entries= [ 3, 10, 17, 24]
#nume_entries= [ 0,  1,  2, 3,  4,  5,  6,  7,  8,  9,10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]
#bin_entries=[]

#nume_entries= [ 0, 1, 2, 3, 5, 7, 8, 9, 10, 12, 14, 15, 16, 17, 19, 21, 22, 23, 24, 26]
#bin_entries= [ 4, 6, 11, 13, 18, 20, 25, 27]
#cat_entries = []
#cat_ranges = []
#random_vars = [0,1,7,8,14,15,21,22]
#random_entries = {"numerical":[],"binary":[],"categorical":[]}
#
#nume_entries= [ 0, 1, 2, 6, 7, 8, 12, 13, 14, 18, 19, 20]
#bin_entries= [ 3, 4, 5, 9, 10, 11, 15, 16, 17, 21, 22, 23]
#cat_entries = []
#cat_ranges = []
#random_vars = [0,6,12,18]
#random_entries = {"numerical":[],"binary":[],"categorical":[]}
# Positions inside the flat feature vector, grouped by variable type.
nume_entries = [0, 1, 3, 5, 7]
bin_entries = [2, 4, 6, 8]
cat_entries = []
cat_ranges = []

# Feature-vector positions of the variables flagged as random.
random_vars = [0]
random_entries = {"numerical": [], "binary": [], "categorical": []}

# Candidate index lists, probed in this order for every random variable.
_typed_entries = (
    ("numerical", nume_entries),
    ("binary", bin_entries),
    ("categorical", cat_entries),
)

# Re-express each random variable as its offset inside the index list of the
# type it belongs to, and file that offset under the type's name.
for v in random_vars:
    for key, _entries in _typed_entries:
        if v in _entries:
            val = _entries.index(v)
            break
    else:
        # The variable is not listed under any of the known types.
        raise Exception("No type for given variable")

    random_entries[key].append(val)




#_entry_names = ["die","current_position","current_cash", "in_jail","num_utilities_possessed","num_total_houses","num_total_hotels","is_current_turn"]
#_entry_names=["die1","die2","current_position","current_cash", "in_jail","num_properties","is_current_turn"]
#_entry_names=["dice","current_position","current_cash", "status","jail_chance_card","jail_community_card"]
# Per-player fields, in the order they appear in the feature vector.
_entry_names = ["current_position", "status"]


# Player keys 'player_1' .. 'player_4'.
players = [f'player_{i}' for i in range(1, 5)]

# 'dice' comes first, followed by one '<player>_<field>' name per player field.
feature_names = ['dice', *(f"{p}_{n}" for p in players for n in _entry_names)]

# Feature name -> position in the feature vector.
feature_inds = dict(zip(feature_names, range(len(feature_names))))


# Bundle of the index bookkeeping above: the per-type index lists as numpy
# arrays plus the random-variable offsets (kept as the plain dict of lists).
data_info = {
    "cat_entries": np.array(cat_entries),
    'cat_ranges': np.array(cat_ranges),
    "nume_entries": np.array(nume_entries),
    "bin_entries": np.array(bin_entries),
    "random_entries": random_entries,
}

def types2vec(types):
    """Scatter per-type values back into a single flat feature vector.

    Args:
        types: Mapping with a ``'numerical'`` and a ``'binary'`` entry. Each
            entry must be assignable to the slots selected by ``nume_entries``
            and ``bin_entries`` respectively (presumably one value per index --
            confirm against callers). Any other key is not read.

    Returns:
        A float numpy array of length ``len(nume_entries) + len(bin_entries)``
        with the numerical values at ``nume_entries`` and the binary values
        (multiplied by 1) at ``bin_entries``.
    """
    vector_length = len(nume_entries) + len(bin_entries)
    flat = np.zeros(vector_length)

    numerical_values = types['numerical']
    flat[nume_entries] = numerical_values

    # Multiplying by 1 mirrors the original conversion of the binary values.
    binary_values = 1 * types['binary']
    flat[bin_entries] = binary_values

    return flat

def name_features(vec):
    """Label the entries of a feature vector with their feature names.

    Args:
        vec: Sequence or array indexable by every position stored in
            ``feature_inds``.

    Returns:
        A new dict mapping each feature name (in ``feature_inds`` order) to
        the element of ``vec`` at that feature's position.
    """
    return {name: vec[position] for name, position in feature_inds.items()}

def named2vec(named):
    """Convert a name -> value mapping into a numpy array of its values.

    Args:
        named: Mapping whose values are taken in the mapping's iteration order.

    Returns:
        ``np.ndarray`` holding the values of ``named``.
    """
    return np.array([value for value in named.values()])


def player_vector(player,jobj):
    """Return the ``[current_position, status]`` features of one player.

    Args:
        player: Key of the player inside ``jobj['prev_state']['players']``.
        jobj: Decoded step record; only its ``'prev_state'`` part is read.

    Returns:
        A two-element list: the player's ``current_position`` followed by 1
        when the player's ``status`` equals ``'current_move'`` and 0 otherwise.
        In type terms the first entry is numerical and the second binary.

    Raises:
        KeyError: If the player or one of the read fields is missing.
    """
    state = jobj['prev_state']['players'][player]
    status = 1 if state['status'] == 'current_move' else 0
    return [state['current_position'], status]




def player_vector_(player,jobj):
    """Return the extended six-feature description of one player.

    All values are read from ``jobj['prev_state']``.

    Args:
        player: Key of the player inside ``jobj['prev_state']['players']``.
        jobj: Decoded step record containing ``'prev_state'`` with a
            ``'die_sequence'`` list and a ``'players'`` mapping.

    Returns:
        ``[dice, current_position, current_cash, status,
        has_jail_chance_card, has_jail_community_card]`` where

        * ``status`` is 1 when the player's status is ``'current_move'``,
          else 0;
        * ``dice`` is the sum of the first two values of the last
          ``die_sequence`` entry, forced to 0 when ``status`` is 0 or when
          there is no such roll;
        * the two card flags are 1/0 versions of the player's
          ``has_get_out_of_jail_*`` fields.

    Raises:
        KeyError: If the player or one of the read fields is missing.
    """
    prev_state = jobj['prev_state']
    state = prev_state['players'][player]
    status = 1 if state['status'] == 'current_move' else 0

    try:
        last_roll = prev_state['die_sequence'][-1]
        die1, die2 = last_roll[0], last_roll[1]
    except IndexError:
        # No roll recorded yet (or an incomplete one): treat as "no dice".
        die1 = die2 = 0

    # The roll is only attributed to the player whose move it is.
    dice = die1 + die2 if status else 0

    has_jail_chance_card = 1 if state['has_get_out_of_jail_chance_card'] else 0
    has_jail_community_card = 1 if state['has_get_out_of_jail_community_chest_card'] else 0

    return [
        dice,
        state['current_position'],
        state['current_cash'],
        status,
        has_jail_chance_card,
        has_jail_community_card,
    ]


def _player_vector(player,jobj):
    """Return a seven-feature description of one player after a step.

    The previous position comes from ``jobj['prev_state']``; everything else
    is read from ``jobj['true_next_state']``.

    Args:
        player: Key of the player inside the ``'players'`` mappings.
        jobj: Decoded step record containing ``'prev_state'`` and
            ``'true_next_state'``.

    Returns:
        ``[die1, die2, current_position, current_cash, in_jail,
        num_properties, is_current_turn]`` where

        * ``die1``/``die2`` are the first two values of the last
          ``die_sequence`` entry, forced to 0 when the player's position did
          not change or when there is no such roll;
        * ``in_jail`` is ``int(currently_in_jail)``;
        * ``num_properties`` is the player's ``num_utilities_possessed``;
        * ``is_current_turn`` is 1 when ``die1`` is non-zero, else 0.

    Raises:
        KeyError: If the player or one of the read fields is missing.
    """
    prev_pos = jobj['prev_state']['players'][player]['current_position']
    next_state = jobj['true_next_state']
    state = next_state['players'][player]
    current_position = state['current_position']

    try:
        last_roll = next_state['die_sequence'][-1]
        die1, die2 = last_roll[0], last_roll[1]
    except IndexError:
        # No roll recorded (or an incomplete one): treat as "no dice".
        die1 = die2 = 0

    # A player that did not move is considered not to have rolled.
    if prev_pos == current_position:
        die1 = die2 = 0

    in_jail = int(state['currently_in_jail'])
    # NOTE(review): despite the name, this is the utilities count only --
    # confirm whether other property counts were meant to be included.
    num_properties = state['num_utilities_possessed']
    is_current_turn = 1 if die1 != 0 else 0

    return [
        die1,
        die2,
        current_position,
        state['current_cash'],
        in_jail,
        num_properties,
        is_current_turn,
    ]


def process(jobj):
    """Turn one decoded step record into its flat feature vector.

    Args:
        jobj: Decoded step record; ``jobj['prev_state']`` must provide a
            ``'die_sequence'`` list and the per-player data that
            ``player_vector`` reads.

    Returns:
        ``np.ndarray`` laid out as the dice total followed by the
        ``player_vector`` output of every entry of ``players``, in order.
        The dice total is the sum of the first two values of the last
        ``die_sequence`` entry, or 0 when no such roll is available.
    """
    try:
        last_roll = jobj['prev_state']['die_sequence'][-1]
        first_die = last_roll[0]
        second_die = last_roll[1]
    except IndexError:
        # Empty die sequence or incomplete roll.
        first_die = second_die = 0

    features = [first_die + second_die]
    for name in players:
        features.extend(player_vector(name, jobj))

    return np.array(features)


if __name__ == "__main__":
    # Debug helper: build the feature vector of one recorded step and print
    # which positions hold numerical and which hold binary variables, so the
    # output can be compared with nume_entries / bin_entries at the top of
    # this file.

    # The dataset location may be overridden by the first CLI argument.
    dataset_root = sys.argv[1] if len(sys.argv) > 1 else "/home/plymper/data/monopoly"

    for i in range(2,20):
        with open(f"{dataset_root}/episodes/0/{i}.json", 'r') as f:
            jobj = json.load(f)

        # process() lays the vector out as the dice total followed by
        # (current_position, status) for every player. player_vector() returns
        # values only, so the matching variable types are spelled out here.
        vectors = process(jobj)
        variable_types = np.array(
            ["numerical"] + ["numerical", "binary"] * len(players)
        )
        print(vectors)
        print(variable_types)
        nume_nodes = np.where(variable_types == 'numerical')
        bin_nodes = np.where(variable_types == 'binary')

        print("Nume nodes")
        print(*nume_nodes[0], sep=', ')
        print("Bin nodes")
        print(*bin_nodes[0], sep=', ')

        # The layout is the same for every step, so one file is enough.
        break