#!/usr/bin/env python
# -*- coding: utf-8 -*-

from d4rl.gym_minigrid.minigrid import MiniGridEnv, Grid, Goal, Lava
from d4rl.gym_minigrid.register import register
import numpy as np


class DeceptiveEnv(MiniGridEnv):
    """
    Four-rooms gridworld with a penalized "trap" cell.

    The grid is split into four rooms by a vertical and a horizontal wall,
    each wall segment having one randomly placed doorway. A ``Lava`` object
    (the trap) is placed on the first empty cell found in column
    ``width // 2`` when scanning rows from ``height // 2`` downward, i.e. on
    the doorway in the lower half of the central vertical wall.

    Rewards:
        * every call to ``step`` costs ``step_cost`` (-0.01),
        * moving forward into the goal adds ``goal_reward`` and ends the
          episode,
        * moving forward into the trap adds ``trap_penalty``; the episode
          continues.

    The agent and goal positions can be fixed through the constructor;
    when left as ``None`` they are placed at random.
    """

    def __init__(self, agent_pos=None, goal_pos=None, trap_penalty=-0.18, goal_reward=0.9):
        """
        Args:
            agent_pos: optional fixed ``(x, y)`` start cell of the agent.
            goal_pos: optional fixed ``(x, y)`` cell of the goal.
            trap_penalty: reward added when the agent moves forward into
                the trap cell (expected to be negative).
            goal_reward: reward added when the agent reaches the goal.
        """
        self._agent_default_pos = agent_pos
        self._goal_default_pos = goal_pos
        self.step_cost = -0.01
        self.trap_penalty = trap_penalty
        self.goal_reward = goal_reward
        # Filled in by _gen_grid once the doorway positions are known.
        self.trap_pos = None
        super().__init__(grid_size=19, max_steps=100)

    def _gen_grid(self, width, height):
        """Build the four-rooms layout, then place agent, goal and trap."""
        # Create the grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.horz_wall(0, 0)
        self.grid.horz_wall(0, height - 1)
        self.grid.vert_wall(0, 0)
        self.grid.vert_wall(width - 1, 0)

        room_w = width // 2
        room_h = height // 2

        # For each row of rooms
        for j in range(0, 2):

            # For each column
            for i in range(0, 2):
                xL = i * room_w
                yT = j * room_h
                xR = xL + room_w
                yB = yT + room_h

                # Right wall and door
                if i + 1 < 2:
                    self.grid.vert_wall(xR, yT, room_h)
                    pos = (xR, self._rand_int(yT + 1, yB))
                    self.grid.set(*pos, None)

                # Bottom wall and door
                if j + 1 < 2:
                    self.grid.horz_wall(xL, yB, room_w)
                    pos = (self._rand_int(xL + 1, xR), yB)
                    self.grid.set(*pos, None)

        # Fix or randomize the player start position; the start direction
        # is random when the position is fixed.
        if self._agent_default_pos is not None:
            self.agent_pos = self._agent_default_pos
            self.grid.set(*self._agent_default_pos, None)
            self.agent_dir = self._rand_int(0, 4)  # assuming random start direction
        else:
            self.place_agent()

        if self._goal_default_pos is not None:
            goal = Goal()
            self.put_obj(goal, *self._goal_default_pos)
            # Both attributes hold the full (x, y) position. Tuple-unpacking
            # the position here would store x in init_pos and y in cur_pos.
            goal_pos = tuple(self._goal_default_pos)
            goal.init_pos = goal_pos
            goal.cur_pos = goal_pos
        else:
            self.place_obj(Goal())

        # Place the trap.
        # NOTE(review): this happens after the goal is placed, so a goal
        # that ended up on the trap cell would be overwritten - confirm
        # whether that can occur with the configurations in use.
        self.trap_pos = self.get_trap_pos()
        trap = Lava()
        self.put_obj(trap, *self.trap_pos)

        self.mission = 'Reach the goal'

    def _reward(self):
        """Return the constant reward granted for reaching the goal."""
        # return 0.9 * (1 - self.step_count / self.max_steps)
        return self.goal_reward

    def get_trap_pos(self):
        """
        Locate the cell on which the trap is placed.

        Scans column ``width // 2`` from row ``height // 2`` to the last
        row and returns the first empty cell as ``(x, y)``. If no cell in
        that range is empty, the last scanned cell is returned.
        """
        x = self.width // 2
        # y is the door position
        for y in range(self.height // 2, self.height):
            if self.grid.get(x, y) is None:
                break
        return x, y

    def step(self, action):
        """
        Apply ``action`` and return ``(obs, reward, done, info)``.

        The reward starts at ``step_cost`` for every action. A forward move
        into a goal cell adds the goal reward and ends the episode; a
        forward move into a lava cell adds ``trap_penalty``. The episode
        also ends once ``step_count`` reaches ``max_steps``. ``info`` is
        always an empty dict.

        Raises:
            AssertionError: if ``action`` is not one of ``self.actions``.
        """
        # obs, reward, done, info = MiniGridEnv.step(self, action)

        self.step_count += 1

        reward = 0
        reward += self.step_cost
        done = False

        # Get the position in front of the agent
        fwd_pos = self.front_pos

        # Get the contents of the cell in front of the agent
        fwd_cell = self.grid.get(*fwd_pos)

        # Rotate left
        if action == self.actions.left:
            self.agent_dir = (self.agent_dir - 1) % 4

        # Rotate right
        elif action == self.actions.right:
            self.agent_dir = (self.agent_dir + 1) % 4

        # Move forward
        elif action == self.actions.forward:
            if fwd_cell is None or fwd_cell.can_overlap():
                self.agent_pos = fwd_pos
            if fwd_cell is not None:
                if fwd_cell.type == 'goal':
                    done = True
                    reward += self._reward()
                elif fwd_cell.type == 'lava':
                    # The trap only penalizes; it does not end the episode.
                    reward += self.trap_penalty

        # Pick up an object
        elif action == self.actions.pickup:
            if fwd_cell and fwd_cell.can_pickup():
                if self.carrying is None:
                    self.carrying = fwd_cell
                    self.carrying.cur_pos = np.array([-1, -1])
                    self.grid.set(*fwd_pos, None)

        # Drop an object
        elif action == self.actions.drop:
            if not fwd_cell and self.carrying:
                self.grid.set(*fwd_pos, self.carrying)
                self.carrying.cur_pos = fwd_pos
                self.carrying = None

        # Toggle/activate an object
        elif action == self.actions.toggle:
            if fwd_cell:
                fwd_cell.toggle(self, fwd_pos)

        # Done action (not used by default)
        elif action == self.actions.done:
            pass

        else:
            # Raised explicitly (rather than `assert False`) so the check
            # survives `python -O`; the exception type is unchanged.
            raise AssertionError("unknown action")

        if self.step_count >= self.max_steps:
            done = True

        obs = self.gen_obs()
        return obs, reward, done, {}

