from __future__ import annotations

import copy
import random
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Literal, NewType, overload

import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers.util import cos_sim

# Node identifiers. NewType wrappers are erased at runtime (they return their
# argument unchanged); they only mark ints/strs that are used as node ids.
# `_Id` accepts either kind.
_IdInt = NewType("_IdInt", int)
_IdStr = NewType("_IdStr", str)
_Id = _IdInt | _IdStr
# Node class names, e.g. the "class_name" field of a source node dict.
_Name = NewType("_Name", str)


def _id(id: int | str) -> _Id:
    if isinstance(id, int):
        return _IdInt(id)
    elif isinstance(id, str):
        return _IdStr(id)
    else:
        raise TypeError(f"Expected int or str, got {type(id)}")


def compute_similarity(
    instruction: torch.Tensor, kg_triples: torch.Tensor
) -> np.ndarray:
    """Return the pairwise cosine-similarity matrix of two embedding sets.

    Entry ``[i, j]`` is the cosine similarity between ``instruction[i]`` and
    ``kg_triples[j]``, as computed by ``sentence_transformers.util.cos_sim``.
    The result is returned as a CPU numpy array.

    Inputs are detached and handed to ``cos_sim`` as tensors on their own
    device. Converting them with ``.numpy()`` first fails for CUDA tensors.
    Array-likes (e.g. numpy arrays) are also accepted.
    """

    def as_tensor(x):
        # Tensors: drop autograd history; other inputs: wrap as a tensor.
        return x.detach() if isinstance(x, torch.Tensor) else torch.as_tensor(x)

    similarity = cos_sim(as_tensor(instruction), as_tensor(kg_triples))
    return similarity.cpu().numpy()


class Relation(Enum):
    """Edge relation types; both str() and repr() give the lower-case name."""

    ON = 0
    INSIDE = 1
    CLOSE = 2
    HOLD = 3
    IS = 4
    ADJACENT = 5
    FACING = 6
    BETWEEN = 7

    # For CARLA
    LEFT = 8
    RIGHT = 9
    STRAIGHT = 10

    @classmethod
    def all(cls):
        """Return every relation member, in definition order."""
        return [*Relation]

    def __repr__(self):
        return self.name.lower()

    def __str__(self):
        return repr(self)


class State(Enum):
    CLOSED = 0
    OPEN = 1
    ON = 2
    OFF = 3
    PLUGGED_IN = 4
    PLUGGED_OUT = 5
    DIRTY = 5
    CLEAN = 6

    CALLED = 7
    QUITE = 8
    WAITING = 9
    DELIVERED = 10
    PACKED = 11

    def opposite(self) -> State:
        match self:
            case State.CLOSED:
                return State.OPEN
            case State.OPEN:
                return State.CLOSED
            case State.ON:
                return State.OFF
            case State.OFF:
                return State.ON
            case State.PLUGGED_IN:
                return State.PLUGGED_OUT
            case State.PLUGGED_OUT:
                return State.PLUGGED_IN
            case State.DIRTY:
                return State.CLEAN
            case State.CLEAN:
                return State.DIRTY
        raise NotImplementedError(
            f"State {self} does not have an opposite state."
        )

    def __invert__(self):
        return self.opposite()

    def __str__(self):
        return self.name.lower()


class Room(Enum):
    """Rooms that edges can point to; str() gives the lower-case name."""

    KITCHEN = 0
    LIVINGROOM = 1
    BATHROOM = 2
    BEDROOM = 3

    def __str__(self):
        return self.name.lower()

    @property
    def id(self) -> _IdStr:
        """The room's lower-case name, tagged as a string id."""
        return _IdStr(self.__str__())

    @staticmethod
    def all():
        """Return the lower-case names of every room, in definition order."""
        return list(map(str.lower, Room.__members__))


@dataclass
class GraphNode:
    """A named node of the knowledge graph.

    Equality compares ``id`` and ``name`` (dataclass default); ``str()``
    returns the bare name.
    """

    id: _Id
    name: _Name

    def __str__(self):
        return self.name

    @staticmethod
    def from_dict(source: dict):
        """Build a node from a dict holding ``"id"`` and ``"class_name"``.

        Raises:
            ValueError: a key is missing or ``source`` cannot be indexed by
                string keys. The underlying error is chained as ``__cause__``.
        """
        # Catch only lookup failures: the previous bare ``except`` also
        # swallowed KeyboardInterrupt/SystemExit and discarded the cause.
        try:
            node_id, class_name = source["id"], source["class_name"]
        except (KeyError, TypeError) as err:
            raise ValueError(
                "GraphNode must have 'id' and 'class_name' keys in the dictionary."
            ) from err
        return GraphNode(node_id, class_name)


def unique[T](sequence: list[T]):
    seen = set()
    return [x for x in sequence if not (x in seen or seen.add(x))]


# Endpoint types of a GraphEdge: a source is a graph node or a Room; a target
# may additionally be a State (as in the "is" relation).
_FromNode = GraphNode | Room
_ToNode = GraphNode | State | Room


@dataclass
class GraphEdge:
    from_node: _FromNode
    relation: Relation
    to_node: _ToNode
    timestep: int = field(compare=False)

    def conflict(self, other: GraphEdge):
        if self.timestep >= other.timestep:
            return False

        if (
            isinstance(self.to_node, GraphNode)
            and self.to_node.name == "character"
        ):
            return True

        if self.relation == Relation.ADJACENT:
            return True

        if self.from_node != other.from_node or self.relation != other.relation:
            return False

        match self.relation:
            case Relation.INSIDE:
                if (
                    isinstance(self.to_node, Room)
                    and isinstance(other.to_node, Room)
                    and self.to_node != other.to_node
                ):
                    return True

                if (
                    not isinstance(self.to_node, Room)
                    and not isinstance(other.to_node, Room)
                    and self.to_node != other.to_node
                ):
                    return True

            case Relation.ON:
                if self.to_node != other.to_node:
                    return True

            case Relation.IS:
                if (
                    isinstance(self.to_node, State)
                    and ~self.to_node == other.to_node
                ):
                    return True

        return False

    def __format__(self, spec: str):
        ret = f"({self.from_node}, {self.relation}, {self.to_node}"
        if spec == "t":
            ret += f", {self.timestep})"
        else:
            ret += ")"
        return ret

    def __repr__(self):
        return f"GraphEdge{self:t}"

    def __str__(self):
        return f"{self}"


# Anything that identifies edges for lookup or removal: a GraphEdge, or a
# (from, relation, to) tuple of nodes/names. KnowledgeGraph.search_edges reads
# only the first two elements, so `to` may be None.
_EdgeLike = (
    GraphEdge
    | tuple[
        GraphNode | _Name | str,
        Relation | str,
        GraphNode | _Name | str | None,
    ]
)


class KnowledgeGraph:
    """Timestamped scene graph of nodes and (from, relation, to) edges.

    Besides the flat ``nodes`` and ``edges`` lists, three indexes are kept in
    sync by ``add_node``/``add_edge``/``remove_edge``:

    * ``_node_map``: node id -> node
    * ``_name_map``: node name -> every node registered under that name
    * ``_edge_map``: (source-node id, relation) -> edges with that source and
      relation

    Edges retired by refinement, because a newer edge conflicts with them,
    are collected in ``depreciated_edges``. The misspelled public name is
    kept for compatibility.
    """

    def __init__(self):
        self.__nodes: list[GraphNode] = []
        self.__edges: list[GraphEdge] = []

        self._node_map: dict[_Id, GraphNode] = {}
        self._name_map: defaultdict[_Name, list[GraphNode]] = defaultdict(list)
        self._edge_map: defaultdict[tuple[_Id, Relation], list[GraphEdge]] = (
            defaultdict(list)
        )
        self.__depreciated_edges: list[GraphEdge] = []

    @property
    def nodes(self) -> list[GraphNode]:
        """All registered nodes in insertion order (the live internal list)."""
        return self.__nodes

    @property
    def edges(self) -> list[GraphEdge]:
        """All current edges in insertion order (the live internal list)."""
        return self.__edges

    @property
    def depreciated_edges(self) -> list[GraphEdge]:
        """Retired edges, in the order they were retired (live internal list)."""
        return self.__depreciated_edges

    def clone(self):
        """Return an independent deep copy of this graph.

        All containers are copied through one shared ``deepcopy`` memo. A
        node referenced from the node list, the indexes and edge endpoints
        therefore stays a single shared object in the clone, as it is here.
        Separate deep copies would give each container its own duplicates.
        """
        memo: dict = {}
        kg = KnowledgeGraph()
        kg.__nodes = copy.deepcopy(self.__nodes, memo)
        kg.__edges = copy.deepcopy(self.__edges, memo)

        kg._node_map = copy.deepcopy(self._node_map, memo)
        kg._name_map = copy.deepcopy(self._name_map, memo)
        kg._edge_map = copy.deepcopy(self._edge_map, memo)
        kg.__depreciated_edges = copy.deepcopy(self.__depreciated_edges, memo)
        return kg

    @staticmethod
    def _parse_edge_triples(source: dict) -> list[tuple[_Id, Relation, _Id]]:
        """Parse ``source["edges"]`` into ``(from_id, Relation, to_id)`` triples.

        Relation names are upper-cased. The per-hand ``HOLDS_RH``/``HOLDS_LH``
        relations are folded into the single ``HOLD`` relation. Everything is
        parsed before the caller mutates the graph, so a bad edge (KeyError for
        an unknown relation, TypeError for a bad id) leaves edges untouched.
        """
        triples = []
        for raw in source["edges"]:
            relation_name = raw["relation_type"].upper()
            if relation_name in ("HOLDS_RH", "HOLDS_LH"):
                relation_name = "HOLD"
            triples.append(
                (
                    _id(raw["from_id"]),
                    Relation[relation_name],
                    _id(raw["to_id"]),
                )
            )
        return triples

    def _resolve_to_node(self, to_id: _Id) -> _ToNode:
        """Map a parsed ``to_id`` to a registered node, a State, or a Room.

        Int ids are node ids. String ids are tried, case-insensitively, as a
        State name, then a Room name, then as a registered string node id.

        Raises:
            KeyError: an int id is not registered.
            ValueError: a string id matches none of the above.
        """
        if isinstance(to_id, int):
            return self._node_map[to_id]
        to_upper = to_id.upper()
        if to_upper in State.__members__:
            return State[to_upper]
        if to_upper in Room.__members__:
            return Room[to_upper]
        if to_id in self._node_map:
            return self._node_map[to_id]
        raise ValueError(
            f"Unknown node type: {to_id}. Must be an int, State, or Room."
        )

    @classmethod
    def from_dict(cls, source: dict, timestep: int = -1):
        """Build a graph from a ``{"nodes": [...], "edges": [...]}`` dict.

        Node dicts need ``id`` and ``class_name``. Edge dicts need ``from_id``,
        ``relation_type`` and ``to_id``. Every edge is stamped with
        ``timestep``. A string ``to_id`` may name a State, a Room, or a
        string node id; string node ids are resolved the same way
        ``extend_from_dict`` resolves them.

        Raises:
            KeyError: an unknown relation name or int node id.
            ValueError: a malformed node dict, or an unresolvable string
                ``to_id``.
        """
        kg = cls()
        for raw_node in source["nodes"]:
            kg.add_node(GraphNode.from_dict(raw_node))

        for from_id, relation, to_id in cls._parse_edge_triples(source):
            to_node = kg._resolve_to_node(to_id)
            from_node = kg._node_map[from_id]
            kg.add_edge(GraphEdge(from_node, relation, to_node, timestep))

        return kg

    def extend_from_dict(
        self, source: dict, timestep: int, use_refinement: bool = True
    ):
        """Merge one observation ``{"nodes": [...], "edges": [...]}``.

        Unseen nodes are registered and appear in ``nodes``. An edge is kept
        when both endpoint ids are in this observation's node list, or when
        its target is a State/Room name. New edges are stamped ``timestep``.

        With ``use_refinement``, an incoming edge whose (source id, relation)
        key already exists first replaces equal older edges, which are dropped
        silently. It then retires older edges that ``GraphEdge.conflict``
        reports into ``depreciated_edges``. On the first such incoming edge,
        the character's HOLD edges that are equal to it or that it conflicts
        with are dropped as well.
        """
        visible_node_ids: set[_Id] = set()
        for raw_node in source["nodes"]:
            node = GraphNode.from_dict(raw_node)
            visible_node_ids.add(node.id)
            if node.id not in self._node_map:
                # add_node also appends to ``nodes``. Previously only the
                # indexes were updated, so ``nodes`` missed these.
                self.add_node(node)

        hold_checked = False
        for from_id, relation, to_id in self._parse_edge_triples(source):
            both_visible = (
                from_id in visible_node_ids and to_id in visible_node_ids
            )
            targets_attribute = isinstance(to_id, str) and (
                to_id.upper() in State.__members__
                or to_id.upper() in Room.__members__
            )
            if not (both_visible or targets_attribute):
                continue

            to_node = self._resolve_to_node(to_id)
            from_node = self._node_map[from_id]
            new_edge = GraphEdge(from_node, relation, to_node, timestep)
            key = (new_edge.from_node.id, new_edge.relation)

            if use_refinement and key in self._edge_map:
                # Iterate over snapshots: remove_edge shrinks these very
                # lists, which previously made the loops skip edges.
                if not hold_checked:
                    held = self.search_edges(("character", "HOLD", None))
                    for edge in list(held or ()):
                        if edge == new_edge or edge.conflict(new_edge):
                            self.remove_edge(edge)
                    hold_checked = True

                for edge in list(self._edge_map[key]):
                    if edge == new_edge:
                        self.remove_edge(edge)
                    elif edge.conflict(new_edge):
                        self.__depreciated_edges.append(edge)
                        self.remove_edge(edge)

            self.add_edge(new_edge)

    def extend(
        self,
        new_graph: KnowledgeGraph | dict,
        /,
        *,
        timestep: int,
        use_refinement: bool = True,
    ):
        """Merge ``new_graph`` (an observation dict) at ``timestep``.

        Raises:
            NotImplementedError: ``new_graph`` is a KnowledgeGraph.
            TypeError: ``new_graph`` is neither a KnowledgeGraph nor a dict.
        """
        if isinstance(new_graph, KnowledgeGraph):
            raise NotImplementedError
        elif isinstance(new_graph, dict):
            self.extend_from_dict(new_graph, timestep, use_refinement)
        else:
            raise TypeError(
                f"Expected KnowledgeGraph or dict, got {type(new_graph)}"
            )

    def __iand__(self, other: KnowledgeGraph | dict):
        """``kg &= observation`` merges ``other`` at timestep -1.

        Returns ``self`` so the left-hand name stays bound to the graph.
        Returning ``extend``'s None rebound it to None.
        """
        self.extend(other, timestep=-1)
        return self

    def add_node(self, new_node: GraphNode):
        """Register ``new_node`` in the node list and the id/name indexes."""
        self.__nodes.append(new_node)
        self._node_map[new_node.id] = new_node
        self._name_map[new_node.name].append(new_node)

    def add_edge(self, new_edge: GraphEdge):
        """Append ``new_edge`` to the edge list and the (source, relation) index."""
        self._edge_map[(new_edge.from_node.id, new_edge.relation)].append(
            new_edge
        )
        self.__edges.append(new_edge)

    def __ilshift__(self, other: GraphNode | GraphEdge):
        """``kg <<= x`` adds a node or an edge.

        Raises:
            TypeError: ``other`` is neither a GraphNode nor a GraphEdge.
        """
        if isinstance(other, GraphNode):
            self.add_node(other)
        elif isinstance(other, GraphEdge):
            self.add_edge(other)
        else:
            raise TypeError(
                f"Expected GraphNode or GraphEdge, got {type(other)}"
            )
        return self

    def remove_edge(self, edge: _EdgeLike):
        """Remove edges.

        A GraphEdge removes the first edge equal to it (equality ignores the
        timestep); a missing edge is ignored. A tuple removes every edge that
        ``search_edges`` returns for it: all edges from that source name with
        that relation, whatever their target.
        """
        if isinstance(edge, GraphEdge):
            bucket = self._edge_map[(edge.from_node.id, edge.relation)]
            if edge in bucket:
                bucket.remove(edge)
            if edge in self.__edges:
                self.__edges.remove(edge)
            return

        # search_edges returns the live index list that this loop shrinks;
        # iterate over a snapshot so no edge is skipped.
        for e in list(self.search_edges(edge) or ()):
            self._edge_map[(e.from_node.id, e.relation)].remove(e)
            self.__edges.remove(e)

    def __irshift__(self, other: GraphEdge):
        """``kg >>= edge`` removes ``edge``.

        Raises:
            TypeError: ``other`` is not a GraphEdge.
        """
        if isinstance(other, GraphEdge):
            self.remove_edge(other)
        else:
            raise TypeError(f"Expected GraphEdge, got {type(other)}")
        return self

    def search_edges(self, edge: _EdgeLike):
        """Return the edges that share ``edge``'s source name and relation.

        ``edge`` is a GraphEdge or a ``(from, relation, to)`` tuple. ``from``
        is a node or a name; ``relation`` is a Relation or a case-insensitive
        relation name. The ``to`` element is ignored.

        Only the first node registered under the source name is consulted.
        Returns None when the name, or its (node id, relation) key, is
        unknown. The result is the graph's internal list, not a copy:
        snapshot it with ``list(...)`` before removing edges in a loop.

        Raises:
            KeyError: ``relation`` is a string naming no Relation.
        """
        if isinstance(edge, GraphEdge):
            target = edge.to_node
            edge = (
                edge.from_node.name,
                edge.relation.name,
                target.name if isinstance(target, GraphNode) else str(target),
            )

        node_from = edge[0]
        if isinstance(node_from, GraphNode):
            node_from = node_from.name
        else:
            node_from = _Name(node_from)

        relation = edge[1]
        if isinstance(relation, str):
            relation = Relation[relation.upper()]

        if node_from not in self._name_map:
            return None

        key = (self._name_map[node_from][0].id, relation)
        if key not in self._edge_map:
            return None
        return self._edge_map[key]

    def search_objects_inside_of(self, obj: GraphNode):
        """Return the name of the first container ``obj`` is INSIDE, or None.

        Targets whose name is a State or Room member name are skipped. Those
        are upper-case names, as carried by State/Room enum targets.
        """
        edges = self.search_edges((obj, "INSIDE", None))
        for e in edges or ():
            name = e.to_node.name
            if name not in State.__members__ and name not in Room.__members__:
                return name
        return None

    def search_objects_has_closed(self, obj: GraphNode | _Name | str):
        """Return True if ``obj`` has an ``is`` edge to a target named "closed"."""
        edges = self.search_edges((obj, "is", None))
        return any(e.to_node.name.lower() == "closed" for e in edges or ())

    def search_closing_objects(self, obj: GraphNode | _Name | str):
        """Return True if the character has a CLOSE edge to ``obj`` (by name)."""
        if isinstance(obj, GraphNode):
            obj = obj.name
        obj = _Name(obj)

        close_edges = self.search_edges(("character", "CLOSE", None))
        return any(e.to_node.name == obj for e in close_edges or ())

    @overload
    def search_adjacent_objects(self, edge: GraphEdge, /): ...
    @overload
    def search_adjacent_objects(
        self,
        /,
        *,
        from_node: GraphNode | _Name | str,
        to_node: GraphNode | _Name | str,
    ): ...
    def search_adjacent_objects(
        self,
        edge: GraphEdge | None = None,
        /,
        *,
        from_node: GraphNode | _Name | str | None = None,
        to_node: GraphNode | _Name | str | None = None,
    ):
        """Return True if ``from_node`` has an ADJACENT edge to ``to_node``.

        Pass either an edge, whose endpoint names are used, or both keyword
        endpoints. Endpoints are compared by name.

        Raises:
            TypeError: ``edge`` is given but is not a GraphEdge.
            ValueError: neither ``edge`` nor both endpoints are given.
        """
        if edge is not None:
            if not isinstance(edge, GraphEdge):
                raise TypeError(f"Expected GraphEdge, got {type(edge)}")
            from_node = edge.from_node.name
            to_node = edge.to_node.name

        elif from_node is None or to_node is None:
            raise ValueError(
                "Either edge or both from_node and to_node must be provided."
            )

        if isinstance(from_node, GraphNode):
            from_node = from_node.name
        if isinstance(to_node, GraphNode):
            to_node = to_node.name

        adjacent = self.search_edges((from_node, "ADJACENT", None))
        return any(e.to_node.name == to_node for e in adjacent or ())

    def search_room_within(self, obj: GraphNode | _Name | str):
        """Return the name of the room ``obj`` is INSIDE, or None.

        A target matches when its name is one of ``Room.all()``, i.e. a
        lower-case room name.
        """
        edges = self.search_edges((obj, "INSIDE", None))
        if edges is None:
            return None

        room_names = Room.all()
        for e in edges:
            if e.to_node.name in room_names:
                return e.to_node.name
        return None

    def search_node(self, name: _Name):
        """Return every node registered under ``name`` (live list), or None."""
        if name not in self._name_map:
            return None
        return self._name_map[name]

    def refinement(self, new_graph: KnowledgeGraph):
        """Retire existing edges that edges of ``new_graph`` conflict with.

        For each edge of ``new_graph``, existing edges with the same
        (source id, relation) key for which ``conflict`` holds are moved to
        ``depreciated_edges``. ``new_graph``'s own edges are not added.
        """
        for new_edge in new_graph.edges:
            bucket = self._edge_map[(new_edge.from_node.id, new_edge.relation)]
            stale = [edge for edge in bucket if edge.conflict(new_edge)]
            self.__depreciated_edges.extend(stale)
            for edge in stale:
                bucket.remove(edge)
                self.__edges.remove(edge)

    @staticmethod
    def _edge_text_pairs(edges, with_timestep: bool) -> list[tuple[str, str]]:
        """Pair each edge's plain text with the text to return for it.

        The plain ``(from, rel, to)`` text is used for filtering and scoring.
        The returned text also carries the timestep when ``with_timestep``.
        """
        return [
            (str(edge), f"{edge:t}" if with_timestep else str(edge))
            for edge in edges
        ]

    @staticmethod
    def _drop_edges(
        pairs: list[tuple[str, str]], excluded_prefixes: tuple[str, ...]
    ) -> list[tuple[str, str]]:
        """Keep pairs whose plain text avoids the prefixes and "adjacent"."""
        return [
            (plain, shown)
            for plain, shown in pairs
            if not plain.startswith(excluded_prefixes)
            and "adjacent" not in plain
        ]

    @staticmethod
    def _relevance(
        edge_texts: list[str],
        instructions: list[str],
        embedding_fns: Callable[[list[str]], torch.Tensor],
        embedding_weight: float,
    ) -> np.ndarray:
        """Score each edge against the instructions (max over instructions).

        Score = BM25 over the edge's comma-separated terms, plus
        ``embedding_weight`` times the embedding cosine similarity. The tiny
        weight makes the embedding term negligible next to BM25 (presumably
        a tie-breaker). Relation and State names are dropped from the
        instruction tokens, compared case-insensitively. The member names are
        upper-case, so an exact comparison never matched lower-case
        instructions.
        """
        bm25 = BM25Okapi([text[1:-1].split(", ") for text in edge_texts])
        stopwords = {
            name.lower()
            for name in (*Relation.__members__, *State.__members__)
        }
        bm25_scores = [
            bm25.get_scores(
                [tok for tok in instruction.split() if tok.lower() not in stopwords]
            )
            for instruction in instructions
        ]
        edge_embedding = embedding_fns(edge_texts)
        instruction_embedding = embedding_fns(instructions)

        similarity = np.array(bm25_scores) + (
            np.array(compute_similarity(instruction_embedding, edge_embedding))
            * embedding_weight
        )
        return np.max(similarity, axis=0)

    @staticmethod
    def _weighted_sample(
        candidates: list[str], weights, size: int, replace: bool
    ) -> list[str]:
        """Draw ``size`` candidates with probability proportional to ``weights``.

        Without replacement the draw is capped at ``len(candidates)``.
        Negative weights, which BM25/cosine scores can produce, are clipped
        to 0. If no positive weight remains, the draw is uniform. If too few
        weights are positive for a draw without replacement, zero-weight
        candidates get a negligible share. Otherwise numpy rejects these
        cases.
        """
        if not replace:
            size = min(size, len(candidates))
        weights = np.clip(np.asarray(weights, dtype=float), 0.0, None)
        total = float(weights.sum())
        if not np.isfinite(total) or total <= 0.0:
            weights = np.ones(len(candidates))
        elif not replace and np.count_nonzero(weights) < size:
            weights = np.where(weights > 0.0, weights, total * 1e-12)
        probs = weights / weights.sum()
        picked = np.random.choice(candidates, p=probs, size=size, replace=replace)
        return [str(text) for text in picked]

    def retrieve(
        self,
        instructions: str | list[str],
        embedding_fns: Callable[[list[str]], torch.Tensor] | None = None,
        num_edges: int = 50,
        return_type: Literal["default", "with_timestep"] = "default",
        character_info: bool = True,
        room_info: bool = True,
        hold_info: bool = True,
        replace: bool = False,
        shuffle: bool = True,
    ) -> list[str]:
        """Select current edges relevant to ``instructions``.

        * With at most ``num_edges`` edges, all are returned.
        * Without ``embedding_fns``, a uniform random ``num_edges`` subset is
          returned.
        * Otherwise edges are sampled in proportion to their relevance (see
          ``_relevance``, cosine weight 1e-8). Sampling excludes
          character-INSIDE edges and anything mentioning ``adjacent``. It
          also excludes character-HOLD edges when ``hold_info`` is off and
          character-CLOSE edges when ``character_info`` is off. After
          sampling, the character's INSIDE edge (``character_info``), its
          HOLD edge or ``(character, hold, none)`` (``hold_info``), and the
          ADJACENT edges of its room (``room_info``) are appended as plain
          text.

        Args:
            instructions: one instruction or a list of them.
            embedding_fns: maps a list of strings to embeddings.
            num_edges: number of edges to select.
            return_type: ``"with_timestep"`` renders selected edges as
                ``(from, rel, to, timestep)``.
            character_info / room_info / hold_info: see above.
            replace: sample with replacement. Without it, the draw is capped
                at the number of candidates.
            shuffle: shuffle the de-duplicated result.

        Returns:
            De-duplicated edge strings ([] for an empty graph).
        """
        if not self.__edges:
            return []

        pairs = self._edge_text_pairs(
            self.__edges, return_type == "with_timestep"
        )

        if len(pairs) <= num_edges:
            selected = [shown for _, shown in pairs]
        elif embedding_fns is None:
            selected = [shown for _, shown in pairs]
            random.shuffle(selected)
            selected = selected[:num_edges]
        else:
            excluded = ["(character, inside"]
            if not hold_info:
                excluded.append("(character, hold")
            if not character_info:
                excluded.append("(character, close")
            candidates = self._drop_edges(pairs, tuple(excluded))

            selected = []
            # BM25 cannot be built over zero documents.
            if candidates:
                if not isinstance(instructions, list):
                    instructions = [instructions]
                scores = self._relevance(
                    [plain for plain, _ in candidates],
                    instructions,
                    embedding_fns,
                    embedding_weight=1e-8,
                )
                selected = self._weighted_sample(
                    [shown for _, shown in candidates],
                    scores,
                    num_edges,
                    replace,
                )

            if character_info:
                inside = self.search_edges(("character", "inside", None))
                if inside:
                    selected.append(str(inside[0]))
            if hold_info:
                held = self.search_edges(("character", "hold", None))
                selected.append(
                    str(held[0]) if held else "(character, hold, none)"
                )
            if room_info:
                room = self.search_room_within("character")
                assert room is not None, "Character must be inside a room."
                adjacent = self.search_edges((room, "adjacent", None))
                for adjacent_edge in adjacent or []:
                    selected.append(str(adjacent_edge))

        selected = unique(selected)
        if shuffle:
            np.random.shuffle(selected)

        return [str(e) for e in selected]

    def history_retrieve(
        self,
        instructions,
        embedding_fns=None,
        num_edges=50,
        return_type="str",
        character_info=True,
    ):
        """Select retired (depreciated) edges relevant to ``instructions``.

        Selection follows ``retrieve``, with these differences. Character
        INSIDE and HOLD edges are always excluded from sampling. Sampling is
        without replacement, with cosine weight 1e-6. The appended context
        (INSIDE edge, HOLD edge or ``(character, hold, none)``, and the
        room's ADJACENT edges) is governed by ``character_info`` alone.

        Args:
            instructions: one instruction string or a list of them.
            return_type: ``"str"`` and ``"with_timestep"`` return one
                lower-cased, comma-joined string. ``"with_timestep"`` and
                ``"with_timestep_list"`` render selected edges with their
                timesteps. Any other value returns a list.

        Returns:
            A string or a list as described above. When nothing has been
            retired yet, that is " " (string modes) or []. When every retired
            edge is filtered out, it is "" or [].
        """
        as_string = return_type in ("str", "with_timestep")
        if not self.__depreciated_edges:
            return " " if as_string else []

        pairs = self._edge_text_pairs(
            self.__depreciated_edges,
            return_type in ("with_timestep", "with_timestep_list"),
        )

        if len(pairs) <= num_edges:
            selected = [shown for _, shown in pairs]
        elif embedding_fns is None:
            selected = [shown for _, shown in pairs]
            random.shuffle(selected)
            selected = selected[:num_edges]
        else:
            excluded = ("(character, inside", "(character, hold")
            if not character_info:
                excluded += ("(character, close",)
            candidates = self._drop_edges(pairs, excluded)
            if not candidates:
                return "" if as_string else []

            # A bare string would otherwise be scored character by character.
            if isinstance(instructions, str):
                instructions = [instructions]
            scores = self._relevance(
                [plain for plain, _ in candidates],
                instructions,
                embedding_fns,
                embedding_weight=1e-6,
            )
            selected = self._weighted_sample(
                [shown for _, shown in candidates],
                scores,
                num_edges,
                replace=False,
            )

            if character_info:
                inside = self.search_edges(("character", "inside", None))
                if inside:
                    selected.append(str(inside[0]))
                held = self.search_edges(("character", "hold", None))
                selected.append(
                    str(held[0]) if held else "(character, hold, none)"
                )
                room = self.search_room_within("character")
                assert room is not None, "Character must be inside a room."
                adjacent = self.search_edges((room, "adjacent", None))
                for adjacent_edge in adjacent or []:
                    selected.append(str(adjacent_edge))

        selected = unique(selected)
        if as_string:
            return ", ".join(selected).lower()
        return selected

    def return_string_tuple(self):
        """Return every current edge as its ``(from, rel, to)`` string."""
        return [str(edge) for edge in self.__edges]

    def __str__(self):
        """All current edges, comma-joined and lower-cased."""
        return ", ".join(str(edge) for edge in self.__edges).lower()
