from pathlib import Path

import numpy as np


def make_3d(size=50, path='./data/3d.edgelist'):
    """Write a weighted, undirected 3-D periodic lattice as an edge list.

    The graph has ``size**3`` nodes, one per cell of a cube whose faces wrap
    around.  Cell ``(x, y, z)`` has index ``x*size*size + y*size + z`` and is
    linked to its six axis-aligned neighbours (+/-1 along each axis, modulo
    ``size``).  Each edge is written once per line as ``"u v w"`` with
    ``u <= v`` and a weight ``w`` drawn uniformly from
    ``{0.01, 0.02, ..., 1.00}``.

    The global NumPy RNG is reseeded with 1, so the output is reproducible.
    This also changes the global RNG state seen by callers.

    Args:
        size: Side length of the cube.  The default, 50, reproduces the
            original dataset.
        path: Output file.  Missing parent directories are created.
    """
    np.random.seed(1)
    D = size
    N_STATES = D*D*D

    def idx_to_coord(idx):
        return idx // (D*D), (idx % (D*D)) // D, idx % D

    def coord_to_idx(coord):
        x, y, z = coord
        return x*D*D + y*D + z

    def neighbors_c(coord):
        # +/-1 along each axis, wrapping modulo D.
        x, y, z = coord
        return [
            ((x + 1) % D, y, z),
            ((x - 1) % D, y, z),
            (x, (y + 1) % D, z),
            (x, (y - 1) % D, z),
            (x, y, (z + 1) % D),
            (x, y, (z - 1) % D)
        ]

    def neighbors_i(idx):
        return list(map(coord_to_idx, neighbors_c(idx_to_coord(idx))))

    edge_dict = dict()
    for node in range(N_STATES):
        for neighbor in neighbors_i(node):
            # Each edge is generated more than once (at least once from each
            # endpoint).  The order-normalised key makes later visits
            # overwrite earlier weights.  That wastes draws, but it is kept so
            # the weights match what earlier versions of this function produced.
            edge_dict["{} {}".format(*sorted((node, neighbor)))] = np.random.randint(1, 101) / 100.

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w') as f:
        f.writelines('{} {}\n'.format(k, v) for k, v in edge_dict.items())


def make_3dd(size=50, path='./data/3dd.edgelist'):
    """Write a weighted, directed 3-D periodic lattice as an edge list.

    Same node layout as the undirected 3-D lattice: ``size**3`` cells with
    wrap-around, cell ``(x, y, z)`` at index ``x*size*size + y*size + z``.
    Each node only gets the three "forward" edges (+1 along x, y and z,
    modulo ``size``).  The backward directions are pruned.  Each edge is written
    as ``"node neighbor w"`` with ``w`` drawn uniformly from
    ``{0.01, 0.02, ..., 1.00}``.

    The global NumPy RNG is reseeded with 1, so the output is reproducible.

    Args:
        size: Side length of the cube.  The default, 50, reproduces the
            original dataset.
        path: Output file.  Missing parent directories are created.
    """
    np.random.seed(1)
    D = size
    N_STATES = D*D*D

    def idx_to_coord(idx):
        return idx // (D*D), (idx % (D*D)) // D, idx % D

    def coord_to_idx(coord):
        x, y, z = coord
        return x*D*D + y*D + z

    # backward direction for each node is pruned
    def neighbors_c(coord):
        x, y, z = coord
        return [
            ((x + 1) % D, y, z),
            (x, (y + 1) % D, z),
            (x, y, (z + 1) % D),
        ]

    def neighbors_i(idx):
        return list(map(coord_to_idx, neighbors_c(idx_to_coord(idx))))

    edge_dict = dict()
    for node in range(N_STATES):
        for neighbor in neighbors_i(node):
            # Key order is preserved (not sorted): the edge is directed.
            edge_dict["{} {}".format(node, neighbor)] = np.random.randint(1, 101) / 100.

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w') as f:
        f.writelines('{} {}\n'.format(k, v) for k, v in edge_dict.items())


def make_taxi(size=25, path='./data/taxi.edgelist'):
    """Write the weighted, undirected state graph of a grid pick-up task.

    A state is a 5-tuple ``(gx, gy, bx, by, held)``.  The first two entries
    are the gripper position, the next two the box position, and ``held`` is
    1 only when the gripper carries the box.  All ``size**4`` states with
    ``held == 0`` exist, plus the ``size**2`` states with ``held == 1`` in
    which gripper and box coincide.  Node indices follow the enumeration
    order below, and a held state directly follows its unheld twin.

    Transitions, on a ``size x size`` grid without wrap-around:
      * held: drop the box, or move gripper and box together one step;
      * not held: pick the box up when co-located, or move the gripper alone.

    Each undirected edge is written once per line as ``"u v w"`` with
    ``u < v`` and a weight drawn uniformly from ``{0.01, ..., 1.00}``.
    The global NumPy RNG is reseeded with 1, so the output is reproducible.

    Args:
        size: Grid side length.  The default, 25, reproduces the original
            dataset.
        path: Output file.  Missing parent directories are created.
    """
    np.random.seed(1)
    D = size

    # Enumerate every state.  A state's position in this list is its node index.
    state_list = []
    for a in range(D):
        for b in range(D):
            for c in range(D):
                for d in range(D):
                    state_list.append((a, b, c, d, 0))
                    if a == c and b == d:
                        state_list.append((a, b, c, d, 1))
    # Key on the tuples themselves.  The original keyed on hash(state), which
    # only worked as long as no two states happened to collide.
    state_to_idx = {s: i for i, s in enumerate(state_list)}

    def twoDneighbors(x, y):
        # Grid moves clamped at the border, excluding staying in place.
        # Kept verbatim: its set-iteration order fixes the order of the random
        # draws below, and therefore the written weights.
        return list(set([(max(0, x-1), y), (x, max(0, y-1)), (min(D-1, x+1), y), (x, min(D-1, y+1))]) - set([(x, y)]))

    def neighbors_c(coord):
        ns = []
        c = coord
        # Gripper holding box
        if c[4] == 1:
            # Gripper drops box
            ns.append((c[0], c[1], c[2], c[3], 0))
            # Gripper and box move together
            for n in twoDneighbors(c[0], c[1]):
                ns.append((n[0], n[1], n[0], n[1], 1))
        # Gripper not holding box
        elif c[4] == 0:
            # Gripper picks up box if at same position
            if c[0] == c[2] and c[1] == c[3]:
                ns.append((c[0], c[1], c[2], c[3], 1))
            # Gripper moves, leaving the box where it is
            for n in twoDneighbors(c[0], c[1]):
                ns.append((n[0], n[1], c[2], c[3], 0))
        return ns

    edge_dict = dict()
    for node, state in enumerate(state_list):
        for next_state in neighbors_c(state):
            neighbor = state_to_idx[next_state]
            # Each edge is reached from both endpoints.  The sorted key makes
            # the later draw overwrite the earlier one.
            edge_dict["{} {}".format(*sorted((node, neighbor)))] = np.random.randint(1, 101) / 100.

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w') as f:
        f.writelines('{} {}\n'.format(k, v) for k, v in edge_dict.items())


def make_octagon(size=200, path='./data/octagon.edgelist'):
    """Write a weighted, directed, 8-connected periodic 2-D grid.

    The graph has ``size**2`` nodes.  Cell ``(x, y)`` has index
    ``x*size + y``, and coordinates wrap modulo ``size``.  Each node
    generates edges to four "forward" cells: ``(+1, 0)``, ``(0, +1)``,
    ``(+1, +1)`` and ``(+1, -1)``.  Every such edge is written in both
    directions, so each node ends up linked to all eight surrounding cells.

    For each pair, a base cost is drawn uniformly from ``{0.01, ..., 1.00}``.
    Each direction then gets independent Gaussian noise (sigma 0.2) and is
    clamped to at least 0.01, so ``u -> v`` and ``v -> u`` usually differ.
    Lines are ``"u v w"``.  The global NumPy RNG is reseeded with 1.

    Args:
        size: Grid side length.  The default, 200, reproduces the original
            dataset.
        path: Output file.  Missing parent directories are created.
    """
    np.random.seed(1)
    D = size
    N_STATES = D*D

    def idx_to_coord(idx):
        return idx // D, idx % D

    def coord_to_idx(coord):
        x, y = coord
        return x*D + y

    def neighbors_c(coord):
        # Half of the 8-neighbourhood.  The reverse directions come from
        # writing every edge both ways below.
        x, y = coord
        return [
            ((x + 1) % D, y),
            (x, (y + 1) % D),
            ((x + 1) % D, (y + 1) % D),
            ((x + 1) % D, (y - 1) % D)
        ]

    def neighbors_i(idx):
        return list(map(coord_to_idx, neighbors_c(idx_to_coord(idx))))

    edge_dict = dict()
    for node in range(N_STATES):
        for neighbor in neighbors_i(node):
            base_dist = np.random.randint(1, 101) / 100
            # Independent noise per direction, clamped so costs stay positive.
            edge_dict["{} {}".format(node, neighbor)] = max(base_dist + np.random.normal(0, 0.2), 0.01)
            edge_dict["{} {}".format(neighbor, node)] = max(base_dist + np.random.normal(0, 0.2), 0.01)

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w') as f:
        f.writelines('{} {}\n'.format(k, v) for k, v in edge_dict.items())


def make_traffic(size=200, path='./data/traffic.edgelist'):
    """Write a weighted, directed 4-connected 2-D grid without wrap-around.

    The graph has ``size**2`` nodes.  Cell ``(x, y)`` has index
    ``x*size + y``.  Each node generates edges to ``(x+1, y)`` and
    ``(x, y+1)`` when those cells lie on the grid.  Every such edge is
    written in both directions, so interior nodes end up with four neighbours.

    For each pair, a base cost is drawn uniformly from ``{0.01, ..., 1.00}``.
    Each direction then gets independent Gaussian noise (sigma 0.2) and is
    clamped to at least 0.01.  Lines are ``"u v w"``.  The global NumPy RNG
    is reseeded with 1.

    Args:
        size: Grid side length.  The default, 200, reproduces the original
            dataset.
        path: Output file.  Missing parent directories are created.
    """
    np.random.seed(1)
    D = size
    N_STATES = D*D

    def idx_to_coord(idx):
        return idx // D, idx % D

    def coord_to_idx(coord):
        x, y = coord
        return x*D + y

    def neighbors_c(coord):
        x, y = coord
        return [
            (x + 1, y),
            (x, y + 1),
        ]

    def neighbors_i(idx):
        # Drop neighbours that step off the grid (no wrap-around).  The
        # previous filter also tested for -1, which forward steps never
        # produce.
        inside = [(a, b) for a, b in neighbors_c(idx_to_coord(idx))
                  if 0 <= a < D and 0 <= b < D]
        return list(map(coord_to_idx, inside))

    edge_dict = dict()
    for node in range(N_STATES):
        for neighbor in neighbors_i(node):
            base_dist = np.random.randint(1, 101) / 100
            # Independent noise per direction, clamped so costs stay positive.
            edge_dict["{} {}".format(node, neighbor)] = max(base_dist + np.random.normal(0, 0.2), 0.01)
            edge_dict["{} {}".format(neighbor, node)] = max(base_dist + np.random.normal(0, 0.2), 0.01)

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w') as f:
        f.writelines('{} {}\n'.format(k, v) for k, v in edge_dict.items())
