from typing import Any, Dict, Tuple
from lark import Transformer, v_args

import iceberg as ice
from td.environments import Environment
from .helpers import EasyDrawableNode, copy_with_mod, typewriter
from .tree_layout import DrawNode, do_layout, LayoutNode


class DrawTransformer(Transformer):
    """Base lark ``Transformer`` for draw-tree builders.

    Subclasses map grammar rules to drawable nodes; they can reach drawing
    settings of the owning ``CodeTree`` (e.g. ``line_style``) via
    ``self.code_tree``.
    """

    def __init__(self, code_tree: "CodeTree", visit_tokens=True):
        super().__init__(visit_tokens=visit_tokens)
        # Owning diagram; consulted by rule callbacks for styling.
        self.code_tree = code_tree


class BlockworldToDraw(DrawTransformer):
    """Build a draw tree from a parsed blockworld program."""

    @staticmethod
    def _direction_label(label, meta):
        """Return a plain labelled node for a direction rule (``v`` / ``h``)."""
        return EasyDrawableNode(text_label=label, metas=[meta])

    @v_args(meta=True)
    def box(self, meta, _):
        """Leaf node labelled "Box"."""
        label_node = EasyDrawableNode(text_label="Box", metas=[meta])
        return DrawNode(label_node)

    @v_args(meta=True)
    def arrange(self, meta, children):
        """Draw an "Arrange" label connected by a line to its direction label.

        The two arranged sub-programs become the node's children in the tree.
        """
        direction_node, first, second = children

        title = EasyDrawableNode(text_label="Arrange", metas=[meta])
        row = ice.Arrange([title, direction_node], gap=5)

        # Read the anchor corners inside the ``row`` context so that
        # ``relative_bounds`` is (presumably) resolved relative to ``row``.
        with row:
            start = title.relative_bounds.corners[ice.Corner.MIDDLE_RIGHT]
            end = direction_node.relative_bounds.corners[ice.Corner.MIDDLE_LEFT]
            connector = ice.Line(
                start=start,
                end=end,
                path_style=self.code_tree.line_style,
            )

        row += connector

        return DrawNode(row, children=[first, second])

    @v_args(meta=True)
    def s(self, meta, children):
        """Pass the single child through.

        If it is a ``DrawNode``, its drawable's metas are replaced with this
        rule's span.
        """
        node = children[0]
        if not isinstance(node, DrawNode):
            return node
        node.drawable.metas = [meta]
        return node

    @v_args(meta=True)
    def v(self, meta, _):
        """Direction label "v"."""
        return self._direction_label("v", meta)

    @v_args(meta=True)
    def h(self, meta, _):
        """Direction label "h"."""
        return self._direction_label("h", meta)


class CSG2DADraw(DrawTransformer):
    """Build a draw tree from a parsed csg2da program.

    Shapes (``quad``, ``circle``) become leaf nodes rendered as a typewriter
    "call" listing their parameters; ``binop`` becomes a circled operator node
    with its two operands as children. Number and angle rules return
    ``(value, meta)`` pairs consumed by the shape rules.
    """

    @staticmethod
    def _scaled_param(label, value):
        """Format a ``(number, meta)`` pair as ``label=<number * 2>``, keeping its meta.

        NOTE(review): the factor of 2 mirrors the original display code; it
        presumably converts grammar units to displayed units — confirm.
        """
        number, number_meta = value
        return (f"{label}={number*2}", number_meta)

    @staticmethod
    def _call_node(name, meta, params):
        """Render ``name(``, one indented line per parameter, then ``)``.

        Each entry of ``params`` is the list of ``typewriter`` segments (plain
        strings or ``(text, meta)`` pairs) making up one parameter line.
        Returns a ``DrawNode`` whose drawable carries ``meta``.
        """
        segments = [f"{name}(", "\n"]
        for param in params:
            segments.append("  ")
            segments.extend(param)
            segments.append("\n")
        segments.append(")")

        inner = typewriter(segments)
        return DrawNode(EasyDrawableNode(sub_child=inner, metas=[meta]))

    @v_args(meta=True)
    def quad(self, meta, children):
        """Draw ``Quad(x, y, w, h, θ)``; lengths are doubled, the angle is shown as-is."""
        x, y, w, h, angle_degrees = children
        angle, angle_meta = angle_degrees
        return self._call_node(
            "Quad",
            meta,
            [
                [self._scaled_param("x", x)],
                [self._scaled_param("y", y)],
                [self._scaled_param("w", w)],
                [self._scaled_param("h", h)],
                # The theta symbol is a separate, meta-less segment so only the
                # value part is linked to the angle token.
                ["$\\theta$", (f"={angle}°", angle_meta)],
            ],
        )

    @v_args(meta=True)
    def circle(self, meta, children):
        """Draw ``Circle(x, y, r)``; the grammar order is ``r, x, y``, values are doubled."""
        r, x, y = children
        return self._call_node(
            "Circle",
            meta,
            [
                [self._scaled_param("x", x)],
                [self._scaled_param("y", y)],
                [self._scaled_param("r", r)],
            ],
        )

    @v_args(meta=True)
    def binop(self, meta, children):
        """Draw a circled operator symbol with the two operands as children.

        The operator node is tagged with the operator token's meta (from
        ``add`` / ``subtract``), not with this rule's ``meta``; an enclosing
        ``s`` rule appends its own span later.
        """
        (op, op_meta), left, right = children

        operator_node = EasyDrawableNode(
            sub_child=ice.MathTex(op).scale(1.5),
            metas=[op_meta],
            circle=True,
        )
        return DrawNode(operator_node, children=[left, right])

    @v_args(meta=True)
    def add(self, meta, children):
        """Operator token: ``("+", meta)``."""
        return ("+", meta)

    @v_args(meta=True)
    def subtract(self, meta, children):
        """Operator token: ``("-", meta)``."""
        return ("-", meta)

    @v_args(meta=True)
    def s(self, meta, children):
        """Pass the single child through, appending this span to its drawable's metas."""
        c = children[0]

        if isinstance(c, DrawNode):
            c.drawable.metas.append(meta)

        return c

    # Number rules: each returns ``(value, meta)`` for the literal 0..15.

    @v_args(meta=True)
    def zero(self, meta, _):
        return (0, meta)

    @v_args(meta=True)
    def one(self, meta, _):
        return (1, meta)

    @v_args(meta=True)
    def two(self, meta, _):
        return (2, meta)

    @v_args(meta=True)
    def three(self, meta, _):
        return (3, meta)

    @v_args(meta=True)
    def four(self, meta, _):
        return (4, meta)

    @v_args(meta=True)
    def five(self, meta, _):
        return (5, meta)

    @v_args(meta=True)
    def six(self, meta, _):
        return (6, meta)

    @v_args(meta=True)
    def seven(self, meta, _):
        return (7, meta)

    @v_args(meta=True)
    def eight(self, meta, _):
        return (8, meta)

    @v_args(meta=True)
    def nine(self, meta, _):
        return (9, meta)

    @v_args(meta=True)
    def ten(self, meta, _):
        return (10, meta)

    @v_args(meta=True)
    def eleven(self, meta, _):
        return (11, meta)

    @v_args(meta=True)
    def twelve(self, meta, _):
        return (12, meta)

    @v_args(meta=True)
    def thirteen(self, meta, _):
        return (13, meta)

    @v_args(meta=True)
    def fourteen(self, meta, _):
        return (14, meta)

    @v_args(meta=True)
    def fifteen(self, meta, _):
        return (15, meta)

    # Angle rules: the k-th rule returns ``(k * 45, meta)`` (degrees).

    @v_args(meta=True)
    def zerodeg(self, meta, _):
        return (0, meta)

    @v_args(meta=True)
    def onedeg(self, meta, _):
        return (45, meta)

    @v_args(meta=True)
    def twodeg(self, meta, _):
        return (90, meta)

    @v_args(meta=True)
    def threedeg(self, meta, _):
        return (135, meta)

    @v_args(meta=True)
    def fourdeg(self, meta, _):
        return (180, meta)

    @v_args(meta=True)
    def fivedeg(self, meta, _):
        return (225, meta)

    @v_args(meta=True)
    def sixdeg(self, meta, _):
        return (270, meta)

    @v_args(meta=True)
    def sevendeg(self, meta, _):
        return (315, meta)


# Maps an environment's ``name()`` to the transformer class that turns its
# parse trees into draw trees; looked up by ``CodeTree.setup``.
_transformer_classes: Dict[str, type] = dict(
    blockworld=BlockworldToDraw,
    csg2da=CSG2DADraw,
)


class CodeTree(ice.DrawableWithChild):
    """Draw the parse tree of ``expresion`` as a node-link diagram.

    The expression is parsed with ``env.grammar``. It is converted into a tree
    of ``DrawNode`` objects by the transformer registered for ``env.name()`` in
    ``_transformer_classes``, laid out with ``do_layout``, and composed (edges,
    then nodes) into this drawable's child.
    """

    # Source text to parse. The misspelled name is part of the constructor
    # interface and is therefore kept.
    expresion: str
    # Supplies the grammar and the name used to select a transformer.
    env: Environment
    # Scale factors from layout coordinates to drawing coordinates.
    x_unit_size: float = 1
    y_unit_size: float = 2
    # Distance by which each end of an edge is pulled in towards the other end.
    radius: float = 0
    # Path style for tree edges (transformers may also read it via code_tree).
    line_style: ice.PathStyle = ice.PathStyle(color=ice.Colors.BLACK)
    # Per-span overrides forwarded to ``copy_with_mod``; ``None`` means none.
    span_settings: Dict[Tuple[int, int] | int, Dict[str, Any]] | None = None

    def setup(self):
        """Parse, transform, lay out and compose the diagram, then set it as child.

        Raises:
            KeyError: if no draw transformer is registered for ``env.name()``.
        """
        lark_tree = self.env.grammar.parse(self.expresion)

        env_name = self.env.name()
        try:
            transformer_cls = _transformer_classes[env_name]
        except KeyError:
            raise KeyError(
                f"no draw transformer registered for environment {env_name!r}"
            ) from None

        root = transformer_cls(self).transform(lark_tree)
        layout_root = do_layout(root)

        objects = []
        lines = []

        def walk(node: LayoutNode, parent=None):
            """Place ``node``'s drawable, add the edge from its parent, recurse.

            ``parent`` is the parent's ``(x, y)`` anchor (horizontal center,
            vertical middle), or ``None`` for the root.
            """
            x, y = node.x * self.x_unit_size, node.y * self.y_unit_size
            drawable = node.data.drawable
            # The scaled layout position becomes the node's top-middle point.
            draw_node = drawable.move_to(x, y, ice.Corner.TOP_MIDDLE)
            # Shift the anchor to the node's vertical middle for edge routing.
            y += drawable.bounds.height / 2
            objects.append(draw_node)

            if parent is not None:
                px, py = parent

                # Unit vector from the parent's anchor towards this node's anchor.
                dx, dy = x - px, y - py
                length = (dx**2 + dy**2) ** 0.5
                if length > 0:
                    dx, dy = dx / length, dy / length
                else:
                    # Coincident anchors: no direction to push along (avoids
                    # a ZeroDivisionError).
                    dx, dy = 0.0, 0.0

                # Push the parent end out and this end in by ``radius``.
                px, py = px + dx * self.radius, py + dy * self.radius
                mx, my = x - dx * self.radius, y - dy * self.radius

                if not node.hasChildren:
                    # Leaves: end the edge at the top-middle of the placed node.
                    mx, my = draw_node.bounds.corners[ice.Corner.TOP_MIDDLE]

                lines.append(
                    ice.Line((px, py), (mx, my), path_style=self.line_style)
                )

            for child in node.children or []:
                walk(child, parent=(x, y))

        walk(layout_root)

        # Edges are listed before nodes (presumably so nodes render over them).
        self._final_draw_tree = copy_with_mod(
            ice.Compose(lines + objects),
            self.span_settings,
        )
        self.set_child(self._final_draw_tree)

    def node_reference(self, start_pos, end_pos=None) -> EasyDrawableNode | None:
        """Return the first drawn node tagged with a meta starting at ``start_pos``.

        If ``end_pos`` is given, the meta's ``end_pos`` must match as well.
        Only objects that have a ``metas`` attribute are considered. Returns
        ``None`` when nothing matches.
        """

        def matches(candidate) -> bool:
            if not hasattr(candidate, "metas"):
                return False
            return any(
                start_pos == meta.start_pos
                and (end_pos is None or end_pos == meta.end_pos)
                for meta in candidate.metas
            )

        found = self._final_draw_tree.find_all(matches)
        return found[0] if found else None
