import numpy as np
import json
import pickle as pkl
import sys
import torch
from sklearn import datasets
sys.path.append("../")
from dataset import monopoly_json2vec
from natsort import natsorted
import os
"""
def dice_novelty(curr,prev, player):
    labels = {k:0 for k in curr.keys()}
    position = f"{player}_current_position"
    die = f"{player}_die"
    
    labels[position]=1
    labels[die]=1

    curr[position]+=np.random.randint(1,3)

    return curr,prev,labels



def has_player_moved(curr,prev,player):
    applied = False
    entry = f"{player}_current_position"
    if curr[entry]!= prev[entry]:
        applied = True
    
    return applied
    
    

def inject_novelty(data):

    device = data.last_state.get_device()
    curr = data.last_state.numpy()[0]
    past_idx = -1

    prev = data.x[0,:,past_idx].numpy() # first step in window -- may want to make this random
    
    curr = monopoly_json2vec.name_features(curr)
    prev = monopoly_json2vec.name_features(prev)
    
    labels = {k:0 for k in curr.keys()}
    

    for p in monopoly_json2vec.players:
        if has_player_moved(curr,prev,p):
            curr,prev,labels = dice_novelty(curr,prev,p)
            break
    
    curr = monopoly_json2vec.named2vec(curr)
    prev = monopoly_json2vec.named2vec(prev)
    labels = np.array(list(labels.values()))
    
    data.last_state = torch.tensor(curr)#.to(device)
    data.x[0,:,past_idx] = torch.tensor(prev)#.to(device)
    data.labels = torch.tensor(labels)#.to(device)

    return data

if __name__ == '__main__':
    dataset_root="/home/plymper/data/polycraftv2/normal/val/"
    folders = os.listdir(dataset_root)
    print(folders)
    folders = natsorted(list(filter(lambda x: str(x).isnumeric(),folders)))
    print(folders)
    exit()
    for k,v in errors.items():
        jsons = natsorted(list(filter(lambda x:".json" in x, os.listdir(data_path+k))))
        episodes[k]=[]
        for j in jsons:
            with open(data_path+k+"/"+j,'r') as f:
                episodes[k].append(json.load(f))

"""







def cash_change_condition(node_feats,last_state, cash_idx):
    '''
        Check whether a cash feature changed in the latest step.

        Compares the last column of ``node_feats`` (indexed as
        ``node_feats[cash_idx, -1]``) with ``last_state[cash_idx]``.

        Returns True if the two values differ, False otherwise.
    '''
    # No debug print here: this runs once per injected sample.
    return bool(node_feats[cash_idx, -1] != last_state[cash_idx])


def cash_change_novelty(node_feats,last_state, normalizer=None, node_info=None, old_node_info=None,**kwargs):
    '''
        Change a random player's cash value.

        A player is drawn at random. If that player's cash feature already
        changed between the last step of ``node_feats`` and ``last_state``
        (see ``cash_change_condition``), a raw cash delta drawn from
        {-200, -100, 100, 200, 300, 400} is added to it. The delta is applied
        in the normalized space, where x_norm = (x - mean) / std.

        Args:
            node_feats: feature history, indexed as ``[feature, time]``.
            last_state: 1-D feature vector; modified in place.
            normalizer: dict with 'mean' and 'std' sequences. Presumably
                aligned with ``node_info['nume_nodes']``; TODO confirm.
            node_info / old_node_info: dicts whose 'nume_nodes' remap the
                hard-coded player-cash indices to indices in ``last_state``.

        Returns:
            (last_state, np.array([feat_idx])) if cash was changed, otherwise
            (last_state, np.array([])).
    '''
    player_cash_idx = [24,44,38,32]

    idx_dict = dict(zip(old_node_info['nume_nodes'],node_info['nume_nodes']))
    player_cash_idx = [idx_dict[i] for i in player_cash_idx]

    player_idx = np.random.randint(4)
    feat_idx = player_cash_idx[player_idx]

    if not cash_change_condition(node_feats,last_state,feat_idx):
        return last_state, np.array([])

    temp = list(node_info['nume_nodes']).index(feat_idx)
    mean,std = normalizer['mean'][temp],normalizer['std'][temp]

    # Invert x_norm = (x - mean) / std  ->  x = x_norm * std + mean.
    # The previous form (x_norm + mean) * std added a spurious
    # (mean - mean / std) offset on top of the intended delta.
    curr_val = last_state[feat_idx]*std + mean

    delta = np.random.choice([-200,-100,100,200,300,400])
    new_val = (curr_val+delta-mean)/std
    last_state[feat_idx]=new_val

    return last_state, np.array([feat_idx])

def player_steals_cash_novelty(last_state, normalizer=None, node_info=None, old_node_info=None, **kwargs):
    '''
    One player steals cash from another.

    Two distinct players are drawn at random. A raw amount drawn from
    {300, 400, 500} is subtracted from the first player's cash and added to
    the second's. The transfer is applied in the normalized space, where
    x_norm = (x - mean) / std.

    Args:
        last_state: 1-D feature vector; modified in place.
        normalizer: dict with 'mean' and 'std' sequences. Presumably aligned
            with ``node_info['nume_nodes']``; TODO confirm.
        node_info / old_node_info: dicts whose 'nume_nodes' remap the
            hard-coded player-cash indices to indices in ``last_state``.

    Returns:
        (last_state, np.array([victim_idx, thief_idx]))
    '''
    player_cash_idx = [24,44,38,32]

    idx_dict = dict(zip(old_node_info['nume_nodes'],node_info['nume_nodes']))
    player_cash_idx = [idx_dict[i] for i in player_cash_idx]

    feat_idx1,feat_idx2 = np.random.choice(player_cash_idx,size=(2,),replace=False)

    nume_nodes = list(node_info['nume_nodes'])
    std1 = normalizer['std'][nume_nodes.index(feat_idx1)]
    std2 = normalizer['std'][nume_nodes.index(feat_idx2)]

    amount = np.random.choice([300,400,500])

    # A transfer is a difference of raw values, so the mean cancels:
    # shifting x by `amount` shifts (x - mean) / std by amount / std.
    # Each player's feature is scaled with its own std.
    last_state[feat_idx1]= last_state[feat_idx1]-amount/std1
    last_state[feat_idx2]= last_state[feat_idx2]+amount/std2

    return last_state,np.array([feat_idx1,feat_idx2])

def move_player_condition(node_feats,last_state, pos_idx):
    '''
        Check whether a position feature changed in the latest step.

        Compares the last column of ``node_feats`` (indexed as
        ``node_feats[pos_idx, -1]``) with ``last_state[pos_idx]``.

        Returns True if the two values differ, False otherwise.
    '''
    # No debug print here: this runs once per injected sample.
    return bool(node_feats[pos_idx, -1] != last_state[pos_idx])

def move_player_novelty(node_feats,last_state, all_node_array=None, node_info=None, old_node_info=None, **kwargs):
    '''
        Randomly move a player to a different board position.

        A player is drawn at random. If that player's position feature
        already changed between the last step of ``node_feats`` and
        ``last_state`` (see ``move_player_condition``), the position is
        replaced by a different normalized position drawn uniformly from
        ``possible_vals``.

        Args:
            node_feats: feature history, indexed as ``[feature, time]``.
            last_state: 1-D feature vector; modified in place.
            all_node_array: unused; kept for interface compatibility.
            node_info / old_node_info: dicts whose 'nume_nodes' remap the
                hard-coded player-position indices to indices in ``last_state``.

        Returns:
            (last_state, np.array([feat_idx])) if a player was moved,
            otherwise (last_state, np.array([])).
    '''
    player_pos_idx = [15,40,34,27]
    idx_dict = dict(zip(old_node_info['nume_nodes'],node_info['nume_nodes']))
    player_pos_idx = [idx_dict[i] for i in player_pos_idx]

    feat_idx = np.random.choice(player_pos_idx)
    if not move_player_condition(node_feats,last_state,feat_idx):
        return last_state, np.array([])

    # Normalized board positions, evenly spaced about 0.0891 apart. The grid
    # step between 0.97225364 and 1.15054637 is absent; presumably this is a
    # square a player cannot end a turn on. TODO confirm.
    possible_vals = [-1.61299089, -1.52384452, -1.43469816, -1.3455518 , -1.25640543,
       -1.16725907, -1.07811271, -0.98896635, -0.89981998, -0.81067362,
       -0.72152726, -0.63238089, -0.54323453, -0.45408817, -0.3649418 ,
       -0.27579544, -0.18664908, -0.09750271, -0.00835635,  0.08079001,
        0.16993638,  0.25908274,  0.3482291 ,  0.43737546,  0.52652183,
        0.61566819,  0.70481455,  0.79396092,  0.88310728,  0.97225364,
        1.15054637,  1.23969273,  1.3288391 ,  1.41798546,  1.50713182,
        1.59627819,  1.68542455,  1.77457091,  1.86371728]

    # Exclude the current square with a tolerance. last_state may be float32
    # while these candidates are float64, so an exact == test could re-pick the
    # current square and label a no-op as a novelty.
    current = float(last_state[feat_idx])
    candidates = [v for v in possible_vals if not np.isclose(v, current)]
    last_state[feat_idx] = np.random.choice(candidates)

    return last_state, np.array([feat_idx])

def all_players_lose_novelty(last_state,node_info=None, old_node_info=None, **kwargs):
    '''
        Set every player's status feature to the categorical code 45.

        45 presumably encodes the "lost" status; TODO confirm against the
        categorical encoding.

        The hard-coded status indices are remapped through
        ``old_node_info['cat_nodes']`` -> ``node_info['cat_nodes']``.
        ``last_state`` is modified in place.

        Returns the state and an array of the status indices that actually
        changed. Entries that were already 45 are not reported.
    '''
    remap = dict(zip(old_node_info['cat_nodes'], node_info['cat_nodes']))
    status_features = [remap[old] for old in (17, 41, 35, 28)]

    changed = []
    for feature in status_features:
        if last_state[feature] == 45:
            continue  # already in the target status; nothing to flag
        last_state[feature] = 45
        changed.append(feature)

    return last_state, np.array(changed)



def inject_novelty(node_feats,last_state, novelty_type='cash_change',task='monopoly', **kwargs):
    '''
        Apply a synthetic novelty to ``last_state`` and build its labels.

        For task 'monopoly', dispatches on ``novelty_type`` ('cash_change',
        'cash_steal', 'move_player' or 'all_players_lose') and forwards the
        extra keyword arguments to the chosen novelty function. Any other
        ``novelty_type`` raises ValueError. Any other task (including
        'gridworld') leaves the state untouched.

        Returns:
            (last_state, label, node_labels). ``label`` is 1 if any feature
            was modified and 0 otherwise. ``node_labels`` is a float array
            shaped like ``last_state``, with 1 at every modified index.
    '''
    node_labels = np.zeros(last_state.shape)
    if task != 'monopoly':
        return last_state, 0, node_labels

    handlers = {
        'cash_change': lambda: cash_change_novelty(node_feats, last_state, **kwargs),
        'cash_steal': lambda: player_steals_cash_novelty(last_state, **kwargs),
        'move_player': lambda: move_player_novelty(node_feats, last_state, **kwargs),
        'all_players_lose': lambda: all_players_lose_novelty(last_state, **kwargs),
    }
    handler = handlers.get(novelty_type)
    if handler is None:
        raise ValueError("Novelty Not available")

    last_state, changed_idx = handler()
    if len(changed_idx) == 0:
        return last_state, 0, node_labels

    node_labels[changed_idx] = 1
    return last_state, 1, node_labels
