from typing import List, Any, Sequence

import iceberg as ice
import dataclasses
import typing
from functools import reduce

_FAMILY_MONO = "Fira Mono"
_FAMILY = "IBM Plex Sans"


def prob_colormap(prob, base_color="#3498db", background_color=ice.Colors.WHITE):
    """Linearly blend a hex colour over a background colour.

    Each RGB channel is computed as ``base * prob + background * (1 - prob)``,
    so ``prob = 1`` gives the base colour and ``prob = 0`` the background.
    ``prob`` is not clamped to ``[0, 1]``.

    Args:
        prob: Blend weight given to ``base_color``.
        base_color: Foreground colour as a hex string, parsed with
            ``ice.Color.from_hex``.
        background_color: ``ice.Color`` to blend against.

    Returns:
        A new ``ice.Color`` whose channels are Python floats.
    """
    foreground = ice.Color.from_hex(base_color)
    # Pair up foreground/background channels in r, g, b order.
    channel_pairs = zip(
        (foreground.r, foreground.g, foreground.b),
        (background_color.r, background_color.g, background_color.b),
    )
    return ice.Color(
        *(float(fg * prob + bg * (1 - prob)) for fg, bg in channel_pairs)
    )


def _resolve_type_hint(value, hint):
    """Pick the class to dispatch on for ``value`` and the hint for its elements.

    Returns ``(cls, hint)``. ``cls`` is always a real class, so it is safe to
    pass to ``issubclass``. ``hint`` is the annotation whose type arguments
    describe sequence elements, or None when no usable annotation exists.

    ``Optional[X]`` / ``Union[...]`` / ``X | Y`` are narrowed to the first arm
    that the runtime value is an instance of. Hints that are not classes
    (``Any``, ``TypeVar``, string forward references, the ``Ellipsis`` of
    ``Tuple[X, ...]``, bare ``Union``/``Optional``) fall back to
    ``type(value)`` with no hint.
    """
    import types  # local import: only needed for the PEP 604 union check

    if hint is None:
        return type(value), None

    origin = typing.get_origin(hint)
    union_type = getattr(types, "UnionType", None)  # exists on Python 3.10+
    if origin is typing.Union or (union_type is not None and origin is union_type):
        for arm in typing.get_args(hint):
            arm_cls = typing.get_origin(arm) or arm
            try:
                matches = isinstance(value, arm_cls)
            except TypeError:
                # Arms such as Any, TypeVars or string forward references.
                matches = False
            if matches:
                return _resolve_type_hint(value, arm)
        return type(value), None

    cls = origin if origin is not None else hint
    if cls is typing.Any or not isinstance(cls, type):
        return type(value), None
    return cls, hint


def _element_hints(hint, count):
    """Return one element type hint per item of a ``count``-long sequence.

    Hint arguments map to items as follows:

    - ``Tuple[A, B]`` maps positionally.
    - ``List[X]`` and ``Tuple[X, ...]`` repeat ``X`` for every item.
    - Anything else (no hint, an unparameterized ``list``, an arity mismatch)
      yields None for every item.
    """
    args = typing.get_args(hint)
    if len(args) == 2 and args[1] is Ellipsis:
        return [args[0]] * count
    if len(args) == count:
        return list(args)
    if len(args) == 1:
        return [args[0]] * count
    return [None] * count


def _lookup_span_settings(metas, span_settings):
    """Return the first truthy override dict matching any meta's span, or None.

    Each meta is tried by ``start_pos`` first, then by
    ``(start_pos, end_pos)``.
    """
    for meta in metas:
        settings = span_settings.get(meta.start_pos) or span_settings.get(
            (meta.start_pos, meta.end_pos)
        )
        if settings:
            return settings
    return None


def copy_with_mod(scene, span_settings, a_type=None):
    """Return a copy of ``scene`` with per-span setting overrides applied.

    The scene graph is walked recursively:

    - Drawables are read with ``dataclasses.fields``, so they must be
      dataclasses. They are rebuilt with ``from_fields`` from recursively
      copied field values.
    - Lists and tuples are rebuilt item by item. Tuples stay tuples; any other
      sequence becomes a list. ``str``/``bytes`` are treated as atomic values.
    - Every other value is returned as-is, not copied.

    For an ``EasyDrawableNode`` with ``metas``, ``span_settings`` is looked up
    by each meta's ``start_pos`` and then by ``(start_pos, end_pos)``. The
    first truthy match is merged over the node's fields, and its keys win.

    Args:
        scene: Root of the (sub)graph to copy. None is returned unchanged.
        span_settings: Mapping from a span key (``start_pos`` or
            ``(start_pos, end_pos)``) to a dict of field overrides.
        a_type: Optional annotation describing ``scene``. It is used to choose
            how to recurse and to type sequence items. Defaults to
            ``type(scene)``.

    Returns:
        The copied scene (or ``scene`` itself for non-drawable, non-sequence
        values).
    """
    if scene is None:
        return None

    cls, hint = _resolve_type_hint(scene, a_type)

    if issubclass(cls, ice.Drawable):
        new_fields = {
            field.name: copy_with_mod(
                getattr(scene, field.name), span_settings, a_type=field.type
            )
            for field in dataclasses.fields(scene)
        }
        if isinstance(scene, EasyDrawableNode) and scene.metas and span_settings:
            overrides = _lookup_span_settings(scene.metas, span_settings)
            if overrides:
                new_fields = new_fields | overrides
        return scene.__class__.from_fields(**new_fields)

    # Sequence captures a lot; text and byte strings are sequences too but
    # must be kept atomic rather than exploded into characters/ints.
    if issubclass(cls, (list, tuple, Sequence)) and not issubclass(
        cls, (str, bytes, bytearray)
    ):
        copied = [
            copy_with_mod(item, span_settings, a_type=item_hint)
            for item, item_hint in zip(scene, _element_hints(hint, len(scene)))
        ]
        return tuple(copied) if isinstance(scene, tuple) else copied

    return scene


class EasyDrawableNode(ice.DrawableWithChild):
    """A label (text or an arbitrary drawable) framed by a rectangle or circle.

    ``setup`` builds the frame around the padded label and installs the
    result as this node's child. ``copy_with_mod`` rebuilds these nodes field
    by field and may override fields per span via ``metas``.
    """

    # Text rendered (monospace) when ``sub_child`` is falsy.
    text_label: str = None
    # Drawable shown instead of ``text_label`` when truthy.
    sub_child: ice.Drawable = None
    # Fill colour of the frame (may be None, as ``typewriter`` passes).
    background_color: ice.Color = ice.Colors.WHITE
    # Colour of the ``text_label`` text.
    text_color: ice.Color = ice.Colors.BLACK
    # Frame outline colour; ignored when ``border_thickness`` <= 0.
    border_color: ice.Color = ice.Colors.BLACK
    border_thickness: float = 1
    # Corner radius; only applies to the rectangular frame.
    border_radius: float = 5
    # Span metadata; ``copy_with_mod`` matches each entry's ``start_pos`` /
    # ``(start_pos, end_pos)`` against its span settings.
    metas: List[Any] = None
    font_size: float = 12
    # Padding around the label used to size the frame.
    label_padding: float = 5
    # Draw a circle instead of a rectangle.
    circle: bool = False
    # Wrap the result in ``ice.Anchor`` anchored on the label.
    padding_anchor: bool = False

    def setup(self):
        """Build the framed label and set it as this node's child."""
        label = self.sub_child
        if not label:
            # No custom drawable: render ``text_label`` in the monospace font.
            label = ice.Text(
                text=self.text_label,
                font_style=ice.FontStyle(
                    family=_FAMILY_MONO, size=self.font_size, color=self.text_color
                ),
            )

        padded_bounds = label.pad(self.label_padding).bounds
        # A non-positive thickness means "no border", so drop the colour too.
        outline = self.border_color if self.border_thickness > 0 else None

        if self.circle:
            # Square box on the larger padded side so the ellipse is a circle.
            side = max(padded_bounds.width, padded_bounds.height)
            frame = ice.Ellipse(
                rectangle=ice.Bounds.from_size(side, side),
                border_color=outline,
                border_thickness=self.border_thickness,
                fill_color=self.background_color,
            )
        else:
            frame = ice.Rectangle(
                padded_bounds,
                border_color=outline,
                border_thickness=self.border_thickness,
                fill_color=self.background_color,
                border_radius=self.border_radius,
            )

        framed = frame.add_centered(label)
        if self.padding_anchor:
            # NOTE(review): presumably makes the label (index 1) rather than the
            # padded frame drive layout/bounds — confirm against iceberg's Anchor.
            framed = ice.Anchor([framed, label], anchor_index=1)

        self.set_child(framed)


def typewriter(text_list, fs=11):
    """Lay out text fragments left to right, breaking only at explicit newlines.

    Each item of ``text_list`` is either a string or a ``(string, meta)``
    tuple, handled as follows:

    - The string ``"\\n"`` moves the cursor to the start of the next line.
    - A string starting with ``"$"`` has its first and last characters
      stripped and is rendered with ``ice.MathTex``.
    - Any other string is rendered as plain text.

    Every fragment becomes a borderless ``EasyDrawableNode``, with ``meta``
    in ``metas`` when ``meta`` is truthy. The node's bottom-left corner is
    placed at the cursor, and the cursor then advances by the node's width.

    Args:
        text_list: Iterable of fragments as described above.
        fs: Font size for the text and for measuring the line height.

    Returns:
        An ``ice.Compose`` of the placed nodes.
    """
    # Height of a single space at this font size, plus 2 units of spacing.
    probe = EasyDrawableNode(
        text_label=" ",
        font_size=fs,
        border_thickness=0,
        label_padding=0,
    )
    line_height = probe.bounds.height + 2

    x = y = 0
    nodes = []
    for item in text_list:
        text, meta = item if isinstance(item, tuple) else (item, None)

        if text == "\n":
            x, y = 0, y + line_height
            continue

        formula = None
        if text.startswith("$"):
            text = text[1:-1]  # drop the surrounding "$" delimiters
            formula = ice.MathTex(text).scale(1)
            # Shift by (0, -2) but crop to the pre-shift bounds.
            formula = formula.move(0, -2).crop(formula.bounds)

        node = EasyDrawableNode(
            text_label=text,
            sub_child=formula,
            metas=[meta] if meta else [],
            font_size=fs,
            border_thickness=0,
            label_padding=3,
            background_color=None,
            padding_anchor=True,
        ).move_to(x, y, ice.Corner.BOTTOM_LEFT)
        nodes.append(node)
        x += node.bounds.width

    return ice.Compose(nodes)
