"""Preparation module for generating JSONL batch files for LLM evaluation."""
import copy
import hashlib
import json
import random
from pathlib import Path
from typing import Optional, Union
from xml.dom import minidom
from xml.etree.ElementTree import Element, SubElement, tostring

import fire
import pandas as pd

from src.llm_eval.questions import PROMPTS, SYSTEM_PROMPTS


# Swap labels ablation


def _stable_seed_from_layout(layout_id: Union[int, str]) -> int:
    """Derive a reproducible 32-bit seed from a layout identifier.

    The identifier is stringified and hashed with MD5, and the low 32 bits of
    the digest are returned, so the same ``layout_id`` gives the same seed in
    every process (unlike the built-in ``hash``, which is salted for strings).

    Args:
        layout_id: Layout identifier; ``1`` and ``"1"`` map to the same seed.

    Returns:
        An integer in ``[0, 2**32)``.
    """
    digest_bytes = hashlib.md5(str(layout_id).encode()).digest()
    # Big-endian interpretation of the digest equals int(hexdigest, 16);
    # masking with 0xFFFFFFFF is the same as reducing modulo 2**32.
    return int.from_bytes(digest_bytes, "big") & 0xFFFFFFFF


def swap_object_labels(
    room: dict,
    strategy: str = "rotate",  # "rotate" | "reverse" | "shuffle"
    seed: Optional[int] = None,  # used only when strategy="shuffle"
) -> dict:
    """
    Permute the ``label`` values across a room's objects.

    The input is never modified; a deep copy is permuted and returned. Only
    the labels move: every other field stays with its original object.
    Objects without a ``label`` key contribute an empty string to the pool
    and end up with a ``label`` key after the swap.

    Args:
        room: Room data dictionary; ``room["objects"]`` is expected to be a
            list of dicts.
        strategy: "rotate" (each object takes the next object's label),
            "reverse" (label order is reversed) or "shuffle" (seeded random
            permutation).
        seed: Seed for the "shuffle" strategy; ``None`` is treated as 0.

    Returns:
        A deep copy of ``room`` with the labels permuted. If ``objects`` is
        not a list or has fewer than two entries, the copy is returned
        untouched (and ``strategy`` is not validated).

    Raises:
        ValueError: If ``strategy`` is not one of the supported names and
            there are at least two objects.
    """
    swapped = copy.deepcopy(room)
    objects = swapped.get("objects")
    if not (isinstance(objects, list) and len(objects) >= 2):
        return swapped  # nothing to permute

    original = [entry.get("label", "") for entry in objects]
    rotated = original[1:] + original[:1]

    if strategy == "rotate":
        permuted = rotated
    elif strategy == "reverse":
        permuted = original[::-1]
    elif strategy == "shuffle":
        permuted = list(original)
        random.Random(0 if seed is None else seed).shuffle(permuted)
        # A shuffle can land on the identity permutation; fall back to a
        # rotation so the label order is not left exactly as it was.
        if permuted == original:
            permuted = rotated
    else:
        raise ValueError("strategy must be 'rotate', 'reverse', or 'shuffle'")

    for entry, label in zip(objects, permuted):
        entry["label"] = label
    return swapped


# XML ablation


# Container tag -> child tag pairs that take precedence over the generic
# "drop the trailing s" rule when list items are named during XML
# serialization. "room_boundary" has no plural suffix to strip, so its
# children are explicitly named "point".
_SINGULAR_OVERRIDES = dict(
    walls="wall",
    objects="object",
    points="point",
    windows="window",
    doors="door",
    room_boundary="point",
)


def _singularize(tag: str) -> str:
    """Return the element name to use for the items of a list tagged ``tag``.

    Explicit entries in ``_SINGULAR_OVERRIDES`` win. Otherwise a trailing
    "s" is dropped from tags longer than one character, and anything else
    falls back to the generic name "item".
    """
    try:
        return _SINGULAR_OVERRIDES[tag]
    except KeyError:
        pass
    if len(tag) > 1 and tag.endswith("s"):
        return tag[:-1]
    return "item"


def _dict_to_xml(parent: Element, key: str, value):
    """
    Append ``value`` to ``parent`` as an XML element named ``key``.

    Dicts become an element with one child per entry, handled recursively.
    Lists become a container element whose children are named after the
    singular form of ``key``. Anything else becomes a leaf element holding
    ``str(value)``, with ``None`` rendered as empty text.

    Args:
        parent: Element that receives the new child
        key: Tag name for the new child
        value: Data to serialize (dict, list, or scalar)
    """
    if isinstance(value, dict):
        container = SubElement(parent, key)
        for child_key, child_value in value.items():
            _dict_to_xml(container, child_key, child_value)
        return

    if not isinstance(value, list):
        leaf = SubElement(parent, key)
        leaf.text = "" if value is None else str(value)
        return

    child_tag = _singularize(key)
    container = SubElement(parent, key)
    for entry in value:
        if isinstance(entry, (dict, list)):
            # Nested structures recurse; a nested list becomes its own
            # container named child_tag.
            _dict_to_xml(container, child_tag, entry)
            continue
        leaf = SubElement(container, child_tag)
        leaf.text = "" if entry is None else str(entry)


def room_to_xml(room: dict) -> str:
    """
    Serialize a room dictionary to an indented ``<room>...</room>`` string.

    ``layout_id``, ``room_type`` and ``units`` are emitted first (when
    present) so the header of the document stays tidy; all remaining keys
    follow in the dictionary's own order. The result has two-space
    indentation, no blank lines and no XML declaration.
    """
    leading = ("layout_id", "room_type", "units")
    ordered_keys = [name for name in leading if name in room]
    ordered_keys += [name for name in room if name not in leading]

    root = Element("room")
    for name in ordered_keys:
        _dict_to_xml(root, name, room[name])

    # Round-trip through minidom purely for pretty-printing.
    raw = tostring(root, encoding="utf-8")
    document = minidom.parseString(raw)
    rendered = document.toprettyxml(indent="  ", encoding="utf-8").decode("utf-8")

    kept = [line for line in rendered.splitlines() if line.strip()]
    # toprettyxml always prepends a declaration; drop it.
    if kept and kept[0].startswith("<?xml"):
        del kept[0]
    return "\n".join(kept)


def _format_user_prompt(
    dataset: str,
    row: pd.Series,
    room: dict,
    room_type: str,
    layout_id: str,
    layout_dir: str
) -> str:
    """Render the user prompt for one benchmark row using a label-swapped room.

    The room's object labels are shuffled with a seed derived from
    ``layout_id``, so a given layout is always permuted the same way. The
    swapped room is also written to
    ``<layout_dir>/swapped_labels/room_<layout_id>.json`` for inspection,
    then embedded as JSON in the ``PROMPTS[dataset]`` template.

    Columns missing from ``row`` fall back to an empty string ("XML" for
    ``format``). A placeholder in the template that is not among the
    supplied fields still raises ``KeyError`` from ``str.format``.
    """
    template = PROMPTS[dataset]

    # Deterministic per-layout permutation of the object labels.
    shuffled_room = swap_object_labels(
        room,
        strategy="shuffle",
        seed=_stable_seed_from_layout(layout_id),
    )

    # Persist the permuted room next to the layouts for inspection / reuse.
    dump_dir = Path(layout_dir) / "swapped_labels"
    dump_dir.mkdir(parents=True, exist_ok=True)
    dump_path = dump_dir / f"room_{layout_id}.json"
    print(f"Writing swapped room to {dump_path}...")
    with open(dump_path, "w", encoding="utf-8") as handle:
        json.dump(shuffled_room, handle)

    values = {
        "room_type": room_type,
        "room": json.dumps(shuffled_room, ensure_ascii=False),
        "obj1": row.get("object_1", ""),
        "obj2": row.get("object_2", ""),
        "format": row.get("format", "XML"),
    }
    # Optional per-dataset columns whose placeholder matches the column name.
    for column in (
        "clearance",
        "object_name",
        "object_width",
        "object_depth",
        "object_to_move",
        "direction",
    ):
        values[column] = row.get(column, "")

    return template.format(**values)


def create_jsonl_for_batch(
    df: pd.DataFrame,
    dataset: str,
    room_type: str,
    max_tokens: int,
    layout_dir: Union[str, Path],
    model_id: str,
    output_jsonl_path: Union[str, Path],
) -> None:
    """Create a JSONL batch file of chat-completion requests for LLM inference.

    One request is built per row of ``df``: the row's layout is loaded from
    ``<layout_dir>/room_<layout_id>.json``, turned into a user prompt by
    ``_format_user_prompt`` and wrapped in a ``POST /v1/chat/completions``
    batch entry. The entries are then written to ``output_jsonl_path``, one
    JSON object per line.

    Args:
        df: Benchmark questions; must contain a ``layout_id`` column.
        dataset: Key into ``PROMPTS`` / ``SYSTEM_PROMPTS``.
        room_type: Room type string substituted into the prompt.
        max_tokens: Completion token budget for each request.
        layout_dir: Directory holding the ``room_<layout_id>.json`` files.
        model_id: Model name placed in each request body. Ids containing
            "gpt-5" get ``max_completion_tokens`` plus verbosity / reasoning
            settings; all others get ``max_tokens`` and ``temperature=0.0``.
        output_jsonl_path: Destination file; overwritten if it exists, and
            its parent directory is created when missing.

    Raises:
        KeyError: If ``df`` has no ``layout_id`` column or ``dataset`` has no
            prompt template.
        FileNotFoundError: If a layout file is missing.
    """
    layout_dir = Path(layout_dir)
    output_jsonl_path = Path(output_jsonl_path)

    sys_prompt = SYSTEM_PROMPTS.get(dataset, None)

    # The generation parameters depend only on the model, so decide once.
    if "gpt-5" in model_id:
        generation_params = {
            "max_completion_tokens": max_tokens,
            "verbosity": "low",
            "reasoning_effort": "minimal",
            "response_format": {"type": "text"},
        }
    else:
        generation_params = {
            "max_tokens": max_tokens,
            "temperature": 0.0,
        }

    requests = []
    for _, row in df.iterrows():
        layout_id = row["layout_id"]

        layout_path = layout_dir / f"room_{layout_id}.json"
        with open(layout_path, "r", encoding="utf-8") as f:
            room = json.load(f)

        user_prompt = _format_user_prompt(
            dataset, row, room, room_type, layout_id, layout_dir
        )

        messages = []
        if sys_prompt:  # only include when defined
            messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": user_prompt})

        # NOTE(review): custom_id is derived from layout_id only. If df can
        # hold several rows for the same layout, the ids repeat - confirm
        # that the consumer of this file tolerates that.
        requests.append(
            {
                "custom_id": f"request-{layout_id}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model_id,
                    "messages": messages,
                    **generation_params,
                },
            }
        )

    print(f"Writing {len(requests)} requests to {output_jsonl_path}...")
    # Previously the requests were built and announced but never persisted.
    output_jsonl_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_jsonl_path, "w", encoding="utf-8") as out_file:
        for request_entry in requests:
            out_file.write(json.dumps(request_entry, ensure_ascii=False) + "\n")


def generate_batches(
    layout_dir_template: str = "data/generated_data/{room_type}",
    input_csv_template: str = "benchmark/{dataset}/{dataset}_qa_{room_type}.csv",
    output_jsonl_template: str = (
        "qa_jsonl_ablation/swap_labels/{dataset}/{room_type}/"
        "{model_name}_{max_tokens}.jsonl"
    ),
):
    """Generate batch JSONL files for every configured model and dataset.

    For each (room type, dataset) pair the benchmark CSV is read once and
    ``create_jsonl_for_batch`` is called once per model, with the output
    path built from ``output_jsonl_template``.

    Args:
        layout_dir_template: Layout directory pattern with a ``{room_type}``
            placeholder. Not used for "hssd_data_simplified", which has a
            fixed directory.
        input_csv_template: Benchmark CSV pattern with ``{dataset}`` and
            ``{room_type}`` placeholders.
        output_jsonl_template: Output path pattern with ``{dataset}``,
            ``{room_type}``, ``{model_name}`` and ``{max_tokens}``
            placeholders.
    """
    # (max_tokens, model id) pairs to generate batches for.
    model_configs = [
        (12288, "openai/gpt-oss-120b"),
        (8192, "openai/gpt-oss-20b"),
        (12288, "Qwen/Qwen3-235B-A22B-Instruct-2507"),
    ]
    room_types = ["hssd_data_simplified"]
    datasets = ["obstruction", "view_angle", "repositioning"]

    for room_type in room_types:
        for dataset in datasets:
            # The simplified HSSD layouts live outside the templated tree.
            layout_dir = (
                "data/hssd_data/json_simplified"
                if room_type == "hssd_data_simplified"
                else layout_dir_template.format(room_type=room_type)
            )

            csv_path = input_csv_template.format(
                dataset=dataset, room_type=room_type
            )
            df = pd.read_csv(csv_path)
            print(df[:1])

            for max_tokens, model_id in model_configs:
                # Drop the organisation prefix for the file name.
                short_name = model_id.rsplit("/", 1)[-1]
                target = Path(
                    output_jsonl_template.format(
                        dataset=dataset,
                        room_type=room_type,
                        model_name=short_name,
                        max_tokens=max_tokens,
                    )
                )
                target.parent.mkdir(parents=True, exist_ok=True)

                create_jsonl_for_batch(
                    df=df,
                    dataset=dataset,
                    room_type=room_type,
                    max_tokens=max_tokens,
                    layout_dir=layout_dir,
                    model_id=model_id,
                    output_jsonl_path=target,
                )

            print(
                f"Batch JSONL files written for dataset '{dataset}' "
                f"and room type '{room_type}'."
            )
