"""Faithful reimplementation of the Reingold-Tilford tree layout algorithm.

Translated from d3-flextree.js, which is a JavaScript implementation of the
Reingold-Tilford tree layout algorithm. The original JavaScript code is
licensed under the MIT license.
"""

from dataclasses import dataclass
from typing import List, Optional


class LayoutNode(object):
    """A tree node holding its size plus the working state of the layout.

    ``size`` is indexed as ``size[0]`` = horizontal extent and
    ``size[1]`` = vertical extent. After layout, ``x`` is the horizontal
    centre of the node and ``y`` its top edge (see ``left``/``right``/
    ``top``/``bottom``). The remaining attributes are scratch state that the
    module-level layout functions read and mutate.
    """

    def __init__(
        self,
        size,
        children: Optional[List["LayoutNode"]] = None,
        data=None,
    ):
        self.size = size
        # Child nodes; None or an empty list both mean "leaf" (see hasChildren).
        self.children = children
        # Arbitrary payload carried alongside the node.
        self.data = data

        # Final coordinates.
        self.x = 0
        self.y = 0
        # Offset from the parent, and the preliminary x used during layout.
        self.relX = 0
        self.prelim = 0
        # Pending spacing for intermediate siblings, folded into relX by
        # shift_change.
        self.shift = 0
        self.change = 0
        # Left/right extreme nodes of this subtree, the accumulated relX
        # leading to each, and the contour threads used once a subtree runs
        # out of children.
        self.lExt = self
        self.lExtRelX = 0
        self.lThr = None
        self.rExt = self
        self.rExtRelX = 0
        self.rThr = None

    @property
    def hasChildren(self):
        """True when the node has at least one child.

        An empty ``children`` list counts as a leaf, like ``None``. Otherwise
        ``firstChild``/``lastChild`` would raise IndexError when the layout
        follows a contour into such a node.
        """
        return bool(self.children)

    @property
    def firstChild(self):
        """Leftmost child; only valid when ``hasChildren`` is true."""
        return self.children[0]

    @property
    def lastChild(self):
        """Rightmost child; only valid when ``hasChildren`` is true."""
        return self.children[-1]

    @property
    def top(self):
        """Top edge, i.e. ``y``."""
        return self.y

    @property
    def bottom(self):
        """Bottom edge: ``y`` plus the vertical size."""
        return self.y + self.size[1]

    @property
    def left(self):
        """Left edge: centre ``x`` minus half the horizontal size."""
        return self.x - self.size[0] / 2

    @property
    def right(self):
        """Right edge: centre ``x`` plus half the horizontal size."""
        return self.x + self.size[0] / 2

    @property
    def xSize(self):
        """Horizontal size (``size[0]``)."""
        return self.size[0]

    @property
    def ySize(self):
        """Vertical size (``size[1]``)."""
        return self.size[1]

    def spacing(self, other):
        """Horizontal gap required between this node and ``other``.

        Currently a constant 5, independent of ``other``.
        """
        return 5


def layout_children(w: LayoutNode, y=0):
    """Recursively lay out the subtree rooted at ``w`` with its top at ``y``.

    Each child is laid out directly below ``w`` (at ``w.y + w.ySize``). Every
    child after the first is then pushed clear of its left siblings with
    ``separate``. Afterwards the accumulated shifts are applied
    (``shift_change``) and ``w`` is positioned over its children
    (``position_root``). Only relative positions are produced here;
    ``resolve_x`` turns them into absolute ``x`` values.

    Returns ``w`` itself.
    """
    w.y = y
    lows = None
    for idx, kid in enumerate(w.children or []):
        layout_children(kid, w.y + w.ySize)
        # Read the low point before separate() can rewrite kid's extremes:
        # the left extreme for the first child, the right extreme otherwise.
        extreme = kid.rExt if idx else kid.lExt
        kid_low = extreme.bottom
        if idx:
            separate(w, idx, lows)
        lows = update_lows(kid_low, idx, lows)
    shift_change(w)
    position_root(w)
    return w


# Position root between children, taking into account their modifiers
def position_root(w: LayoutNode):
    """Centre ``w`` over its children and inherit their extreme nodes.

    ``w.prelim`` becomes the midpoint between the left edge of the first
    child and the right edge of the last child, each measured from
    ``prelim + relX``. The left extreme is taken from the first child and the
    right extreme from the last. Leaves are left untouched.
    """
    if not w.hasChildren:
        return
    first = w.firstChild
    last = w.lastChild
    # Keep the exact left-to-right summation order of the original formula.
    span = (
        first.prelim + first.relX - first.xSize / 2
        + last.relX + last.prelim + last.xSize / 2
    )
    w.prelim = span / 2
    w.lExt = first.lExt
    w.lExtRelX = first.lExtRelX
    w.rExt = last.rExt
    w.rExtRelX = last.rExtRelX


# Process shift and change for all children, to add intermediate spacing to
# each child's modifier.
def shift_change(w: LayoutNode):
    """Fold each child's pending ``shift``/``change`` into its ``relX``.

    Going left to right, ``shift`` values are summed into one running total.
    That total plus each child's ``change`` is summed into a second running
    total, which is added to the child's ``relX``.
    """
    running_shift = 0
    running_change = 0
    for kid in w.children or []:
        running_shift = running_shift + kid.shift
        running_change = running_change + running_shift + kid.change
        kid.relX += running_change


@dataclass
class Lows:
    """One cell of a singly linked list of left siblings' low points.

    Attributes:
        lowY: Bottom y coordinate recorded for the sibling's subtree.
        index: Position of that sibling in its parent's ``children`` list.
        next: The following cell, or ``None`` at the end of the list.
    """

    lowY: float
    index: int
    # None terminates the list, so the annotation must admit it.
    next: Optional["Lows"] = None


# Make/maintain a linked list of the indexes of left siblings and their
# lowest vertical coordinate.
def update_lows(low_y, index, last_lows):
    """Push a sibling's low point onto the ``Lows`` list.

    Before pushing, the function drops head cells whose ``lowY`` does not
    exceed ``low_y``, because the new subtree reaches at least as low and
    hides them. It returns a new ``Lows`` cell for ``index`` whose ``next``
    is whatever remains of the list.
    """
    remaining = last_lows
    while remaining is not None and remaining.lowY <= low_y:
        remaining = remaining.next
    return Lows(lowY=low_y, index=index, next=remaining)


# Move subtree by changing relX.
def move_subtree(subtree: LayoutNode, distance):
    """Shift ``subtree`` horizontally by ``distance``.

    The node's own ``relX`` and the cached relative offsets of its left and
    right extreme nodes all move by the same amount.
    """
    for attr in ("relX", "lExtRelX", "rExtRelX"):
        setattr(subtree, attr, getattr(subtree, attr) + distance)


def distribute_extra(w: LayoutNode, cur_subtree_i, left_sib_i, dist):
    """Record a move of child ``cur_subtree_i`` for the siblings in between.

    This only acts when at least one sibling lies strictly between
    ``left_sib_i`` and ``cur_subtree_i``. In that case ``dist`` is split into
    per-gap steps and written to the ``shift``/``change`` accumulators, which
    ``shift_change`` later folds into the children's ``relX``.
    """
    target = w.children[cur_subtree_i]
    gap = cur_subtree_i - left_sib_i
    if gap <= 1:
        return
    step = dist / gap
    w.children[left_sib_i + 1].shift += step
    target.shift -= step
    target.change -= dist - step


def next_l_contour(w: LayoutNode):
    """Return the next node down the left contour below ``w``.

    That is ``w``'s first child if it has any, otherwise its left thread
    (``lThr``, possibly ``None``).
    """
    if w.hasChildren:
        return w.firstChild
    return w.lThr


def next_r_contour(w: LayoutNode):
    """Return the next node down the right contour below ``w``.

    That is ``w``'s last child if it has any, otherwise its right thread
    (``rThr``, possibly ``None``).
    """
    if not w.hasChildren:
        return w.rThr
    return w.lastChild


def set_l_thr(w: LayoutNode, i, lContour, lSumMods):
    """Thread the left siblings' left extreme into child ``i``'s contour.

    This is used when child ``i`` is taller than the forest to its left. The
    left extreme of that forest (kept on ``w.firstChild``) gets ``lThr``
    pointing at ``lContour``. Its ``relX`` is corrected so that summing
    modifiers along the thread gives ``lSumMods``, and its ``prelim`` is
    corrected by the opposite amount so the node itself does not move.
    Finally ``w.firstChild`` adopts child ``i``'s left extreme.
    """
    leftmost = w.firstChild
    extreme = leftmost.lExt
    source = w.children[i]
    extreme.lThr = lContour
    correction = lSumMods - lContour.relX - leftmost.lExtRelX
    extreme.relX += correction
    extreme.prelim -= correction
    leftmost.lExt = source.lExt
    leftmost.lExtRelX = source.lExtRelX


# Mirror image of set_l_thr.
def set_r_thr(w: LayoutNode, i, rContour, rSumMods):
    """Right-hand mirror of ``set_l_thr``.

    This is used when the forest left of child ``i`` is taller. Child ``i``'s
    right extreme gets ``rThr`` pointing at ``rContour``. Its ``relX`` and
    ``prelim`` are corrected by opposite amounts so the thread's modifier sum
    equals ``rSumMods`` without moving the node. Child ``i`` then adopts the
    right extreme of child ``i - 1``.
    """
    subtree = w.children[i]
    extreme = subtree.rExt
    neighbour = w.children[i - 1]
    extreme.rThr = rContour
    correction = rSumMods - rContour.relX - subtree.rExtRelX
    extreme.relX += correction
    extreme.prelim -= correction
    subtree.rExt = neighbour.rExt
    subtree.rExtRelX = neighbour.rExtRelX


# Separates the latest child from its previous sibling
def separate(w: LayoutNode, i, lows):
    """Push child ``i`` of ``w`` sideways until it clears its left siblings.

    The function walks two contours in lockstep, advancing whichever current
    node has the higher ``bottom`` (both advance on a tie):
    - the right contour of the left siblings, starting at ``w.children[i - 1]``;
    - the left contour of child ``i``.

    At each step it computes ``dist``, the extra horizontal offset needed so
    the two nodes are ``spacing`` apart. Child ``i`` is moved by ``dist``
    when the nodes are too close, and also on the very first step when they
    are too far apart. Each move is also passed to ``distribute_extra`` for
    the siblings between ``lows.index`` and ``i``. When one contour runs out
    before the other, a thread is installed (``set_l_thr``/``set_r_thr``) so
    later walks can continue into the deeper side.

    Args:
        w: Parent node whose children are being separated.
        i: Index of the child being placed (called with ``i >= 1``).
        lows: Head of the ``Lows`` list built by ``update_lows`` for the
            siblings ``0 .. i-1``.
    """
    lSib: LayoutNode = w.children[i - 1]
    cur_subtree: LayoutNode = w.children[i]
    # Current node on each contour, plus the running sum of relX along it.
    rContour = lSib
    rSumMods = lSib.relX
    lContour = cur_subtree
    lSumMods = cur_subtree.relX
    # Only on the first (top) step may the subtree be pulled leftwards.
    is_first = True
    # LayoutNode defines no __bool__/__len__, so these truth tests only fail
    # once a contour has run out (None).
    while rContour and lContour:
        # Once the right contour is below the current entry's lowY, move on
        # to the next entry of the lows list.
        if rContour.bottom > lows.lowY:
            lows = lows.next
        # How far to the left of the right side of rContour is the left side
        # of lContour? First compute the center-to-center distance, then add
        # the "spacing"
        dist = (
            (rSumMods + rContour.prelim)
            - (lSumMods + lContour.prelim)
            + rContour.xSize / 2
            + lContour.xSize / 2
            + rContour.spacing(lContour)
        )
        # dist > 0: the nodes are closer than the required spacing.
        # dist < 0 on the first step: there is more room than needed, so the
        # subtree is moved left.
        if dist > 0 or (dist < 0 and is_first):
            lSumMods += dist
            # Move subtree by changing relX.
            move_subtree(cur_subtree, dist)
            # Record the move for the siblings between lows.index and i.
            distribute_extra(w, i, lows.index, dist)
        is_first = False
        # Advance highest node(s) and sum(s) of modifiers
        rightBottom = rContour.bottom
        leftBottom = lContour.bottom
        if rightBottom <= leftBottom:
            rContour = next_r_contour(rContour)
            if rContour:
                rSumMods += rContour.relX
        if rightBottom >= leftBottom:
            lContour = next_l_contour(lContour)
            if lContour:
                lSumMods += lContour.relX
    # Set threads and update extreme nodes. In the first case, the
    # current subtree is taller than the left siblings.
    if not rContour and lContour:
        set_l_thr(w, i, lContour, lSumMods)
    # In the next case, the left siblings are taller than the current subtree
    elif rContour and not lContour:
        set_r_thr(w, i, rContour, rSumMods)


# Resolves the relative coordinate properties - relX and prelim --
# to set the final, absolute x coordinate for each node. This also sets
# `prelim` to 0, so that `relX` for each node is its x-coordinate relative
# to its parent.
def resolve_x(w: LayoutNode, prev_sum=None, parent_x=None):
    """Turn relative positions into absolute ``x`` for the whole subtree.

    Afterwards every node in the subtree has ``prelim == 0``, ``relX`` equal
    to its x offset from its parent, and an absolute ``x``. When called with
    only ``w``, ``w`` is treated as the root and placed at ``x == 0`` (up to
    floating-point rounding).

    Nodes are visited parent-first, left-to-right (pre-order), using an
    explicit stack instead of recursion.

    Returns ``w``.
    """
    if prev_sum is None:
        # Root call: choose the inherited sum so the root lands at x == 0.
        prev_sum = -w.relX - w.prelim
        parent_x = 0
    stack = [(w, prev_sum, parent_x)]
    while stack:
        node, inherited, origin = stack.pop()
        acc = inherited + node.relX
        node.relX = acc + node.prelim - origin
        node.prelim = 0
        node.x = origin + node.relX
        kids = list(node.children or [])
        # Push in reverse so the leftmost child is processed next.
        stack.extend((kid, acc, node.x) for kid in reversed(kids))
    return w


class DrawNode(object):
    """Input tree node: a drawable object plus its child ``DrawNode``s.

    ``do_layout`` sizes each node from ``drawable.bounds.width`` and
    ``drawable.bounds.height``.
    """

    def __init__(self, drawable, children=None):
        self.drawable = drawable
        # Always store a fresh list, so the caller's iterable is never
        # shared and ``None`` becomes an empty list.
        if children is None:
            self.children = []
        else:
            self.children = list(children)


def do_layout(tree: DrawNode):
    """Lay out a ``DrawNode`` tree and return the positioned ``LayoutNode`` tree.

    Each node's size is ``(drawable.bounds.width, drawable.bounds.height)``,
    and the originating ``DrawNode`` is kept in ``LayoutNode.data``. The root
    is placed at ``y == 0`` and ``x == 0``, and each child's top is placed at
    its parent's bottom.

    Args:
        tree: Root of the input tree.

    Returns:
        The root ``LayoutNode``, with final ``x``/``y`` set on every node.
    """

    def _convert(node: DrawNode) -> LayoutNode:
        # Build the LayoutNode mirror of ``node``; leaves get children=None.
        kids = [_convert(child) for child in node.children]
        bounds = node.drawable.bounds
        return LayoutNode(
            size=(bounds.width, bounds.height),
            children=kids or None,
            data=node,
        )

    root = _convert(tree)
    layout_children(root)
    resolve_x(root)
    return root
