"""
Contains various utilities for handling circuits and performing various
operations with them.
"""
import random
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, Union

from acdc.acdc_utils import TorchIndex

from hypo_interp.types_ import Circuit, CircuitEdge

# Module-level memo of already-built circuit graphs.
#
# Many circuit graphs get built, often for the same circuit, and building one
# (parsing the circuit, finding sinks, pruning unreachable nodes) is not
# cheap. The resulting state is therefore stored here instead of being rebuilt
# on every call.
#
# Keys are _CircuitGraphArgs (the circuit plus the use_pos_embed flag); values
# are _CircuitGraphState. Nothing in this module evicts entries.
# TODO: THIS IS A HACK AND SHOULD BE FIXED (ideally by making the graph the
# native circuit representation so it never has to be rebuilt).
CIRCUIT_CACHE: Dict = {}

################################################
# Private helper functions and data structures
################################################


@dataclass
class _Node:
    """
    A vertex of the circuit graph: a hook name together with the TorchIndex
    that accompanies it in the circuit's edge tuples.

    Equality (generated by the dataclass) and hashing both use
    ``(name, shape)``, so nodes can serve as dict keys and set members.
    """

    name: str
    shape: TorchIndex

    def __hash__(self):
        # Same fields as the dataclass-generated __eq__, keeping the
        # hash/eq contract intact.
        identity = (self.name, self.shape)
        return hash(identity)


@dataclass
class _Edge:
    """
    A directed edge ``sender_node -> reciever_node`` of the circuit graph.

    ``value`` is whatever ACDC paired with the edge in the ``Circuit`` list.
    The traversal code never reads it, but it is carried through so that
    circuits can be rebuilt from edges. It DOES take part in equality and
    hashing, so two edges with the same endpoints but different values are
    distinct.
    """

    sender_node: _Node
    # (sic) misspelled attribute name, referenced throughout the module
    reciever_node: _Node
    value: bool

    def __hash__(self):
        # Hash exactly the fields the dataclass-generated __eq__ compares.
        identity = (self.sender_node, self.reciever_node, self.value)
        return hash(identity)


# These lists contain _Node instances, so they can only be built once _Node
# has been defined. With use_pos_embed (Tracr tasks) the graph starts at the
# token and positional embedding hooks; otherwise it starts at the layer-0
# residual stream. _POSSIBLE_SOURCES is the union of both conventions.
_SOURCES_WITH_POS_EMBED = [
    _Node(hook_name, TorchIndex([None]))
    for hook_name in ("hook_embed", "hook_pos_embed")
]
_SOURCES_WITHOUT_POS_EMBED = [
    _Node("blocks.0.hook_resid_pre", TorchIndex([None])),
]
_POSSIBLE_SOURCES = [*_SOURCES_WITH_POS_EMBED, *_SOURCES_WITHOUT_POS_EMBED]


def _to_torch_index(index: Union[TorchIndex, tuple]) -> TorchIndex:
    """Wraps a plain tuple in ``TorchIndex``; any other value is returned as-is."""
    return TorchIndex(index) if isinstance(index, tuple) else index


def _reformat_circuit_to_torch_index(circuit: Circuit) -> Circuit:
    """
    Normalises every index in ``circuit`` to ``TorchIndex``.

    Each entry is ``(edge, value)`` where ``edge`` is a 4-tuple
    ``(name, index, name, index)`` (receiver first, sender second, as in the
    rest of this module). Both indices go through ``_to_torch_index``; names
    and values are left untouched.

    NOTE: the list is modified IN PLACE and that same object is returned, so
    the caller's circuit changes as a side effect. ``circuit`` therefore has to
    support item assignment (e.g. a list).
    """
    for position, entry in enumerate(circuit):
        edge, value = entry
        normalised_edge = (
            edge[0],
            _to_torch_index(edge[1]),
            edge[2],
            _to_torch_index(edge[3]),
        )
        circuit[position] = (normalised_edge, value)
    return circuit


def _reachable_nodes(nodes: List[_Node], edges: Dict[_Node, List[_Edge]]) -> Set[_Node]:
    """
    Returns every node reachable from any node in ``nodes`` by following ``edges``.

    The start nodes themselves are always part of the result. Passing the
    reversed adjacency map instead yields every node that can reach ``nodes``.

    A node with no entry in ``edges`` is treated as having no outgoing edges
    rather than raising ``KeyError``. This happens, for example, when a
    configured source does not appear in the circuit, and it lets callers
    report a descriptive error instead.

    The traversal is iterative, so long paths cannot hit Python's recursion
    limit.
    """
    visited: Set[_Node] = set()
    stack: List[_Node] = list(nodes)
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        for edge in edges.get(node, ()):
            child = edge.reciever_node
            if child not in visited:
                stack.append(child)
    return visited


def _find_sinks(
    nodes: List[_Node],
    edges: Dict[_Node, List[_Edge]],
    reversed_edges: Dict[_Node, List[_Edge]],
) -> List[_Node]:
    """
    Returns the nodes of ``nodes`` that have no outgoing edges in ``edges`` and
    that can be traced back through ``reversed_edges`` to at least one of
    ``_POSSIBLE_SOURCES``.

    The second condition drops dangling ends left by ACDC placeholder edges
    that are not connected to any source. Calling it with the two adjacency
    maps swapped finds source-like nodes (no incoming edges) instead.
    """
    dead_ends = [node for node in nodes if not edges[node]]

    def _connected_to_a_source(candidate: _Node) -> bool:
        ancestors = _reachable_nodes([candidate], reversed_edges)
        return any(ancestor in _POSSIBLE_SOURCES for ancestor in ancestors)

    return [candidate for candidate in dead_ends if _connected_to_a_source(candidate)]


def _all_nodes_are_reachable_from_source(
    nodes: List[_Node], edges: Dict[_Node, List[_Edge]], source: _Node
):
    """
    Checks that every node is reachable from ``source``.

    Counts the nodes reachable from ``source`` (itself included) via ``edges``
    and compares that count with ``len(nodes)``. This assumes ``nodes`` has no
    duplicates and that ``edges`` only mentions nodes from ``nodes``.

    Raises KeyError if the walk reaches a node that has no entry in ``edges``.
    """
    seen = {source}
    frontier = [source]
    while frontier:
        current = frontier.pop()
        for outgoing in edges[current]:
            neighbour = outgoing.reciever_node
            if neighbour not in seen:
                seen.add(neighbour)
                frontier.append(neighbour)
    return len(seen) == len(nodes)


def _sample_path(
    sources: List[_Node], sink: _Node, edges: Dict[_Node, List[_Edge]], seed: int = None
) -> List[_Edge]:
    """
    Samples one path from a randomly chosen source to ``sink``.

    The start node is drawn uniformly from ``sources``. At every step one
    outgoing edge of the current node is drawn uniformly, until ``sink`` is
    reached. This is a uniform random *walk*, not a uniform draw over all
    source-to-sink paths: a path's probability is the product of
    ``1 / out_degree`` over the nodes it leaves.

    Args:
        sources: candidate start nodes.
        sink: node at which the walk stops.
        edges: outgoing edges per node.
        seed: if given, reseeds the global ``random`` module first.

    Returns:
        The edges of the path, in order from the source to ``sink``.

    Raises:
        ValueError: if the walk reaches a node other than ``sink`` that has no
            outgoing edges. Previously this surfaced as an opaque
            ``IndexError``/``KeyError``.
    """
    if seed is not None:
        random.seed(seed)

    edges_in_path: List[_Edge] = []

    current_node = random.choice(sources)
    while current_node != sink:
        possible_next_edges = edges.get(current_node, [])
        if not possible_next_edges:
            raise ValueError(
                f"Random walk reached {current_node}, which has no outgoing "
                f"edges but is not the sink {sink}"
            )
        next_edge = random.choice(possible_next_edges)
        edges_in_path.append(next_edge)
        current_node = next_edge.reciever_node
    return edges_in_path


def _remove_unreachable_nodes(
    sources: List[_Node],
    sinks: List[_Node],
    nodes: List[_Node],
    edges: Dict[_Node, List[_Edge]],
    reversed_edges: Dict[_Node, List[_Edge]],
) -> Tuple[List[_Node], Dict[_Node, List[_Edge]], Dict[_Node, List[_Edge]],]:
    """
    Removes all nodes that are not reachable from the sources or that don't
    lead into a sink. This takes care of ACDC's placeholder edges.

    Args:
        sources: start nodes of the graph.
        sinks: end nodes of the graph.
        nodes: all nodes of the graph. Their order is preserved in the result,
            and nodes not listed here are never returned.
        edges: outgoing edges per node.
        reversed_edges: incoming edges per node, stored with sender and
            receiver swapped, so ``edge.reciever_node`` is the other endpoint.

    Returns:
        nodes: the nodes that are reachable from the sources and lead into a sink
        edges: for each kept node, its outgoing edges whose other endpoint is kept
        reversed_edges: the same filtering applied to the reversed adjacency map
    """
    # A node survives iff it is downstream of a source AND upstream of a sink.
    reachable_from_sources = set(_reachable_nodes(sources, edges))
    leading_into_sinks = set(_reachable_nodes(sinks, reversed_edges))
    kept = reachable_from_sources & leading_into_sinks

    # Iterate the input list rather than the ``kept`` set so the output order
    # is deterministic (set order depends on string hashing, which varies
    # between interpreter runs). dict.fromkeys drops duplicate nodes, as the
    # previous set-based construction did.
    new_nodes = [node for node in dict.fromkeys(nodes) if node in kept]

    new_edges = {
        node: [edge for edge in edges[node] if edge.reciever_node in kept]
        for node in new_nodes
    }
    new_reversed_edges = {
        node: [edge for edge in reversed_edges[node] if edge.reciever_node in kept]
        for node in new_nodes
    }
    return new_nodes, new_edges, new_reversed_edges


def _edges_to_circuit(edges: List[_Edge]) -> Circuit:
    """
    Converts graph edges back into ACDC's ``Circuit`` list format.

    ``Circuit`` edge tuples list the receiver first, i.e.
    ``(receiver_name, receiver_index, sender_name, sender_index)``, so each
    ``_Edge`` is flipped on the way out and paired with its ``value``.

    Any iterable of ``_Edge`` is accepted (callers also pass sets). The output
    follows its iteration order.
    """
    return [
        (
            (
                edge.reciever_node.name,
                edge.reciever_node.shape,
                edge.sender_node.name,
                edge.sender_node.shape,
            ),
            edge.value,
        )
        for edge in edges
    ]


# ----------------------------#
# Stuff for CIRCUIT_CACHE
# ----------------------------#
# TODO: THIS IS A HACK AND SHOULD BE FIXED


@dataclass
class _CircuitGraphState:
    """
    Snapshot of a built circuit graph, stored as a value in ``CIRCUIT_CACHE``.

    The containers are stored by reference, not copied, so every graph
    restored from the same cache entry shares them.
    """

    # Nodes kept after pruning.
    nodes: List[_Node]
    # Outgoing edges of each node.
    edges: Dict[_Node, List[_Edge]]
    # Incoming edges of each node, stored with sender/receiver swapped.
    reversed_edges: Dict[_Node, List[_Edge]]
    # Start nodes for path sampling (chosen by the use_pos_embed flag).
    sources: List[_Node]
    # The single sink of the graph.
    sink: _Node


class _CircuitGraphArgs:
    """
    Cache key for ``CIRCUIT_CACHE``: the arguments a circuit graph is built from.

    Construction normalises ``circuit`` IN PLACE to use ``TorchIndex``.
    Equality and hashing both use an immutable snapshot taken at construction
    time. Previously only ``__hash__`` was defined, so dict lookups fell back to
    identity comparison: the cache never hit and only grew.
    """

    def __init__(self, circuit: Circuit, use_pos_embed: bool):
        # convert circuit to uniform format (mutates and returns ``circuit``)
        self.circuit = _reformat_circuit_to_torch_index(circuit)
        self.use_pos_embed = use_pos_embed
        # Snapshot so that later mutation of the caller's list cannot change
        # this key's hash while it sits in the cache. The snapshot is
        # order-sensitive on purpose: edge order determines adjacency-list
        # order in the built graph, and therefore the result of seeded
        # sampling. Circuits with the same edges in a different order must not
        # share a cache entry.
        self._key = (tuple(self.circuit), use_pos_embed)

    def __eq__(self, other):
        if not isinstance(other, _CircuitGraphArgs):
            return NotImplemented
        return self._key == other._key

    def __hash__(self):
        return hash(self._key)


################################
# Private classes and functions
################################


class _CircuitGraph:
    """
    Simple data structure for handling circuits in a proper graph format.

    This class is not the true representation of a ``Circuit``; it is a
    utility view used for sampling and sizing circuits. Construction does
    three things:

    * turns the flat ``(edge, value)`` list into forward and reverse adjacency
      maps;
    * prunes every node that is not on a path from the configured sources to
      a sink (this removes ACDC's placeholder edges);
    * validates that exactly one sink remains.

    Built graphs are memoised in the module-level ``CIRCUIT_CACHE``. Cached
    containers are shared between instances, so no method of this class
    mutates ``nodes``, ``edges`` or ``reversed_edges``.
    """

    def __init__(self, circuit: Circuit, use_pos_embed: bool):
        # Building the key also normalises ``circuit`` to TorchIndex in place.
        args = _CircuitGraphArgs(circuit, use_pos_embed)
        cached_state = CIRCUIT_CACHE.get(args)
        if cached_state is not None:
            self._load_state(cached_state)
            return

        self._full_init(circuit, use_pos_embed)
        # Only reached when _full_init validated the graph, so graphs that fail
        # validation are never cached.
        CIRCUIT_CACHE[args] = _CircuitGraphState(
            nodes=self.nodes,
            edges=self.edges,
            reversed_edges=self.reversed_edges,
            sources=self.sources,
            sink=self.sink,
        )

    def _load_state(self, state: _CircuitGraphState) -> None:
        """Points this instance at a cached state (shared, not copied)."""
        self.nodes = state.nodes
        self.edges = state.edges
        self.reversed_edges = state.reversed_edges
        self.sources = state.sources
        self.sink = state.sink

    def _restore_from_cache(self, circuit: Circuit, use_pos_embed: bool):
        """
        Restores the state of the circuit graph from the cache.

        Raises KeyError if there is no entry for ``(circuit, use_pos_embed)``.
        """
        self._load_state(CIRCUIT_CACHE[_CircuitGraphArgs(circuit, use_pos_embed)])

    def _full_init(self, circuit: Circuit, use_pos_embed: bool):
        """
        Initializes a circuit graph from a circuit.

        use_pos_embed:
            set to true for Tracr tasks. The graph then starts at the embedding
            and positional-embedding hooks instead of
            ``blocks.0.hook_resid_pre``.

        Raises:
            ValueError: if the pruned graph does not contain the configured
                sources or does not have exactly one sink.
        """
        # Builds the graph from the Circuit type data structure
        nodes, edges, reversed_edges = self._build_graph(circuit)
        sinks = _find_sinks(nodes, edges, reversed_edges)
        sources = _SOURCES_WITH_POS_EMBED if use_pos_embed else _SOURCES_WITHOUT_POS_EMBED

        # Remove any nodes that are not reachable from the sources and that
        # don't lead into a sink (this takes care of acdc's placeholder edges).
        nodes, edges, reversed_edges = _remove_unreachable_nodes(
            sources=sources,
            sinks=sinks,
            nodes=nodes,
            edges=edges,
            reversed_edges=reversed_edges,
        )

        # -------------------- #
        # State of the graph   #
        # -------------------- #
        self.nodes: List[_Node] = nodes
        # Dictionary of edges outgoing from a node
        self.edges: Dict[_Node, List[_Edge]] = edges
        # Dictionary of edges in the flipped direction
        self.reversed_edges: Dict[_Node, List[_Edge]] = reversed_edges
        # There might be multiple sources due to placeholder edges in acdc,
        # but there should be only one sink (and only one true source).
        self.sources = sources
        # Pruning can drop a sink that is only connected to the *other* source
        # convention, so pick the sink among those that survived. If none or
        # several survive, _check_graph below raises a descriptive ValueError
        # (previously this was an IndexError or a silently wrong sink).
        kept = set(nodes)
        surviving_sinks = [sink for sink in sinks if sink in kept]
        self.sink = surviving_sinks[0] if surviving_sinks else None

        # Do some additional checks to make sure the graph is valid
        self._check_graph(self.nodes, self.edges, self.reversed_edges)

    def __len__(self):
        """Number of edge entries in the (pruned) graph."""
        return len(self.all_edges)

    @property
    def all_edges(self) -> List[_Edge]:
        """
        Returns all the edges in the graph, grouped by sender in node order.
        """
        all_edges = []
        for node in self.nodes:
            all_edges.extend(self.edges[node])
        return all_edges

    def sample_circuit(self, minimum_number_of_edges: int, seed: int = None) -> Circuit:
        """
        Samples a circuit from the graph.

        Paths from a source to the sink are sampled (see ``_sample_path``) and
        their edges are merged until the circuit contains at least
        ``minimum_number_of_edges`` distinct edges. Whole paths are added, so
        the result may overshoot the minimum.

        Each path is a random walk: starting at a source, one outgoing edge of
        the current node is picked uniformly at random until the sink is
        reached.

        Args:
            minimum_number_of_edges: lower bound on the number of distinct
                edges in the returned circuit.
            seed: if given, reseeds the global ``random`` module.

        Raises:
            ValueError: if ``minimum_number_of_edges`` is negative, or larger
                than the number of distinct edges in the graph (the sampling
                loop could otherwise never terminate).
        """
        if minimum_number_of_edges < 0:
            raise ValueError("Minimum number of edges must be non-negative")

        available_edges = len(set(self.all_edges))
        if minimum_number_of_edges > available_edges:
            raise ValueError(
                f"Cannot sample {minimum_number_of_edges} edges from a graph "
                f"with only {available_edges} distinct edges"
            )

        if seed is not None:
            random.seed(seed)

        edges_in_circuit: Set[_Edge] = set()
        while len(edges_in_circuit) < minimum_number_of_edges:
            edges_in_circuit.update(_sample_path(self.sources, self.sink, self.edges))

        return _edges_to_circuit(edges_in_circuit)

    def make_inflated_from_super_set(
        self, super_set_circuit_graph, inflate_size: int, seed: int = None
    ) -> Circuit:
        """
        Returns this graph's edges plus at least ``inflate_size`` new edges
        sampled from ``super_set_circuit_graph``.

        Extra edges are added one whole path at a time, walking the super-set
        graph from its sources to its sink, so the result may overshoot.

        Args:
            super_set_circuit_graph: ``_CircuitGraph`` of the complete circuit,
                expected to contain every edge of this graph.
            inflate_size: minimum number of new distinct edges to add.
            seed: if given, reseeds the global ``random`` module.

        Raises:
            ValueError: if ``inflate_size`` is negative, or exceeds the number
                of super-set edges not already in this graph (the sampling loop
                could otherwise never terminate).
        """
        # super set circuit is the complete circuit
        if inflate_size < 0:
            raise ValueError("inflate size edges must be non-negative")

        edges_in_circuit: Set[_Edge] = set(self.all_edges)
        target_size = len(edges_in_circuit) + inflate_size
        max_size = len(edges_in_circuit.union(super_set_circuit_graph.all_edges))
        if target_size > max_size:
            raise ValueError(
                f"Cannot inflate by {inflate_size} edges: the super set only has "
                f"{max_size - len(edges_in_circuit)} edges not already in the circuit"
            )

        if seed is not None:
            random.seed(seed)

        # The walk runs on the super-set's edges, so it must start at that
        # graph's sources and stop at its sink. Stopping at self.sink could hit
        # the super-set's sink (a dead end) if the two differ.
        while len(edges_in_circuit) < target_size:
            path = _sample_path(
                super_set_circuit_graph.sources,
                super_set_circuit_graph.sink,
                super_set_circuit_graph.edges,
            )
            edges_in_circuit.update(path)

        # It is ok to not care about the nodes because this is implicitly
        # handled by the _Edge datastructure
        return _edges_to_circuit(edges_in_circuit)

    def make_original_circuit(self) -> Circuit:
        """
        Returns the graph as a ``Circuit``.

        Indices are ``TorchIndex`` and placeholder edges removed during
        pruning are not included.
        """
        return _edges_to_circuit(self.all_edges)

    def make_reformatted_circuit(self) -> Circuit:
        """Same as ``make_original_circuit``; kept for existing callers."""
        return self.make_original_circuit()

    def _build_graph(
        self, circuit: Circuit
    ) -> Tuple[List[_Node], Dict[_Node, List[_Edge]], Dict[_Node, List[_Edge]],]:
        """
        Builds the graph from the circuit.

        Circuit edges are ``(receiver_name, receiver_index, sender_name,
        sender_index)``. Every node gets an entry in both adjacency maps, even
        if that entry is empty. Nodes are returned in first-seen order.
        """
        edges: Dict[_Node, List[_Edge]] = {}
        reversed_edges: Dict[_Node, List[_Edge]] = {}
        # dict used as an insertion-ordered set: node order then follows the
        # circuit instead of string-hash order, which varies between runs.
        seen_nodes: Dict[_Node, None] = {}

        # value should be ignored but we keep track of it for
        # consistency with acdc's definition of a circuit
        for edge, value in circuit:
            reciever_node = _Node(edge[0], _to_torch_index(edge[1]))
            sender_node = _Node(edge[2], _to_torch_index(edge[3]))

            seen_nodes.setdefault(reciever_node, None)
            seen_nodes.setdefault(sender_node, None)

            forward_edge = _Edge(
                sender_node=sender_node, reciever_node=reciever_node, value=value
            )
            reversed_edge = _Edge(
                sender_node=reciever_node, reciever_node=sender_node, value=value
            )

            # Explicit setdefault rather than defaultdict: later lookups of
            # unknown nodes should fail loudly, not create empty entries.
            edges.setdefault(sender_node, []).append(forward_edge)
            reversed_edges.setdefault(reciever_node, []).append(reversed_edge)

        # Add empty edge lists for nodes that don't have any edges
        nodes = list(seen_nodes)
        for node in nodes:
            edges.setdefault(node, [])
            reversed_edges.setdefault(node, [])

        return nodes, edges, reversed_edges

    def _check_graph(
        self,
        nodes: List[_Node],
        edges: Dict[_Node, List[_Edge]],
        reversed_edges: Dict[_Node, List[_Edge]],
    ) -> None:
        """
        Checks that the graph satisfies the basic invariants we expect and
        rely on in our code.

        Raises a ValueError if the graph does not satisfy these invariants.
        """
        # Swapping the adjacency maps turns "find sinks" into "find sources".
        sources = _find_sinks(nodes, reversed_edges, edges)
        sinks = _find_sinks(nodes, edges, reversed_edges)

        # Check that each configured source is a true source of the graph.
        # There may be more sources than configured because of acdc's
        # placeholder edges.
        for source in self.sources:
            if source not in sources:
                msg = f"Source {source} is not a source of the graph. Please match the "
                msg += f"convension of the source being {source} in the graph."
                msg += f"\nSources: {sources}"
                raise ValueError(msg)

        if len(sinks) != 1:
            msg = "Computational graph should have exactly one sink. "
            msg += f"Found {len(sinks)} sinks."
            msg += f"\nSinks: {sinks}"
            raise ValueError(msg)

        # NOTE: when called from _full_init, every node already leads into a
        # sink thanks to _remove_unreachable_nodes, so no separate
        # "all nodes reach the sink" check is performed here.


################################################
# Public functions
################################################

# TODO: ALL OF THESE FUNCTIONS SHOULD BE DEPRECATED
# AND WE SHOULD USE THE CIRCUIT GRAPH CLASS INSTEAD
# AS THE NATIVE REPRESENTATION OF A CIRCUIT IN OUR CODE.
# THIS WILL MAKE THE CODE MUCH CLEANER AND EASIER TO READ
# AND FASTER TO RUN.


def sample_circuit_from_circuit(
    circuit: Circuit, minimum_number_of_edges: int, use_pos_embed=False, seed: int = 0
) -> Circuit:
    """
    Randomly samples a sub-circuit of ``circuit``.

    Source-to-sink paths of ``circuit`` are sampled and merged until the
    result holds at least ``minimum_number_of_edges`` distinct edges. Unless
    ``seed`` is None, the global ``random`` module is reseeded first (the
    default seed is 0). ``use_pos_embed`` should be true for Tracr tasks.

    Note that ``circuit`` itself is normalised to ``TorchIndex`` in place.
    """
    return _CircuitGraph(circuit, use_pos_embed).sample_circuit(
        minimum_number_of_edges, seed
    )


def sample_inflated_circuit_from_circuit(
    circuit_to_inflate: Circuit,
    complete_circuit: Circuit,
    inflate_size: int,
    use_pos_embed=False,
    seed: int = 0,
) -> Circuit:
    """
    Inflates a circuit with extra edges sampled from a larger circuit.

    Both circuits are converted to pruned circuit graphs. The result contains
    every edge of ``circuit_to_inflate`` plus at least ``inflate_size`` new
    edges, added by sampling source-to-sink paths of the complete graph. Both
    inputs are normalised to ``TorchIndex`` in place.

    Args:
        circuit_to_inflate: the (canonical) circuit to grow.
        complete_circuit: the circuit to draw extra edges from. It must
            contain every edge of ``circuit_to_inflate``.
        inflate_size: minimum number of new distinct edges to add.
        use_pos_embed: set to true for Tracr tasks (selects the sources).
        seed: if not None, reseeds the global ``random`` module before sampling.

    Raises:
        ValueError: if, after pruning, some edge of ``circuit_to_inflate`` is
            missing from ``complete_circuit`` (edges are compared by endpoints;
            the per-edge value is ignored), or if ``inflate_size`` is invalid.
    """
    circuit_graph_canonical = _CircuitGraph(circuit_to_inflate, use_pos_embed)
    circuit_graph_complete = _CircuitGraph(complete_circuit, use_pos_embed)

    # Compare edges, not nodes. ``graph.edges`` is a dict keyed by node, so the
    # previous ``set(graph.edges)`` check only verified that nodes matched.
    # ``raise`` instead of ``assert`` so the check survives ``python -O``.
    def _endpoints(graph):
        return {(edge.sender_node, edge.reciever_node) for edge in graph.all_edges}

    missing = _endpoints(circuit_graph_canonical) - _endpoints(circuit_graph_complete)
    if missing:
        raise ValueError(
            "canonical circuit is not a subset of the complete circuit "
            f"({len(missing)} edge(s) missing)"
        )

    return circuit_graph_canonical.make_inflated_from_super_set(
        circuit_graph_complete, inflate_size, seed
    )


def reformat_circuit_to_torch_index(circuit: Circuit) -> Circuit:
    """
    Public entry point for normalising a circuit's indices to ``TorchIndex``.

    Delegates to the private helper, which is defined earlier in the module so
    that ``_CircuitGraphArgs`` can use it. Like the helper, this modifies
    ``circuit`` IN PLACE and returns that same list.
    """
    reformatted = _reformat_circuit_to_torch_index(circuit)
    return reformatted


def compute_actual_circuit_size(circuit: Circuit, use_pos_embed) -> int:
    """
    Returns the number of edges that remain once ``circuit`` is turned into a
    circuit graph.

    Ideally this equals ``len(circuit)``. However, ACDC's placeholder edges
    (those not on a path from the sources to the sink) are pruned when the
    graph is built, so the count can be smaller.
    """
    return len(_CircuitGraph(circuit, use_pos_embed))
