from minigrid.envs.babyai.goto import (
    GoToRedBall,
    GoToRedBallGrey
)
from minigrid.envs.babyai.core.verifier import GoToInstr, ObjDesc

class BabyAIGoToRedBall(GoToRedBall):
    """GoToRedBall level with a configurable number of distractor objects.

    The mission instruction is to go to a red ball; a red ball is added to
    room (0, 0) together with ``num_dists`` distractor objects.
    """

    def __init__(self, room_size: int = 8, num_dists: int = 7, **kwargs):
        """Create the level.

        Args:
            room_size: Forwarded to the parent level as ``room_size``.
            num_dists: Number of distractors ``gen_mission`` adds; when it is
                not positive, mission generation is delegated to the parent.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # Stored before calling the parent constructor, presumably so the
        # attribute exists even if the parent generates a mission during
        # __init__.
        self.num_dists = num_dists
        super().__init__(room_size=room_size, num_dists=num_dists, **kwargs)

    def gen_mission(self) -> None:
        """Populate the level and set the go-to-red-ball instruction.

        When ``num_dists`` is positive: place the agent, add a red ball to
        room (0, 0), add ``num_dists`` distractors, check that every object
        is reachable, and set ``self.instrs`` to go to the red ball.
        Otherwise the parent's ``gen_mission`` is used unchanged.
        """
        if self.num_dists <= 0:
            super().gen_mission()
            return

        self.place_agent()
        target, _ = self.add_object(0, 0, "ball", "red")
        # all_unique=False presumably allows repeated distractors.
        # NOTE(review): the distractors are not inspected here, so one of
        # them could itself be a red ball matching the same ObjDesc as the
        # target -- confirm that ambiguity is intended.
        self.add_distractors(num_distractors=self.num_dists, all_unique=False)

        # Make sure no unblocking is required
        self.check_objs_reachable()

        self.instrs = GoToInstr(ObjDesc(target.type, target.color))

    def __str__(self) -> str:
        """Render the environment, or a placeholder if the agent is unplaced.

        Returns:
            The parent's rendering once ``agent_pos`` has been set, otherwise
            a fixed explanatory message.
        """
        # Compare with None instead of relying on truthiness: a truthiness
        # test would raise ValueError if the position were ever stored as a
        # multi-element NumPy array.
        if self.agent_pos is None:
            return 'str(env) invoked before the grid is initialized.'
        return super().__str__()

class BabyAIGoToRedBallGrey(GoToRedBallGrey):
    """GoToRedBallGrey level with a configurable number of grey distractors.

    The mission instruction is to go to a red ball; a red ball is added to
    room (0, 0) together with ``num_dists`` distractors, all recolored grey.
    """

    def __init__(self, room_size: int = 8, num_dists: int = 7, **kwargs):
        """Create the level.

        Args:
            room_size: Forwarded to the parent level as ``room_size``.
            num_dists: Number of distractors ``gen_mission`` adds; when it is
                not positive, mission generation is delegated to the parent.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # Stored before calling the parent constructor, presumably so the
        # attribute exists even if the parent generates a mission during
        # __init__.
        self.num_dists = num_dists
        super().__init__(room_size=room_size, num_dists=num_dists, **kwargs)

    def gen_mission(self) -> None:
        """Populate the level and set the go-to-red-ball instruction.

        When ``num_dists`` is positive: place the agent, add a red ball to
        room (0, 0), add ``num_dists`` distractors and recolor them grey,
        check that every object is reachable, and set ``self.instrs`` to go
        to the red ball. Otherwise the parent's ``gen_mission`` is used.
        """
        if self.num_dists <= 0:
            super().gen_mission()
            return

        self.place_agent()
        target, _ = self.add_object(0, 0, "ball", "red")
        # all_unique=False presumably allows repeated distractors.
        distractors = self.add_distractors(
            num_distractors=self.num_dists, all_unique=False
        )

        # Recolor every distractor grey so the red ball is the only
        # non-grey object this method adds.
        for distractor in distractors:
            distractor.color = "grey"

        # Make sure no unblocking is required
        self.check_objs_reachable()

        self.instrs = GoToInstr(ObjDesc(target.type, target.color))

    def __str__(self) -> str:
        """Render the environment, or a placeholder if the agent is unplaced.

        Returns:
            The parent's rendering once ``agent_pos`` has been set, otherwise
            a fixed explanatory message.
        """
        # Compare with None instead of relying on truthiness: a truthiness
        # test would raise ValueError if the position were ever stored as a
        # multi-element NumPy array.
        if self.agent_pos is None:
            return 'str(env) invoked before the grid is initialized.'
        return super().__str__()
