from __future__ import annotations

import math
import re
from typing import Literal, Tuple

import numpy as np
import plotly.colors as pc


# Quadrant tags accepted by the colouring functions. A single tag names one
# quadrant of the point grid (T = a >= p, R = b >= p, matching the
# comparisons used when tag == "full"; B/L are the complements); "full" means
# points from all four quadrants are passed together.
Quadrant = Literal["full", "BL", "BR", "TL", "TR"]

# Printable form of each quadrant's base expression (mirrors _base_val):
# left quadrants use a + b, right quadrants use b - a.
_QUAD_EXPR = dict(
    BL="a+b",
    TL="a+b",
    BR="b-a",
    TR="b-a",
)

# NOTE(review): "colour_pair_mod_g" is exported below, but the parity colourer
# visible in this module is named colour_pair_mod2. If colour_pair_mod_g is not
# defined elsewhere, `from <module> import *` raises AttributeError — confirm.
__all__ = [
    "colour_quad_mul_f",
    "colour_quad_mod_g",
    "colour_quad_a_only",
    "colour_quad_b_only",
    "colour_pair_mod_g",
]

# Plotly colorscale name returned by the *_no_fb colouring functions.
_DEFAULT = "Viridis"


def _base_val(a: np.ndarray, b: np.ndarray, tag: str) -> np.ndarray:
    """Return the per-quadrant base value for the points ``(a, b)``.

    Left quadrants (``BL``/``TL``) use ``a + b``; right quadrants
    (``BR``/``TR``) use ``b - a``.

    Raises:
        ValueError: if ``tag`` is not one of the four quadrant tags.
    """
    left_tags = ("BL", "TL")
    right_tags = ("BR", "TR")
    if tag not in left_tags and tag not in right_tags:
        raise ValueError("tag must be one of BL/BR/TL/TR")
    return a + b if tag in left_tags else b - a


def step_size(f: int, p: int) -> int:
    """Return the inverse of ``f / gcd(f, p)`` modulo ``p / gcd(f, p)``.

    Uses the three-argument ``pow`` with exponent -1 (Python 3.8+), so any
    error it raises for a non-invertible base or zero modulus propagates.
    """
    common = math.gcd(f, p)
    modulus = p // common
    reduced_f = f // common
    return pow(reduced_f, -1, modulus)


def _ensure_int_freq(f) -> int:
    """Coerce a frequency value ``f`` to a Python ``int``.

    Accepted inputs:
      * Python/NumPy integers -> returned as ``int``.
      * Python/NumPy floats -> must be finite; rounded with ``round()``
        (ties go to the even integer, e.g. ``2.5 -> 2``).
      * strings -> the first optionally-negative run of digits,
        e.g. ``"f=12" -> 12``.
      * anything else -> ``int(f)``.

    Raises:
        TypeError: if ``f`` is ``None``, a non-finite float, a string with no
            digits, or an object that ``int()`` rejects (the original error is
            chained as ``__cause__``).
    """
    if isinstance(f, (int, np.integer)):
        return int(f)
    if isinstance(f, (float, np.floating)):
        if not np.isfinite(f):
            raise TypeError(f"freq f={f!r} is not finite")
        return int(round(float(f)))
    if isinstance(f, str):
        m = re.search(r"-?\d+", f)
        if m:
            return int(m.group(0))
        raise TypeError(f"freq f={f!r} contains no digits")
    if f is None:
        raise TypeError("freq f is None")
    try:
        return int(f)
    except Exception as err:
        # Chain explicitly so the traceback shows why int() failed instead of
        # presenting this as a second error raised while handling the first.
        raise TypeError(f"freq f={f!r} cannot be interpreted as an integer") from err


def _resolve_scale(scale_name_or_list):
    """Return a Plotly colorscale for a name, or pass an explicit scale through.

    Lists/tuples are returned unchanged. Names are looked up, in order:
      1. ``plotly.colors.PLOTLY_SCALES``;
      2. the ``sequential``, ``diverging``, ``cyclical`` and ``qualitative``
         namespaces of ``plotly.express.colors`` (case-sensitive attribute);
      3. lower-case aliases ``oranges``/``orrd`` -> ``Oranges`` and
         ``ylorrd`` -> ``YlOrRd``.
    Any exception raised while importing or searching ``plotly.express`` is
    swallowed, and the name is then reported as unknown.

    Raises:
        KeyError: if no lookup succeeds.
    """
    if isinstance(scale_name_or_list, (list, tuple)):
        return scale_name_or_list

    key = str(scale_name_or_list)
    known = pc.PLOTLY_SCALES
    if key in known:
        return known[key]

    try:
        import plotly.express as px

        express_colors = px.colors
        namespaces = (
            express_colors.sequential,
            express_colors.diverging,
            express_colors.cyclical,
            express_colors.qualitative,
        )
        for namespace in namespaces:
            if hasattr(namespace, key):
                return getattr(namespace, key)

        lowered = key.lower()
        # NOTE(review): lower-case "orrd" resolves to Oranges rather than OrRd
        # (exact "OrRd" is found above) — confirm this alias is intended.
        if lowered in ("oranges", "orrd"):
            return express_colors.sequential.Oranges
        if lowered == "ylorrd":
            return express_colors.sequential.YlOrRd
    except Exception:
        # Best effort only: fall through to the KeyError below.
        pass

    raise KeyError(f"Unknown colorscale name: {key}")


def _interp(scale_name_or_list, t: float) -> str:
    """Sample the single colour at position ``t`` of a (named or explicit) colorscale."""
    resolved = _resolve_scale(scale_name_or_list)
    samples = pc.sample_colorscale(resolved, [t])
    return samples[0]


def _interp_rainbow_red_orange(t: float) -> str:
    """Sample Viridis on its upper sub-range [0.55, 1.0] at fraction ``t``.

    Despite its name, this samples Viridis rather than a red/orange scale.
    """
    position = 0.55 + 0.45 * t
    return _interp("Viridis", position)


def _interp_viridis(t: float) -> str:
    """Sample Viridis on its lower sub-range [0.0, 0.45] at fraction ``t``."""
    position = 0.0 + 0.45 * t
    return _interp("Viridis", position)


def build_split_scale_red_orange(g: int) -> list[tuple[float, str]]:
    """Build a two-band colorscale with ``g`` stops per band.

    The lower band occupies positions 0.0..0.45 and is sampled with
    ``_interp_viridis``; the upper band occupies 0.55..1.0 and is sampled
    with ``_interp_rainbow_red_orange``. For ``g <= 1`` a plain two-stop
    Viridis scale (colours at 0.0 and 1.0) is returned instead.
    """
    if g <= 1:
        return [(0.0, _interp("Viridis", 0.0)), (1.0, _interp("Viridis", 1.0))]

    last = g - 1
    bands = ((0.0, _interp_viridis), (0.55, _interp_rainbow_red_orange))
    scale: list[tuple[float, str]] = []
    for start, sample in bands:
        for i in range(g):
            scale.append((start + 0.45 * i / last, sample(i / last)))
    return scale


def build_vi_scale(g: int) -> list[tuple[float, str]]:
    """Build a colorscale of ``g`` stops sampled from dark Viridis (0.0..0.45).

    Stop ``i`` sits at position ``i / (g - 1)`` with the Viridis colour at
    ``0.45 * i / (g - 1)``. Values of ``g`` below 2 are treated as 2 (stops
    at 0 and 1). Otherwise ``g == 1`` divides by zero and ``g <= 0`` yields
    an empty list, and neither is a usable Plotly colorscale.
    """
    n = max(g, 2)
    last = n - 1
    return [(i / last, _interp("Viridis", 0.0 + 0.45 * i / last)) for i in range(n)]


def build_ro_scale(g: int) -> list[tuple[float, str]]:
    """Build a colorscale of ``g`` stops sampled from bright Viridis (0.55..1.0).

    Stop ``i`` sits at position ``i / (g - 1)`` with the Viridis colour at
    ``0.55 + 0.45 * i / (g - 1)``. Values of ``g`` below 2 are treated as 2
    (stops at 0 and 1). Otherwise ``g == 1`` divides by zero, which happens
    when f is a multiple of p in the mod-g colourers, and ``g <= 0`` yields
    an empty list.
    """
    n = max(g, 2)
    last = n - 1
    return [(i / last, _interp("Viridis", 0.55 + 0.45 * i / last)) for i in range(n)]


def colour_quad_mod_g_no_fb(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, str]:
    """Colour points by ``(a +/- b) mod g`` with no front/back split.

    ``g = p // gcd(p, f)`` (``p`` when the gcd is 0). A single quadrant uses
    ``a + b`` (BL/TL) or ``b - a`` (BR/TR). ``tag == "full"`` applies the
    matching rule per point: ``b - a`` where ``b >= p``, ``a + b`` elsewhere.

    Returns:
        ``(colour, caption, colour.max() + 1, _DEFAULT)``.
    """
    # Coerce f before deriving g: math.gcd only accepts integers, so a
    # float/str frequency accepted by _ensure_int_freq would otherwise fail.
    f = _ensure_int_freq(f)
    gcd = math.gcd(p, f)
    g = p // gcd if gcd != 0 else p
    a = np.asarray(a, int)
    b = np.asarray(b, int)

    if tag == "full":
        # One vectorised pass over all quadrants. g divides p, so reducing a
        # and b mod p first (as a per-quadrant view would) does not change the
        # value mod g. This also avoids calling .max() on an empty quadrant
        # (which raises ValueError) when some quadrant has no points.
        right = b >= p
        out = np.where(right, b - a, a + b) % g
        caption = f"(a+/-b) mod {g} [g=p//gcd({f},{p})]"
        return out, caption, int(out.max()) + 1, _DEFAULT

    base = _base_val(a, b, tag)
    colour = base % g
    expr = _QUAD_EXPR[tag]
    caption = f"({expr}) mod {g}"
    return colour, caption, int(colour.max()) + 1, _DEFAULT


def colour_quad_mul_f(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by ``f * (a +/- b) mod p``.

    For a single quadrant the value is ``f * base mod p``, where ``base`` is
    ``a + b`` (BL/TL) or ``b - a`` (BR/TR). BR/TL use the bright Viridis
    sub-scale and BL/TR the dark one.

    For ``tag == "full"`` each point's quadrant comes from ``a >= p`` (top)
    and ``b >= p`` (right). Front quadrants (BL/TR) keep indices ``0..p-1``.
    Back quadrants (TL/BR) are shifted to ``p..2p-1``, the half of the index
    range covered by the bright band of the split scale.

    Returns:
        ``(colour, caption, colour.max() + 1, colorscale)``.
    """
    a = np.asarray(a, int)
    b = np.asarray(b, int)
    f = _ensure_int_freq(f)

    if tag != "full":
        base = _base_val(a, b, tag)
        bright = tag in ("BR", "TL")
        colour = (f * base) % p
        caption = f"{f}*({_QUAD_EXPR[tag]}) mod {p}"
        cmap = build_ro_scale(p) if bright else build_vi_scale(p)
        return colour, caption, int(colour.max()) + 1, cmap

    top = a >= p
    right = b >= p
    residue = (f * np.where(right, b - a, a + b)) % p
    # Front quadrants (BL/TR) are exactly those where both comparisons agree.
    front = top == right
    colour_idx = np.where(front, residue, p + residue)
    caption = f"f*(b+/-a) mod {p} (BL/TR -> Viridis dark; TL/BR -> Viridis bright)"
    return colour_idx, caption, int(colour_idx.max()) + 1, build_split_scale_red_orange(p)


def colour_quad_mod_g(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by ``remap * (a +/- b) mod g`` with ``g = p // gcd(p, f)``.

    ``remap`` is 1 when ``step_size(f, p) == 1`` and ``f`` otherwise.
    A single quadrant uses ``a + b`` (BL/TL) or ``b - a`` (BR/TR); BR/TL get
    the bright Viridis sub-scale and BL/TR the dark one. ``tag == "full"``
    applies the rule per point (``b - a`` where ``b >= p``) and shifts the
    back quadrants (TL/BR) up by ``g`` for the split scale.

    Returns:
        ``(colour, caption, colour.max() + 1, colorscale)``.
    """
    f = _ensure_int_freq(f)
    divisor = math.gcd(p, f)
    g = p if divisor == 0 else p // divisor
    a = np.asarray(a, int)
    b = np.asarray(b, int)

    remap = 1 if step_size(f, p) == 1 else f

    if tag != "full":
        base = _base_val(a, b, tag)
        bright = tag in ("BR", "TL")
        colour = (remap * base) % g
        caption = f"({_QUAD_EXPR[tag]}) mod {g}"
        cmap = build_ro_scale(g) if bright else build_vi_scale(p)
        return colour, caption, int(colour.max()) + 1, cmap

    top = a >= p
    right = b >= p
    residue = (remap * np.where(right, b - a, a + b)) % g
    # Front quadrants (BL/TR) keep [0, g); back quadrants move to [g, 2g).
    colour_idx = np.where(top == right, residue, g + residue)
    caption = f"(b+/-a) mod {g} (BL/TR -> Viridis dark; TL/BR -> Viridis bright)"
    return colour_idx, caption, int(colour_idx.max()) + 1, build_split_scale_red_orange(g)


def colour_c_mod_p(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by ``c = remap * (a +/- b) mod g`` plus a block sub-index.

    ``g = p // gcd(p, f)`` (``p`` when the gcd is 0). ``remap`` is 1 when
    ``step_size(f, p) == 1`` and ``f`` otherwise.

    * ``tag == "full"``: each point gets ``c * g + sub`` with
      ``sub = (a // p + b // p) mod g``. Front quadrants (BL/TR) occupy
      ``[0, g*g)``; back quadrants (TL/BR) are shifted up by ``g * g``.
    * single quadrant: the value is ``c`` alone, with ``a + b`` (BL/TL) or
      ``b - a`` (BR/TR). BR/TL use the bright Viridis sub-scale, BL/TR the
      dark one.

    Returns:
        ``(colour, caption, colour.max() + 1, colorscale)``.
    """
    f = _ensure_int_freq(f)
    g = p // math.gcd(p, f) if math.gcd(p, f) != 0 else p
    a = np.asarray(a, int)
    b = np.asarray(b, int)

    step = step_size(f, p)
    remap = f if step != 1 else 1

    if tag == "full":
        top = a >= p
        right = b >= p

        # c: remapped base value (b - a on the right half, a + b on the left).
        c = (remap * np.where(right, b - a, a + b)) % g
        # sub: sum of the p-sized block coordinates, folded mod g.
        sub = ((a // p) + (b // p)) % g
        base_idx = c * g + sub

        mask_front = (~top & ~right) | (top & right)
        colour_idx = np.where(mask_front, base_idx, base_idx + g * g)

        n_colours = int(colour_idx.max()) + 1
        caption = f"c=(b+/-a) mod {g}, sub=((a//p)+(b//p)) mod {g}"
        cmap = build_split_scale_red_orange(n_colours)
        return colour_idx, caption, n_colours, cmap

    base = _base_val(a, b, tag)
    is_rain = tag in ("BR", "TL")
    # Apply remap here too, so a single quadrant gets the same c as the
    # full-grid branch above (and as colour_quad_mod_g's per-quadrant path).
    colour = (remap * base) % g
    expr = _QUAD_EXPR[tag]
    caption = f"({expr}) mod {g}"
    cmap = build_ro_scale(g) if is_rain else build_vi_scale(p)
    return colour, caption, int(colour.max()) + 1, cmap


def colour_quad_a_only(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by ``remap * a mod g``; ``b`` does not affect the colour.

    ``g = p // gcd(p, f)`` (``p`` when the gcd is 0). ``remap`` is 1 when
    ``step_size(f, p) == 1`` and ``f`` otherwise.

    * ``tag == "full"``: top-half points (``a >= p``) are shifted by ``g``
      into the upper (bright) half of a ``2 * g``-colour split scale.
    * single quadrant: TL/TR use the bright Viridis sub-scale, BL/BR the dark.

    Returns:
        ``(colour, caption, n_colours, colorscale)``; ``n_colours`` is ``2 * g``
        in full mode and ``colour.max() + 1`` otherwise.
    """
    f = _ensure_int_freq(f)
    g = p // math.gcd(p, f) if math.gcd(p, f) != 0 else p

    a = np.asarray(a, int)
    b = np.asarray(b, int)  # converted like the other inputs; unused below
    step = step_size(f, p)
    remap = f if step != 1 else 1

    if tag == "full":
        is_rain = a >= p
        base = (remap * a) % g
        colour = np.where(is_rain, g + base, base)
        caption = f"Viridis dark (a<p): a mod {g}; Viridis bright (a>=p): a mod {g}"
        return colour, caption, 2 * g, build_split_scale_red_orange(g)

    is_rain = tag in ("TR", "TL")
    colour = (a * remap) % g
    caption = f"a mod {g} ({'Viridis bright' if is_rain else 'Viridis dark'})"
    cmap = build_ro_scale(g) if is_rain else build_vi_scale(p)
    return colour, caption, int(colour.max()) + 1, cmap


def colour_quad_b_only(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by ``remap * b mod g``; ``a`` does not affect the colour.

    ``g = p // gcd(p, f)`` (``p`` when the gcd is 0). ``remap`` is 1 when
    ``step_size(f, p) == 1`` and ``f`` otherwise.

    * ``tag == "full"``: right-half points (``b >= p``) are shifted by ``g``
      into the upper (bright) half of a ``2 * g``-colour split scale.
    * single quadrant: BR/TR use the bright Viridis sub-scale, BL/TL the dark.

    Returns:
        ``(colour, caption, n_colours, colorscale)``; ``n_colours`` is ``2 * g``
        in full mode and ``colour.max() + 1`` otherwise.
    """
    f = _ensure_int_freq(f)
    g = p // math.gcd(p, f) if math.gcd(p, f) != 0 else p

    a = np.asarray(a, int)  # converted like the other inputs; unused below
    b = np.asarray(b, int)
    step = step_size(f, p)
    remap = f if step != 1 else 1

    if tag == "full":
        is_rain = b >= p
        base = (remap * b) % g
        colour = np.where(is_rain, g + base, base)
        caption = f"Viridis dark (b<p): b mod {g}; Viridis bright (b>=p): b mod {g}"
        return colour, caption, 2 * g, build_split_scale_red_orange(g)

    is_rain = tag in ("BR", "TR")
    colour = (remap * b) % g
    caption = f"b mod {g} ({'Viridis bright' if is_rain else 'Viridis dark'})"
    cmap = build_ro_scale(g) if is_rain else build_vi_scale(p)
    return colour, caption, int(colour.max()) + 1, cmap


def colour_quad_a_only_no_fb(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: str,
):
    """Colour points by ``a mod g`` using the default colorscale.

    ``g = p // gcd(p, f)`` (``p`` when the gcd is 0). In ``"full"`` mode
    ``a`` is reduced mod ``p`` before being reduced mod ``g``. ``b`` is
    converted to an int array but does not affect the colour.

    Returns:
        ``(colour, caption, g, _DEFAULT)``.
    """
    f = _ensure_int_freq(f)
    divisor = math.gcd(p, f)
    g = p if divisor == 0 else p // divisor

    a = np.asarray(a, int)
    b = np.asarray(b, int)

    colour = (a % p) % g if tag == "full" else a % g
    return colour, f"a mod {g}", int(g), _DEFAULT


def colour_quad_b_only_no_fb(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: str,
):
    """Colour points by ``b mod g`` using the default colorscale.

    ``g = p // gcd(p, f)`` (``p`` when the gcd is 0). In ``"full"`` mode
    ``b`` is reduced mod ``p`` before being reduced mod ``g``. ``a`` is
    converted to an int array but does not affect the colour.

    Returns:
        ``(colour, caption, g, _DEFAULT)``.
    """
    f = _ensure_int_freq(f)
    divisor = math.gcd(p, f)
    g = p if divisor == 0 else p // divisor

    a = np.asarray(a, int)
    b = np.asarray(b, int)

    colour = (b % p) % g if tag == "full" else b % g
    return colour, f"b mod {g}", int(g), _DEFAULT


def _order_by_step(local_vals: np.ndarray, g: int, d: int) -> np.ndarray:
    """Return indices that order ``local_vals`` along the walk ``r0, r0+d, ...`` (mod g).

    Values are reduced mod ``g`` and ``r0`` is the smallest residue present.
    Each value is ranked by the step ``t`` (``0 <= t < g``) at which the walk
    ``(r0 + t*d) % g`` visits its residue; when a residue is visited more than
    once, the last visit counts. Ties keep their input order (stable sort).

    Residues the walk never reaches (possible when ``gcd(d, g) > 1``) are
    ranked after every visited residue, in ascending residue order. Before,
    this case raised a TypeError from sorting ``None`` keys.

    Returns:
        An int index array of the same length as ``local_vals`` (empty for
        empty input).
    """
    resid = np.asarray(local_vals, int) % g
    if resid.size == 0:
        return np.arange(0, 0, dtype=int)

    r0 = int(np.min(resid))
    # rank[r] = step at which the walk visits residue r; unvisited residues
    # keep g + r so they sort after all visited ones.
    rank = g + np.arange(g)
    for t in range(g):
        rank[(r0 + t * d) % g] = t

    order_keys = rank[resid]
    return np.argsort(order_keys, kind="stable")


def lines_a_mod_g_step(a_vals, b_vals, p, g, d):
    """Group point indices into polylines of constant ``b``, ordered along ``a``.

    For each quadrant, every fixed ``b`` among its first ``g`` rows (``0..g-1``
    or ``p..p+g-1``) collects the points whose ``a`` lies in that quadrant's
    first ``g`` columns (``[0, g)`` or ``[p, p + g)``). The points are ordered
    by ``_order_by_step`` on ``a`` (offset by ``p`` in the top half). Groups
    with fewer than two points are dropped.

    Returns:
        list of ``(indices, dash, colour, label)`` tuples, grouped in the order
        BL, TL, BR, TR.
    """
    a_arr = np.asarray(a_vals, int)
    b_arr = np.asarray(b_vals, int)

    # (b offset, a offset, dash style, colour, label prefix)
    groups = (
        (0, 0, "solid", "blue", "BL"),
        (0, p, "dash", "blue", "TL"),
        (p, 0, "solid", "red", "BR"),
        (p, p, "dash", "red", "TR"),
    )

    lines = []
    for b_off, a_off, dash, colour, prefix in groups:
        in_cols = (a_arr >= a_off) & (a_arr < a_off + g)
        for b_fix in range(b_off, b_off + g):
            members = np.where((b_arr == b_fix) & in_cols)[0]
            if members.size < 2:
                continue
            order = _order_by_step((a_arr[members] - a_off) % g, g, d)
            lines.append((members[order], dash, colour, f"{prefix}_b{b_fix}"))
    return lines


def lines_b_mod_g_step(a_vals, b_vals, p, g, d):
    """Group point indices into polylines of constant ``a``, ordered along ``b``.

    For each quadrant, every fixed ``a`` among its first ``g`` columns
    (``0..g-1`` or ``p..p+g-1``) collects the points whose ``b`` lies in that
    quadrant's first ``g`` rows (``[0, g)`` or ``[p, p + g)``). The points are
    ordered by ``_order_by_step`` on ``b`` (offset by ``p`` in the right half).
    Groups with fewer than two points are dropped.

    Returns:
        list of ``(indices, dash, colour, label)`` tuples, grouped in the order
        BL, BR, TL, TR.
    """
    a_arr = np.asarray(a_vals, int)
    b_arr = np.asarray(b_vals, int)

    # (a offset, b offset, dash style, colour, label prefix)
    groups = (
        (0, 0, "solid", "blue", "BL"),
        (0, p, "solid", "red", "BR"),
        (p, 0, "dash", "blue", "TL"),
        (p, p, "dash", "red", "TR"),
    )

    lines = []
    for a_off, b_off, dash, colour, prefix in groups:
        in_rows = (b_arr >= b_off) & (b_arr < b_off + g)
        for a_fix in range(a_off, a_off + g):
            members = np.where((a_arr == a_fix) & in_rows)[0]
            if members.size < 2:
                continue
            order = _order_by_step((b_arr[members] - b_off) % g, g, d)
            lines.append((members[order], dash, colour, f"{prefix}_a{a_fix}"))
    return lines


def lines_c_mod_g_step(a_vals, b_vals, p, g, d):
    """Group point indices into polylines of equal residue ``r`` (``0 <= r < g``).

    Only the first ``g`` rows/columns of each quadrant are considered. Per
    quadrant, the residue is:
      * BL: ``(b mod g + a mod g) mod g``
      * TR, BR: ``(b - a) mod g``
      * TL: ``(b mod g + (a - p) mod g) mod g``
    Points in a group are ordered by ``_order_by_step`` on ``a`` (offset by
    ``p`` in the top half). Groups with fewer than two points are dropped.

    Returns:
        list of ``(indices, "solid", colour, label)`` tuples; for each ``r``
        the order is BL (blue), TR (red), BR (green), TL (purple).
    """
    A = np.asarray(a_vals, int)
    B = np.asarray(b_vals, int)
    top = A >= p
    right = B >= p

    # First g columns/rows of the bottom/top and left/right halves.
    a_low = A < g
    a_high = top & (A < p + g)
    b_low = B < g
    b_high = right & (B < p + g)

    mask_bl = ~top & ~right & a_low & b_low
    mask_tr = a_high & b_high
    mask_br = ~top & a_low & b_high
    mask_tl = a_high & ~right & b_low

    residue_classes = range(g)
    if not residue_classes:
        return []

    a_mod = A % g
    a_shift = (A - p) % g
    diff_res = (B - A) % g
    quadrants = (
        ("BL", "blue", mask_bl, ((B % g) + a_mod) % g, a_mod),
        ("TR", "red", mask_tr, diff_res, a_shift),
        ("BR", "green", mask_br, diff_res, a_mod),
        ("TL", "purple", mask_tl, ((B % g) + a_shift) % g, a_shift),
    )

    lines = []
    for r in residue_classes:
        for prefix, colour, mask, residue, walk in quadrants:
            idx = np.where(mask & (residue == r))[0]
            if idx.size >= 2:
                order = _order_by_step(walk[idx], g, d)
                lines.append((idx[order], "solid", colour, f"{prefix}_r{r}"))
    return lines


def _build_discrete_scale(colors: list[str]) -> list[tuple[float, str]]:
    """Build a stepped Plotly colorscale giving each colour an equal band.

    Colour ``i`` of ``k`` gets two stops, at ``i / k`` and ``(i + 1) / k``,
    so neighbouring colours meet at a hard edge. An empty list yields a
    black two-stop scale.
    """
    k = len(colors)
    if not k:
        return [(0.0, "#000000"), (1.0, "#000000")]
    return [
        (edge, colour)
        for i, colour in enumerate(colors)
        for edge in (i / k, (i + 1) / k)
    ]


# Four fixed colours for colour_pair_mod2's parity codes 0..3.
_MOD2_COLORS = [
    "#3B528B",
    "#21918C",
    "#5EC962",
    "#FDE725",
]
# Two fixed colours for colour_front_back_by_c (index 0 = front, 1 = back).
_FB2_COLORS = [
    "#2C728E",
    "#95D840",
]


def colour_pair_mod2(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by the parity pair ``(a mod 2, b mod 2)``.

    The code is ``2 * (a mod 2) + (b mod 2)``: 00 -> 0, 01 -> 1, 10 -> 2,
    11 -> 3. It is drawn with the four fixed colours of ``_MOD2_COLORS``.
    ``p``, ``f`` and ``tag`` are accepted for a uniform signature but unused.
    """
    a_parity = np.asarray(a, int) & 1
    b_parity = np.asarray(b, int) & 1
    colour = 2 * a_parity + b_parity
    caption = "(a mod 2, b mod 2) in {00,01,10,11} -> {0,1,2,3}"
    return colour, caption, 4, _build_discrete_scale(_MOD2_COLORS)


def colour_front_back_by_c(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points 0 (front: BL/TR) or 1 (back: TL/BR) by quadrant.

    With ``tag == "full"`` each point's quadrant comes from ``a >= p`` (top)
    and ``b >= p`` (right). Otherwise every point gets the class of ``tag``.
    ``f`` is unused.

    Raises:
        ValueError: for a tag other than ``"full"``, BL, BR, TL or TR.
    """
    a = np.asarray(a, int)
    b = np.asarray(b, int)

    if tag == "full":
        top = a >= p
        right = b >= p
        # Front quadrants (BL/TR) are those where both comparisons agree.
        colour = np.where(top == right, 0, 1)
        caption = "front (BL/TR) vs back (TL/BR) quadrants"
        return colour, caption, 2, _build_discrete_scale(_FB2_COLORS)

    front_tags = ("BL", "TR")
    back_tags = ("TL", "BR")
    if tag not in front_tags and tag not in back_tags:
        raise ValueError("tag must be one of BL/BR/TL/TR or 'full'")
    cls = 0 if tag in front_tags else 1

    colour = np.full_like(a, fill_value=cls, dtype=int)
    caption = f"{'front' if cls == 0 else 'back'} quadrant ({tag})"
    return colour, caption, 2, _build_discrete_scale(_FB2_COLORS)
