import argparse

import numpy as np
import torch as t
from tqdm import trange

from ...simulators.navigation import Navigation

# Parse the command line when the module runs; `--save` is the required
# output path that the collected dataset is written to via `t.save`.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--save",
    type=str,
    required=True,
)
args = parser.parse_args()


def compute_V(free_space, goals, gamma=0.9, tol=1e-3):
    """Run value iteration on a grid and return the converged state values.

    Each sweep computes, for every cell,
    ``V = max over moves of (gamma * V[destination] + goals)`` and then
    multiplies by ``free_space``, so blocked cells are pinned to 0. The moves
    are stay / left / up / right / down. Moves that leave the grid see a
    value of 0 because ``conv2d`` zero-pads. Iteration stops once no cell
    changes by ``tol`` or more between sweeps.

    Args:
        free_space: mask shaped (B, 1, H, W); cells equal to 0 are forced to
            value 0 after every sweep.
        goals: per-cell reward shaped (B, 1, H, W), added on every sweep.
        gamma: discount factor; must be in [0, 1) so the iteration converges.
        tol: stopping threshold on the max absolute change between sweeps.

    Returns:
        Value tensor shaped (B, H, W). Its floating dtype is the promotion of
        the two inputs' dtypes, or float32 if both are non-floating.

    Raises:
        ValueError: if ``gamma`` is outside [0, 1).
    """
    if not 0.0 <= gamma < 1.0:
        raise ValueError(f"gamma must be in [0, 1), got {gamma!r}")

    # Work in one floating dtype for everything. Otherwise V gets promoted by
    # the free_space multiply and no longer matches the kernel's dtype.
    dtype = t.promote_types(goals.dtype, free_space.dtype)
    if not dtype.is_floating_point:
        dtype = t.float32
    R = goals.to(dtype)

    # conv2d is a cross-correlation: a 1 at kernel[r][c] reads the input cell
    # at offset (r - 1, c - 1). The kernels are built on R's dtype/device so
    # float64 or GPU inputs don't hit a mismatch inside conv2d.
    transition = t.tensor(
        [
            [[0, 0, 0], [0, 1, 0], [0, 0, 0]],  # stay          ( 0,  0)
            [[0, 0, 0], [1, 0, 0], [0, 0, 0]],  # column left   ( 0, -1)
            [[0, 1, 0], [0, 0, 0], [0, 0, 0]],  # row above     (-1,  0)
            [[0, 0, 0], [0, 0, 1], [0, 0, 0]],  # column right  ( 0, +1)
            [[0, 0, 0], [0, 0, 0], [0, 1, 0]],  # row below     (+1,  0)
        ],
        dtype=dtype,
        device=R.device,
    ).reshape(5, 1, 3, 3)

    V = t.zeros_like(R)
    while True:
        # One Q channel per move; R broadcasts over the 5 move channels.
        Q = t.nn.functional.conv2d(V * gamma, transition, bias=None, stride=1, padding=1) + R
        V_next = t.max(Q, dim=1, keepdim=True).values * free_space
        if t.max(t.abs(V_next - V)) < tol:
            return V_next.squeeze(1)
        V = V_next


def compute_policy(V):
    """Return the greedy action for every cell of a value map.

    For each cell, the chosen action is the move whose destination cell has
    the largest value in ``V``. Moves that leave the grid see a value of 0
    because ``conv2d`` zero-pads.

    Args:
        V: value map shaped (B, H, W), e.g. the output of ``compute_V``. It is
            unsqueezed to (B, 1, H, W) for ``conv2d``. Non-floating tensors
            are converted to float32.

    Returns:
        Integer tensor shaped (B, H, W) holding action indices
        0=RIGHT, 1=DOWN, 2=LEFT, 3=UP, 4=STAY. Ties go to the lowest index,
        because ``argmax`` returns the first maximum. That means a cell whose
        neighbourhood is all zeros gets 0 (RIGHT).
    """
    values = V if V.is_floating_point() else V.float()
    # conv2d is a cross-correlation: a 1 at kernel[r][c] reads the cell at
    # offset (r - 1, c - 1). The kernels are built on V's dtype/device so
    # float64 or GPU inputs work.
    transition = t.tensor(
        [
            [[0, 0, 0], [0, 0, 1], [0, 0, 0]],  # 0: RIGHT
            [[0, 0, 0], [0, 0, 0], [0, 1, 0]],  # 1: DOWN
            [[0, 0, 0], [1, 0, 0], [0, 0, 0]],  # 2: LEFT
            [[0, 1, 0], [0, 0, 0], [0, 0, 0]],  # 3: UP
            [[0, 0, 0], [0, 1, 0], [0, 0, 0]],  # 4: STAY
        ],
        dtype=values.dtype,
        device=values.device,
    ).reshape(5, 1, 3, 3)

    Q = t.nn.functional.conv2d(values.unsqueeze(1), transition, bias=None, stride=1, padding=1)
    return t.argmax(Q, dim=1)


def _state_to_numpy(simulator):
    """Snapshot the simulator's state tensor as a squeezed boolean numpy array."""
    return simulator.state_tensor().squeeze().bool().cpu().numpy()


def _collect_episode():
    """Roll out one freshly reset Navigation episode under the value-iteration policy.

    Returns:
        Tuple ``(states, actions, rewards, values)`` of numpy arrays with one
        entry per step, plus one more for the final solved state. ``values``
        is the undiscounted reward-to-go: the reversed cumulative sum of
        ``rewards``.

    Raises:
        RuntimeError: if the episode terminates without ``is_solved()``.
    """
    simulator = Navigation()
    simulator.reset()

    t_state = simulator.state_tensor()[0]
    grid = (1, 1, simulator.N, simulator.N)
    # Cells whose channel 0 is 0 are treated as free space; channel 1 is used
    # as the goal reward map.
    free_space = (t_state[0] == 0).float().view(grid)
    goal_map = t_state[1].float().view(grid)

    with t.inference_mode():
        V = compute_V(free_space, goal_map)
        # .cpu() so .numpy() still works if the simulator's tensors are on a GPU.
        policy = compute_policy(V).view(-1).cpu().numpy()

    states, actions, rewards = [], [], []
    terminal = False
    while not terminal:
        states.append(_state_to_numpy(simulator))
        _, _, robot = simulator.state
        # Look the action up once, before stepping, so the recorded action is
        # exactly the one that was executed.
        action = policy[robot]
        _, reward, terminal = simulator.step(action)
        actions.append(action)
        rewards.append(reward)

    # Explicit check rather than `assert`, which `python -O` strips.
    if not simulator.is_solved():
        raise RuntimeError("Episode terminated without reaching the goal")

    # Also record the final (solved) state, paired with a random placeholder
    # action and a reward of 0.
    states.append(_state_to_numpy(simulator))
    actions.append(np.random.randint(simulator.n_actions))
    rewards.append(0)

    rewards = np.asarray(rewards)
    # np.flip returns a negative-stride view; make it contiguous so consumers
    # such as torch.from_numpy can use it directly.
    values = np.ascontiguousarray(np.flip(np.cumsum(np.flip(rewards))))
    return np.asarray(states), np.asarray(actions), rewards, values


data = {
    "states": [],
    "actions": [],
    "rewards": [],
    "values": [],
}
for _ in trange(10000):
    states, actions, rewards, values = _collect_episode()
    data["states"].append(states)
    data["actions"].append(actions)
    data["rewards"].append(rewards)
    data["values"].append(values)

t.save(
    data,
    args.save,
)
