from __future__ import annotations

import math
from typing import Literal, Tuple

import numpy as np
import plotly.colors as pc
import re


# Tag accepted by the colouring functions: "full" for the whole grid, or one
# of the four quadrants.  In "full" mode the functions treat coordinates
# ``a >= p`` as the top half and ``b >= p`` as the right half.
Quadrant = Literal["full", "BL", "BR", "TL", "TR"]

# Text of the base expression shown in captions for each quadrant; it mirrors
# the arithmetic performed by ``_base_val`` (sum for BL/TL, difference for
# BR/TR).
_QUAD_EXPR = {
    "BL": "a+b",
    "TL": "a+b",
    "BR": "b-a",
    "TR": "b-a",
}

# Public API of the module.
# NOTE(review): "colour_pair_mod_g" is not defined in the visible part of this
# file (only ``colour_pair_mod2`` is) — confirm the name, otherwise
# ``from <module> import *`` will fail with AttributeError.
__all__ = [
    "colour_quad_mul_f",
    "colour_quad_mod_g",
    "colour_quad_a_only",
    "colour_quad_b_only",
    "colour_pair_mod_g",
]

# Colour-scale name returned by the ``*_no_fb`` colouring helpers.
_DEFAULT = "Viridis"

def _base_val(a: np.ndarray, b: np.ndarray, tag: str) -> np.ndarray:
    if tag in ("BL", "TL"):
        return a + b
    elif tag in ("BR", "TR"):
        return b - a
    raise ValueError("tag must be one of BL/BR/TL/TR")

def step_size(f: int, p: int) -> int:
    """Return the inverse of ``f / gcd(f, p)`` modulo ``p / gcd(f, p)``.

    Both *f* and *p* are first divided by their greatest common divisor; the
    result is the modular inverse of the reduced *f* in the reduced modulus.
    """
    common = math.gcd(f, p)
    modulus = p // common
    reduced_f = f // common
    return pow(reduced_f, -1, modulus)

def _ensure_int_freq(f):
    if isinstance(f, (int, np.integer)):
        return int(f)
    if isinstance(f, (float, np.floating)):
        if not np.isfinite(f):
            raise TypeError(f"freq f={f!r} is not finite")
        return int(round(float(f)))
    if isinstance(f, str):
        m = re.search(r"-?\d+", f)
        if m:
            return int(m.group(0))
        raise TypeError(f"freq f={f!r} contains no digits")
    if f is None:
        raise TypeError("freq f is None")
    try:
        return int(f)
    except Exception:
        raise TypeError(f"freq f={f!r} cannot be interpreted as an integer")

def _resolve_scale(scale_name_or_list):
    if isinstance(scale_name_or_list, (list, tuple)):
        return scale_name_or_list
    name = str(scale_name_or_list)
    if name in pc.PLOTLY_SCALES:
        return pc.PLOTLY_SCALES[name]
    try:
        import plotly.express as px
        for ns in (px.colors.sequential, px.colors.diverging, px.colors.cyclical, px.colors.qualitative):
            if hasattr(ns, name):
                return getattr(ns, name)
        if name.lower() in ("oranges", "orrd"):
            return px.colors.sequential.Oranges
        if name.lower() in ("ylorrd",):
            return px.colors.sequential.YlOrRd
    except Exception:
        pass
    raise KeyError(f"Unknown colorscale name: {name}")

def _interp(scale_name_or_list, t: float) -> str:
    """Sample one colour at position *t* from the given colour scale.

    The scale is resolved with ``_resolve_scale`` and sampled through
    ``plotly.colors.sample_colorscale``; the single resulting colour is
    returned.
    """
    sampled = pc.sample_colorscale(_resolve_scale(scale_name_or_list), [t])
    return sampled[0]

def _interp_rainbow_red_orange(t: float) -> str:
    """Sample Viridis at ``0.55 + 0.45 * t`` (the upper band for t in [0, 1])."""
    position = 0.55 + 0.45 * t
    return _interp("Viridis", position)

def _interp_viridis(t: float) -> str:
    """Sample Viridis at ``0.45 * t`` (the lower band for t in [0, 1])."""
    position = 0.0 + 0.45 * t
    return _interp("Viridis", position)

def build_split_scale_red_orange(g: int) -> list[tuple[float,str]]:
    """Build a two-band colour scale with *g* stops per band.

    The first band occupies positions ``0 .. 0.45`` and is coloured with
    ``_interp_viridis``; the second occupies ``0.55 .. 1`` and is coloured
    with ``_interp_rainbow_red_orange``.  For ``g <= 1`` a plain two-stop
    Viridis scale (positions 0 and 1) is returned instead.
    """
    if g <= 1:
        return [(0.0, _interp("Viridis", 0.0)), (1.0, _interp("Viridis", 1.0))]

    last = g - 1
    lower_band = []
    upper_band = []
    for i in range(g):
        frac = i / last
        lower_band.append((0.45 * i / last, _interp_viridis(frac)))
        upper_band.append((0.55 + 0.45 * i / last, _interp_rainbow_red_orange(frac)))
    return lower_band + upper_band

def build_vi_scale(g: int) -> list[tuple[float, str]]:
    """Build a *g*-stop colour scale over the lower Viridis band.

    Stop ``i`` sits at position ``i / (g - 1)`` and takes the Viridis colour
    at ``0.45 * i / (g - 1)``.

    For ``g <= 1`` the evenly spaced formula would divide by zero (or give an
    empty scale), so a two-stop scale covering the same band is returned,
    keeping the result a usable colour scale spanning positions 0 and 1.
    """
    if g <= 1:
        return [(0.0, _interp("Viridis", 0.0)), (1.0, _interp("Viridis", 0.45))]
    last = g - 1
    return [(i / last, _interp("Viridis", 0.0 + 0.45 * i / last))
            for i in range(g)]

def build_ro_scale(g: int) -> list[tuple[float, str]]:
    """Build a *g*-stop colour scale over the upper Viridis band.

    Stop ``i`` sits at position ``i / (g - 1)`` and takes the Viridis colour
    at ``0.55 + 0.45 * i / (g - 1)``.

    For ``g <= 1`` the evenly spaced formula would divide by zero (or give an
    empty scale), so a two-stop scale covering the same band is returned,
    keeping the result a usable colour scale spanning positions 0 and 1.
    """
    if g <= 1:
        return [(0.0, _interp("Viridis", 0.55)), (1.0, _interp("Viridis", 1.0))]
    last = g - 1
    return [(i / last, _interp("Viridis", 0.55 + 0.45 * i / last))
            for i in range(g)]

def colour_quad_mod_g_no_fb(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, str]:
    """Colour points by their quadrant base value modulo ``g = p // gcd(p, f)``.

    Args:
        a, b: integer coordinates of the points.
        p: modulus; in ``"full"`` mode coordinates ``>= p`` mark the top
            (``a``) and right (``b``) halves.
        f: frequency; coerced with ``_ensure_int_freq`` before use.
        tag: ``"full"`` or one of ``BL``/``BR``/``TL``/``TR``.

    Returns:
        ``(colour, caption, levels, scale_name)`` where ``colour`` holds the
        per-point residue, ``levels`` is ``max(colour) + 1`` (``0`` for empty
        input) and ``scale_name`` is the module default colour scale.

    Raises:
        ValueError: if *tag* is not a recognised quadrant tag.
    """
    # Coerce the frequency first so that gcd() always receives an int.
    f = _ensure_int_freq(f)
    common = math.gcd(p, f)
    g = p // common if common != 0 else p
    a = np.asarray(a, int)
    b = np.asarray(b, int)

    if tag == "full":
        top, right = a >= p, b >= p
        out = np.empty_like(a)
        quadrant_masks = {
            "BL": (~top) & (~right),
            "BR": (~top) & right,
            "TL": top & (~right),
            "TR": top & right,
        }
        for quad, m in quadrant_masks.items():
            # A quadrant without points contributes nothing; skipping it
            # avoids reducing over an empty selection.
            if not m.any():
                continue
            out[m] = _base_val(a[m] % p, b[m] % p, quad) % g
        caption = f"(a+-b) mod {g} [g=p//gcd({f},{p})]"
        levels = int(out.max()) + 1 if out.size else 0
        return out, caption, levels, _DEFAULT

    base = _base_val(a, b, tag)
    colour = base % g
    expr = _QUAD_EXPR[tag]
    caption = f"({expr}) mod {g}"
    levels = int(colour.max()) + 1 if colour.size else 0
    return colour, caption, levels, _DEFAULT

def colour_quad_mul_f(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float,str]]]:
    """Colour points by ``f * base mod p``.

    The base value is ``a + b`` or ``b - a`` depending on the quadrant.  In
    ``"full"`` mode points in BL/TR keep the residue while points in TL/BR
    are shifted by ``p`` so that they land in the second band of the split
    colour scale.

    Returns:
        ``(colour, caption, levels, colour_scale)`` with
        ``levels = max(colour) + 1``.
    """
    a = np.asarray(a, int)
    b = np.asarray(b, int)
    f = _ensure_int_freq(f)

    if tag != "full":
        combined = _base_val(a, b, tag)
        bright = tag in ("BR", "TL")
        residue = (f * combined) % p
        label = f"{f}·({_QUAD_EXPR[tag]}) mod {p}"
        palette = build_ro_scale(p) if bright else build_vi_scale(p)
        return residue, label, int(residue.max()) + 1, palette

    upper = a >= p
    right_half = b >= p
    signed = np.where(right_half, b - a, a + b)
    residue = (f * signed) % p
    # BL and TR are exactly the points whose two half-flags agree.
    same_side = upper == right_half
    shifted = np.where(same_side, residue, p + residue)
    label = f"f*(b+-a) mod {p} (BL/TR->Viridis dark; TL/BR->Viridis bright)"
    return shifted, label, int(shifted.max()) + 1, build_split_scale_red_orange(p)

def colour_quad_mod_g(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float,str]]]:
    """Colour points by their base value modulo ``g = p // gcd(p, f)``.

    The base value (``a + b`` or ``b - a`` by quadrant) is multiplied by
    ``f`` unless ``step_size(f, p)`` equals 1, and then reduced modulo ``g``.
    In ``"full"`` mode points in TL/BR are shifted by ``g`` so that they use
    the second band of the split colour scale.

    Returns:
        ``(colour, caption, levels, colour_scale)`` with
        ``levels = max(colour) + 1``.

    Raises:
        ValueError: if *tag* is not a recognised quadrant tag.
    """
    f = _ensure_int_freq(f)
    common = math.gcd(p, f)
    g = p // common if common != 0 else p
    a = np.asarray(a, int)
    b = np.asarray(b, int)
    step = step_size(f, p)
    remap = f if step != 1 else 1

    if tag == "full":
        top = a >= p
        right = b >= p
        base_raw = np.where(right, b - a, a + b)
        c = (remap * base_raw) % g
        mask_front = (~top & ~right) | (top & right)
        colour_idx = np.where(mask_front, c, g + c)
        caption = f"(b+-a) mod {g} (BL/TR->Viridis dark; TL/BR->Viridis bright)"
        return colour_idx, caption, int(colour_idx.max()) + 1, build_split_scale_red_orange(g)

    base = _base_val(a, b, tag)
    is_rain = tag in ("BR", "TL")
    colour = (remap * base) % g
    expr = _QUAD_EXPR[tag]
    caption = f"({expr}) mod {g}"
    # Both bands are sized from g (the number of residues); at least two
    # stops are requested so the scale always spans positions 0 and 1.
    stops = max(g, 2)
    cmap = build_ro_scale(stops) if is_rain else build_vi_scale(stops)
    return colour, caption, int(colour.max()) + 1, cmap

def colour_quad_a_only(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by the ``a`` coordinate modulo ``g = p // gcd(p, f)``.

    ``a`` is multiplied by ``f`` unless ``step_size(f, p)`` equals 1, then
    reduced modulo ``g``.  In ``"full"`` mode points with ``a >= p`` are
    shifted by ``g`` into the second band of the split colour scale; for a
    single quadrant the TR/TL tags select the bright scale and every other
    tag the dark one.  ``b`` is converted but otherwise unused.

    Returns:
        ``(colour, caption, levels, colour_scale)``; ``levels`` is ``2 * g``
        in ``"full"`` mode and ``max(colour) + 1`` otherwise.
    """
    f = _ensure_int_freq(f)
    common = math.gcd(p, f)
    g = p // common if common != 0 else p
    a = np.asarray(a, int)
    b = np.asarray(b, int)
    step = step_size(f, p)
    remap = f if step != 1 else 1

    if tag == "full":
        is_rain = a >= p
        base = (remap * a) % g
        colour = np.where(is_rain, g + base, base)
        caption = (
            f"Viridis dark (a<p): a mod {g}; "
            f"Viridis bright (a>p): a mod {g}"
        )
        return colour, caption, 2 * g, build_split_scale_red_orange(g)

    is_rain = tag in ("TR", "TL")
    colour = (a * remap) % g
    caption = f"a mod {g}  ({'Viridis bright' if is_rain else 'Viridis dark'})"
    # Both bands are sized from g (the number of residues); at least two
    # stops are requested so the scale always spans positions 0 and 1.
    stops = max(g, 2)
    cmap = build_ro_scale(stops) if is_rain else build_vi_scale(stops)
    return colour, caption, int(colour.max()) + 1, cmap

def colour_quad_b_only(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by the ``b`` coordinate modulo ``g = p // gcd(p, f)``.

    ``b`` is multiplied by ``f`` unless ``step_size(f, p)`` equals 1, then
    reduced modulo ``g``.  In ``"full"`` mode points with ``b >= p`` are
    shifted by ``g`` into the second band of the split colour scale; for a
    single quadrant the BR/TR tags select the bright scale and every other
    tag the dark one.  ``a`` is converted but otherwise unused.

    Returns:
        ``(colour, caption, levels, colour_scale)``; ``levels`` is ``2 * g``
        in ``"full"`` mode and ``max(colour) + 1`` otherwise.
    """
    f = _ensure_int_freq(f)
    common = math.gcd(p, f)
    g = p // common if common != 0 else p
    a = np.asarray(a, int)
    b = np.asarray(b, int)
    step = step_size(f, p)
    remap = f if step != 1 else 1

    if tag == "full":
        is_rain = b >= p
        base = (remap * b) % g
        colour = np.where(is_rain, g + base, base)
        # The caption describes b, the coordinate actually used for colouring.
        caption = (
            f"Viridis dark (b<p): b mod {g}; "
            f"Viridis bright (b>p): b mod {g}"
        )
        return colour, caption, 2 * g, build_split_scale_red_orange(g)

    is_rain = tag in ("BR", "TR")
    colour = (remap * b) % g
    caption = f"b mod {g}  ({'Viridis bright' if is_rain else 'Viridis dark'})"
    # Both bands are sized from g (the number of residues); at least two
    # stops are requested so the scale always spans positions 0 and 1.
    stops = max(g, 2)
    cmap = build_ro_scale(stops) if is_rain else build_vi_scale(stops)
    return colour, caption, int(colour.max()) + 1, cmap

def colour_quad_a_only_no_fb(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: str,
) -> Tuple[np.ndarray, str, int, str]:
    """Colour points by ``a mod g`` with ``g = p // gcd(p, f)``.

    In ``"full"`` mode ``a`` is first reduced modulo ``p``; for any other
    tag it is used as given.  ``b`` is converted but otherwise unused.

    Returns:
        ``(colour, caption, g, scale_name)`` where ``scale_name`` is the
        module default colour scale.
    """
    f = _ensure_int_freq(f)
    g = p // math.gcd(p, f) if math.gcd(p, f) != 0 else p
    a = np.asarray(a, int)
    b = np.asarray(b, int)
    if tag == "full":
        A = a % p
        colour = A % g
    else:
        colour = a % g
    caption = f"a mod {g}"
    pbar = int(g)
    return colour, caption, pbar, _DEFAULT

def colour_quad_b_only_no_fb(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: str,
):
    """Colour points by ``b mod g`` with ``g = p // gcd(p, f)``.

    In ``"full"`` mode ``b`` is first reduced modulo ``p``; for any other
    tag it is used as given.  ``a`` is converted but otherwise unused.
    Returns ``(colour, caption, g, scale_name)``.
    """
    f = _ensure_int_freq(f)
    g = p if math.gcd(p, f) == 0 else p // math.gcd(p, f)
    a = np.asarray(a, int)
    b = np.asarray(b, int)
    reduced = b % p if tag == "full" else b
    return reduced % g, f"b mod {g}", int(g), _DEFAULT

def _order_by_step(local_vals: np.ndarray, g: int, d: int) -> np.ndarray:
    resid = np.asarray(local_vals, int) % g
    if resid.size == 0:
        return np.arange(0, 0, dtype=int)
    r0 = int(np.min(resid))
    seq = [(r0 + t * d) % g for t in range(g)]
    pos = {r: i for i, r in enumerate(seq)}
    order_keys = np.vectorize(pos.get)(resid)
    return np.argsort(order_keys, kind="stable")

def lines_a_mod_g_step(a_vals, b_vals, p, g, d):
    """Build line segments that run along ``a`` for each fixed ``b``.

    Only points inside the ``g``-wide window at the start of each quadrant
    are considered.  For every fixed ``b`` with at least two such points the
    point indices are ordered with ``_order_by_step`` on the quadrant-local
    ``a`` residue.

    Returns a list of ``(indices, dash, colour, label)`` tuples, emitted
    quadrant by quadrant in the order BL, TL, BR, TR.
    """
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)

    # (label prefix, start of b range, start of a window, dash, colour)
    layout = (
        ("BL", 0, 0, "solid", "blue"),
        ("TL", 0, p, "dash", "blue"),
        ("BR", p, 0, "solid", "red"),
        ("TR", p, p, "dash", "red"),
    )

    segments = []
    for quad, b_start, a_start, dash, colour in layout:
        in_window = (a_vals >= a_start) & (a_vals < a_start + g)
        for b_fix in range(b_start, b_start + g):
            members = np.where((b_vals == b_fix) & in_window)[0]
            if members.size < 2:
                continue
            ranking = _order_by_step((a_vals[members] - a_start) % g, g, d)
            segments.append((members[ranking], dash, colour, f"{quad}_b{b_fix}"))
    return segments

def lines_b_mod_g_step(a_vals, b_vals, p, g, d):
    """Build line segments that run along ``b`` for each fixed ``a``.

    Only points inside the ``g``-wide window at the start of each quadrant
    are considered.  For every fixed ``a`` with at least two such points the
    point indices are ordered with ``_order_by_step`` on the quadrant-local
    ``b`` residue.

    Returns a list of ``(indices, dash, colour, label)`` tuples, emitted
    quadrant by quadrant in the order BL, BR, TL, TR.
    """
    a_vals = np.asarray(a_vals, int)
    b_vals = np.asarray(b_vals, int)

    # (label prefix, start of a range, start of b window, dash, colour)
    layout = (
        ("BL", 0, 0, "solid", "blue"),
        ("BR", 0, p, "solid", "red"),
        ("TL", p, 0, "dash", "blue"),
        ("TR", p, p, "dash", "red"),
    )

    segments = []
    for quad, a_start, b_start, dash, colour in layout:
        in_window = (b_vals >= b_start) & (b_vals < b_start + g)
        for a_fix in range(a_start, a_start + g):
            members = np.where((a_vals == a_fix) & in_window)[0]
            if members.size < 2:
                continue
            ranking = _order_by_step((b_vals[members] - b_start) % g, g, d)
            segments.append((members[ranking], dash, colour, f"{quad}_a{a_fix}"))
    return segments

def lines_c_mod_g_step(a_vals, b_vals, p, g, d):
    """Build line segments joining points that share a residue class.

    Within the ``g``-wide window at the start of each quadrant, points are
    grouped by ``(a + b) mod g`` (BL, TL, using quadrant-local ``a``) or
    ``(b - a) mod g`` (TR, BR).  Each group with at least two points is
    ordered with ``_order_by_step`` on the quadrant-local ``a`` residue.

    Returns a list of ``(indices, "solid", colour, label)`` tuples; for each
    residue the quadrants are emitted in the order BL, TR, BR, TL.
    """
    A = np.asarray(a_vals, int)
    B = np.asarray(b_vals, int)
    upper, right_side = (A >= p), (B >= p)

    # Window membership per axis: lower/left half vs. upper/right half.
    a_low = (~upper) & (A < g)
    a_high = upper & (A < p + g)
    b_low = (~right_side) & (B < g)
    b_high = right_side & (B < p + g)

    groups = []
    if g > 0:
        # Residues are only needed (and only well defined) when the loop
        # below actually runs.
        a_local_low = A % g
        a_local_high = (A - p) % g
        diff = (B - A) % g
        groups = [
            (a_low & b_low, ((B % g) + a_local_low) % g, a_local_low, "blue", "BL"),
            (a_high & b_high, diff, a_local_high, "red", "TR"),
            (a_low & b_high, diff, a_local_low, "green", "BR"),
            (a_high & b_low, ((B % g) + a_local_high) % g, a_local_high, "purple", "TL"),
        ]

    segments = []
    for r in range(g):
        for mask, residue, local_a, colour, quad in groups:
            members = np.where(mask & (residue == r))[0]
            if members.size < 2:
                continue
            ranking = _order_by_step(local_a[members], g, d)
            segments.append((members[ranking], "solid", colour, f"{quad}_r{r}"))
    return segments

def _build_discrete_scale(colors: list[str]) -> list[tuple[float, str]]:
    k = len(colors)
    if k == 0:
        return [(0.0, "#000000"), (1.0, "#000000")]
    scale = []
    for i, c in enumerate(colors):
        lo = i / k
        hi = (i + 1) / k
        scale.append((lo, c))
        scale.append((hi, c))
    return scale

# Four hex colours, one per class index 0..3 produced by ``colour_pair_mod2``
# for the pair (a mod 2, b mod 2).
_MOD2_COLORS = ["#3B528B",
                "#21918C",
                "#5EC962",
                "#FDE725"]

# Two hex colours used by ``colour_front_back_by_c``: index 0 for the BL/TR
# quadrants and index 1 for the TL/BR quadrants.
_FB2_COLORS  = ["#2C728E",
                "#95D840"]

def colour_pair_mod2(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by the parity pair ``(a mod 2, b mod 2)``.

    The class index is ``2 * (a mod 2) + (b mod 2)``, i.e. 0..3.  The
    arguments ``p``, ``f`` and ``tag`` do not influence the result.

    Returns:
        ``(colour, caption, 4, colour_scale)`` with a four-band discrete
        colour scale.
    """
    a_parity = np.asarray(a, int) & 1
    b_parity = np.asarray(b, int) & 1
    classes = 2 * a_parity + b_parity
    label = "(a mod 2, b mod 2) in {00,01,10,11} -> {0,1,2,3}"
    return classes, label, 4, _build_discrete_scale(_MOD2_COLORS)

def colour_front_back_by_c(
    a: np.ndarray | list[int],
    b: np.ndarray | list[int],
    p: int,
    f: int,
    tag: Quadrant,
) -> Tuple[np.ndarray, str, int, list[tuple[float, str]]]:
    """Colour points by quadrant sign class: 0 for BL/TR, 1 for TL/BR.

    In ``"full"`` mode the class is derived per point from ``a >= p`` and
    ``b >= p``; for a single quadrant every point gets the class of *tag*.

    Returns:
        ``(colour, caption, 2, colour_scale)`` with a two-band discrete
        colour scale.

    Raises:
        ValueError: if *tag* is neither ``"full"`` nor a quadrant tag.
    """
    a = np.asarray(a, int)
    b = np.asarray(b, int)

    if tag == "full":
        upper = (a >= p)
        right_side = (b >= p)
        # BL and TR are exactly the points whose two half-flags agree.
        classes = np.where(upper == right_side, 0, 1)
        label = "Sign +1 (BL/TR) vs Sign -1 (TL/BR) quadrants"
        return classes, label, 2, _build_discrete_scale(_FB2_COLORS)

    if tag not in ("BL", "TR", "TL", "BR"):
        raise ValueError("tag must be one of BL/BR/TL/TR or 'full'")
    sign_class = 1 if tag in ("TL", "BR") else 0

    classes = np.full_like(a, fill_value=sign_class, dtype=int)
    sign_text = "Sign +1" if sign_class == 0 else "Sign -1"
    label = f"{sign_text} quadrant ({tag})"
    return classes, label, 2, _build_discrete_scale(_FB2_COLORS)
