from td.environments import Environment
from lark.tree import Tree
import iceberg as ice
from .helpers import prob_colormap


def get_text(text, x, y, font_style):
    """Build a text drawable whose bottom-left corner is placed at ``(x, y)``.

    Args:
        text: The string to render.
        x: Horizontal position of the bottom-left corner.
        y: Vertical position of the bottom-left corner.
        font_style: ``ice.FontStyle`` used to render the text.

    Returns:
        The moved ``ice.Text`` object.
    """
    label = ice.Text(text=text, font_style=font_style)
    return label.move_to(x, y, corner=ice.Corner.BOTTOM_LEFT)


def text_cursor(text_obj):
    """Return the bottom-right corner of ``text_obj``'s bounding box.

    Callers use this point as the position where the next fragment on the
    same line should start.
    """
    corners = text_obj.bounds.corners
    return corners[ice.Corner.BOTTOM_RIGHT]


def defer_combine(
    objects,
    font_style,
    edit_probs_lookup,
    draw_fn,
    line_spacing=5,
    cursor_x=0,
    cursor_y=0,
    node=None,
):
    """Lay out text fragments and subtrees left to right as one drawable.

    ``objects`` is consumed in order, starting at ``(cursor_x, cursor_y)``:

    * the string ``"\\n"`` moves the cursor back to the starting x and down
      by one line height (height of a rendered ``"A"`` plus ``line_spacing``);
    * any other ``str`` is rendered with ``font_style`` and the cursor moves
      to its bottom-right corner;
    * a ``lark.Tree`` is rendered via ``draw_fn(tree, x, y)`` and the cursor
      moves to the bottom-right corner of the result's bounding box.

    If ``node`` is given and ``(node.meta.start_pos, node.meta.end_pos)`` is a
    key of ``edit_probs_lookup``, the result is drawn on top of a rounded
    rectangle filled with ``prob_colormap(prob)``.

    Args:
        objects: Sequence of ``str`` fragments, ``"\\n"`` markers and
            ``lark.Tree`` subtrees.
        font_style: ``ice.FontStyle`` for string fragments.
        edit_probs_lookup: Mapping ``(start_pos, end_pos) -> prob``.
        draw_fn: Callable ``(tree, x, y) -> drawable`` for subtrees.
        line_spacing: Extra vertical gap added between lines.
        cursor_x: Starting x position (also the x every new line returns to).
        cursor_y: Starting y position.
        node: Parse-tree node whose source span selects the highlight. When
            ``None``, no highlight is applied.

    Returns:
        An ``ice.Compose`` of the laid-out pieces, or an ``ice.Anchor``
        wrapping it with a highlight rectangle.

    Raises:
        ValueError: If an item in ``objects`` is neither ``str`` nor ``Tree``.
    """
    line_height = get_text("A", 0, 0, font_style).bounds.height + line_spacing
    all_objs = []
    start_x = cursor_x
    for obj in objects:
        if obj == "\n":
            cursor_x = start_x
            cursor_y += line_height
            continue

        if isinstance(obj, str):
            drawn = get_text(obj, cursor_x, cursor_y, font_style)
        elif isinstance(obj, Tree):
            drawn = draw_fn(obj, cursor_x, cursor_y)
        else:
            raise ValueError(f"Unexpected object: {obj}")

        # Bounding-box corner: for a multi-line subtree this is the right edge
        # of its widest line, so callers follow such subtrees with "\n".
        cursor_x, cursor_y = text_cursor(drawn)
        all_objs.append(drawn)

    composed = ice.Compose(all_objs)

    # Without a node there is no source span to look up (the default
    # node=None previously crashed on node.meta).
    if node is None:
        return composed

    span = (node.meta.start_pos, node.meta.end_pos)
    if span not in edit_probs_lookup:
        return composed

    prob = edit_probs_lookup[span]
    # NOTE(review): dont_modify_bounds=True and anchor_index=1 presumably keep
    # the highlight rectangle from changing the layout bounds — confirm
    # against the iceberg documentation.
    return ice.Anchor(
        [
            ice.Rectangle(
                composed.bounds,
                fill_color=prob_colormap(prob),
                dont_modify_bounds=True,
                border_radius=10,
            ),
            composed,
        ],
        anchor_index=1,
    )


# Display text for csg2da leaf rules; rule names not listed render verbatim.
# NOTE(review): numeric rules display as twice their index (one -> 2, ...,
# sixteen -> 32) and angle rules in 45° steps; presumably this mirrors how the
# csg2da environment scales these values — confirm.
_CSG2DA_LITERALS = {
    "zero": "0",
    "one": "2",
    "two": "4",
    "three": "6",
    "four": "8",
    "five": "10",
    "six": "12",
    "seven": "14",
    "eight": "16",
    "nine": "18",
    "ten": "20",
    "eleven": "22",
    "twelve": "24",
    "thirteen": "26",
    "fourteen": "28",
    "fifteen": "30",
    "sixteen": "32",
    "zerodeg": "0°",
    "onedeg": "45°",
    "twodeg": "90°",
    "threedeg": "135°",
    "fourdeg": "180°",
    "fivedeg": "225°",
    "sixdeg": "270°",
    "sevendeg": "315°",
    "add": "Add",
    "subtract": "Subtract",
    "intersect": "Intersect",
}


def draw_code_csg2da(
    env: Environment,
    expression: str,
    edit_probs=None,
    font_family="Fira Mono",
    font_size=24,
):
    """Render a csg2da expression as laid-out, optionally highlighted code.

    The expression is parsed with ``env.grammar`` and drawn recursively.
    Binary operations span several lines, with operands indented by two
    spaces. ``circle``/``quad`` nodes are drawn as ``Circle(...)`` /
    ``Quad(...)`` calls, and leaf rules through ``_CSG2DA_LITERALS``.

    Args:
        env: Environment whose ``grammar.parse`` produces the lark tree.
        expression: Source text in the csg2da language.
        edit_probs: Optional iterable of ``(start, end, prob)`` triples. A
            node whose ``(meta.start_pos, meta.end_pos)`` equals
            ``(start, end)`` is highlighted with ``prob_colormap(prob)``.
        font_family: Font family for all text.
        font_size: Font size for all text.

    Returns:
        The drawable for the whole expression, as returned by
        ``defer_combine``. Errors raised by ``env.grammar.parse`` propagate.
    """
    edit_probs_lookup = {(s, e): p for s, e, p in edit_probs} if edit_probs else {}
    tree = env.grammar.parse(expression)
    font_style = ice.FontStyle(
        family=font_family, size=font_size, color=ice.Colors.BLACK
    )

    def combine(parts, node, cursor_x, cursor_y):
        # Every rule shares the same layout settings; only the fragment list,
        # node and cursor vary.
        return defer_combine(
            parts,
            font_style=font_style,
            edit_probs_lookup=edit_probs_lookup,
            draw_fn=draw,
            cursor_x=cursor_x,
            cursor_y=cursor_y,
            node=node,
        )

    def draw(node, cursor_x=0, cursor_y=0):
        kind = node.data
        if kind == "s":
            # The start rule wraps a single expression; draw it directly.
            return draw(node.children[0], cursor_x, cursor_y)

        if kind == "binop":
            op, left, right = node.children
            parts = ["(", op, "\n", "  ", left, "\n", "  ", right, "\n", ")"]
        elif kind == "circle":
            x, y, r = node.children
            parts = ["Circle(x=", x, ", y=", y, ", r=", r, ")"]
        elif kind == "quad":
            x, y, w, h, angle = node.children
            parts = [
                "Quad(x=", x, ", y=", y, ", w=", w, ", h=", h,
                ", angle=", angle, ")",
            ]
        else:
            parts = [_CSG2DA_LITERALS.get(kind, kind)]

        return combine(parts, node, cursor_x, cursor_y)

    return draw(tree)


# Display text for tinysvg leaf rules whose rendering differs from the rule
# name. Any other leaf (e.g. "v", "h", colour names, "none") renders as-is.
_TINYSVG_LITERALS = {
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    # Sign prefixes drawn before a Move offset (see the "move" rule below).
    "true": "-",
    "false": "+",
}


def draw_code_tinysvg(
    env: Environment,
    expression: str,
    edit_probs=None,
    font_family="Fira Mono",
    font_size=24,
):
    """Render a tinysvg expression as laid-out, optionally highlighted code.

    The expression is parsed with ``env.grammar`` and drawn recursively.
    ``arrange`` and ``move`` nodes span several lines, with children indented
    by two spaces. ``rect``/``ellipse`` nodes are drawn as ``Rect(...)`` /
    ``Ball(...)`` calls, and leaf rules through ``_TINYSVG_LITERALS``.

    Args:
        env: Environment whose ``grammar.parse`` produces the lark tree.
        expression: Source text in the tinysvg language.
        edit_probs: Optional iterable of ``(start, end, prob)`` triples. A
            node whose ``(meta.start_pos, meta.end_pos)`` equals
            ``(start, end)`` is highlighted with ``prob_colormap(prob)``.
        font_family: Font family for all text.
        font_size: Font size for all text.

    Returns:
        The drawable for the whole expression, as returned by
        ``defer_combine``. Errors raised by ``env.grammar.parse`` propagate.
    """
    edit_probs_lookup = {(s, e): p for s, e, p in edit_probs} if edit_probs else {}
    tree = env.grammar.parse(expression)
    font_style = ice.FontStyle(
        family=font_family, size=font_size, color=ice.Colors.BLACK
    )

    def combine(parts, node, cursor_x, cursor_y):
        # Every rule shares the same layout settings; only the fragment list,
        # node and cursor vary.
        return defer_combine(
            parts,
            font_style=font_style,
            edit_probs_lookup=edit_probs_lookup,
            draw_fn=draw,
            cursor_x=cursor_x,
            cursor_y=cursor_y,
            node=node,
        )

    def draw(node, cursor_x=0, cursor_y=0):
        kind = node.data
        if kind == "s":
            # The start rule wraps a single expression; draw it directly.
            return draw(node.children[0], cursor_x, cursor_y)

        if kind == "arrange":
            direction, left, right, gap = node.children
            parts = [
                "(Arrange ", direction, " gap=", gap, "\n",
                "  ", left, "\n",
                "  ", right, "\n",
                ")",
            ]
        elif kind == "move":
            drawable, x, y, negx, negy = node.children
            # Each offset is drawn as its sign (negx/negy) followed by the
            # magnitude.
            parts = [
                "(Move ", "x=", negx, x, " ", "y=", negy, y, "\n",
                "  ", drawable, "\n",
                ")",
            ]
        elif kind == "rect" or kind == "ellipse":
            w, h, fill_color, stroke_color, stroke_width = node.children
            opname = "Rect" if kind == "rect" else "Ball"
            parts = [
                f"{opname}(w=", w,
                ", h=", h,
                ", fill=", fill_color,
                ", border=(", stroke_color, ",", stroke_width, ")",
                ")",
            ]
        else:
            parts = [_TINYSVG_LITERALS.get(kind, kind)]

        return combine(parts, node, cursor_x, cursor_y)

    return draw(tree)


# Registry mapping a DSL name to the function that renders its expressions
# as code.
code_drawers = dict(
    csg2da=draw_code_csg2da,
    tinysvg=draw_code_tinysvg,
)
