from collections.abc import Sequence
import itertools
from collections import defaultdict

import numpy as np

from cirkit.templates.region_graph.algorithms.quad import QuadTree as SingleQuadTree
from cirkit.templates.region_graph.algorithms.utils import HypercubeToScope
from cirkit.templates.region_graph.graph import (
    PartitionNode,
    RegionGraph,
    RegionGraphNode,
    RegionNode,
)
from cirkit.utils.scope import Scope


def union_region_graphs(rgs: Sequence[RegionGraph]) -> RegionGraph:
    """Join several region graphs that share the same output scope under one root.

    The single output region of the first graph becomes the shared root. Every region
    node (of any graph) whose scope equals the root's scope is dropped. The partitions
    over that full scope are instead attached as inputs of the shared root. All other
    region and partition nodes are reused as-is, not copied, together with their inputs.

    Args:
        rgs: The region graphs to join. Must be non-empty, and every graph's first
            output must have the same scope as the first graph's (single) output.

    Returns:
        RegionGraph: A region graph whose single output is the shared root.
    """
    assert len(rgs) > 0
    (shared_root,) = tuple(rgs[0].outputs)
    assert all(shared_root.scope == tuple(graph.outputs)[0].scope for graph in rgs)

    all_nodes: list[RegionGraphNode] = [shared_root]
    children: dict[RegionGraphNode, list[RegionGraphNode]] = {shared_root: []}

    for graph in rgs:
        for node in graph.topological_ordering():
            if isinstance(node, RegionNode):
                if node.scope == shared_root.scope:
                    # Each graph's own root collapses onto the shared root.
                    continue
                all_nodes.append(node)
                children[node] = graph.region_inputs(node)
                continue

            assert isinstance(node, PartitionNode)
            all_nodes.append(node)
            children[node] = graph.partition_inputs(node)
            if node.scope == shared_root.scope:
                # Top-level partitions of every graph feed the shared root.
                children[shared_root].append(node)

    return RegionGraph(all_nodes, children, [shared_root])


# pylint: disable-next=invalid-name
def QuadTree(shape: tuple[int, int, int], *, num_repetitions: int = 1, num_patch_splits: int = 2) -> RegionGraph:
    """Build ``num_repetitions`` single quad trees and join them under a shared root.

    Args:
        shape: The image shape, forwarded unchanged to the single quad tree builder.
        num_repetitions: How many quad trees to build and join. Must be positive.
        num_patch_splits: Forwarded unchanged to the single quad tree builder.

    Returns:
        RegionGraph: The union of the built quad trees.
    """
    assert num_repetitions > 0
    repetitions: list[RegionGraph] = []
    for _ in range(num_repetitions):
        repetitions.append(SingleQuadTree(shape, num_patch_splits=num_patch_splits))
    return union_region_graphs(repetitions)


# pylint: disable-next=invalid-name
def RandomQuadTree2(shape: tuple[int, int, int], *, num_repetitions: int = 1, seed: int = 42) -> RegionGraph:
    r"""Constructs a Random Quad Tree region graph.

    Pixel regions are merged bottom-up, one level at a time, in 2x2 blocks until a single
    region covering the whole image remains. When a block contains all four sub-regions,
    the direction is chosen at random, once per depth:
    - merge horizontally first (the two rows, then the rows together), or
    - merge vertically first (the two columns, then the columns together).

    Each repetition builds its own tree over fresh pixel regions. All repetitions share
    the root region.

    Args:
        shape: The image shape $(C, H, W)$, where $H$ is the height, $W$ is the width,
            and $C$ is the number of channels.
        num_repetitions: The number of Random Quad Tree region graph repetitions. Must be
            positive. Defaults to 1.
        seed: The seed used for randomly choosing whether to split horizontally or
            vertically at each depth. Defaults to 42.

    Returns:
        RegionGraph: A Random Quad Tree region graph.

    Raises:
        ValueError: The image shape is not valid.
        ValueError: The number of repetitions is not valid.
    """
    if len(shape) != 3:
        raise ValueError("RandomQuadTree2 region graph only works for images")
    num_channels, height, width = shape
    if num_channels <= 0 or height <= 0 or width <= 0:
        raise ValueError("The number of channels, the height and the width must be positive")
    if num_repetitions <= 0:
        raise ValueError("The number of repetitions must be positive")

    # Instantiate the random state
    random_state = np.random.RandomState(seed)

    # Sentinel marking grid cells outside the image. It is only compared by identity and
    # is never added to the graph. Its scope {C*H*W} lies just past the variables
    # range(C*H*W) covered by the root.
    # DISABLE: This is considered a constant here, although RegionNode is mutable.
    PADDING = RegionNode({num_channels * height * width})  # pylint: disable=invalid-name

    # An object mapping rectangles of coordinates into variable scopes
    hypercube_to_scope = HypercubeToScope(shape)

    # The list of region and partition nodes
    nodes: list[RegionGraphNode] = []

    # A map from each region/partition node to its children
    in_nodes: dict[RegionGraphNode, list[RegionGraphNode]] = defaultdict(list)

    # Instantiate the root region node, shared by all repetitions
    root = RegionNode(range(num_channels * height * width))
    nodes.append(root)

    def merge_regions_(rgn_in: list[RegionNode]) -> RegionNode:
        """Merge 2 or 4 regions into a larger region through one new partition.

        A merged region covering the whole image is the shared root, so the top-level
        partition of every repetition becomes an input of the same root node.
        """
        assert len(rgn_in) in {2, 4}
        scope = Scope.union(*tuple(rgn.scope for rgn in rgn_in))
        if scope == root.scope:
            rgn = root
        else:
            rgn = RegionNode(scope)
            nodes.append(rgn)
        ptn = PartitionNode(scope)
        nodes.append(ptn)
        in_nodes[rgn].append(ptn)
        in_nodes[ptn] = rgn_in
        return rgn

    def merge_4_regions_tree_(rgn_in: list[RegionNode], *, merge_horizontally_first: bool = True) -> RegionNode:
        """Merge a 2x2 block using binary partitions only (structured decomposability).

        ``rgn_in`` is ordered [top-left, top-right, bottom-left, bottom-right].
        """
        assert len(rgn_in) == 4

        if merge_horizontally_first:
            # Merge each row, then stack the two rows
            region_top = merge_regions_(rgn_in[:2])
            region_bot = merge_regions_(rgn_in[2:])
            return merge_regions_([region_top, region_bot])

        # Merge each column, then join the two columns
        region_left = merge_regions_(rgn_in[::2])
        region_right = merge_regions_(rgn_in[1::2])
        return merge_regions_([region_left, region_right])

    for _ in range(num_repetitions):
        # The regions of the current layer, in shape (H, W), plus one padding row/column
        # so that odd sizes can be handled uniformly. The same PADDING object is reused.
        grid = [[PADDING] * (width + 1) for _ in range(height + 1)]

        # Add input region nodes
        for i, j in itertools.product(range(height), range(width)):
            scope = hypercube_to_scope[((0, i, j), (num_channels, i + 1, j + 1))]
            if scope == root.scope:
                # Single-pixel image: the pixel region is the whole image, so reuse the
                # root instead of creating a disconnected duplicate of it.
                rgn = root
            else:
                rgn = RegionNode(scope)
                nodes.append(rgn)
            grid[i][j] = rgn

        # Merge frontier by frontier, loop until (H, W)==(1, 1).
        cur_height, cur_width = height, width
        while cur_height > 1 or cur_width > 1:
            cur_height = (cur_height + 1) // 2
            cur_width = (cur_width + 1) // 2
            prev_grid, grid = grid, [[PADDING] * (cur_width + 1) for _ in range(cur_height + 1)]
            # One random draw per depth, shared by every 2x2 block at this depth
            merge_horizontally_first = random_state.rand() < 0.5

            for i, j in itertools.product(range(cur_height), range(cur_width)):
                regions = [  # Keep only the in-image regions among the 4 possible sub-regions.
                    rgn
                    for rgn in (
                        prev_grid[i * 2][j * 2],
                        prev_grid[i * 2][j * 2 + 1],
                        prev_grid[i * 2 + 1][j * 2],
                        prev_grid[i * 2 + 1][j * 2 + 1],
                    )
                    if rgn is not PADDING
                ]
                if len(regions) == 1:
                    node = regions[0]
                elif len(regions) == 2:
                    node = merge_regions_(regions)
                elif len(regions) == 4:
                    node = merge_4_regions_tree_(regions, merge_horizontally_first=merge_horizontally_first)
                else:
                    # NOTE: Padding only ever fills a whole row and/or column of a 2x2 block,
                    #       so 1, 2 or 4 regions are the only possible counts. Raise instead of
                    #       `assert` so the guard also holds under `python -O`.
                    raise AssertionError("This should not happen")
                grid[i][j] = node

    return RegionGraph(nodes, in_nodes, outputs=[root])


# pylint: disable-next=invalid-name
def DecisionLeafQuadGraph(shape: tuple[int, int, int], *, max_patch_size: int = 8) -> RegionGraph:
    r"""Constructs a region graph by recursively halving rectangular image patches.

    Starting from the whole image, a region is cut at the midpoint of its height (a
    horizontal cut) and/or of its width (a vertical cut):

    - A patch whose height and width are both at least ``max_patch_size`` receives both
      a horizontal and a vertical partition (each only if that side is longer than one
      pixel). The result is therefore in general not a tree.
    - A smaller patch receives a single partition. Cuts alternate direction along each
      path; if the preferred side is only one pixel long, the other side is cut instead.

    Recursion stops at regions covering a single pixel (all of its channels). Every cut
    creates fresh region nodes: regions with equal scopes reached through different
    paths are not shared.

    Args:
        shape: The image shape $(C, H, W)$, where $H$ is the height, $W$ is the width,
            and $C$ is the number of channels.
        max_patch_size: Patches with both sides at least this long are partitioned in
            both directions. Defaults to 8.

    Returns:
        RegionGraph: The constructed region graph, whose single output is the root.

    Raises:
        ValueError: The image shape is not valid.
    """
    if len(shape) != 3:
        raise ValueError("DecisionLeafQuadGraph region graph only works for images")
    num_channels, height, width = shape
    if num_channels <= 0 or height <= 0 or width <= 0:
        raise ValueError("The number of channels, the height and the width must be positive")

    # An object mapping rectangles of coordinates into variable scopes
    hypercube_to_scope = HypercubeToScope(shape)

    # The list of region and partition nodes
    nodes: list[RegionGraphNode] = []

    # A map from each region/partition node to its children
    in_nodes: dict[RegionGraphNode, list[RegionGraphNode]] = defaultdict(list)

    # Instantiate the root region node
    root = RegionNode(range(num_channels * height * width))
    nodes.append(root)

    # Stack of (region, (top, left), (bottom, right), split-horizontally-next) to process.
    # Patch bounds are half-open: rows [top, bottom), columns [left, right).
    rgn_to_split = [(root, (0, 0), (height, width), True)]

    # Hypercube axes: 0 is the channel axis, 1 the height axis, 2 the width axis.
    height_axis, width_axis = 1, 2

    def split_(
        rgn: RegionNode, sa: tuple[int, int], sb: tuple[int, int], *, axis: int, next_hsplit: bool
    ) -> None:
        """Halve the patch [sa, sb) along ``axis``, attach the partition to ``rgn``, and push both halves."""
        lower = (0, *sa)
        upper = (num_channels, *sb)
        mid = (lower[axis] + upper[axis]) // 2
        halves = (
            (lower, upper[:axis] + (mid,) + upper[axis + 1 :]),
            (lower[:axis] + (mid,) + lower[axis + 1 :], upper),
        )
        # A half with a zero-length side gets an empty scope instead of a hypercube lookup.
        scopes = [hypercube_to_scope[a, b] if a[axis] < b[axis] else Scope() for a, b in halves]
        children = [RegionNode(scope) for scope in scopes]
        ptn = PartitionNode(scopes[0] | scopes[1])
        nodes.extend(children)
        nodes.append(ptn)
        in_nodes[rgn].append(ptn)
        in_nodes[ptn] = children
        for child, (a, b) in zip(children, halves):
            rgn_to_split.append((child, a[1:], b[1:], next_hsplit))

    while rgn_to_split:
        rgn, sa, sb, hsplit = rgn_to_split.pop()
        if len(rgn.scope) == num_channels:
            # A single pixel (all of its channels): nothing left to split.
            continue
        patch_height, patch_width = sb[0] - sa[0], sb[1] - sa[1]

        if patch_height < max_patch_size or patch_width < max_patch_size:
            # Small patch: a single cut in the alternating direction, falling back to
            # the other direction when the preferred side is one pixel long.
            if hsplit:
                if patch_height > 1:
                    split_(rgn, sa, sb, axis=height_axis, next_hsplit=False)
                else:
                    split_(rgn, sa, sb, axis=width_axis, next_hsplit=True)
            elif patch_width > 1:
                split_(rgn, sa, sb, axis=width_axis, next_hsplit=True)
            else:
                split_(rgn, sa, sb, axis=height_axis, next_hsplit=False)
            continue

        # Large patch: two alternative partitions, one per cut direction.
        if patch_height > 1:
            split_(rgn, sa, sb, axis=height_axis, next_hsplit=False)
        if patch_width > 1:
            split_(rgn, sa, sb, axis=width_axis, next_hsplit=True)

    return RegionGraph(nodes, in_nodes, [root])

