# mmr_gym/tasks/charts/charts_bar.py
from __future__ import annotations
import random
from typing import Any, Dict, List

from mmr_gym.base import Task
from mmr_gym.registry import register_task
from mmr_gym.config import MAX_BUILD_RETRIES

from mmr_gym.charts import (
    ChartSpec,
    CHART_MIN_K, CHART_MAX_K,
    sample_category_labels,
    sample_percentages_int,
    choose_colors,
    render_bar_chart,
)
# The bar chart is self-labeled (letters on top of each bar); no legend and no numbers.

# Prompts (numbers are hidden; use letters shown above each bar)
# Every prompt asks for a comma-separated list of letters, because the expected
# answer is always the letters joined with "," (see _rank_answer).
PROMPTS_SORT_ASC = [
    "Using the letters shown above the bars, rank the bars from shortest to tallest. Answer with letters only, comma-separated.",
    "Order the bars by height in increasing order (smallest to largest). Respond with the letters above the bars, comma-separated.",
    "From the smallest bar to the largest bar, list the letters shown above each bar. Use commas between letters.",
    "Sort all bars in ascending order of height. Provide the letters (shown above the bars), comma-separated.",
    "Rank the bars from least tall to most tall using the letters above them. Return letters only, comma-separated.",
    "Arrange the bars by height from lowest to highest. Answer using the letters on the bars, comma-separated.",
    "Using the labels above each bar, list bars from shortest to tallest. Letters only, comma-separated.",
    "Give the ascending order of bar heights using the letters above the bars. Return a comma-separated list.",
    "Sort the bar heights from smallest to largest and respond with the letters shown above the bars, comma-separated.",
    "Provide the letters above the bars in increasing order of height (shortest → tallest), comma-separated.",
]

PROMPTS_SORT_DESC = [
    "Using the letters shown above the bars, rank the bars from tallest to shortest. Answer with letters only, comma-separated.",
    "Order the bars by height in decreasing order (largest to smallest). Respond with the letters above the bars, comma-separated.",
    "From the tallest bar to the shortest bar, list the letters shown above each bar. Use commas between letters.",
    "Sort all bars in descending order of height. Provide the letters (shown above the bars), comma-separated.",
    "Rank the bars from most tall to least tall using the letters above them. Return letters only, comma-separated.",
    "Arrange the bars by height from highest to lowest. Answer using the letters on the bars, comma-separated.",
    "Using the labels above each bar, list bars from tallest to shortest. Letters only, comma-separated.",
    "Give the descending order of bar heights using the letters above the bars. Return a comma-separated list.",
    "Sort the bar heights from largest to smallest and respond with the letters shown above the bars, comma-separated.",
    "Provide the letters above the bars in decreasing order of height (tallest → shortest), comma-separated.",
]

# ---------------------------- helpers -----------------------------------

def _sample_distinct_percentages(rng: random.Random, k: int, max_tries: int = 5000) -> List[int]:
    """Sample ``k`` distinct integer percentages (each >= 1) that sum to 100.

    First tries rejection sampling with ``sample_percentages_int`` up to
    ``max_tries`` times. If no draw is all-distinct, it falls back to a direct
    construction that is distinct by design.

    Args:
        rng: Random source; all randomness is drawn from it.
        k: Number of values. Must satisfy ``k >= 1`` and ``k*(k+1)/2 <= 100``
            (i.e. ``k <= 13``): ``k`` distinct positive integers always sum to
            at least ``1 + 2 + ... + k``.
        max_tries: Number of rejection-sampling attempts before the fallback.

    Returns:
        ``k`` distinct ints, each >= 1, summing to exactly 100, in random order.

    Raises:
        ValueError: if ``k`` is outside the feasible range above.
    """
    min_sum = k * (k + 1) // 2
    if k < 1 or min_sum > 100:
        # Previously this silently returned duplicates / a wrong total.
        raise ValueError(
            f"cannot build {k} distinct positive integer percentages summing to 100"
        )

    for _ in range(max_tries):
        p = sample_percentages_int(rng, k, enforce_min1=True)
        if len(set(p)) == k:
            return p

    # Fallback: start from the strictly increasing base 1..k and add a
    # nondecreasing increment vector that sums to the remainder. A strictly
    # increasing vector plus a nondecreasing one stays strictly increasing, so
    # every value is distinct and the total is exactly 100.
    remainder = 100 - min_sum
    # Stars and bars: k-1 cut positions among (remainder + k - 1) slots split
    # the remainder into k non-negative parts (the gaps between cuts).
    slots = remainder + k - 1
    cuts = sorted(rng.sample(range(slots), k - 1))
    bounds = [-1, *cuts, slots]
    increments = sorted(hi - lo - 1 for lo, hi in zip(bounds, bounds[1:]))
    values = [base + inc for base, inc in zip(range(1, k + 1), increments)]
    # Randomize which position receives which value.
    rng.shuffle(values)
    return values

def _rank_answer(labels: List[str], perc: List[int], direction: str) -> str:
    """Return the labels ordered by bar height, joined with commas.

    ``direction == "desc"`` puts the tallest bar first; any other value puts
    the shortest bar first. Equal heights are ordered by label string.
    """
    descending = direction == "desc"

    def height_then_label(idx: int):
        height = perc[idx]
        primary = -height if descending else height
        return primary, labels[idx]

    ranked = sorted(range(len(labels)), key=height_then_label)
    return ",".join(labels[idx] for idx in ranked)

# ---------------------------- task class --------------------------------

@register_task
class ChartsBarTask(Task):
    """
    Bar chart with K categories (K∈[CHART_MIN_K, CHART_MAX_K]).

    • Only sorting questions (asc/desc, 50/50).
    • Letters are drawn on TOP of each bar; no legend; no numbers.
    • All bars share a single random color.

    All randomness (chart spec, sort direction, prompt wording) is drawn from
    the ``rng`` passed to ``generate_instance``, so a seeded ``rng`` reproduces
    the same instance.
    """
    name = "charts_bar"

    def __init__(self):
        # Maximum number of spec-sampling attempts in generate_instance.
        self.max_retries = int(MAX_BUILD_RETRIES)

    def _sample_spec(self, rng: random.Random) -> ChartSpec:
        """Sample a single-color bar-chart spec with unique integer heights.

        A fresh seed is drawn from ``rng`` and used for a local generator; the
        seed is stored on the spec for provenance.
        """
        seed = rng.randint(0, 2**31 - 1)
        lrng = random.Random(seed)

        k = lrng.randint(CHART_MIN_K, CHART_MAX_K)
        labels = sample_category_labels(lrng, k)

        # Distinct integer percentages (unique heights)
        perc = _sample_distinct_percentages(lrng, k)

        # Counts are just provenance; bars render from percentages
        counts = perc[:]
        value_kind = "percentage"

        # One color for all bars
        one_color, _ = choose_colors(lrng, 1)
        cols = [one_color[0]] * k

        return ChartSpec(
            seed=seed,
            chart_type="bar",
            labels=labels,
            value_kind=value_kind,
            counts=counts,
            percentages_int=perc,
            colors=cols,                 # single color used for every bar
            color_mode="mono",
            width_px=1024,
            height_px=768,
            render_mode="color",
        )

    def generate_instance(self, motif_impls: Dict[str, Any], rng: random.Random):
        """Build one "sort the bars by height" question over a sampled bar chart.

        Args:
            motif_impls: Not used by this task (kept for the shared task
                interface — presumably required by ``Task``; TODO confirm).
            rng: Source of all randomness for this instance.

        Returns:
            ``(image, [spec], meta)``: the output of ``render_bar_chart``, the
            sampled ``ChartSpec`` in a list, and a metadata dict holding the
            question, the expected comma-separated answer and chart details.

        Raises:
            RuntimeError: if no valid spec is produced within
                ``self.max_retries`` attempts.
        """
        for _ in range(self.max_retries):
            spec = self._sample_spec(rng)

            labels = spec.labels
            perc = spec.percentages_int

            # The ranking is only unambiguous when each bar has its own letter
            # and its own height; otherwise resample.
            if (
                len(perc) != len(labels)
                or len(set(labels)) != len(labels)
                or len(set(perc)) != len(perc)
            ):
                continue

            direction = "desc" if rng.random() < 0.5 else "asc"
            # Use the task rng (not the global ``random`` module) so the prompt
            # wording is reproducible from the seed like everything else.
            prompt = rng.choice(PROMPTS_SORT_DESC if direction == "desc" else PROMPTS_SORT_ASC)
            answer = _rank_answer(labels, perc, direction)

            image = render_bar_chart(spec, show_values=False)  # letters only on top

            meta = {
                "pattern_kind": "charts",
                "pattern": self.name,
                "variant": {
                    "kind": "sort",
                    "direction": direction,
                    "numbers_shown": False,
                    "single_color": True,
                    "label_on_top": True,
                },
                "question": prompt,
                "answer": answer,
                "chart": {
                    "type": spec.chart_type,
                    "k": len(labels),
                    "labels": labels,
                    "value_kind": spec.value_kind,
                    "percentages": {lab: int(p) for lab, p in zip(labels, perc)},
                    "color_mode": spec.color_mode,
                    "numbers_shown": False,
                },
                "dims": (spec.width_px, spec.height_px),
                "composite_ready": True,
            }
            return image, [spec], meta

        raise RuntimeError(f"{self.name}: failed to build a valid sample after {self.max_retries} attempts.")
