#!/usr/bin/env python3
"""
Collect training trajectories from the trace files produced by WebGen-Agent
and write them out as JSONL.
"""

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import re
from tqdm import tqdm
# from transformers import AutoTokenizer
import random
import shutil
import time
import copy
import re

from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial

from safe_shell_execute import safe_shell_execute
import traceback

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import get_validation_prompt, get_core_system_prompt, TEMPLATES, ToolRegistry, get_planning_prompt
from tools import (
    ReadFileTool, 
    WriteFileTool,
    ListDirectoryTool, 
    GlobTool,
    ShellTool,
    EditTool,
    GrepTool,
    ReadManyFilesTool,
    BackendTestTool,
    FrontendTestTool,
    get_all_tools
)

from convert_messages_replace_working_dir import convert_messages_replace_working_dir

# -----------------------------------------------------------------------------
# config
# -----------------------------------------------------------------------------
# NOTE(review): none of the constants below is referenced in the visible part
# of this file (the tokenizer import above is commented out); confirm whether
# other modules import them before removing or changing anything.
# Presumably the local path of the Qwen3-Coder checkpoint — TODO confirm.
MODEL_NAME = "/root/user/models/Qwen3-Coder-30B-A3B-Instruct"
# Presumably scoring / filtering knobs for a later selection stage — TODO confirm.
SCORE_WEIGHT = 0.9
BACKEND_TEST_THRESHOLD = 4
FRONTEND_TEST_THRESHOLD = 4
RESPONSE_NONEMPTY_LENGTH = 10

def load_jsonl(in_file):
    """
    Read a JSON Lines file and return its records as a list.

    Each non-blank line is decoded with ``json.loads``.  Blank or
    whitespace-only lines (e.g. a stray empty line at the end of the file)
    are skipped instead of aborting the whole load with a decode error.

    Args:
        in_file: path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        List of decoded records, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    datas = []
    with open(in_file, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            if not line.strip():
                # nothing to decode on this line
                continue
            datas.append(json.loads(line))
    return datas


def save_jsonl(datas, out_file):
    """
    Write every item of ``datas`` to ``out_file`` as one JSON document per line.

    The file is (over)written as UTF-8 and non-ASCII characters are kept
    verbatim rather than escaped.
    """
    with open(out_file, "w", encoding="utf-8") as sink:
        for record in tqdm(datas):
            serialized = json.dumps(record, ensure_ascii=False)
            sink.write(serialized + "\n")


# -----------------------------------------------------------------------------
# helpers
# -----------------------------------------------------------------------------
def ensure_mapping(value: Any) -> Dict[str, Any]:
    """Return ``value`` itself when it is a dict; otherwise a new empty dict."""
    if isinstance(value, dict):
        return value
    # None, lists, strings, ... all collapse to an empty mapping
    return {}


def clean_tool_schema(raw: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalise one tool schema so Jinja's |items filter succeeds on its
    ``parameters.properties``.

    If ``raw`` carries an OpenAI-style ``"function"`` wrapper, the inner dict
    is used; otherwise ``raw`` itself.  The chosen dict is modified in place
    and returned.

    Guarantees on the returned dict:
      * ``description`` exists (defaults to ``""``),
      * ``parameters`` is a dict (a missing, ``None`` or otherwise non-dict
        value is replaced by an empty one),
      * ``parameters.type`` exists (defaults to ``"object"``),
      * ``parameters.properties`` is a dict whose values are all dicts
        (non-dict property specs become ``{"type": "string"}``),
      * ``parameters.required`` exists (defaults to ``[]``).
    """
    tool = raw.get("function", raw)  # unwrap "type":"function" wrapper if any

    tool.setdefault("description", "")

    # setdefault alone is not enough: an explicit `"parameters": null` would
    # survive it and crash on the attribute access below.
    params = tool.get("parameters")
    if not isinstance(params, dict):
        params = {}
        tool["parameters"] = params
    params.setdefault("type", "object")

    props = params.get("properties")
    if not isinstance(props, dict):
        props = {}
    params["properties"] = props
    params.setdefault("required", [])

    # For every property, make sure its spec is a mapping
    for key, spec in list(props.items()):
        if not isinstance(spec, dict):
            props[key] = {"type": "string"}

    return tool


def _normalise_tool_call(tc: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``{"function": func}`` where ``func["arguments"]`` is a dict."""
    func = tc.get("function", tc)  # may already be at root
    args = func.get("arguments", {})
    if isinstance(args, str):
        try:
            args = json.loads(args)
        except (ValueError, RecursionError):
            # undecodable argument strings are replaced by an empty mapping
            args = {}
    if not isinstance(args, dict):
        args = {}
    func["arguments"] = args
    return {"function": func}


def clean_messages(raw_msgs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Normalise message structure so the template expectations are met.

    Only assistant messages that carry a ``tool_calls`` key are touched
    (in place):
      * ``tool_calls`` becomes a list of ``{"function": {...}}`` entries; a
        single dict is wrapped in a list and an explicit ``None`` becomes an
        empty list instead of crashing the iteration,
      * each call's ``arguments`` becomes a dict (JSON strings are decoded;
        anything undecodable or non-dict becomes ``{}``),
      * ``content`` becomes ``""`` when it is missing or falsy.

    Returns:
        A new list holding the same (possibly modified) message dicts, in
        the original order.
    """
    cleaned: List[Dict[str, Any]] = []
    for m in raw_msgs:
        if m.get("role") == "assistant" and "tool_calls" in m:
            tc_list = m["tool_calls"]
            if tc_list is None:
                tc_list = []
            elif isinstance(tc_list, dict):
                tc_list = [tc_list]
            m["tool_calls"] = [_normalise_tool_call(tc) for tc in tc_list]
            # content may be null – template wants string
            m["content"] = m.get("content") or ""
        cleaned.append(m)
    return cleaned


# -----------------------------------------------------------------------------
# trace loading and trajectory construction
# -----------------------------------------------------------------------------
def load_json(path: Path) -> Dict[str, Any]:
    """Decode and return the UTF-8 JSON document stored at ``path``."""
    text = path.read_text(encoding="utf-8")
    return json.loads(text)
    

def get_messages_tools(log_dir, step_idx):
    """
    Load the cleaned conversation and tool schemas recorded for one step.

    Reads ``<step_idx>_llm_response.json`` and, when it exists,
    ``<step_idx>_tool_response.json`` from ``log_dir``.

    Returns:
        ``(messages, tools, is_tool)`` where
          * ``tools`` are the LLM request's ``tool_schemas`` passed through
            ``clean_tool_schema``,
          * ``messages`` are the request messages plus the response, passed
            through ``clean_messages``; they come from the tool-response file
            when it exists (its ``response`` is concatenated as a list),
            otherwise from the LLM-response file (its ``response`` is
            appended as a single message),
          * ``is_tool`` tells which of the two sources was used.
    """
    base = Path(log_dir)
    llm_path = base / f"{step_idx}_llm_response.json"
    tool_path = base / f"{step_idx}_tool_response.json"

    is_tool = tool_path.is_file()

    llm_record = load_json(llm_path)
    tool_record = load_json(tool_path) if is_tool else None

    tools = [clean_tool_schema(schema) for schema in llm_record["request"]["tool_schemas"]]

    if is_tool:
        raw_messages = tool_record["request"]["messages"] + tool_record["response"]
    else:
        raw_messages = llm_record["request"]["messages"] + [llm_record["response"]]

    return clean_messages(raw_messages), tools, is_tool


def get_compress_history(log_dir, step_idx):
    """
    Return the cleaned conversation stored in
    ``<log_dir>/<step_idx>_compress_history.json``: the request messages
    followed by the single response message.
    """
    record = load_json(Path(log_dir) / f"{step_idx}_compress_history.json")
    conversation = record["request"]["messages"] + [record["response"]]
    return clean_messages(conversation)

# Loose alias for one trajectory record (a plain dict).  Records built in this
# file carry the keys "step" (int), "kind" (str), "messages" (list),
# "tools" (list or None) and "compressed_section" (list); later stages also
# add "log_dir", "orig_working_dir" and "working_dir".
Trajectory = Dict[str, Optional[object]]


def get_compressed(log_dir, n):
    """
    Return the cleaned message list that is embedded, as a JSON string, in
    the ``content`` of the second request message of
    ``<log_dir>/<n>_compress_history.json``.
    """
    history_path = Path(os.path.join(log_dir, f"{n}_compress_history.json"))
    record = load_json(history_path)
    embedded = record["request"]["messages"][1]["content"]
    return clean_messages(json.loads(embedded))


def _collect_step_numbers(log_dir: Path, pattern: re.Pattern) -> List[int]:
    """Return, sorted, the integer in group 1 of every file name in ``log_dir`` matching ``pattern``."""
    return sorted(
        int(m.group(1)) for p in log_dir.iterdir()
        if (m := pattern.match(p.name))
    )


def _mark_message_ids(messages: List[Dict[str, Any]], new_messages_start: int, idx: int) -> int:
    """
    Tag ``messages`` in place and return the updated running assistant index.

    Messages before ``new_messages_start`` get ``is_old = True``; walking
    backwards, old assistant messages at position >= 4 receive the ids
    ``idx, idx - 1, ...``.  From ``new_messages_start`` on, every assistant
    message at position >= 3 receives the next fresh id (``idx + 1, ...``).
    """
    # NOTE(review): the position thresholds differ (>= 4 for old, >= 3 for new
    # messages); kept exactly as in the original logic — confirm intent.
    k = 0
    if new_messages_start > 0:
        for i in range(new_messages_start - 1, -1, -1):
            # mark as old message
            messages[i]["is_old"] = True

            # mark old assistant messages with consistent indexes
            if messages[i]["role"] == "assistant" and i >= 4:
                messages[i]["msg_id"] = idx - k
                k += 1

    for i in range(new_messages_start, len(messages)):
        # add new indexes for new messages
        if messages[i]["role"] == "assistant" and i >= 3:
            idx += 1
            messages[i]["msg_id"] = idx

    return idx


def get_trajectories(log_dir: str | Path) -> List[Trajectory]:
    """
    Build the list of trajectory records for one sample's log directory.

    For every ``<n>_compress_history.json`` (ascending ``n``) a
    ``"compress_before"`` record is emitted from step ``n - 1``; then one
    ``"llm_final"`` record from the highest ``<k>_llm_response.json`` (if any);
    finally one ``"uncompressed_messages"`` record (``step=-1``) holding the
    concatenation of the "new" message tail of every previous record.

    Each record's ``compressed_section`` is the compressed message list read
    at the previous compression step (empty for the first record and for the
    uncompressed record).

    The messages placed in the uncompressed record are deep copies, so it
    never shares message dicts with the per-step records; otherwise in-place
    post-processing of one record would silently alter the other.

    Args:
        log_dir: directory containing the numbered trace files.

    Returns:
        The trajectory records, always ending with the uncompressed record.

    Raises:
        FileNotFoundError: if ``log_dir`` is not an existing directory.
    """
    log_dir = Path(log_dir)
    if not log_dir.is_dir():
        raise FileNotFoundError(log_dir)

    traj: List[Trajectory] = []

    tools = None
    compressed_section = []
    new_messages_start = 0
    idx = 0
    full_messages = []

    comp_re = re.compile(r"^(\d+)_compress_history\.json$")
    for n in _collect_step_numbers(log_dir, comp_re):
        prev = n - 1

        messages, tools, _ = get_messages_tools(log_dir, prev)
        compressed_messages = get_compressed(log_dir, n)

        idx = _mark_message_ids(messages, new_messages_start, idx)

        # deep copy: keep the uncompressed record independent of this one
        full_messages.extend(copy.deepcopy(messages[new_messages_start:]))

        traj.append(
            dict(
                step=prev,
                kind="compress_before",
                messages=messages,
                tools=tools,
                compressed_section=compressed_section,
            )
        )

        compressed_section = compressed_messages
        # NOTE(review): presumably the compressed span is replaced by a single
        # message in the next step's history, hence the "+ 1" — TODO confirm.
        new_messages_start = len(messages) - len(compressed_messages) + 1

    llm_re = re.compile(r"^(\d+)_llm_response\.json$")
    llm_steps = _collect_step_numbers(log_dir, llm_re)
    if llm_steps:                       # the directory *should* have at least one
        last_llm = llm_steps[-1]
        messages, tools, _ = get_messages_tools(log_dir, last_llm)

        idx = _mark_message_ids(messages, new_messages_start, idx)

        full_messages.extend(copy.deepcopy(messages[new_messages_start:]))

        traj.append(
            dict(
                step=last_llm,
                kind="llm_final",
                messages=messages,
                tools=tools,
                compressed_section=compressed_section,
            )
        )

    traj.append(
        dict(
            step=-1,
            kind="uncompressed_messages",
            messages=full_messages,
            tools=tools,
            compressed_section=[],
        )
    )

    return traj


def process_trajectories(trajectories, working_dir, log_dir):
    """
    Stamp every trajectory (in place) with ``log_dir`` and with
    ``working_dir`` under both ``orig_working_dir`` and ``working_dir``;
    return the same list object.
    """
    for entry in trajectories:
        entry.update(
            log_dir=log_dir,
            orig_working_dir=working_dir,
            working_dir=working_dir,
        )
    return trajectories


def get_planning_trajectory(log_dir, working_dir):
    """
    Build the ``"planning_messages"`` trajectory record from
    ``<log_dir>/0_get_plans.json``: its request messages followed by its
    response, with no tools and an empty compressed section.
    """
    plan_record = load_json(Path(log_dir) / "0_get_plans.json")
    conversation = plan_record["request"]["messages"]
    conversation.append(plan_record["response"])

    return dict(
        step=-1,
        kind="planning_messages",
        tools=[],
        log_dir=log_dir,
        orig_working_dir=working_dir,
        working_dir=working_dir,
        messages=conversation,
        compressed_section=[],
    )


def _process_one_sample(
    info_data: Dict[str, Any],
    *,
    log_root_dir: str,
    template: Dict[str, Any],
) -> List[Trajectory] | None:
    """
    Execute the original per-sample pipeline.
    In case of *any* exception we convert it into a picklable form so the parent
    process never sees un-picklable objects such as os.DirEntry.
    """
    try:
        # ---------------------------------------------------------------------
        # original code (unchanged)
        # ---------------------------------------------------------------------
        log_dir = os.path.join(log_root_dir, info_data["id"])
        if not os.path.isfile(os.path.join(log_dir, "finished.json")):
            return info_data

        trajectories = get_trajectories(log_dir)

        if len(trajectories) == 0:
            return info_data

        working_dir = log_dir.replace("logs_root", "workspaces_root")

        trajectories = process_trajectories(
            trajectories,
            working_dir,
            log_dir,
        )

        trajectories.append(get_planning_trajectory(log_dir, working_dir))
        return trajectories

    # -------------------------------------------------------------------------
    # catch *everything*, make it picklable
    # -------------------------------------------------------------------------
    except Exception as exc:
        tb = traceback.format_exc()
        sys.stderr.write(
            f"[worker] sample {info_data.get('id', '<unknown>')} failed:\n{tb}\n"
        )
        sys.stderr.flush()
        return info_data


def get_train(
    log_root_dir: str,
    info_path: str,
    out_path: str,
    unfinished_path: str,
    *,
    template_name: str = "nextjs-nextjs-postresql",
    max_workers: int | None = None,         # let the caller tune the pool size
) -> None:
    """
    Convert all samples listed in ``info_path`` into trajectory records.

    Every info record is processed by ``_process_one_sample`` in a process
    pool.  Successful samples contribute their trajectory records to
    ``out_path``; all other samples (including those whose worker raised or
    died) are written unchanged to ``unfinished_path``.  Both outputs are
    JSONL files.

    Args:
        log_root_dir: root directory containing one log directory per sample.
        info_path: JSONL file with one info record per sample.
        out_path: destination JSONL file for the trajectory records.
        unfinished_path: destination JSONL file for unfinished samples.
        template_name: name of the entry of ``TEMPLATES["templates"]`` that
            is broadcast to the workers.
        max_workers: size of the process pool (``None`` = executor default).

    Raises:
        ValueError: if no template is named ``template_name``.
    """
    # locate the template once; we will broadcast it to the workers
    template = next(
        (temp for temp in TEMPLATES["templates"] if temp["name"] == template_name),
        None,
    )
    if template is None:
        raise ValueError(f"unknown template name: {template_name!r}")

    info_datas = load_jsonl(info_path)

    datas: List[Trajectory] = []
    unfinished_datas = []
    finished_samples = 0

    process_one = partial(
        _process_one_sample,
        log_root_dir=log_root_dir,
        template=template,
    )

    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(process_one, info): info for info in info_datas}

        with tqdm(total=len(futures), desc="processing samples") as pbar:
            for fut in as_completed(futures):
                try:
                    result = fut.result()
                except Exception:
                    # e.g. the worker process died or its result could not be
                    # transferred: keep going and record the sample as
                    # unfinished instead of losing the whole run.
                    sys.stderr.write(
                        f"[parent] worker failed:\n{traceback.format_exc()}\n"
                    )
                    sys.stderr.flush()
                    result = futures[fut]

                if isinstance(result, list):
                    # each successful result is a list[Trajectory]
                    datas.extend(result)
                    finished_samples += 1
                elif isinstance(result, dict):
                    unfinished_datas.append(result)
                pbar.update(1)

    print(f"\n\nfinished samples: {finished_samples}")
    for i, data in enumerate(datas):
        # Called for their effect on the message lists; the return value was
        # never stored.  NOTE(review): the helper appears to return a new
        # working dir — confirm whether data["working_dir"] should be updated.
        convert_messages_replace_working_dir(data["messages"], data["working_dir"], i)
        convert_messages_replace_working_dir(data["compressed_section"], data["working_dir"], i)
    save_jsonl(datas, out_path)

    print(f"\n\nunfinished samples: {len(unfinished_datas)}")
    save_jsonl(unfinished_datas, unfinished_path)


def main() -> None:
    """Run ``get_train`` with the hard-coded paths of the direct-gen training run."""
    log_root_dir = "logs_root/model-Qwen3-Coder-30B-A3B-Instruct_hist-100_iter-400_compress-0.5_val-1_sum-5__direct-gen-train"
    info_path = "data/WebGen-Bench/train.jsonl"
    out_path = "src/run_process_data/jsonl_files/webgen-instruct_direct-gen_Qwen3-Coder-30B-A3B-Instruct.jsonl"
    unfinished_path = "src/run_process_data/jsonl_files/webgen-instruct_direct-gen_Qwen3-Coder-30B-A3B-Instruct_unfinished.jsonl"
    max_workers = 16
    get_train(
        log_root_dir,
        info_path,
        out_path,
        unfinished_path,
        template_name="nextjs-nextjs-postresql",
        max_workers=max_workers,
    )


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()