import numpy as np
import json
import pickle as pkl
import sys
import torch
from sklearn import datasets
sys.path.append("../")
from dataset import gridworld_json2vec, gridworld_json_info
from natsort import natsorted
import os

def recipe_novelty(curr,prev,recipe):
    """Replace the output of a just-crafted recipe with a randomly drawn item.

    ``curr`` is modified in place: the recipe output's count is set to 0 and
    ``out_quant`` units of a random inventory item (never the output itself)
    are added instead.

    Args:
        curr: item name -> count mapping for the current step (mutated).
        prev: item name -> count mapping for the previous step; returned as is.
        recipe: ``(out, inputs, out_quant)`` tuple; ``inputs`` is iterated for
            the ingredient names that get flagged in the labels.

    Returns:
        ``(curr, prev, labels)`` where ``labels`` maps every key of ``curr`` to
        0, except the substituted item, the recipe output and every recipe
        input, which are set to 1.
    """
    output_name, ingredients, amount = recipe
    # NOTE(review): the output count is zeroed rather than restored to
    # prev[output_name]; confirm this is the intended novelty.
    curr[output_name] = 0

    # Candidate pool: keys of the inventory index map, presumably item names
    # comparable to ``output_name`` (TODO confirm).
    pool = list(gridworld_json_info.inverse_inventory_indeces.keys())
    substitute = np.random.choice(pool)
    while substitute == output_name:  # resample until a different item is drawn
        substitute = np.random.choice(pool)

    curr[substitute] += amount

    labels = dict.fromkeys(curr.keys(), 0)
    for flagged in (substitute, output_name, *ingredients):
        labels[flagged] = 1
    return curr, prev, labels


def stick_recipe_novelty(curr,prev):
    """Stick-recipe novelty: a random item appears in place of the sticks.

    Mutates ``curr`` in place: the ``minecraft:stick`` count is set to 0 and
    ``np.random.randint(4)`` units (0 to 3) of a randomly drawn item other
    than sticks are added.

    NOTE: a function with this same name is defined again later in this
    module. At import time the later definition replaces this one.

    Returns:
        ``(curr, prev, labels)``; ``labels`` maps every key of ``curr`` to 0,
        except sticks, the substituted item and ``minecraft:planks`` (set to 1).
    """
    stick = "minecraft:stick"
    curr[stick] = 0

    pool = list(gridworld_json_info.inverse_inventory_indeces.keys())
    while True:
        substitute = np.random.choice(pool)
        if substitute != stick:
            break
    curr[substitute] += np.random.randint(4)

    labels = dict.fromkeys(curr.keys(), 0)
    labels.update({stick: 1, substitute: 1, "minecraft:planks": 1})
    return curr, prev, labels

def pogo_recipe_novelty(curr,prev):
    """Replace a freshly crafted wooden pogo stick with a randomly drawn item.

    Mutates ``curr`` in place: the ``polycraft:wooden_pogo_stick`` count is set
    to 0 and ``np.random.randint(4)`` units (0 to 3) of a random item other
    than the pogo stick are added.

    This is the counterpart of ``is_pogo_crafted``. It has its own name so it
    is not silently shadowed by ``stick_recipe_novelty``, which is defined
    elsewhere in this module.

    Returns:
        ``(curr, prev, labels)``; ``labels`` maps every key of ``curr`` to 0,
        except the pogo stick, the substituted item and ``minecraft:planks``
        (set to 1).
    """
    pogo = "polycraft:wooden_pogo_stick"
    curr[pogo] = 0
    candidates = list(gridworld_json_info.inverse_inventory_indeces.keys())
    item = np.random.choice(candidates)
    while item == pogo:  # resample until a different item is drawn
        item = np.random.choice(candidates)
    # NOTE(review): randint(4) yields 0-3, while is_pogo_crafted expects an
    # output of exactly 1; confirm the intended quantity.
    curr[item] += np.random.randint(4)
    labels = {k: 0 for k in curr.keys()}
    labels[pogo] = 1
    labels[item] = 1
    # NOTE(review): only planks are flagged as an ingredient; confirm this
    # matches the actual pogo-stick recipe.
    labels["minecraft:planks"] = 1
    return curr, prev, labels

def is_stick_crafted(curr,prev):
    """Return True when the stick count grew by exactly 4 since ``prev``.

    NOTE: a function with this same name is defined again later in this
    module. At import time the later definition replaces this one.
    """
    gained = curr["minecraft:stick"] - prev["minecraft:stick"]
    return bool(gained == 4)

def is_pogo_crafted(curr,prev):
    """Return True when exactly one wooden pogo stick was gained since ``prev``."""
    key = "polycraft:wooden_pogo_stick"
    delta = curr[key] - prev[key]
    return bool(delta == 1)

def stick_recipe_novelty(curr,prev):
    """Stick-recipe novelty: a random item is produced instead of sticks.

    This definition appears after the earlier definitions of the same name,
    so it is the one bound at module level once the file is imported.

    Mutates ``curr`` in place: the ``minecraft:stick`` count is set to 0 and
    ``np.random.randint(4)`` units (0 to 3) of a random item other than
    sticks are added.

    Returns:
        ``(curr, prev, labels)``; ``labels`` maps every key of ``curr`` to 0,
        except sticks, the substituted item and ``minecraft:planks`` (set to 1).
    """
    target = "minecraft:stick"
    curr[target] = 0

    def draw():
        # One uniform draw over the keys of the inventory index map.
        return np.random.choice(list(gridworld_json_info.inverse_inventory_indeces.keys()))

    item = draw()
    while item == target:
        item = draw()
    # NOTE(review): randint(4) yields 0-3, while is_stick_crafted expects a
    # stick recipe output of 4; confirm the intended quantity.
    curr[item] += np.random.randint(4)

    labels = {name: 0 for name in curr.keys()}
    for flagged in (target, item, "minecraft:planks"):
        labels[flagged] = 1
    return curr, prev, labels

def is_stick_crafted(curr,prev):
    """Return True when exactly 4 sticks were gained between ``prev`` and ``curr``.

    Looks up the ``"minecraft:stick"`` key, the same item name used by the
    other stick helpers in this module. This definition overrides the
    earlier one with the same name, so the key must match theirs.

    Args:
        curr: item name -> count after the step.
        prev: item name -> count before the step.

    Returns:
        A plain ``bool``.
    """
    return bool(curr["minecraft:stick"] - prev["minecraft:stick"] == 4)

def is_recipe_crafted(curr,prev,recipe):
    """Check whether ``recipe`` was applied exactly once between two states.

    Args:
        curr: item name -> count after the step.
        prev: item name -> count before the step.
        recipe: ``(out, inputs, out_quant)``; ``inputs`` maps each ingredient
            name to the amount consumed.

    Returns:
        True iff the output count grew by exactly ``out_quant`` and every
        ingredient count dropped by exactly its listed amount.
    """
    product, ingredients, produced = recipe

    def delta(name):
        return curr[name] - prev[name]

    # All comparisons are evaluated before combining them (no short-circuit),
    # so a missing key raises KeyError no matter which check fails first.
    outcomes = [delta(product) == produced]
    outcomes += [delta(name) == -ingredients[name] for name in ingredients]
    return all(outcomes)
    
    

def inject_novelty(data,type="recipes"):
    """Inject a recipe novelty into the most recent transition stored in ``data``.

    The current inventory comes from ``data.last_state`` and the previous one
    from a column of ``data.x``. For every recipe in
    ``gridworld_json_info.recipes`` that ``is_recipe_crafted`` detects,
    ``recipe_novelty`` swaps the crafted output for a random item. The
    (possibly modified) vectors are then written back to ``data``, and the
    per-feature novelty labels are stored in ``data.labels``.

    Args:
        data: object with tensor attributes ``last_state`` and ``x``.
            Presumably ``last_state`` is ``(1, F)`` and ``x`` is ``(1, F, T)``;
            TODO confirm against the dataset code.
        type: currently unused by the body. It also shadows the builtin
            ``type`` inside this function.

    Returns:
        The same ``data`` object, mutated in place.
    """

    # Currently unused: the .to(device) calls below are commented out.
    device = data.last_state.get_device()
    # NOTE(review): Tensor.numpy() only works on CPU tensors that do not require grad.
    curr = data.last_state.numpy()[0]
    past_idx = -1

    prev = data.x[0,:,past_idx].numpy() # last step in window (past_idx = -1) -- may want to make this random
    
    # Convert both vectors into name -> value mappings (presumably keyed by item name).
    curr = gridworld_json2vec.name_features(curr)
    prev = gridworld_json2vec.name_features(prev)
    
    # Default: no feature is novel.
    labels = {k:0 for k in curr.keys()}
    
    # NOTE(review): curr/prev/labels are reassigned inside the loop. Later
    # recipes are therefore checked against the already-modified state, and a
    # later match overwrites the labels produced by an earlier one.
    for k,v in gridworld_json_info.recipes.items():
        # (output name, inputs, output quantity), the tuple layout unpacked by is_recipe_crafted
        recipe = k,v[0],v[1]
        if is_recipe_crafted(curr,prev,recipe):
            curr,prev,labels = recipe_novelty(curr,prev,recipe)
    
    curr = gridworld_json2vec.named2vec(curr)
    prev = gridworld_json2vec.named2vec(prev)
    # Relies on the labels' key order matching curr's feature order; TODO confirm
    # that named2vec uses the same ordering.
    labels = np.array(list(labels.values()))
    
    # NOTE(review): last_state was indexed with [0] above; if named2vec returns
    # a 1-D vector, this assignment drops that leading dimension. Confirm the
    # shape expected downstream.
    data.last_state = torch.tensor(curr)#.to(device)
    data.x[0,:,past_idx] = torch.tensor(prev)#.to(device)
    data.labels = torch.tensor(labels)#.to(device)

    return data

if __name__ == '__main__':
    # Quick inspection of the validation split: print every entry, then only
    # the numerically named episode folders in natural sort order.
    dataset_root="/home/plymper/data/polycraftv2/normal/val/"
    folders = os.listdir(dataset_root)
    print(folders)
    folders = natsorted([name for name in folders if str(name).isnumeric()])
    print(folders)
    # TODO: to load episodes, read every *.json file inside each
    # dataset_root/<folder> into a per-folder list.

