import gym
import numpy as np
from wrappers import FlattenWrapper
import gridworld
from models import NonLinearModel
import torch
import torch.nn.functional as F


def build_dataset(env, a_forwards_or_backwards, num_steps=10000):
    """Collect transitions from random rollouts and package them as tensors.

    The environment is driven with actions from ``env.action_space.sample()``
    for ``num_steps`` steps. Each step yields one (input, target) pair:

    * ``"forwards"``:  input = concat(prev_obs, action_feature), target = obs
    * ``"backwards"``: input = concat(obs, action_feature), target = prev_obs

    ``action_feature`` is an array shaped like the observation and filled with
    ``action / 5``. When an episode ends, the environment is reset. This means
    no pair links a terminal observation to the start of a new episode.

    Args:
        env: Environment using the 4-tuple ``step`` API
            ``(obs, reward, done, info)``, whose ``reset()`` returns an
            observation. Observations are assumed to be 1-D arrays, so
            ``np.concatenate`` joins them along the feature axis.
            TODO: confirm this against the wrapper in use.
        a_forwards_or_backwards: ``"forwards"`` or ``"backwards"``.
        num_steps: Number of environment steps, which equals the number of
            rows in the returned tensors.

    Returns:
        ``(inputs, outputs)`` as ``torch.FloatTensor`` objects with
        ``num_steps`` rows.

    Raises:
        ValueError: If ``a_forwards_or_backwards`` is not a supported direction.
    """
    # Fail fast instead of silently returning empty tensors.
    if a_forwards_or_backwards not in ("forwards", "backwards"):
        raise ValueError(
            "a_forwards_or_backwards must be 'forwards' or 'backwards', "
            f"got {a_forwards_or_backwards!r}"
        )
    forwards = a_forwards_or_backwards == "forwards"

    inputs = []
    outputs = []
    prev_obs = env.reset()
    for _ in range(num_steps):
        action = env.action_space.sample()
        obs, _, done, _ = env.step(action)
        # Encode the action as a constant feature map with the observation's
        # shape. The /5 is presumably a normalisation by the action count.
        # NOTE(review): confirm against the env's action space.
        action_feature = (np.ones_like(obs) * action) / 5
        if forwards:
            inputs.append(np.concatenate((prev_obs, action_feature)))
            outputs.append(obs)
        else:
            inputs.append(np.concatenate((obs, action_feature)))
            outputs.append(prev_obs)
        # Stepping a finished episode is undefined in gym. Start a fresh
        # episode so the next pair does not span an episode boundary.
        prev_obs = env.reset() if done else obs

    return torch.FloatTensor(np.array(inputs)), torch.FloatTensor(np.array(outputs))


# Train one dynamics model per direction and save each to
# model_<direction>.pt:
#   "forwards"  predicts the next observation from (observation, action);
#   "backwards" predicts the previous observation from (next obs, action).
BATCH_SIZE = 64
NUM_TRAIN_STEPS = 200000
LOG_EVERY = 1000
# Fall back to CPU so the script still runs on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

env = FlattenWrapper(gym.make("GridEnv-v1"))
forwards_or_backwards = ["forwards", "backwards"]
for a_forwards_or_backwards in forwards_or_backwards:
    inputs, outputs = build_dataset(env, a_forwards_or_backwards)
    # Derive layer sizes from the data rather than hard-coding them. They
    # were previously fixed at 882 (input) and 441 (output).
    model = NonLinearModel(
        input_size=inputs.shape[1], output_size=outputs.shape[1], hidden_size=512
    )
    # Move the model to its device *before* building the optimizer, as the
    # torch.optim documentation recommends.
    model = model.to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), 1e-4)
    inputs = inputs.to(DEVICE)
    outputs = outputs.to(DEVICE)
    for step in range(NUM_TRAIN_STEPS):
        # Sample independent random rows. Contiguous slices of one rollout
        # are highly correlated, and the old randint(0, len - 64) could never
        # select the final sample.
        batch_idx = torch.randint(0, len(inputs), (BATCH_SIZE,), device=DEVICE)
        opt.zero_grad()
        loss = F.mse_loss(model(inputs[batch_idx]), outputs[batch_idx])
        loss.backward()
        opt.step()
        # .item() forces a host/device sync, so only log periodically.
        if step % LOG_EVERY == 0:
            print(f"{a_forwards_or_backwards} step {step}: loss {loss.item():.6f}")
    torch.save(
        model.state_dict(),
        f"model_{a_forwards_or_backwards}.pt",
    )
