import abc
import numpy as np
import gym

class Condition(abc.ABC):
    """Abstract predicate that is evaluated against an environment.

    Concrete subclasses implement :meth:`validate`. Two conditions can be
    combined with the ``&`` and ``|`` operators, which build an
    ``AndCondition`` or an ``OrCondition`` respectively.
    """

    @abc.abstractmethod
    def validate(self, env):
        """Evaluate this condition against ``env``; subclasses must override."""
        raise NotImplementedError

    def __or__(self, obj):
        """Return an ``OrCondition`` over ``self`` and ``obj`` (``self | obj``)."""
        return OrCondition(self, obj)

    def __and__(self, obj):
        """Return an ``AndCondition`` over ``self`` and ``obj`` (``self & obj``)."""
        return AndCondition(self, obj)

class AndCondition(Condition):
    """Conjunction of conditions: holds only when every member holds."""

    def __init__(self, *conditions):
        # Kept as the tuple of positional arguments, in the order given.
        self.conditions = conditions

    def validate(self, env):
        """Return True unless some wrapped condition rejects ``env``.

        Members are checked in order and evaluation stops at the first one
        whose ``validate`` result is falsy. An empty conjunction is True.
        """
        for condition in self.conditions:
            if not condition.validate(env):
                return False
        return True

class OrCondition(Condition):
    """Disjunction of conditions: holds as soon as one member holds."""

    def __init__(self, *conditions):
        # Kept as the tuple of positional arguments, in the order given.
        self.conditions = conditions

    def validate(self, env):
        """Return True if at least one wrapped condition accepts ``env``.

        Members are checked in order and evaluation stops at the first one
        whose ``validate`` result is truthy. An empty disjunction is False.
        """
        for condition in self.conditions:
            if condition.validate(env):
                return True
        return False

class CanMoveToCondition(Condition):
    """Checks whether the cell at a fixed offset from the agent can be entered."""

    def __init__(self, name, position_update):
        """Store the move label and the offset to the target cell.

        Args:
            name: Label for the move (e.g. ``"up"``); stored but not read by
                ``validate``.
            position_update: Offset added to ``env.agent_pos`` to obtain the
                target cell coordinates.
        """
        self.name = name
        self.position_update = position_update

    def validate(self, env):
        """Return True when the target cell is empty or can be overlapped.

        When ``env.agent_pos`` is an empty array (no agent position is
        available) the move is treated as allowed.
        """
        if not env.agent_pos.size:
            return True
        fwd_pos = env.agent_pos + self.position_update
        fwd_cell = env.grid.get(*fwd_pos)
        # Identity check: ``== None`` would dispatch to the cell's ``__eq__``.
        return fwd_cell is None or fwd_cell.can_overlap()

class GridWorldActionPreConditionWrapper(gym.core.Wrapper):
    """Gym wrapper that reports whether a movement action is applicable.

    One precondition is registered per direction (up, right, down, left);
    each checks that the neighbouring cell in that direction can be entered.
    """

    def __init__(self, env) -> None:
        super().__init__(env)

        # Each offset is added to ``env.agent_pos`` and the result is unpacked
        # into ``env.grid.get(...)`` by CanMoveToCondition.
        # NOTE(review): the transposed offsets (e.g. (-1, 0) for "up") were
        # previously tried here; presumably the grid is indexed (x, y) --
        # confirm against the environment.
        up_offset = np.array((0, -1))
        right_offset = np.array((1, 0))
        down_offset = np.array((0, 1))
        left_offset = np.array((-1, 0))

        self.up_preconditions = AndCondition(CanMoveToCondition("up", up_offset))
        self.right_preconditions = AndCondition(CanMoveToCondition("right", right_offset))
        self.down_preconditions = AndCondition(CanMoveToCondition("down", down_offset))
        self.left_preconditions = AndCondition(CanMoveToCondition("left", left_offset))

    def _get_agent_pos(self, state):
        """Return the flattened indices of entries in ``state`` equal to the agent id.

        The result is an empty array when no entry matches.
        """
        agent_id = self.object_id("agent")
        matches = np.where(state == agent_id)
        return np.array(matches).flatten()

    def is_applicable(self, state, action) -> float:
        """ Returns the probability of being applicable """
        # Work on a copy so the wrapped environment's agent position is untouched.
        env = self.unwrapped.copy()
        env.agent_pos = self._get_agent_pos(state)

        preconditions_by_action = {
            self.actions.up: self.up_preconditions,
            self.actions.right: self.right_preconditions,
            self.actions.down: self.down_preconditions,
            self.actions.left: self.left_preconditions,
        }
        precondition = preconditions_by_action.get(action)
        if precondition is None:
            # Actions without a registered precondition are always applicable.
            return 1.0
        return float(precondition.validate(env))

class CanOpenDoorCondition(Condition):
    """Checks whether the cell at a fixed offset from the agent is a door that can be opened."""

    def __init__(self, name, position_update):
        """Store the action label and the offset to the target cell.

        Args:
            name: Label for the action (e.g. ``"open"``); stored but not read
                by ``validate``.
            position_update: Offset added to ``env.agent_pos`` to obtain the
                target cell coordinates.
        """
        self.name = name
        self.position_update = position_update

    def validate(self, env):
        """Return True when the target cell holds a closed door and no key remains.

        When ``env.agent_pos`` is an empty array (no agent position is
        available) the action is treated as allowed.
        """
        if not env.agent_pos.size:
            return True
        # The door is never openable while 'KY' appears in the environment's
        # string rendering.
        # NOTE(review): presumably 'KY' is the text code of a key that is
        # still lying on the grid -- confirm against the env's __str__.
        if 'KY' in str(env.unwrapped):
            return False
        fwd_pos = env.agent_pos + self.position_update
        fwd_cell = env.grid.get(*fwd_pos)
        return bool(
            fwd_cell is not None
            and fwd_cell.type == 'door'
            and not fwd_cell.is_open
        )


class CanPickupKeyCondition(Condition):
    """Checks whether the cell at a fixed offset from the agent holds a pickable object."""

    def __init__(self, name, position_update):
        # ``name`` is only a label; ``position_update`` is the offset added
        # to ``env.agent_pos`` to find the cell that is inspected.
        self.position_update = position_update
        self.name = name

    def validate(self, env):
        """Return the target cell's ``can_pickup()`` result, or False for an empty cell.

        When ``env.agent_pos`` is an empty array (no agent position is
        available) the action is treated as allowed.
        """
        if env.agent_pos.size == 0:
            return True
        target_pos = env.agent_pos + self.position_update
        target_cell = env.grid.get(*target_pos)
        return False if target_cell is None else target_cell.can_pickup()

class DoorKeyActionPreconditionWrapper(GridWorldActionPreConditionWrapper):
    """Action-applicability wrapper for a door/key grid world observed as an image.

    The agent position, key presence and door status are read from the colour
    of one sampled pixel per tile of the observation.

    NOTE(review): the observation is indexed as ``state[channel, i, j]`` with
    channels used as (r, g, b); presumably a channel-first RGB render with
    square tiles of ``tile_size`` pixels -- confirm against the caller.
    """

    def __init__(self, env, tile_size=42) -> None:
        """Wrap ``env`` and register the pickup/open preconditions.

        Args:
            env: Environment to wrap.
            tile_size: Sampling stride (in pixels) used to pick one pixel per
                tile from the observation.
        """
        super().__init__(env)

        self._tile_size = tile_size

        self.pickup_key_precondition = AndCondition(
            CanPickupKeyCondition("pickup", np.array((0, -1)))
        )

        self.open_door_precondition = AndCondition(
            CanOpenDoorCondition("open", np.array((1, 0)))
        )

    def _tile_centers(self, state):
        """Return ``state`` subsampled to one pixel per tile.

        Keeps the first axis and takes every ``tile_size``-th element along
        the next two axes, starting at ``tile_size // 2``.
        """
        offset = self._tile_size // 2
        return state[:, offset::self._tile_size, offset::self._tile_size]

    def _cell_rgb(self, state, i, j):
        """Return the three channel values of the sampled pixel at tile ``(i, j)``."""
        centers = self._tile_centers(state)
        return centers[0, i, j], centers[1, i, j], centers[2, i, j]

    def _get_agent_pos(self, state):
        """Return the flattened tile indices whose sampled pixel is (255, 0, 0).

        The result is an empty array when no tile matches.
        """
        centers = self._tile_centers(state)
        pos = list(zip(*np.where(
            (centers[0, ...] == 255) & (centers[1, ...] == 0) & (centers[2, ...] == 0)
        )))
        return np.array(pos).flatten()

    def _is_key_present(self, state):
        """Return True when tile (1, 1) shows the key colour (255, 255, 0).

        Returns False when that tile is (0, 0, 0) or (255, 0, 0), the latter
        being the colour ``_get_agent_pos`` matches.

        Raises:
            ValueError: If the tile has any other colour.
        """
        r, g, b = self._cell_rgb(state, 1, 1)
        if r == 255 and g == 255 and b == 0:
            return True
        if (r == 0 or r == 255) and g == 0 and b == 0:
            return False
        raise ValueError("Unknown color code", r, g, b)

    def _is_door_closed(self, state):
        """Return True when tile (2, 2) shows the closed-door colour (114, 114, 0).

        Returns False when that tile is (0, 0, 0) or (255, 0, 0), the latter
        being the colour ``_get_agent_pos`` matches.

        Raises:
            ValueError: If the tile has any other colour.
        """
        r, g, b = self._cell_rgb(state, 2, 2)
        # An earlier variant matched (161, 161, 0) instead of (114, 114, 0).
        if r == 114 and g == 114 and b == 0:
            return True
        if (r == 0 or r == 255) and g == 0 and b == 0:
            return False
        raise ValueError("Unknown color code", r, g, b)

    def is_applicable(self, state, action) -> float:
        """ Returns the probability of being applicable

        A copy of the unwrapped environment is reset and stepped so that its
        key/door status matches ``state``; the registered precondition for
        ``action`` is then validated against it. Actions without a registered
        precondition yield 0.0.

        Raises:
            ValueError: If the key or door tile has an unrecognised colour.
            AssertionError: If ``state`` shows an open door while the key is
                still present.
        """
        env = self.unwrapped.copy()

        ## !!!! WARNING !!! this code is specific to the MiniGrid-SimpleDoorKey-5x5-v0 environment
        ## TODO for later, make this more generic
        env.reset()
        key_present = self._is_key_present(state)
        if not key_present:
            # remove the key
            env.step(self.actions.up)
            env.step(self.actions.pickup)
            env.step(self.actions.down)
            env.step(self.actions.down)

        if not self._is_door_closed(state):
            # open the door
            assert not key_present, "door is open while the key is still present"
            env.step(self.actions.up)
            env.step(self.actions.open)

        # Override the replayed position with the one observed in ``state``.
        env.agent_pos = self._get_agent_pos(state)

        preconditions_by_action = {
            self.actions.up: self.up_preconditions,
            self.actions.right: self.right_preconditions,
            self.actions.down: self.down_preconditions,
            self.actions.left: self.left_preconditions,
            self.actions.pickup: self.pickup_key_precondition,
            self.actions.open: self.open_door_precondition,
        }
        precondition = preconditions_by_action.get(action)
        if precondition is None:
            # Actions without a registered precondition are never applicable.
            return 0.0
        return float(precondition.validate(env))

class CNNGridWorldActionPreConditionWrapper(GridWorldActionPreConditionWrapper):
    """Movement-precondition wrapper that locates the agent in an image observation.

    NOTE(review): the observation is indexed as ``state[channel, i, j]``;
    presumably a channel-first RGB render with square tiles of ``tile_size``
    pixels -- confirm against the caller.
    """

    def __init__(self, env, tile_size=30) -> None:
        super().__init__(env)

        # Sampling stride used to pick one pixel per tile.
        self._tile_size = tile_size

    def _get_agent_pos(self, state):
        """Return the flattened tile indices whose sampled pixel is (255, 0, 0).

        One pixel is sampled every ``tile_size`` elements along the last two
        axes, starting at ``tile_size // 2``. The result is an empty array
        when no tile matches.
        """
        stride = self._tile_size
        first = stride // 2
        centers = state[:, first::stride, first::stride]

        is_agent = (
            (centers[0, ...] == 255)
            & (centers[1, ...] == 0)
            & (centers[2, ...] == 0)
        )
        hits = np.where(is_agent)
        coords = list(zip(*hits))
        return np.array(coords).flatten()
