import os
import numpy as np
from enum import Enum
import itertools
from gym_minigrid.minigrid import COLOR_NAMES, DIR_TO_VEC

# Object types we are allowed to describe in language
OBJ_TYPES = ['box', 'ball', 'key', 'door']

# Describable object types, excluding doors
OBJ_TYPES_NOT_DOOR = [t for t in OBJ_TYPES if t != 'door']

# Locations are relative to the agent's position and facing direction at the
# moment an object description is matched (normally the start of the episode)
LOC_NAMES = ['left', 'right', 'front', 'behind']

# Environment flag to indicate that done actions should be used by the
# verifier. Unset, empty, '0', 'false', 'no' and 'off' (case-insensitive)
# disable it; any other value enables it. Parsed into a real bool so that
# e.g. BABYAI_DONE_ACTIONS=0 does not accidentally turn the feature on.
_done_actions_flag = os.environ.get('BABYAI_DONE_ACTIONS', '')
use_done_actions = _done_actions_flag.strip().lower() not in ('', '0', 'false', 'no', 'off')


# Types of instructions, used for enumeration: maps an instruction name to its
# (class, arity) pair, filled in by the register_instr decorator. Does not
# include composite instructions
BASE_INSTR_TYPES = {}


class InvalidObjError(Exception):
    """Raised when an object description is not valid for an instruction type."""


def register_instr(name, arity):
    """
    Build a class decorator that registers a base instruction type.

    The decorated class is stored in BASE_INSTR_TYPES under ``name`` as the
    pair ``(cls, arity)``, where ``arity`` is the number of object
    descriptions its constructor takes. The class itself is returned
    unchanged.

    :raises ValueError: when ``name`` has already been registered.
    """
    def register(instr_cls):
        if name not in BASE_INSTR_TYPES:
            BASE_INSTR_TYPES[name] = (instr_cls, arity)
            return instr_cls
        raise ValueError(f"instr type {name} already exists")

    return register


def dot_product(v1, v2):
    """
    Compute the dot product of the vectors v1 and v2.

    Components are paired element-wise with ``zip``, so trailing components
    of the longer vector are ignored; empty vectors give 0.
    """
    # Generator expression: no intermediate list is materialized
    return sum(a * b for a, b in zip(v1, v2))


def pos_next_to(pos_a, pos_b):
    """
    Return whether two grid positions are orthogonally adjacent.

    Adjacent means a Manhattan distance of exactly 1: the positions share a
    row or a column and are one step apart. Diagonal neighbours and identical
    positions are not adjacent.
    """
    (ax, ay), (bx, by) = pos_a, pos_b
    return abs(ax - bx) + abs(ay - by) == 1


class ObjDesc:
    """
    Description of a set of objects in an environment.

    A description is a (type, color, loc) triple in which ``color`` and
    ``loc`` may be None, meaning "unconstrained". After
    ``find_matching_objs`` has run, ``self.objs`` holds the ``(obj, (i, j))``
    pairs found on the grid that match it.
    """

    def __init__(self, type, color=None, loc=None):
        assert type in [None, *OBJ_TYPES], type
        assert color in [None, *COLOR_NAMES], color
        assert loc in [None, *LOC_NAMES], loc

        self.color = color
        self.type = type
        self.loc = loc

        # Set of objects possibly matching the description,
        # stored as (obj, pos) tuples
        self.objs = set()

    def __repr__(self):
        return "{} {} {}".format(self.color, self.type, self.loc)

    def surface(self, env=None):
        """
        Generate a natural language representation of the object description.

        :param env: when given, the matching objects are recomputed first and
            at least one must exist (checked with an assert); when None, the
            currently stored ``self.objs`` is used to choose the article.
        :return: e.g. "the red ball on your left" or "a key in front of you".
        """
        if env is not None:
            # Verify object exists in environment
            self.find_matching_objs(env)
            assert len(self.objs) > 0, "no object matching description"

        if self.type:
            s = str(self.type)
        else:
            s = 'object'

        if self.color:
            s = self.color + ' ' + s

        if self.loc:
            if self.loc == 'front':
                s = s + ' in front of you'
            elif self.loc == 'behind':
                s = s + ' behind you'
            else:
                s = s + ' on your ' + self.loc

        # Article choice: "a" when several objects match (the description
        # does not single one out), "the" otherwise
        if len(self.objs) > 1:
            s = 'a ' + s
        else:
            s = 'the ' + s

        return s

    def find_matching_objs(self, env, use_location=True):
        """
        Find the set of objects matching the description and their positions.

        ``self.objs`` is rebuilt from scratch on every call by scanning the
        whole grid, so it always reflects where matching objects are now.

        :param env: environment providing ``grid`` (with ``width``,
            ``height`` and ``get``), ``agent_pos``, ``agent_dir`` and
            ``room_from_pos``.
        :param use_location: when True and the description has a location,
            only objects inside the agent's current room that lie in the
            described direction relative to the agent's current pose are
            kept. When False, the location constraint is ignored and only
            type and color are matched (this does NOT restrict the result to
            previously tracked objects).
        :return: the updated ``self.objs`` set of ``(obj, (i, j))`` pairs.
        """
        self.objs = set()

        agent_room = env.room_from_pos(*env.agent_pos)

        # The location test depends only on the agent's pose, which does not
        # change during the scan, so the oriented basis is computed once.
        check_loc = use_location and self.loc in LOC_NAMES
        if check_loc:
            agent_x, agent_y = env.agent_pos
            # (d1, d2) is an oriented orthonormal basis: d1 is the agent's
            # direction vector and d2 is d1 rotated a quarter turn
            d1 = DIR_TO_VEC[env.agent_dir]
            d2 = (-d1[1], d1[0])

        for i in range(env.grid.width):
            for j in range(env.grid.height):
                cell = env.grid.get(i, j)
                if cell is None:
                    continue

                # Check if object's type matches description
                if self.type is not None and cell.type != self.type:
                    continue

                # Check if object's color matches description
                if self.color is not None and cell.color != self.color:
                    continue

                # Check if object's position matches description
                if check_loc:
                    # Locations apply only to objects in the agent's room
                    if not agent_room.pos_inside(i, j):
                        continue

                    # Direction from the agent to the object
                    v = (i - agent_x, j - agent_y)

                    pos_matches = {
                        "left": dot_product(v, d2) < 0,
                        "right": dot_product(v, d2) > 0,
                        "front": dot_product(v, d1) > 0,
                        "behind": dot_product(v, d1) < 0
                    }

                    if not pos_matches[self.loc]:
                        continue

                self.objs.add((cell, (i, j)))

        return self.objs

    @classmethod
    def enumerate(cls, use_location=False):
        """
        Enumerate all possible object descriptions.

        :param use_location: when True, also yield every description combined
            with each entry of LOC_NAMES (in addition to no location).
        """
        if use_location:
            locs = [None, *LOC_NAMES]
        else:
            locs = [None]
        # NOTE: for now, obj_type cannot be None.
        for obj_type in [*OBJ_TYPES]:
            for color in [None, *COLOR_NAMES]:
                # Guard kept for when untyped descriptions are re-enabled:
                # a fully unconstrained description is never produced
                if obj_type is None and color is None:
                    continue
                for loc in locs:
                    yield cls(obj_type, color=color, loc=loc)


class Instr:
    """
    Abstract base shared by every instruction of the baby language.

    Subclasses provide ``surface``, ``verify`` and ``is_achievable``;
    ``reset_verifier`` binds an instance to the environment it will check.
    """

    def __init__(self):
        # Environment being verified; bound later by reset_verifier()
        self.env = None

    def surface(self, env):
        """Return the instruction rendered as natural language."""
        raise NotImplementedError

    def reset_verifier(self, env):
        """Bind the instruction to ``env``; call once when an episode begins."""
        self.env = env

    def verify(self, action):
        """
        Report progress on the instruction after ``action`` was taken.

        The result is one of the strings 'success', 'failure' or 'continue'.
        """
        raise NotImplementedError

    def is_achievable(self):
        """Tell whether the instruction can be achieved in the bound environment."""
        raise NotImplementedError

    def update_objs_poss(self):
        """
        Refresh the object descriptions held by this instruction, if any.

        Each of the attributes ``desc``, ``desc_move`` and ``desc_fixed`` that
        exists is re-matched against the bound environment with location
        constraints ignored.
        """
        for attr_name in ('desc', 'desc_move', 'desc_fixed'):
            if not hasattr(self, attr_name):
                continue
            getattr(self, attr_name).find_matching_objs(self.env, use_location=False)

    @staticmethod
    def enumerate():
        """Generate every base (non-composite) instruction, lowest arity first."""
        all_descs = list(ObjDesc.enumerate())
        by_arity = sorted(BASE_INSTR_TYPES.values(), key=lambda entry: entry[1])
        for instr_cls, arity in by_arity:
            for args in itertools.product(all_descs, repeat=arity):
                try:
                    candidate = instr_cls(*args)
                except InvalidObjError:
                    # This class rejects this combination of descriptions
                    continue
                yield candidate


@register_instr("null", arity=0)
class NullInstr(Instr):
    """
    Placeholder instruction: always available, never completed.

    ``verify`` reports failure unconditionally, whatever the action.
    """

    def verify(self, action):
        return "failure"

    def is_achievable(self):
        # Nothing to look up: the null instruction is always available
        return True

    def surface(self, env):
        # Fixed token, independent of the environment
        return '<NULL>'


class ActionInstr(Instr):
    """
    Base class for all action instructions (clauses).

    Subclasses implement ``verify_action``; ``verify`` wraps it to support
    the optional done-action protocol enabled by ``use_done_actions``.
    """

    def __init__(self):
        super().__init__()

        # Indicates that the action was completed on the last step
        # (only consulted when done actions are enabled)
        self.lastStepMatch = False

    def reset_verifier(self, env):
        """
        Bind to ``env`` for a new episode.

        Also clears ``lastStepMatch`` so a match from a previous episode
        cannot make the first done action of this one succeed.
        """
        super().reset_verifier(env)
        self.lastStepMatch = False

    def verify(self, action):
        """
        Verifies actions, with and without the done action.

        Without done actions this is exactly ``verify_action(action)``.
        With done actions, emitting the done action succeeds only if
        ``verify_action`` reported 'success' on the immediately preceding
        step, and fails otherwise; every other step returns 'continue'.

        :return: 'success', 'failure' or 'continue'
        """
        if not use_done_actions:
            return self.verify_action(action)

        if action == self.env.actions.done:
            if self.lastStepMatch:
                return 'success'
            return 'failure'

        res = self.verify_action(action)
        self.lastStepMatch = (res == 'success')
        # Under the done-action protocol only the done action ends the task.
        # NOTE(review): a 'failure' from verify_action (strict mode) is not
        # propagated here either; confirm this is intended.
        return 'continue'

    def verify_action(self, action):
        """
        Each action instruction class should implement this method
        to verify the action.

        :param action: the action the agent just took.
        :return: 'success', 'failure' or 'continue'
        """

        raise NotImplementedError


@register_instr("open", arity=1)
class OpenInstr(ActionInstr):
    """
    Open a door or a box matching a given description
    eg: open the red door

    A door counts once a matching door in front of the agent is open after a
    toggle. A box counts once the cell in front of the agent holds the
    original contents of one of the tracked boxes.
    """

    def __init__(self, obj_desc, strict=False):
        super().__init__()
        openable = {"door", "box"}
        if obj_desc.type not in openable:
            raise InvalidObjError(f"wrong object type {obj_desc.type} for open instr")
        self.desc = obj_desc
        self.strict = strict

    def surface(self, env):
        return f"open {self.desc.surface(env)}"

    def is_achievable(self):
        """
        FIXME: this currently ignores whether a door is locked and whether a
        key is able to unlock it.
        """
        return bool(self.desc.objs)

    def reset_verifier(self, env):
        super().reset_verifier(env)
        # Collect the candidate objects currently matching the description
        self.desc.find_matching_objs(env)

    def verify_action(self, action):
        # Nothing can be opened except through the toggle action
        if action != self.env.actions.toggle:
            return 'continue'

        front_cell = self.env.grid.get(*self.env.front_pos)

        if self.desc.type != "door":
            # Box case: finding a tracked box's original contents in the front
            # cell is taken as evidence that the box was opened
            for box, _ in self.desc.objs:
                revealed = front_cell and front_cell is box.contains
                if revealed and self not in box.instr_opened:
                    # Credit each box at most once per instruction
                    box.instr_opened.add(self)
                    return 'success'
        elif matches(front_cell, self.desc) and front_cell.is_open:
            return 'success'

        # Strict mode: toggling a door that did not satisfy the instruction fails
        if self.strict and front_cell and front_cell.type == 'door':
            return 'failure'

        return 'continue'


@register_instr("goto", arity=1)
class GoToInstr(ActionInstr):
    """
    Go next to (and look towards) an object matching a given description
    eg: go to the door
    """

    def __init__(self, obj_desc):
        super().__init__()
        self.desc = obj_desc

    def surface(self, env):
        return f"go to {self.desc.surface(env)}"

    def is_achievable(self):
        # True when the description currently tracks at least one object
        return bool(self.desc.objs)

    def reset_verifier(self, env):
        super().reset_verifier(env)
        # Collect the candidate objects currently matching the description
        self.desc.find_matching_objs(env)

    def verify_action(self, action):
        """
        Succeed when the agent faces the recorded position of a matching object.

        Toggle, drop and pickup steps never complete the instruction, and a
        destroyed box does not count as a target.
        """
        manipulations = {self.env.actions.toggle, self.env.actions.drop, self.env.actions.pickup}
        if action in manipulations:
            return 'continue'

        for obj, pos in self.desc.objs:
            if not np.array_equal(pos, self.env.front_pos):
                continue
            # The agent faces this object; a destroyed box is not a valid target
            if obj.type == "box" and obj.destroyed:
                return 'continue'
            return 'success'

        return 'continue'


@register_instr("pickup", arity=1)
class PickupInstr(ActionInstr):
    """
    Pick up an object matching a given description
    eg: pick up the grey ball
    """

    def __init__(self, obj_desc, strict=False):
        super().__init__()
        if obj_desc.type == 'door':
            raise InvalidObjError(f"wrong object type {obj_desc.type} for pickup instr")
        self.desc = obj_desc
        self.strict = strict

    def surface(self, env):
        return 'pick up ' + self.desc.surface(env)

    def is_achievable(self):
        return len(self.desc.objs) > 0

    def reset_verifier(self, env):
        super().reset_verifier(env)

        # Object previously being carried
        self.preCarrying = None

        # Identify set of possible matching objects in the environment
        self.desc.find_matching_objs(env)

    def verify_action(self, action):
        """
        Succeed when a pickup takes a matching object into empty hands.

        In strict mode, a pickup step that ends with the agent carrying
        anything else (a non-matching object, or an object that was already
        held) is a failure.

        :return: 'success', 'failure' or 'continue'
        """
        # Remember what was carried before this step and record the current
        # carried object for the next step
        preCarrying = self.preCarrying
        self.preCarrying = self.env.carrying

        # Only verify when the pickup action is performed
        if action != self.env.actions.pickup:
            return 'continue'

        if preCarrying is None and self.env.carrying is not None:
            # Check if carried obj matches current obj
            if matches(self.env.carrying, self.desc):
                return 'success'

        # If in strict mode and the wrong object is picked up, failure
        if self.strict and self.env.carrying:
            return 'failure'

        return 'continue'


@register_instr("putnext", arity=2)
class PutNextInstr(ActionInstr):
    """
    Put an object next to another object
    eg: put the red ball next to the blue key
    """

    def __init__(self, obj_move, obj_fixed, strict=False):
        super().__init__()
        if obj_move.type == 'door':
            raise InvalidObjError(f"wrong move object type {obj_move.type} for putnext instr")
        self.desc_move = obj_move
        self.desc_fixed = obj_fixed
        self.strict = strict

    def surface(self, env):
        return 'put ' + self.desc_move.surface(env) + ' next to ' + self.desc_fixed.surface(env)

    def is_achievable(self):
        return len(self.desc_move.objs) > 0 and len(self.desc_fixed.objs) > 0

    def reset_verifier(self, env):
        super().reset_verifier(env)

        # Object previously being carried
        self.preCarrying = None

        # Identify set of possible matching objects in the environment
        self.desc_move.find_matching_objs(env)
        self.desc_fixed.find_matching_objs(env)

    def objs_next(self):
        """
        Check if the objects are next to each other
        This is used for rejection sampling

        An object is never considered next to itself. This matters when both
        descriptions can match the same object whose recorded position in
        ``desc_fixed`` is stale.
        """
        for obj_a, _ in self.desc_move.objs:
            pos_a = obj_a.cur_pos

            for obj_b, pos_b in self.desc_fixed.objs:
                if obj_b is obj_a:
                    continue
                if pos_next_to(pos_a, pos_b):
                    return True
        return False

    def verify_action(self, action):
        """
        Succeed when a matching object is dropped next to a fixed object.

        In strict mode, picking up an object that does not match the object
        to move is a failure. Picking up the object to move is required to
        solve the task, so it must not fail.

        :return: 'success', 'failure' or 'continue'
        """
        # To keep track of what was carried at the last time step
        preCarrying = self.preCarrying
        self.preCarrying = self.env.carrying

        # In strict mode, picking up the wrong object fails
        if self.strict:
            carried = self.env.carrying
            if (action == self.env.actions.pickup and carried
                    and not matches(carried, self.desc_move)):
                return 'failure'

        # Only verify when the drop action is performed
        if action != self.env.actions.drop:
            return 'continue'

        if matches(preCarrying, self.desc_move):
            # Assumes the environment updates cur_pos of the dropped object
            pos_a = preCarrying.cur_pos

            for obj_b, pos_b in self.desc_fixed.objs:
                # An object cannot be "next to" itself; its recorded position
                # in desc_fixed predates it being picked up
                if obj_b is preCarrying:
                    continue
                if pos_next_to(pos_a, pos_b):
                    return 'success'

        return 'continue'


class SeqInstr(Instr):
    """
    Base class for sequencing instructions (before, after, and)

    :param instr_a: first operand; an action instruction or a conjunction.
    :param instr_b: second operand; an action instruction or a conjunction.
    :param strict: stored for subclasses, which decide what strict means.
    """

    def __init__(self, instr_a, instr_b, strict=False):
        # Initialise the Instr base so self.env exists (as None) before
        # reset_verifier() is called
        super().__init__()
        assert isinstance(instr_a, (ActionInstr, AndInstr))
        assert isinstance(instr_b, (ActionInstr, AndInstr))
        self.instr_a = instr_a
        self.instr_b = instr_b
        self.strict = strict


class BeforeInstr(SeqInstr):
    """
    Sequence two instructions in order:
    eg: go to the red door, then pick up the blue ball

    instr_a must be completed first, then instr_b. In strict mode,
    completing instr_b while instr_a is still pending is a failure.
    """

    def surface(self, env):
        # Rendered as "<instr_a>, then <instr_b>"
        return self.instr_a.surface(env) + ', then ' + self.instr_b.surface(env)

    def reset_verifier(self, env):
        super().reset_verifier(env)
        self.instr_a.reset_verifier(env)
        self.instr_b.reset_verifier(env)
        # Latest verify() result of each sub-instruction; False means "not
        # verified yet" and is only ever compared against result strings
        self.a_done = False
        self.b_done = False

    def verify(self, action):
        """
        Verify instr_a until it succeeds, then verify instr_b.

        :return: 'success' once instr_b succeeds after instr_a, 'failure' if
            the sub-instruction being verified fails (or, in strict mode, if
            instr_b succeeds before instr_a), otherwise 'continue'.
        """
        if self.a_done == 'success':
            # instr_a already done: only instr_b is verified from now on
            self.b_done = self.instr_b.verify(action)

            if self.b_done == 'failure':
                return 'failure'

            if self.b_done == 'success':
                return 'success'
        else:
            self.a_done = self.instr_a.verify(action)
            if self.a_done == 'failure':
                return 'failure'

            if self.a_done == 'success':
                # instr_a just completed: re-enter so instr_b is also checked
                # against this same action
                return self.verify(action)

            # In strict mode, completing b first means failure
            # NOTE: this calls instr_b.verify, which can update instr_b's own
            # tracking state even though its result is not stored in b_done
            if self.strict:
                if self.instr_b.verify(action) == 'success':
                    return 'failure'

        return 'continue'


class AfterInstr(SeqInstr):
    """
    Sequence two instructions in reverse order:
    eg: go to the red door after you pick up the blue ball

    instr_b must be completed first, then instr_a. In strict mode,
    completing instr_a while instr_b is still pending is a failure.
    """

    def surface(self, env):
        # Rendered as "<instr_a> after you <instr_b>"
        return self.instr_a.surface(env) + ' after you ' + self.instr_b.surface(env)

    def reset_verifier(self, env):
        super().reset_verifier(env)
        self.instr_a.reset_verifier(env)
        self.instr_b.reset_verifier(env)
        # Latest verify() result of each sub-instruction; False means "not
        # verified yet" and is only ever compared against result strings
        self.a_done = False
        self.b_done = False

    def verify(self, action):
        """
        Verify instr_b until it succeeds, then verify instr_a.

        :return: 'success' once instr_a succeeds after instr_b, 'failure' if
            the sub-instruction being verified fails (or, in strict mode, if
            instr_a succeeds before instr_b), otherwise 'continue'.
        """
        if self.b_done == 'success':
            # instr_b already done: only instr_a is verified from now on
            self.a_done = self.instr_a.verify(action)

            if self.a_done == 'success':
                return 'success'

            if self.a_done == 'failure':
                return 'failure'
        else:
            self.b_done = self.instr_b.verify(action)
            if self.b_done == 'failure':
                return 'failure'

            if self.b_done == 'success':
                # instr_b just completed: re-enter so instr_a is also checked
                # against this same action
                return self.verify(action)

            # In strict mode, completing a first means failure
            # NOTE: this calls instr_a.verify, which can update instr_a's own
            # tracking state even though its result is not stored in a_done
            if self.strict:
                if self.instr_a.verify(action) == 'success':
                    return 'failure'

        return 'continue'


class AndInstr(SeqInstr):
    """
    Conjunction of two actions, both can be completed in any order
    eg: go to the red door and pick up the blue ball
    """

    def __init__(self, instr_a, instr_b, strict=False):
        assert isinstance(instr_a, ActionInstr)
        assert isinstance(instr_b, ActionInstr)
        super().__init__(instr_a, instr_b, strict)

    def surface(self, env):
        return self.instr_a.surface(env) + ' and ' + self.instr_b.surface(env)

    def reset_verifier(self, env):
        super().reset_verifier(env)
        self.instr_a.reset_verifier(env)
        self.instr_b.reset_verifier(env)
        # Latest verify() result of each sub-instruction; False means "not
        # verified yet"
        self.a_done = False
        self.b_done = False

    def verify(self, action):
        """
        Verify both sub-instructions. Each one stops being checked once it
        has succeeded.

        :return: 'success' once both have succeeded; with done actions
            enabled, 'failure' when the done action is taken and both
            sub-instructions report failure; otherwise 'continue'.
        """
        if self.a_done != 'success':
            self.a_done = self.instr_a.verify(action)

        if self.b_done != 'success':
            self.b_done = self.instr_b.verify(action)

        # Compare by equality, not identity: an action passed as a plain or
        # numpy int equals, but is not the same object as, the done member
        # (assuming env.actions is an IntEnum as in gym_minigrid)
        if use_done_actions and action == self.env.actions.done:
            # NOTE(review): a single failing clause does not fail the
            # conjunction here; confirm this is intended
            if self.a_done == 'failure' and self.b_done == 'failure':
                return 'failure'

        if self.a_done == 'success' and self.b_done == 'success':
            return 'success'

        return 'continue'


# Every base (non-composite) instruction, built once at import time in order
# of increasing arity
INSTRS = [*Instr.enumerate()]


def matches(obj, desc):
    """
    Tell whether ``obj`` satisfies the type and color of ``desc``.

    None components of ``desc`` act as wildcards. ``desc.loc`` is not
    checked here. A missing object (None) never matches.
    """
    if obj is None:
        return False
    if desc.type is not None and obj.type != desc.type:
        return False
    return desc.color is None or obj.color == desc.color
