import numpy as np
import random
import torch
from karel.world import World

def create_map_and_robot(h=18, w=18, max_wall_ratio=0.1, max_marker_ratio=0.3):
    """Sample a random h x w grid of cell codes plus a robot position and heading.

    Cell codes written into the grid:
      * 0      -- empty
      * 5      -- interior wall
      * 6      -- border (the outermost ring of cells)
      * 7..16  -- a cell holding 1..10 markers (code = marker count + 6)

    Args:
        h, w: grid height and width, including the border ring.
        max_wall_ratio: the fraction of interior cells turned into walls is
            drawn uniformly from [0, max_wall_ratio].
        max_marker_ratio: the fraction of interior cells that receive markers
            is drawn uniformly from [0, max_marker_ratio] (re-drawn until
            walls + markers do not exceed the interior). Markers are placed
            only on non-wall cells.

    Returns:
        (grid, robot, direction): ``grid`` is a float ``(h, w)`` ndarray of
        cell codes, ``robot`` is a ``(row, col)`` tuple naming a non-wall
        interior cell (it may hold markers), and ``direction`` is an int in
        0..3 (its compass meaning is defined by the consumer, not here).

    Raises:
        IndexError: if no interior cell is left free for the robot.
    """
    grid = np.zeros((h, w))
    # Border ring.
    grid[0, :] = 6
    grid[-1, :] = 6
    grid[:, 0] = 6
    grid[:, -1] = 6

    # Interior cells in row-major order. The order matters: random.sample and
    # random.choice draw from this list, so keeping it stable keeps seeded
    # runs reproducible.
    free_cells = [(r, c) for r in range(1, h - 1) for c in range(1, w - 1)]
    position_num = len(free_cells)

    # Walls.
    wall_ratio = random.uniform(0, max_wall_ratio)
    wall_num = int(np.floor(position_num * wall_ratio))
    walls = random.sample(free_cells, wall_num)
    for cell in walls:
        grid[cell] = 5
    # One set-based filter instead of an O(n) list.remove per wall; the
    # comprehension keeps the remaining cells in their original order.
    wall_set = set(walls)
    free_cells = [cell for cell in free_cells if cell not in wall_set]

    # Markers: re-draw the ratio until walls + markers fit in the interior.
    marker_ratio = random.uniform(0, max_marker_ratio)
    while wall_ratio + marker_ratio > 1:
        marker_ratio = random.uniform(0, max_marker_ratio)
    marker_num = int(np.floor(position_num * marker_ratio))
    for cell in random.sample(free_cells, marker_num):
        grid[cell] = random.randint(1, 10) + 6

    # Robot: any non-wall interior cell (markers allowed), one of 4 headings.
    robot = random.choice(free_cells)
    direction = random.randint(0, 3)
    return grid, robot, direction


def map2tensor(map, num_values=17):
    """One-hot encode a grid of integer cell codes, dropping the code-0 channel.

    Args:
        map: 2-D array of cell codes in ``[0, num_values)``. Float arrays of
            whole numbers (as returned by ``create_map_and_robot``) are
            accepted and cast to integers.
        num_values: number of distinct codes, including the "empty" code 0.

    Returns:
        Float ndarray of shape ``(h, w, num_values - 1)`` where channel ``c``
        is 1 exactly where ``map == c + 1``; code-0 cells are all zeros.

    Raises:
        ValueError: if any code lies outside ``[0, num_values)``. Without this
            check numpy would silently wrap negative codes to the last columns.
    """
    codes = np.asarray(map).astype(np.intp)
    h, w = codes.shape
    if codes.size and (codes.min() < 0 or codes.max() >= num_values):
        raise ValueError(
            f"cell codes must lie in [0, {num_values}), "
            f"got range [{codes.min()}, {codes.max()}]")
    one_hot = np.zeros((h * w, num_values))
    one_hot[np.arange(h * w), codes.reshape(-1)] = 1
    # Column 0 encodes "empty" and is dropped, so channel c <-> code c + 1.
    return one_hot.reshape(h, w, num_values)[:, :, 1:]


def get_map(h=18, w=18, max_wall_ratio=0.1, max_marker_ratio=0.3):
    """Sample one random grid and encode it as a channels-first one-hot array.

    Channel layout (``map2tensor`` maps cell code k to channel k - 1):
      * 0..3   -- robot position, one channel per heading drawn by
                  ``create_map_and_robot`` (which compass direction each
                  index means is not defined here -- presumably it matches
                  ``karel.world``; verify)
      * 4      -- interior wall (code 5)
      * 5      -- border (code 6)
      * 6..15  -- 1..10 markers (codes 7..16)

    Returns:
        Float ndarray of shape ``(16, h, w)``.
    """
    grid, (row, col), heading = create_map_and_robot(
        h, w, max_wall_ratio, max_marker_ratio)
    encoded = map2tensor(grid.astype(np.int32))
    # Codes 1..4 never appear in a sampled grid, so channels 0..3 are free
    # to mark the robot's cell and heading.
    encoded[row, col, heading] = 1
    return np.transpose(encoded, (2, 0, 1))

def write_dataset(path='datasets/karel/random_grids2.thdump', num_grids=8000000):
    """Generate random grids and save them in sparse form with ``torch.save``.

    Each grid from ``get_map()`` is flattened, and only the int16 indices of
    its non-zero entries are stored. ``grid_desc_to_tensor`` rebuilds the
    dense grid from such an index tensor. All descriptors are collected in
    memory and saved as one Python list.

    Args:
        path: destination file. Missing parent directories are created.
        num_grids: number of grids to generate.
    """
    from pathlib import Path

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    inp_data = []
    for _ in range(num_grids):
        inp_grid = torch.Tensor(get_map())
        # int16 suffices for the default 16x18x18 grid (largest flat index is
        # 5183); larger grids would need a wider dtype.
        inp_idx = inp_grid.view(-1).nonzero(as_tuple=False).view(-1).short()
        inp_data.append(inp_idx)
    torch.save(inp_data, path)

def grid_desc_to_tensor(grid_desc, h=18, w=18):
    """Rebuild a dense 16-channel grid from the flat indices of its non-zero cells.

    This inverts the sparse encoding produced by ``write_dataset``.

    Args:
        grid_desc: 1-D integer tensor of flat indices into a ``(16, h, w)``
            grid. Any integer dtype works; it is widened to int64 for indexing.
        h, w: grid height and width. The default 18x18 matches ``get_map``'s
            defaults.

    Returns:
        Float tensor of shape ``(16, h, w)``, with 1 at every listed index and
        0 elsewhere.
    """
    num_channels = 16
    # Derive the flat size from the shape instead of a separate magic number.
    grid = torch.zeros(num_channels * h * w)
    grid.index_fill_(0, grid_desc.long(), 1)
    return grid.view(num_channels, h, w)

def load_dataset(path='datasets/karel/random_grids.thdump'):
    """Decode every stored grid in a dataset file into a ``World``.

    The decoded worlds are not kept. The function returns None, so in effect
    it only checks that every stored grid can be decoded.

    NOTE(review): the default path differs from the file ``write_dataset``
    writes by default ('random_grids2.thdump'). Pass ``path`` explicitly to
    read that file.

    Args:
        path: file previously written with ``torch.save`` by ``write_dataset``.
    """
    # torch.load unpickles its input; only point this at trusted files.
    inp_data = torch.load(path)
    for inp_idx in inp_data:
        World.fromPytorchTensor(grid_desc_to_tensor(inp_idx))

if __name__ == '__main__':
    # Script entry point: generate the dataset with write_dataset()'s defaults.
    write_dataset()
