import random
import warnings
from typing import Optional

import gymnasium as gym
import numpy as np
from gymnasium import spaces

# Silence warnings; the original motivation was noisy Gymnasium warnings.
# NOTE: called without a category/module filter, this suppresses *every*
# warning for the whole interpreter as soon as this module is imported.
warnings.filterwarnings("ignore")


class GridWorld(gym.Env):
    """Square grid world with a fixed start cell and one rewarded goal cell.

    Observations are a single integer in ``Discrete(size**2)`` encoding the
    agent position ``(x, y)`` as ``x * size + y`` (see ``pos_to_state``).

    Actions are ``Discrete(5)``: 0 = no-op, 1 = up (+y), 2 = right (+x),
    3 = down (-y), 4 = left (-x). Moves that would leave the grid are clipped
    to the border.

    The reward is 1.0 on every step that ends with the agent on the goal
    cell and 0.0 otherwise. Reaching the goal does not end the episode: an
    episode ends only once ``num_steps`` steps have been taken, which is
    reported through ``terminated`` (``truncated`` is always ``False``).

    The goal is chosen once at construction time and is not changed by
    ``reset``.
    """

    # Gymnasium reads the supported modes from the plural key "render_modes".
    metadata = {"render_modes": ["rgb_array"], "render_fps": 1}

    def __init__(
        self, size=5, goal_pos=None, render_mode=None, num_steps=15,
    ):
        """Build the environment.

        Args:
            size: Side length of the square grid.
            goal_pos: Optional 1-D ``(x, y)`` goal position. When ``None`` a
                goal is drawn at random from ``possible_goals``.
            render_mode: ``"rgb_array"`` to make ``render`` return an image;
                any other value makes ``render`` return ``None``.
            num_steps: Number of steps after which an episode ends.
        """
        self.size = size
        self.num_steps = num_steps
        # TODO: check with (x,y) obs
        self.observation_space = spaces.Discrete(self.size**2)
        self.action_space = spaces.Discrete(5)

        # Displacement (dx, dy) applied to the agent position for each action.
        self.action_to_direction = {
            0: np.array((0, 0), dtype=np.float32),  # noop
            1: np.array((0, 1), dtype=np.float32),  # up
            2: np.array((1, 0), dtype=np.float32),  # right
            3: np.array((0, -1), dtype=np.float32),  # down
            4: np.array((-1, 0), dtype=np.float32),  # left
        }

        self.starting_state = (0.0, 0.0)
        self.states = [(x, y) for y in np.arange(0, size) for x in np.arange(0, size)]
        # Goals may be anywhere except the 2x2 corner around the start:
        # (0, 0), (0, 1), (1, 1) and (1, 0). Filtering (instead of chained
        # list.remove calls) keeps the same order and does not raise when the
        # grid is too small to contain those cells.
        self.possible_goals = [
            cell for cell in self.states if cell[0] > 1 or cell[1] > 1
        ]

        self.step_count = 0

        if goal_pos is not None:
            self.goal_pos = np.asarray(goal_pos)
            assert self.goal_pos.ndim == 1
        else:
            self.goal_pos = self.generate_goal_pos()

        self.agent_pos = np.array(self.starting_state, dtype=np.float32)
        self.render_mode = render_mode

    def generate_goal_pos(self):
        """Return one uniformly drawn entry of ``possible_goals`` as an array.

        Uses Python's global ``random`` module. Raises ``ValueError`` when
        ``possible_goals`` is empty.
        """
        return np.array(random.sample(self.possible_goals, 1))[0]

    def pos_to_state(self, pos):
        """Encode an ``(x, y)`` position as the integer ``x * size + y``."""
        return int(pos[0] * self.size + pos[1])

    def state_to_pos(self, state):
        """Decode an integer state back into an ``(x, y)`` array."""
        return np.array(divmod(state, self.size))

    def reset(self, seed=None, options=None):
        """Put the agent back on the start cell and zero the step counter.

        The goal position is left untouched.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed, options=options)
        self.agent_pos = np.array(self.starting_state, dtype=np.float32)
        self.step_count = 0
        return self.pos_to_state(self.agent_pos), {}

    def step(self, action):
        """Apply ``action`` and advance the episode by one step.

        Args:
            action: Key of ``action_to_direction`` (0-4).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` becomes ``True`` once ``num_steps`` steps have been
            taken; ``truncated`` is always ``False`` and ``info`` is empty.
        """
        # Clip so that moving into a wall leaves the agent on the border cell.
        self.agent_pos = np.clip(
            self.agent_pos + self.action_to_direction[action], 0, self.size - 1
        )
        self.step_count += 1

        reward = 1.0 if np.array_equal(self.agent_pos, self.goal_pos) else 0.0
        terminated = self.step_count >= self.num_steps

        return self.pos_to_state(self.agent_pos), reward, terminated, False, {}

    def render(self) -> Optional[np.ndarray]:
        """Return a ``(size, size, 3)`` uint8 image, or ``None``.

        Only ``render_mode == "rgb_array"`` produces an image: white
        background, red goal cell, green agent cell. The agent is drawn last,
        so it hides the goal when both share a cell. The first array index is
        ``x`` and the second is ``y``.
        """
        if self.render_mode != "rgb_array":
            return None

        grid = np.full(
            (self.size, self.size, 3), fill_value=(255, 255, 255), dtype=np.uint8
        )
        # Cast to int so a goal supplied as floats can still index the grid,
        # mirroring how the float32 agent position is handled.
        grid[int(self.goal_pos[0]), int(self.goal_pos[1])] = (255, 0, 0)
        grid[int(self.agent_pos[0]), int(self.agent_pos[1])] = (0, 255, 0)
        return grid


def train_test_goals_gw(grid_size, num_train_goals, seed):
    """Return the admissible goal cells of a ``grid_size`` x ``grid_size`` grid.

    Every cell is admissible except the 2x2 corner (0, 0), (0, 1), (1, 1)
    and (1, 0). Cells are listed with ``y`` as the outer loop and ``x`` as the
    inner loop.

    NOTE: no train/test split is performed. ``num_train_goals`` and ``seed``
    are accepted but ignored, and both returned values are the *same* array
    object.

    Args:
        grid_size: Side length of the square grid.
        num_train_goals: Unused.
        seed: Unused.

    Returns:
        ``(goals, goals)`` where ``goals`` is an array of shape ``(k, 2)``
        holding ``(x, y)`` pairs; ``k`` is 0 when the grid is 2x2 or smaller.
    """
    cells = [(x, y) for y in np.arange(0, grid_size) for x in np.arange(0, grid_size)]
    # Filter rather than list.remove(): same cells in the same order, but no
    # ValueError when the grid is too small to contain the excluded corner.
    allowed = [cell for cell in cells if cell[0] > 1 or cell[1] > 1]

    # reshape keeps the (k, 2) layout even when no cell is admissible.
    possible_goals = np.array(allowed).reshape(-1, 2)
    return possible_goals, possible_goals


if __name__ == "__main__":
    # Interactive smoke test: play the environment by hand from the terminal.
    # Episodes run back-to-back until the user sends Ctrl-C or EOF.
    _, test_goals = train_test_goals_gw(5, 0, 0)
    env = GridWorld(goal_pos=test_goals[0])
    print(test_goals)
    valid_actions = sorted(env.action_to_direction)
    try:
        while True:
            obs, _ = env.reset()
            print("Goal:", env.goal_pos)
            done = False
            while not done:
                print("State:", env.state_to_pos(obs))
                # Re-prompt on bad input instead of crashing the session.
                try:
                    action = int(input("Action:"))
                except ValueError:
                    print("Invalid action, expected one of", valid_actions)
                    continue
                if action not in env.action_to_direction:
                    print("Invalid action, expected one of", valid_actions)
                    continue
                obs, reward, terminated, truncated, _ = env.step(action)
                # An episode is over when either end-of-episode flag is set.
                done = terminated or truncated
                print("Reward", reward)
    except (EOFError, KeyboardInterrupt):
        # Leave the endless loop quietly when the user is finished.
        print()
