from typing import List, Optional, Sequence, cast

import numpy as np

from numpy.typing import NDArray
from cirkit.region_graph import RegionGraph, RegionNode, PartitionNode
from cirkit.region_graph.utils import HypercubeScopeCache


def _partition_node(
    graph: RegionGraph,
    node: RegionNode,
    replica_idx: int,
    num_parts: Optional[int] = None,
    random: Optional[bool] = False,
    proportions: Optional[Sequence[float]] = None,
) -> List[RegionNode]:
    """Split the scope of a region node into consecutive parts and add them to the RG.

    The scope is listed (and optionally shuffled), cut at rounded split points, and every \
        non-empty slice becomes a new region node attached to a fresh partition of the node.

    Args:
        graph (RegionGraph): The region graph that receives the partitioning.
        node (RegionNode): The region node whose scope is split.
        replica_idx (int): Passed as replica_idx to every newly created region node.
        num_parts (Optional[int], optional): The number of equally sized parts. Only used \
            when proportions is not given. Defaults to None.
        random (Optional[bool], optional): Whether to shuffle the scope with \
            np.random.shuffle (numpy's global RNG) before cutting. Defaults to False.
        proportions (Optional[Sequence[float]], optional): The relative sizes of the parts, \
            can be unnormalized. Takes precedence over num_parts. Defaults to None.

    Returns:
        List[RegionNode]: The new sub-regions, or [node] itself (graph untouched) when the \
            split yields exactly one non-empty part.
    """
    variables = list(node.scope)
    if random:
        np.random.shuffle(variables)

    # ANNOTATE: Numpy has typing issues.
    cuts: NDArray[np.float64]  # Unnormalized cut positions, from 0 up to the total weight.
    if proportions is not None:
        cuts = np.array([0.0, *proportions], dtype=np.float64).cumsum()
    else:
        assert num_parts, "Must provide at least one of num_parts and proportions."
        cuts = np.arange(num_parts + 1, dtype=np.float64)

    # Rescale the cut positions to indices into the variable list.
    # CAST: Numpy has typing issues.
    # IGNORE: Numpy has typing issues.
    scaled = cast(
        NDArray[np.float64], cuts / cuts[-1] * len(variables)  # type: ignore[misc]
    )
    # ANNOTATE: ndarray.tolist gives Any.
    boundaries: List[int] = scaled.round().astype(np.int64).tolist()

    # A region must have at least one var, so empty slices are skipped.
    # FUTURE: for lo, hi in itertools.pairwise(boundaries) in 3.10
    sub_regions = [
        RegionNode(variables[lo:hi], replica_idx=replica_idx)
        for lo, hi in zip(boundaries[:-1], boundaries[1:])
        if lo < hi
    ]

    if len(sub_regions) == 1:
        # A single part means the node cannot be partitioned any further, so the original
        # node is kept as the leaf and nothing is added to the graph.
        return [node]

    # Edges point from children to parents: sub-region -> partition -> node.
    partition = PartitionNode(node.scope)
    graph.add_edge(partition, node)
    for sub_region in sub_regions:
        graph.add_edge(sub_region, partition)
    return sub_regions


# DISABLE: We use function name with upper case to mimic a class constructor.
# pylint: disable-next=invalid-name
def BinaryTree(
    num_vars: int,
    num_repetitions: Optional[int] = 1,
    depth: Optional[int] = None,
    random: Optional[bool] = False,
    seed: Optional[int] = 42
) -> RegionGraph:
    """Build a region graph by repeatedly splitting the root scope into two parts.

    Starting from a single root region over range(num_vars), every region of the current \
        layer is split into two parts via _partition_node, for depth layers per repetition.

    Args:
        num_vars (int): The number of variables; the root scope is range(num_vars).
        num_repetitions (Optional[int], optional): How many times the splitting is repeated \
            from the root, each with its own replica_idx. It is passed to range(), so it \
            must be an int despite the Optional annotation. Defaults to 1.
        depth (Optional[int], optional): The number of splitting layers. If not provided, \
            ceil(log2(num_vars)) is used. Defaults to None.
        random (Optional[bool], optional): Whether scopes are shuffled before each split. \
            Defaults to False.
        seed (Optional[int], optional): The seed given to np.random.seed. Defaults to 42.

    Returns:
        RegionGraph: The constructed region graph.
    """
    # NOTE: This reseeds numpy's *global* RNG, which _partition_node uses for shuffling.
    np.random.seed(seed)
    if depth is None:
        depth = int(np.ceil(np.log2(num_vars)))
    graph = RegionGraph()
    root = RegionNode(range(num_vars))
    graph.add_node(root)
    for replica_idx in range(num_repetitions):
        layer = [root]
        for _ in range(depth):
            # Flatten with a nested comprehension; sum(lists, []) would re-copy the
            # accumulated list for every node and is quadratic in the layer size.
            layer = [
                child
                for node in layer
                for child in _partition_node(
                    graph, node, replica_idx=replica_idx, num_parts=2, random=random
                )
            ]
    return graph


def LinearRegionGraph(
    num_variables: int,
    num_repetitions: int = 1,
    randomize: bool = False,
    seed: int = 42,
    extra_root: Optional[bool] = False
) -> RegionGraph:
    """Build a chain-shaped region graph that splits off one variable at a time.

    Each repetition walks over a variable order and, at every step, partitions the current \
        region into a single-variable region and a region holding all remaining variables.

    Args:
        num_variables (int): The number of variables; the root scope is range(num_variables).
        num_repetitions (int, optional): How many chains are built from the same starting \
            region. Defaults to 1.
        randomize (bool, optional): Whether the variable order is shuffled for every \
            repetition. Defaults to False.
        seed (int, optional): The seed of the private RandomState used for shuffling. \
            Defaults to 42.
        extra_root (Optional[bool], optional): Whether to place one more full-scope region, \
            linked through a single-child partition, below the root and start the chains \
            from it. Defaults to False.

    Returns:
        RegionGraph: The constructed region graph.
    """
    root = RegionNode(range(num_variables))
    rg = RegionGraph()
    rg.add_node(root)
    if extra_root:
        # Edges point from children to parents: new region -> partition -> original root.
        partition_node = PartitionNode(set(root.scope))
        rg.add_edge(partition_node, root)
        root = RegionNode(range(num_variables))
        rg.add_edge(root, partition_node)
    random_state = np.random.RandomState(seed)

    for _ in range(num_repetitions):
        parent_node = root
        order = list(range(num_variables))
        if randomize:
            random_state.shuffle(order)
        # The last variable never needs its own split: it is the final "rest" region.
        for i, v in enumerate(order[:-1]):
            partition_node = PartitionNode(set(parent_node.scope))
            rg.add_edge(partition_node, parent_node)
            leaf_node = RegionNode({v})
            # At the last step this slice is exactly [order[-1]], so no special case is needed.
            rest_node = RegionNode(set(order[i + 1:]))
            rg.add_edge(leaf_node, partition_node)
            rg.add_edge(rest_node, partition_node)
            parent_node = rest_node

    return rg


def _merge_2_regions(regions: List[RegionNode], graph: RegionGraph) -> RegionNode:
    """Join two child regions under a new parent region through one partition.

    Args:
        regions (List[RegionNode]): The two children regions.
        graph (RegionGraph): The region graph that receives the new nodes and edges.

    Returns:
        RegionNode: The new region node whose scope is the union of both children.
    """
    assert len(regions) == 2

    first, second = regions[0], regions[1]
    merged_scope = first.scope.union(second.scope)
    partition = PartitionNode(merged_scope)
    parent = RegionNode(merged_scope)

    # Edges point from children to parents: child -> partition -> parent.
    for child in (first, second):
        graph.add_edge(child, partition)
    graph.add_edge(partition, parent)

    return parent


def _merge_4_regions(regions: List[RegionNode], graph: RegionGraph) -> RegionNode:
    """Join four child regions under a new parent region through one partition.

    All four children are attached directly to a single partition; no intermediate \
        (e.g. horizontal-then-vertical) regions are created.

    Args:
        regions (List[RegionNode]): The four children regions.
        graph (RegionGraph): The region graph that receives the new nodes and edges.

    Returns:
        RegionNode: The new region node whose scope is the union of all children.
    """
    assert len(regions) == 4

    children = (regions[0], regions[1], regions[2], regions[3])

    # Accumulate the union of the four scopes, left to right.
    merged_scope = children[0].scope
    for child in children[1:]:
        merged_scope = merged_scope.union(child.scope)

    # Edges point from children to parents: child -> partition -> parent.
    partition = PartitionNode(merged_scope)
    for child in children:
        graph.add_edge(child, partition)

    parent = RegionNode(merged_scope)
    graph.add_edge(partition, parent)

    return parent


def _square_from_buffer(buffer: List[List[RegionNode]], i: int, j: int) -> List[RegionNode]:
    """Collect the (up to) 2x2 block of buffer entries anchored at position (i, j).

    Args:
        buffer (List[List[RegionNode]]): The buffer of all children.
        i (int): The first index of the anchor (into buffer).
        j (int): The second index of the anchor (into buffer[i]).

    Returns:
        List[RegionNode]: The entries at (i, j), (i+1, j), (i, j+1), (i+1, j+1), in this \
            order, leaving out those that fall outside the buffer.
    """
    # TODO: rewrite: len only useful at 2n-1 boundary
    # Offsets that stay inside the buffer along each of the two axes.
    row_steps = (0, 1) if i + 1 < len(buffer) else (0,)
    col_steps = (0, 1) if j + 1 < len(buffer[i]) else (0,)
    # Column offset is the outer loop so that (i+1, j) comes before (i, j+1).
    return [buffer[i + di][j + dj] for dj in col_steps for di in row_steps]


# pylint: disable-next=too-many-locals,invalid-name
def RealQuadTree(width: int, height: int, final_sum=False) -> RegionGraph:
    """Build a quad-tree RG over a square grid of variables.

    One leaf region is created per grid cell. Blocks of (up to) 2x2 neighbouring regions are \
        then merged repeatedly for as long as the grid of regions is larger than 1 in both \
        directions.

    Args:
        width (int): Width of scope. Must be positive and equal to height.
        height (int): Height of scope.
        final_sum (bool, optional): Whether to put one extra partition and one extra region, \
            both with the scope of the single output region, on top of it. When set, the \
            structured-decomposability check is skipped. Defaults to False.

    Returns:
        RegionGraph: The RG.
    """
    assert width == height and width > 0  # TODO: then we don't need two

    shape = (width, height)
    hypercube_to_scope = HypercubeScopeCache()
    graph = RegionGraph()

    # Add leaves: one region per unit cell, stored as grid[i][j].
    grid: List[List[RegionNode]] = []
    for i in range(width):
        leaf_row: List[RegionNode] = []
        for j in range(height):
            cell = ((i, j), (i + 1, j + 1))
            leaf = RegionNode(hypercube_to_scope(cell, shape))
            graph.add_node(leaf)
            leaf_row.append(leaf)
        grid.append(leaf_row)

    lr_offset = 0  # left right # TODO: or choose from 0 and 1?
    td_offset = 0  # top down

    cur_width, cur_height = width, height

    # TODO: also no need to have two for h/w
    while cur_width > 1 and cur_height > 1:  # pylint: disable=while-used
        # Halve the grid, rounding up so that an odd trailing row/column is kept.
        next_width = (cur_width + 1) // 2
        next_height = (cur_height + 1) // 2

        next_grid: List[List[RegionNode]] = []
        for i in range(next_width):
            merged_row: List[RegionNode] = []
            for j in range(next_height):
                block = _square_from_buffer(grid, 2 * i + lr_offset, 2 * j + td_offset)
                num_children = len(block)
                if num_children == 1:
                    # A lone region at the boundary is carried over unchanged.
                    merged = block[0]
                elif num_children == 2:
                    merged = _merge_2_regions(block, graph)
                else:
                    merged = _merge_4_regions(block, graph)
                merged_row.append(merged)
            next_grid.append(merged_row)

        grid, cur_width, cur_height = next_grid, next_width, next_height

    # add root
    if final_sum:
        outputs = list(graph.output_nodes)
        assert len(outputs) == 1
        root = outputs[0]
        partition_node = PartitionNode(root.scope)
        mixed_root = RegionNode(root.scope)
        graph.add_node(root)
        graph.add_node(partition_node)
        # Edges point from children to parents: root -> partition -> mixed_root.
        graph.add_edge(root, partition_node)
        graph.add_edge(partition_node, mixed_root)

    assert graph.is_smooth
    assert graph.is_decomposable

    # note: why if adding a final sum is not structured decomposable anymore?
    if not final_sum:
        assert graph.is_structured_decomposable

    return graph
