import numpy as np
import json
import pickle as pkl
import sys
import torch
from sklearn import datasets
sys.path.append("../")
from dataset import monopoly_json2vec
from natsort import natsorted
import os

def dice_novelty(curr, prev, player):
    """Perturb the shared ``dice`` entry and ``player``'s position by -1, 0 or +1.

    ``curr`` is modified in place. Both perturbed entries are labelled 1 even
    when the drawn offset is 0, so the value itself is left unchanged.

    Args:
        curr: named-feature dict for the current state (mutated).
        prev: named-feature dict for the previous state (returned untouched).
        player: player name prefix used to build ``f"{player}_current_position"``.

    Returns:
        ``(curr, prev, labels)``. ``labels`` has every key of ``curr`` mapped to 0,
        except the position and ``"dice"`` entries, which are mapped to 1.
    """
    pos_key = f"{player}_current_position"
    dice_key = "dice"

    labels = dict.fromkeys(curr, 0)
    labels[pos_key] = 1
    labels[dice_key] = 1

    # Dice are drawn before position; keep this order so seeded runs reproduce.
    curr[dice_key] = curr[dice_key] + np.random.choice([-1, 0, 1])
    curr[pos_key] = curr[pos_key] + np.random.choice([-1, 0, 1])

    return curr, prev, labels


def dice_novelty_2_dice(curr, prev, player):
    """Overwrite ``player``'s two die entries with fresh rolls in 1..6.

    ``curr`` is modified in place. The dice are re-rolled, but the label vector
    marks the player's position as the novel entry.

    Args:
        curr: named-feature dict for the current state (mutated).
        prev: named-feature dict for the previous state (returned untouched).
        player: player name prefix used to build the ``_die1``, ``_die2`` and
            ``_current_position`` keys.

    Returns:
        ``(curr, prev, labels)``. ``labels`` has every key of ``curr`` mapped to 0,
        except ``f"{player}_current_position"``, which is mapped to 1.
    """
    labels = {k: 0 for k in curr.keys()}
    position = f"{player}_current_position"
    die1 = f"{player}_die1"
    die2 = f"{player}_die2"

    # NOTE(review): position is labelled novel but left unchanged (its
    # perturbation was disabled), while the re-rolled dice are labelled 0.
    # Confirm this labelling is intended.
    labels[position] = 1
    labels[die1] = 0
    labels[die2] = 0

    # randint's upper bound is exclusive: (1, 7) rolls a fair six-sided die.
    # The previous (1, 6) could never produce a 6.
    curr[die1] = np.random.randint(1, 7)
    curr[die2] = np.random.randint(1, 7)

    return curr, prev, labels

def position_novelty(curr, prev, player):
    """Shift ``player``'s position by a random integer offset in [-6, 6].

    ``curr`` is modified in place. The entry is labelled 1 even when the drawn
    offset is 0. No wrap-around or clamping is applied to the new position.

    Args:
        curr: named-feature dict for the current state (mutated).
        prev: named-feature dict for the previous state (returned untouched).
        player: player name prefix used to build ``f"{player}_current_position"``.

    Returns:
        ``(curr, prev, labels)``. ``labels`` has every key of ``curr`` mapped to 0,
        except the position entry, which is mapped to 1.
    """
    labels = {k: 0 for k in curr.keys()}
    position = f"{player}_current_position"
    labels[position] = 1

    # randint's upper bound is exclusive; (-6, 7) gives a symmetric shift in
    # [-6, 6]. The previous (-6, 6) could never draw +6.
    curr[position] += np.random.randint(-6, 7)

    return curr, prev, labels

def status_novelty(curr, prev, player):
    """Flip ``player``'s 0/1 status entry and label it as novel.

    Before the flip, every ``player_1_status`` .. ``player_4_status`` entry whose
    value is 1 is also labelled novel. ``curr`` is modified in place.

    Args:
        curr: named-feature dict for the current state (mutated). It must
            contain ``player_1_status`` .. ``player_4_status``.
        prev: named-feature dict for the previous state (returned untouched).
        player: player name prefix used to build ``f"{player}_status"``.
            Presumably one of ``player_1`` .. ``player_4`` (TODO confirm).

    Returns:
        ``(curr, prev, labels)``, where ``labels`` maps every key of ``curr`` to 0/1.
    """
    labels = {k: 0 for k in curr.keys()}
    entry = f"{player}_status"

    # Use a separate loop variable. Reusing ``entry`` here made the flip below
    # always hit player_4_status, whatever ``player`` was passed.
    for i in range(1, 5):
        other = f"player_{i}_status"
        if curr[other] == 1:
            labels[other] = 1

    curr[entry] = 1 - curr[entry]
    labels[entry] = 1

    return curr, prev, labels

def has_player_moved(curr, prev, player):
    """Return whether ``player``'s position entry differs between two states.

    Args:
        curr: named-feature dict for the current state.
        prev: named-feature dict for the previous state.
        player: player name prefix used to build ``f"{player}_current_position"``.

    Returns:
        ``True`` if the two position values compare unequal, otherwise ``False``.
    """
    key = f"{player}_current_position"
    return bool(curr[key] != prev[key])
    
    
def has_cash_changed(curr, prev, player):
    """Return whether ``player``'s cash entry differs between two states.

    Args:
        curr: named-feature dict for the current state.
        prev: named-feature dict for the previous state.
        player: player name prefix used to build ``f"{player}_current_cash"``.

    Returns:
        ``True`` if the two cash values compare unequal, otherwise ``False``.
    """
    cash_key = f"{player}_current_cash"
    changed = curr[cash_key] != prev[cash_key]
    return bool(changed)

def cash_change_novelty(curr, prev, player):
    """Add a random amount from {-200, -100, +100, +200} to ``player``'s cash.

    ``curr`` is modified in place.

    Args:
        curr: named-feature dict for the current state (mutated).
        prev: named-feature dict for the previous state (returned untouched).
        player: player name prefix used to build ``f"{player}_current_cash"``.

    Returns:
        ``(curr, prev, labels)``. ``labels`` has every key of ``curr`` mapped to 0,
        except the cash entry, which is mapped to 1.
    """
    cash_key = f"{player}_current_cash"
    labels = dict.fromkeys(curr, 0)
    curr[cash_key] += np.random.choice([-200, -100, 100, 200])
    labels[cash_key] = 1
    return curr, prev, labels

def has_status_changed(curr, prev, player):
    """Return whether ``player``'s status entry differs between two states.

    Args:
        curr: named-feature dict for the current state.
        prev: named-feature dict for the previous state.
        player: player name prefix used to build ``f"{player}_status"``.

    Returns:
        ``True`` if the two status values compare unequal, otherwise ``False``.
    """
    status_key = f"{player}_status"
    return True if curr[status_key] != prev[status_key] else False

def has_property_changed(curr, prev, player):
    """Return whether ``player``'s property count differs between two states.

    Args:
        curr: named-feature dict for the current state.
        prev: named-feature dict for the previous state.
        player: player name prefix used to build ``f"{player}_num_properties"``.

    Returns:
        ``True`` if the two property counts compare unequal, otherwise ``False``.
    """
    props_key = f"{player}_num_properties"
    now, before = curr[props_key], prev[props_key]
    return bool(now != before)

def property_novelty(curr, prev, player):
    """Halve ``player``'s cash and add one to their property count.

    ``curr`` is modified in place. The halving uses true division, so the cash
    value becomes a float for integer inputs.

    Args:
        curr: named-feature dict for the current state (mutated).
        prev: named-feature dict for the previous state (returned untouched).
        player: player name prefix used to build the ``_current_cash`` and
            ``_num_properties`` keys.

    Returns:
        ``(curr, prev, labels)``. ``labels`` has every key of ``curr`` mapped to 0,
        except the cash and property-count entries, which are mapped to 1.
    """
    cash_key, prop_key = f"{player}_current_cash", f"{player}_num_properties"
    labels = dict.fromkeys(curr, 0)

    curr[cash_key] = curr[cash_key] / 2
    curr[prop_key] = curr[prop_key] + 1

    labels.update({cash_key: 1, prop_key: 1})
    return curr, prev, labels

def inject_novelty(data, move_prob=0.6):
    """Inject one synthetic novelty into a sample and attach per-feature labels.

    Converts ``data.last_state.numpy()[0]`` (current state) and
    ``data.x[0, :, -1]`` (previous state) to named-feature dicts via
    ``monopoly_json2vec.name_features``. Players are then visited in
    ``monopoly_json2vec.players`` order:

    - With probability ``move_prob``, a player who has moved gets
      ``dice_novelty``.
    - Otherwise, a player whose status changed gets ``status_novelty``.

    The first player for whom the drawn check succeeds is perturbed, and the
    loop stops. If no check succeeds, the states are written back unchanged and
    every label is 0.

    Args:
        data: sample exposing ``last_state`` and ``x`` tensors. It is modified
            in place.
        move_prob: per-player probability of trying the movement check instead
            of the status check. Defaults to 0.6, the previously hard-coded value.

    Returns:
        ``data``, with ``last_state``, ``x[0, :, -1]`` and ``labels`` replaced.
    """
    # -1 selects the LAST entry along the window axis of ``x``. (The old
    # comment called it the first step; verify which one is intended.) May
    # want to make this random.
    past_idx = -1

    curr = data.last_state.numpy()[0]
    prev = data.x[0, :, past_idx].numpy()

    curr = monopoly_json2vec.name_features(curr)
    prev = monopoly_json2vec.name_features(prev)

    labels = {k: 0 for k in curr.keys()}

    for p in monopoly_json2vec.players:
        if np.random.rand() < move_prob:
            if has_player_moved(curr, prev, p):
                curr, prev, labels = dice_novelty(curr, prev, p)
                break
        elif has_status_changed(curr, prev, p):
            curr, prev, labels = status_novelty(curr, prev, p)
            break
        # Currently disabled alternatives defined in this module:
        # property_novelty (if has_property_changed), position_novelty
        # (if not has_player_moved), cash_change_novelty (if has_cash_changed).

    curr = monopoly_json2vec.named2vec(curr)
    prev = monopoly_json2vec.named2vec(prev)
    labels = np.array(list(labels.values()))

    # NOTE(review): last_state was read as row [0], but it is written back
    # without that leading dimension, and torch.tensor may infer a different
    # dtype. Confirm downstream code expects this. Tensors are left on the
    # default (CPU) device.
    data.last_state = torch.tensor(curr)
    data.x[0, :, past_idx] = torch.tensor(prev)
    data.labels = torch.tensor(labels)

    return data

if __name__ == '__main__':
    # Manual check: print the sub-folders of a dataset split, then only the
    # numerically named ones in natural sort order. An optional first CLI
    # argument overrides the default dataset root.
    dataset_root = (
        sys.argv[1] if len(sys.argv) > 1
        else "/home/plymper/data/polycraftv2/normal/val/"
    )
    folders = os.listdir(dataset_root)
    print(folders)
    folders = natsorted(name for name in folders if str(name).isnumeric())
    print(folders)