from __future__ import annotations

import io
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
from PIL import Image
from datasets import load_from_disk, Dataset, DatasetDict


# -------------------------
# Config and helpers
# -------------------------


@dataclass
class AppConfig:
    """Static defaults for the dataset viewer application."""

    # Directory passed to ``load_from_disk`` when no other path is supplied.
    data_dir: str = "./data/mathlens"


# Prompt columns offered in the "Prompt field" dropdown for any split whose
# name does not start with "perception_".
DOWNSTREAM_PROMPT_KEYS = [
    "query_vis_wo", "query_text_wo",
    "query_vis_cot", "query_text_cot",
]
# Prompt columns offered for "perception_*" splits.
PERCEPTION_PROMPT_KEYS = ["query_wo", "query_cot"]

# Candidate image columns, tried in this order; the first one that decodes is
# shown. The same names are stripped from the raw-JSON view and the export.
# Listing several names keeps the viewer lenient about the dataset schema.
IMG_COL_KEYS = ["decoded_image", "image", "image_key"]

# Metadata columns shown in the table (only those present in a split are used).
BASIC_COLS = [
    "id", "problem_id", "vqa_id",
    "modification_type", "question_type",
    "image_key",
    "question", "question_text",
    "answer",
]


def _to_pil(x) -> Optional[Image.Image]:
    """Accepts various HF image representations -> PIL.Image or None."""
    if x is None:
        return None
    if isinstance(x, Image.Image):
        return x
    # Some HF datasets store { 'path': str } or { 'bytes': b'...' }
    if isinstance(x, dict):
        if "bytes" in x and isinstance(x["bytes"], (bytes, bytearray)):
            return Image.open(io.BytesIO(x["bytes"]))
        if "path" in x:
            try:
                return Image.open(x["path"]).convert("RGB")
            except Exception:
                return None
    # If it's a path-like
    if isinstance(x, (str, Path)) and Path(x).exists():
        try:
            return Image.open(x).convert("RGB")
        except Exception:
            return None
    return None


def _detect_split_type(split_name: str) -> str:
    return "perception" if split_name.startswith("perception_") else "downstream"


def _available_prompt_keys(split_type: str) -> List[str]:
    """Return the prompt-column names to offer for a split of *split_type*.

    ``"perception"`` maps to ``PERCEPTION_PROMPT_KEYS``. Any other value maps
    to ``DOWNSTREAM_PROMPT_KEYS``. The module-level list itself is returned,
    not a copy.
    """
    if split_type == "perception":
        return PERCEPTION_PROMPT_KEYS
    return DOWNSTREAM_PROMPT_KEYS


def _stringify_answer(ans) -> str:
    # answers are stored as Sequence(Value("string")) -> list[str]
    if ans is None:
        return ""
    if isinstance(ans, (list, tuple)):
        return ", ".join(map(str, ans))
    return str(ans)


def _dataset_to_dataframe(ds: Dataset, cols: List[str]) -> pd.DataFrame:
    """Convert *ds* to a DataFrame containing only the columns listed in *cols*.

    All other columns (notably image columns) are removed before
    ``to_pandas()``, so heavy data is never materialized for the table. If an
    ``answer`` column survives, it is flattened to display strings with
    ``_stringify_answer``.
    """
    wanted = set(cols)
    # Columns of ds that were not requested; order follows ds.column_names.
    unwanted = [name for name in ds.column_names if name not in wanted]
    frame = ds.remove_columns(unwanted).to_pandas()
    if "answer" not in frame.columns:
        return frame
    frame["answer"] = frame["answer"].apply(_stringify_answer)
    return frame


# -------------------------
# Gradio callbacks
# -------------------------

# Module-level state shared by every callback: the loaded DatasetDict and its
# split names. NOTE(review): this is process-wide, so all browser sessions
# share one loaded dataset — confirm that is intended.
app_state: Dict[str, object] = dict(dsdict=None, splits=[])


def cb_load_dataset(path_str: str) -> Tuple[gr.Dropdown, gr.Markdown]:
    """Load a dataset saved with ``save_to_disk`` and publish its splits.

    The loaded ``DatasetDict`` and its split names are cached in ``app_state``
    for the other callbacks. ``load_from_disk`` returns a bare ``Dataset``
    when pointed at a single-split directory. That case is wrapped in a
    one-entry ``DatasetDict`` keyed by the directory name (``"default"`` if
    the name is empty), instead of failing on ``.keys()``.

    Args:
        path_str: Path typed into the UI; ``~`` is expanded and the path
            resolved.

    Returns:
        ``(split dropdown, info markdown)`` component updates. On failure the
        dropdown is emptied, the cached dataset is cleared (so search/export
        no longer act on previously loaded data), and the markdown reports
        the error.
    """

    def _fail(message: str) -> Tuple[gr.Dropdown, gr.Markdown]:
        # Keep app_state consistent with the now-empty split dropdown.
        app_state["dsdict"] = None
        app_state["splits"] = []
        return gr.Dropdown(choices=[], value=None), gr.Markdown(message)

    p = Path(path_str).expanduser().resolve()
    if not p.exists():
        return _fail(f"**Error:** path not found: `{p}`")
    try:
        loaded = load_from_disk(str(p))
    except Exception as e:  # any loader error is reported in the UI, not raised
        return _fail(f"**Failed to load dataset:** {e}")

    if isinstance(loaded, Dataset):
        # Single-split directory: give it a split name so the UI can index it.
        loaded = DatasetDict({p.name or "default": loaded})
    dsdict: DatasetDict = loaded

    app_state["dsdict"] = dsdict
    splits = list(dsdict.keys())
    app_state["splits"] = splits
    meta = [
        f"- **Path**: `{p}`",
        f"- **Splits** ({len(splits)}): {', '.join(splits)}",
        f"- **Sizes**: { {k: len(v) for k, v in dsdict.items()} }",
    ]
    return gr.Dropdown(
        choices=splits, value=(splits[0] if splits else None)
    ), gr.Markdown("\n".join(meta))


def cb_split_changed(split: str):
    """Refresh the split-dependent widgets after the split dropdown changes.

    Returns updates for, in order:
      1. the index slider (range ``0..n-1``, value reset to 0);
      2. the prompt dropdown (prompt keys for this split type, first selected);
      3. the gallery (decodable images among the first 12 rows);
      4. the metadata table (``BASIC_COLS`` present in the split);
      5. the info markdown (split name, type and size).

    When no dataset or split is selected, widgets 1-4 are cleared but the info
    markdown is left untouched. A failed load empties the dropdown, which
    triggers this callback, and the error message written by
    ``cb_load_dataset`` must stay visible.
    """
    dsdict: DatasetDict = app_state.get("dsdict")  # type: ignore
    if not dsdict or split not in dsdict:
        return (
            gr.update(maximum=0, value=0),
            gr.update(choices=[], value=None),
            gr.update(value=None),
            gr.update(value=pd.DataFrame()),
            gr.update(),  # no-op: keep whatever message info_md already shows
        )

    ds = dsdict[split]
    n = len(ds)
    split_type = _detect_split_type(split)
    prompt_keys = _available_prompt_keys(split_type)

    # Gallery preview: from each of the first 12 rows, take the first image
    # column that decodes; rows without a decodable image are skipped.
    thumbs = []
    for i in range(min(12, n)):
        row = ds[i]
        for k in IMG_COL_KEYS:
            if k not in row:
                continue
            img = _to_pil(row[k])
            if img is not None:
                thumbs.append(img)
                break

    # Table shows metadata columns only (images excluded).
    table_cols = [c for c in BASIC_COLS if c in ds.column_names]
    df = _dataset_to_dataframe(ds, table_cols)

    return (
        gr.update(maximum=max(0, n - 1), value=0),
        gr.update(choices=prompt_keys, value=(prompt_keys[0] if prompt_keys else None)),
        gr.update(value=thumbs),
        gr.update(value=df),
        gr.update(value=f"**{split}** | **type:** {split_type} | **examples:** {n}"),
    )


def _get_row(dsdict: DatasetDict, split: str, idx: int) -> dict:
    ds = dsdict[split]
    idx = max(0, min(idx, len(ds) - 1))
    return ds[int(idx)]


def cb_render_example(
    split: str, idx: int, prompt_key: Optional[str]
) -> Tuple[Image.Image | None, str, str, str, str]:
    """Render one example of *split* for the example viewer.

    Args:
        split: Split name selected in the dropdown.
        idx: Row index from the slider (clamped into range by ``_get_row``).
        prompt_key: Column whose value is shown as "Selected prompt"; may be
            None.

    Returns:
        ``(image, question, answer, prompt text, raw row JSON)``. All fields
        are empty when no dataset or split is loaded.
    """
    dsdict: DatasetDict = app_state.get("dsdict")  # type: ignore
    if not dsdict or split not in dsdict:
        return None, "", "", "", ""

    row = _get_row(dsdict, split, idx)

    # First candidate image column that decodes successfully.
    img = None
    for k in IMG_COL_KEYS:
        if k in row:
            img = _to_pil(row[k])
            if img is not None:
                break

    q = row.get("question_text") or row.get("question") or ""
    ans = _stringify_answer(row.get("answer"))

    # Missing column or null value -> ""; anything else is displayed as text.
    prompt_val = row.get(prompt_key) if prompt_key else None
    prompt_txt = "" if prompt_val is None else str(prompt_val)

    # Raw row minus image columns. default=str keeps values that json cannot
    # encode natively (e.g. datetimes, other binary columns) from crashing
    # the viewer.
    safe_row = {k: v for k, v in row.items() if k not in IMG_COL_KEYS}
    raw_json = json.dumps(safe_row, ensure_ascii=False, indent=2, default=str)

    return img, q, ans, prompt_txt, raw_json


def cb_search(split: str, query: str) -> Tuple[int, str]:
    """Find the first row of *split* whose id/question fields contain *query*.

    Matching is a case-insensitive literal substring search over the values
    of ``id``, ``problem_id``, ``question`` and ``question_text``, joined
    with single spaces. Missing columns contribute an empty string.

    Returns:
        ``(index, status)``. Returns ``(0, "")`` when nothing is loaded or the
        query is empty, and ``(0, "No match; showing index 0")`` when no row
        matches.
    """
    dsdict: DatasetDict = app_state.get("dsdict")  # type: ignore
    if not dsdict or split not in dsdict:
        return 0, ""
    ds = dsdict[split]
    if not query:
        return 0, ""
    pat = re.compile(re.escape(query), re.IGNORECASE)
    search_keys = ["id", "problem_id", "question", "question_text"]
    n = len(ds)
    # Read only the searchable columns, column-wise. Indexing ds[i] would
    # decode every column of every row (images included) just to read a few
    # strings. A missing column becomes "" per row, as row.get(k, "") did.
    fields = [
        list(ds[k]) if k in ds.column_names else [""] * n for k in search_keys
    ]
    for i, values in enumerate(zip(*fields)):
        if pat.search(" ".join(map(str, values))):
            return i, f"Found match at index {i}"
    return 0, "No match; showing index 0"


def cb_export_filtered(split: str, filter_text: str) -> str:
    """Export the rows of *split* (optionally filtered) to a JSONL file.

    A row is kept when *filter_text* is empty, or when it appears
    (case-insensitively, as a literal substring) in the space-joined
    ``id``/``problem_id``/``question``/``question_text`` values. Image columns
    (``IMG_COL_KEYS``) are omitted. Output goes to ``export_filtered.jsonl``
    in the current working directory, overwriting any previous export.

    Returns:
        A status message for the UI.
    """
    dsdict: DatasetDict = app_state.get("dsdict")  # type: ignore
    if not dsdict or split not in dsdict:
        return "No dataset loaded."
    ds = dsdict[split]
    pat = re.compile(re.escape(filter_text), re.IGNORECASE) if filter_text else None
    search_keys = ["id", "problem_id", "question", "question_text"]

    # Image columns are dropped from the export anyway. Removing them up front
    # avoids decoding every image while iterating. The dataset is kept as-is
    # if that would leave no columns at all.
    image_cols = [c for c in IMG_COL_KEYS if c in ds.column_names]
    if image_cols and len(image_cols) < len(ds.column_names):
        ds = ds.remove_columns(image_cols)

    lines = []
    for row in ds:
        if pat is not None:
            text = " ".join(str(row.get(k, "")) for k in search_keys)
            if not pat.search(text):
                continue
        safe_row = {k: v for k, v in row.items() if k not in IMG_COL_KEYS}
        # default=str: don't abort the export on values json can't encode.
        lines.append(json.dumps(safe_row, ensure_ascii=False, default=str) + "\n")

    # Serialize first, then write, so a failure never leaves a truncated file.
    out_path = Path("export_filtered.jsonl").resolve()
    with out_path.open("w", encoding="utf-8") as f:
        f.writelines(lines)
    return f"Exported {len(lines)} rows to {out_path}"


# -------------------------
# Build UI
# -------------------------


def build_ui(default_path: str = AppConfig.data_dir):
    """Assemble the dataset-viewer UI and wire its callbacks.

    Args:
        default_path: Initial value of the "Dataset path" textbox.

    Returns:
        The ``gr.Blocks`` app; it is not launched here.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📊 MathLens-style Dataset Viewer")
        with gr.Row():
            data_dir = gr.Textbox(
                value=default_path, label="Dataset path (load_from_disk)"
            )
            load_btn = gr.Button("Load dataset", variant="primary")
        info_md = gr.Markdown("(nothing loaded)")

        with gr.Row():
            split_dd = gr.Dropdown(label="Split", choices=[])
            prompt_dd = gr.Dropdown(label="Prompt field", choices=[])
            idx_slider = gr.Slider(0, 0, value=0, step=1, label="Index")

        with gr.Row():
            search_box = gr.Textbox(label="Search (id / problem_id / question)")
            search_btn = gr.Button("Find first match")
            search_status = gr.Markdown(visible=True)

        gr.Markdown("### Quick gallery preview (first 12 images)")
        gallery = gr.Gallery(height=200, preview=True)

        gr.Markdown("### Table (metadata only; images excluded)")
        table = gr.Dataframe(wrap=True)

        gr.Markdown("### Example viewer")
        with gr.Row():
            img_out = gr.Image(label="Image", interactive=False)
            with gr.Column():
                q_out = gr.Textbox(
                    label="Question / question_text",
                    interactive=False,
                    show_copy_button=True,
                )
                ans_out = gr.Textbox(
                    label="Answer", interactive=False, show_copy_button=True
                )
                prompt_out = gr.Textbox(
                    label="Selected prompt", interactive=False, show_copy_button=True
                )
        raw_json_out = gr.Code(label="Raw example (JSON)")

        with gr.Row():
            filter_box = gr.Textbox(label="Filter & export (optional substring)")
            export_btn = gr.Button("Export filtered to JSONL")
            export_status = gr.Markdown()

        # Wiring
        render_inputs = [split_dd, idx_slider, prompt_dd]
        render_outputs = [img_out, q_out, ans_out, prompt_out, raw_json_out]

        # NOTE(review): if a reload leaves the split dropdown on the same value,
        # split_dd.change may not fire and the gallery/table stay stale — confirm.
        load_btn.click(cb_load_dataset, inputs=[data_dir], outputs=[split_dd, info_md])
        split_dd.change(
            cb_split_changed,
            inputs=[split_dd],
            outputs=[idx_slider, prompt_dd, gallery, table, info_md],
        ).then(
            # Render explicitly after the split widgets are reset. If the
            # slider was already 0 and the prompt choice is unchanged, those
            # resets may emit no change event, leaving the previous split's
            # example on screen.
            cb_render_example,
            inputs=render_inputs,
            outputs=render_outputs,
        )
        idx_slider.change(
            cb_render_example, inputs=render_inputs, outputs=render_outputs
        )
        prompt_dd.change(
            cb_render_example, inputs=render_inputs, outputs=render_outputs
        )
        search_btn.click(
            cb_search,
            inputs=[split_dd, search_box],
            outputs=[idx_slider, search_status],
        )
        export_btn.click(
            cb_export_filtered, inputs=[split_dd, filter_box], outputs=[export_status]
        )

    return demo


if __name__ == "__main__":
    # Command-line entry point: parse options with tyro (imported only here,
    # since it is not needed when the module is imported), build the UI, and
    # serve it.
    import tyro

    @dataclass
    class CLI:
        data_dir: str = AppConfig.data_dir
        server_port: int = 7860
        server_name: str = "0.0.0.0"

    cli_opts = tyro.cli(CLI)
    demo = build_ui(cli_opts.data_dir)
    demo.launch(
        server_name=cli_opts.server_name,
        server_port=cli_opts.server_port,
    )
