#!/usr/bin/env python3
"""
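Build chat-template training data (messages + tool schemas, e.g. for Qwen) from the trace files produced by WebGen-Agent, rewriting project paths and replaying tool calls against a fresh template copy.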
"""

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import re
from tqdm import tqdm
# from transformers import AutoTokenizer
import random
import shutil
import time
import copy

from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial

from wrap_color_theme import wrap_color_theme
from safe_shell_execute import safe_shell_execute

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import get_validation_prompt, get_core_system_prompt, TEMPLATES, ToolRegistry
from tools import (
    ReadFileTool, 
    WriteFileTool,
    ListDirectoryTool, 
    GlobTool,
    ShellTool,
    EditTool,
    GrepTool,
    ReadManyFilesTool,
    BackendTestTool,
    FrontendTestTool,
    get_all_tools
)

from convert_messages_replace_working_dir import convert_messages_replace_working_dir

# -----------------------------------------------------------------------------
# config
# -----------------------------------------------------------------------------
# Local model path. Only the commented-out transformers import above relates to it;
# no function in this file references MODEL_NAME.
MODEL_NAME = "/root/user/models/Qwen3-Coder-30B-A3B-Instruct"

def load_jsonl(in_file):
    """
    Read a JSON-Lines file into a list of parsed objects (with a tqdm progress bar).

    Blank / whitespace-only lines (e.g. a stray empty line at the end of the
    file) are skipped instead of raising ``json.JSONDecodeError``.
    """
    datas = []
    with open(in_file, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            if not line.strip():
                continue
            datas.append(json.loads(line))
    return datas


def save_jsonl(datas, out_file):
    """Write every item of *datas* to *out_file* as one UTF-8 JSON line (non-ASCII kept as-is)."""
    with open(out_file, "w", encoding="utf-8") as handle:
        for record in tqdm(datas):
            serialized = json.dumps(record, ensure_ascii=False)
            handle.write(f"{serialized}\n")


# -----------------------------------------------------------------------------
# helpers
# -----------------------------------------------------------------------------
def ensure_mapping(value: Any) -> Dict[str, Any]:
    """
    Return *value* if it is a dict, otherwise an empty dict.

    Any non-dict input (``None``, a list, a string, ...) is replaced by ``{}``;
    nothing is raised.
    """
    return value if isinstance(value, dict) else {}


def clean_tool_schema(raw: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalise one tool schema so the chat template can render it.

    Strips a possible OpenAI ``{"type": "function", "function": {...}}``
    wrapper and guarantees that the inner dict has:

    * a string ``description`` (missing or ``None`` becomes ``""``);
    * a ``parameters`` mapping (missing or non-dict becomes an empty object
      schema) with a ``type``, a ``properties`` mapping whose values are all
      mappings (non-dict specs become ``{"type": "string"}``), and a
      ``required`` list (missing or non-list becomes ``[]``),

    so Jinja's ``|items`` filter succeeds.

    The schema is normalised in place; the returned dict is ``raw`` itself or
    ``raw["function"]``.
    """
    tool = raw.get("function", raw)  # unwrap "type":"function" wrapper if any

    # ``setdefault`` alone would leave explicit JSON nulls in place, which
    # crash ``params.setdefault`` below or the template later.
    if tool.get("description") is None:
        tool["description"] = ""
    params = tool.get("parameters")
    if not isinstance(params, dict):
        params = {"type": "object", "properties": {}}
        tool["parameters"] = params
    params.setdefault("type", "object")
    params["properties"] = ensure_mapping(params.get("properties"))
    if not isinstance(params.get("required"), list):
        params["required"] = []

    # For every property, make sure it is a mapping
    for key, spec in list(params["properties"].items()):
        if not isinstance(spec, dict):
            params["properties"][key] = {"type": "string"}

    return tool


def clean_messages(raw_msgs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Normalise message structure so the chat-template expectations are met.

    For assistant messages that carry a ``tool_calls`` key:

    * ``tool_calls`` becomes a list (a single dict is wrapped, ``None`` becomes ``[]``);
    * every call is reduced to ``{"function": {...}}``, taking the ``function``
      payload or the call itself when it is already flat;
    * ``arguments`` is decoded from a JSON string when needed and replaced by
      ``{}`` when it is not (or does not decode to) a mapping;
    * missing / ``None`` content becomes ``""``.

    Other messages pass through unchanged. Messages are modified in place; the
    returned list holds the same dict objects.
    """
    cleaned: List[Dict[str, Any]] = []
    for m in raw_msgs:
        if m.get("role") == "assistant" and "tool_calls" in m:
            tc_list = m["tool_calls"]
            if tc_list is None:
                tc_list = []  # explicit null in the log means "no calls"
            elif isinstance(tc_list, dict):
                tc_list = [tc_list]
            new_tc_list = []
            for tc in tc_list:
                func = tc.get("function", tc)  # may already be at root
                # arguments must be a mapping
                args = func.get("arguments", {})
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except (ValueError, RecursionError):  # ValueError covers JSONDecodeError
                        args = {}
                if not isinstance(args, dict):
                    args = {}
                func["arguments"] = args
                new_tc_list.append({"function": func})
            m["tool_calls"] = new_tc_list
            # content may be null – template wants string
            m["content"] = m.get("content") or ""
        cleaned.append(m)
    return cleaned


# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def load_json(path: Path) -> Dict[str, Any]:
    """Parse the UTF-8 JSON document stored at *path* and return it."""
    text = path.read_text(encoding="utf-8")
    return json.loads(text)
    

def get_messages_tools(log_dir, step_idx):
    """
    Load the transcript of step *step_idx* from *log_dir*.

    When ``<step>_tool_response.json`` exists the messages are its request
    messages followed by its response list (``is_tool`` is True); otherwise
    they are the ``<step>_llm_response.json`` request messages followed by the
    single LLM response. Tool schemas and colours always come from the LLM
    request.

    Returns ``(messages, tools, background_color, component_color, is_tool)``.
    """
    base = Path(log_dir)
    tool_path = base / f"{step_idx}_tool_response.json"
    is_tool = tool_path.is_file()

    llm = load_json(base / f"{step_idx}_llm_response.json")
    tool_record = load_json(tool_path) if is_tool else None

    request = llm["request"]
    tools = [clean_tool_schema(schema) for schema in request["tool_schemas"]]

    if is_tool:
        raw_messages = tool_record["request"]["messages"] + tool_record["response"]
    else:
        raw_messages = request["messages"] + [llm["response"]]
    messages = clean_messages(raw_messages)

    return messages, tools, request.get("background_color", None), request.get("component_color", None), is_tool


def get_compress_history(log_dir, step_idx):
    """
    Load ``<step_idx>_compress_history.json`` from *log_dir*.

    Returns ``(messages, background_color, component_color)``: the request
    messages followed by the compression response (normalised by
    ``clean_messages``), and the colours stored in the request (``None`` when
    absent).
    """
    record = load_json(Path(log_dir) / f"{step_idx}_compress_history.json")
    request = record["request"]
    messages = clean_messages(request["messages"] + [record["response"]])
    return messages, request.get("background_color", None), request.get("component_color", None)


def initialize_tool_registry(working_dir, log_dir):
    """Return a new ToolRegistry with every tool from ``get_all_tools(working_dir, log_dir)`` registered in order."""
    available_tools = get_all_tools(working_dir, log_dir)
    tool_registry = ToolRegistry()
    for each_tool in available_tools:
        tool_registry.register_tool(each_tool)
    return tool_registry


def replace_tool_outputs(messages, replacement_dict):
    """
    Apply the replacements collected by ``adjust_tool_output`` to *messages* in place.

    * A system message whose string content starts with ``SYSTEM_PREFIX`` gets
      ``replacement_dict["system_prompt"]``. It is left unchanged when that key
      is absent, where the previous version raised ``KeyError``.
    * For an assistant message whose ``msg_id`` (as a string) is a key of
      *replacement_dict*, the first ``tool`` message after it receives the
      stored content.

    Returns the same list.
    """
    system_prompt = replacement_dict.get("system_prompt")
    for idx, message in enumerate(messages):
        role = message["role"]
        if role == "system":
            content = message.get("content")
            if system_prompt is not None and isinstance(content, str) and content.startswith(SYSTEM_PREFIX):
                message["content"] = system_prompt
        elif role == "assistant":
            msg_id = str(message.get("msg_id", -1))
            if msg_id in replacement_dict:
                tool_content = replacement_dict[msg_id]
                # The reply to this call is the next tool message.
                for j in range(idx + 1, len(messages)):
                    if messages[j]["role"] == "tool":
                        messages[j]["content"] = tool_content
                        break
    return messages


# Substring of a logged shell-tool reply saying the command was rejected because a
# previous command was still running; adjust_tool_output does not replay such commands.
CMD_FAILED_SUB = "Stdout: ERROR: Previous command still running – send `is_input=true` or interrupt (C-c) first."

def adjust_tool_output(trajectories, adjusted_project_path, log_dir, template):
    """
    Replay file-system-relevant tool calls on a fresh template copy and patch logged outputs.

    ``TEMPLATE_ROOT/<template["name"]>`` is copied to *adjusted_project_path*
    (an existing directory there is deleted first). Then the tool calls of the
    last trajectory (the uncompressed transcript) are replayed in order:

    * shell commands run through ``safe_shell_execute`` unless the logged reply
      shows the command was rejected (``CMD_FAILED_SUB``); errors are printed
      and ignored;
    * list-directory / glob / grep calls on the project root are re-executed
      and their output is recorded under the calling message's ``msg_id``;
    * write-file / edit calls, and listing calls on other paths, are
      re-executed only for their effect on the copied project.

    A regenerated system prompt is recorded under ``"system_prompt"``. The
    replacements are stored on the last trajectory as ``"replacement_dict"``
    and applied to every trajectory's ``messages`` and ``compressed_section``.

    Returns *trajectories* (mutated in place).
    """
    template_path = os.path.join(TEMPLATE_ROOT, template["name"])

    if os.path.isdir(adjusted_project_path):
        shutil.rmtree(adjusted_project_path)
    shutil.copytree(template_path, adjusted_project_path)

    registry = initialize_tool_registry(adjusted_project_path, log_dir)

    tools_to_adjust = {
        WriteFileTool.Name,  # ensure files are created for listdirtool, ensure file content for grep tool
        ListDirectoryTool.Name,
        GlobTool.Name,
        ShellTool.Name,
        EditTool.Name,
        GrepTool.Name,
    }
    listing_tools = {ListDirectoryTool.Name, GlobTool.Name, GrepTool.Name}

    full_messages = trajectories[-1]["messages"]
    replacement_dict = {}
    for idx, message in enumerate(full_messages):
        role = message["role"]
        if role == "system":
            content = message.get("content")
            if isinstance(content, str) and content.startswith(SYSTEM_PREFIX):
                replacement_dict["system_prompt"] = get_core_system_prompt(adjusted_project_path)
        if role != "assistant":
            continue

        for tool_call in message.get("tool_calls") or []:
            name = tool_call["function"]["name"]
            args = tool_call["function"]["arguments"]
            if name not in tools_to_adjust:
                continue

            if name == ShellTool.Name:
                # The logged reply is the first tool message after this call.
                tool_content = next(
                    (full_messages[j]["content"] for j in range(idx + 1, len(full_messages))
                     if full_messages[j]["role"] == "tool"),
                    "",
                ) or ""
                # Only replay commands that actually ran in the original session.
                if CMD_FAILED_SUB not in tool_content:
                    try:
                        safe_shell_execute(args.get("command", ""), adjusted_project_path)
                    except Exception as e:  # best effort: a failing replay must not abort the sample
                        print(str(e))
            elif name in listing_tools and args.get("path", "") == adjusted_project_path:
                tool_response = registry.execute_tool(name, args)
                # Assistant messages near the start of a transcript carry no msg_id
                # (get_trajectories only numbers later positions); skip the
                # replacement instead of failing the whole sample with KeyError.
                msg_id = message.get("msg_id")
                if msg_id is not None:
                    replacement_dict[str(msg_id)] = str(tool_response["llmContent"])
            else:
                registry.execute_tool(name, args)

    trajectories[-1]["replacement_dict"] = replacement_dict
    for traj in trajectories:
        traj["messages"] = replace_tool_outputs(traj["messages"], replacement_dict)
        traj["compressed_section"] = replace_tool_outputs(traj["compressed_section"], replacement_dict)

    return trajectories


def convert_compressed_message(content):
    """
    Clean a compressed-history summary.

    Lines mentioning "old project" are dropped, and "new project" becomes
    "project" in the remaining lines.
    """
    kept = []
    for row in content.split("\n"):
        if "old project" in row:
            continue
        kept.append(row.replace("new project", "project"))
    return "\n".join(kept)


# Directory holding one sub-directory per project template (joined with template["name"]).
TEMPLATE_ROOT = "templates"
# Prefixes of user messages as logged by the agent. convert_messages matches on them
# to replace the original instruction with a freshly built one.
FRONTEND_PREFIX = "--- User Instruction ---\n\nYou are tasked with implementing the frontend"
BACKEND_PREFIX = "--- User Instruction ---\n\nYou are tasked with implementing the backend"
VALIDATION_PREFIX = "Validate whether all the required features have been implemented. You must make sure that all features have been fully implemented and tested."
# Marks the frontend flavour of the validation prompt. It is matched verbatim against
# logged text, so its spelling (including "comtinue") must stay in sync with the prompt
# that produced the logs.
FRONTEND_VALIDATION_SUB = "- If the frontend plan has not been fully implemented, comtinue implementing it and make sure everything has been properly implemented."
# Start of the agent's system prompt; such messages get a regenerated system prompt.
SYSTEM_PREFIX = "Environment Context:"
# System messages starting with this hold a compressed-history summary.
COMPRESSED_PREFIX = "<COMPRESSED_HISTORY>"

def convert_messages(messages, working_dir, log_dir, frontend_plan, backend_plan, template, user_instruction, is_pure_frontend, adjusted_working_dir_root):
    """
    Rewrite a logged message list in place so it describes the adjusted project.

    * user: prompts recognised by BACKEND_PREFIX / FRONTEND_PREFIX /
      VALIDATION_PREFIX are rebuilt from *user_instruction*, the plans and
      *template*; then paths are relocated;
    * assistant: a message with a tool call that touches the old project copy
      (or lists the bare *working_dir*) is dropped together with the tool
      replies directly following it. Otherwise string tool arguments and the
      text are relocated, and "new project"/"old project" become "project";
    * system: paths are relocated and compressed-history summaries are
      filtered by ``convert_compressed_message``;
    * tool: paths are relocated and "new_project/" is stripped.

    "Relocated" means ``<working_dir>/new_project``, then *working_dir*, are
    replaced by ``<adjusted_working_dir_root>/<basename(working_dir)>``.
    Messages whose content is missing or not a string are left untouched
    instead of raising. *log_dir* is kept for interface compatibility; it is
    not used.

    Returns the same (mutated) list.
    """
    old_project_path = os.path.join(working_dir, os.path.basename(working_dir).split("___")[-1])
    new_project_path = os.path.join(working_dir, "new_project")
    adjusted_project_path = os.path.join(adjusted_working_dir_root, os.path.basename(working_dir))
    replace_pairs = [("new project", "project"), ("old project", "project")]

    def relocate(text):
        # new_project_path starts with working_dir, so it must be rewritten first.
        text = text.replace(new_project_path, adjusted_project_path)
        return text.replace(working_dir, adjusted_project_path)

    def touches_old_project(tool_call):
        # True when this call must be removed from the trajectory.
        func = tool_call["function"]
        args = func["arguments"]
        path = args.get("path") or ""
        command = args.get("command") or ""
        return (
            old_project_path in path
            or old_project_path in command
            or (func["name"] == "list_directory" and path == working_dir)
        )

    i = 0
    while i < len(messages):
        message = messages[i]
        role = message["role"]

        if role == "assistant":
            should_pop = False
            for tool_call in message.get("tool_calls") or []:
                if touches_old_project(tool_call):
                    should_pop = True
                    break
                args = tool_call["function"]["arguments"]
                for k, v in args.items():
                    if isinstance(v, str):
                        args[k] = relocate(v)

            if should_pop:
                # Drop the call together with the tool replies that answer it.
                messages.pop(i)
                while i < len(messages) and messages[i]["role"] == "tool":
                    messages.pop(i)
                continue

            if isinstance(message.get("content"), str):
                content = relocate(message["content"])
                for old, new in replace_pairs:
                    content = content.replace(old, new)
                message["content"] = content

        elif isinstance(message.get("content"), str):
            content = message["content"]
            if role == "user":
                if content.startswith(BACKEND_PREFIX):
                    content = f"--- User Instruction ---\n\n{user_instruction}\n\n--- Backend Plan ---\n\n{json.dumps(backend_plan, indent=2)}\n\n--- Backend Information ---\n\n{template['backend_instruction']}\n\nImplement the backend part of the project based on the User Instruction and the Backend Plan. You should **only** modify the backend part of the project."
                elif content.startswith(FRONTEND_PREFIX):
                    if is_pure_frontend:
                        content = f"--- User Instruction ---\n\n{user_instruction}\n\n--- Frontend Plan ---\n\n{json.dumps(frontend_plan, indent=2)}\n\n--- Frontend Information ---\n\n{template['frontend_instruction']}\n\nThis is a pure frontend project. Implement the frontend part of the project based on the User Instruction and the Frontend Plan. You should **only** modify the frontend part of the project. Do NOT create, modify, or reference the backend."
                    else:
                        content = f"--- User Instruction ---\n\n{user_instruction}\n\nThe backend has already been implemented above.\n\n--- Frontend Plan ---\n\n{json.dumps(frontend_plan, indent=2)}\n\n--- Frontend Information ---\n\n{template['frontend_instruction']}\n\nImplement the frontend part of the project based on the User Instruction and the Frontend Plan. The backend APIs have already been implemented. You should **only** modify the frontend part of the project if possible. Do NOT modify the backend unless **absolutely necessary** and change as little as possible if you have to modify it."
                elif content.startswith(VALIDATION_PREFIX):
                    content = get_validation_prompt(
                        is_frontend=FRONTEND_VALIDATION_SUB in content,
                        user_instruction=user_instruction,
                    )
                message["content"] = relocate(content)
            elif role == "system":
                content = relocate(content)
                if content.startswith(COMPRESSED_PREFIX):
                    content = convert_compressed_message(content)
                message["content"] = content
            elif role == "tool":
                message["content"] = relocate(content).replace("new_project/", "")

        i += 1

    return messages


# One trajectory record. get_trajectories sets "step" (int), "kind" (str),
# "messages" (list), "tools" (list | None) and "compressed_section" (list);
# process_trajectories later adds "log_dir", "orig_working_dir" and "working_dir",
# and adjust_tool_output adds "replacement_dict" to the last record.
Trajectory = Dict[str, Optional[object]]


def _number_messages(messages: List[Dict[str, Any]], new_messages_start: int, idx: int) -> int:
    """
    Tag carried-over messages and number assistant turns; return the updated id counter.

    Messages before *new_messages_start* get ``is_old = True``. Assistant
    messages among them at position >= 4 are numbered backwards from *idx*
    (latest first), presumably so they match the ids they got when they were
    new (TODO confirm). Assistant messages from *new_messages_start* on, at
    position >= 3, get fresh ids ``idx + 1, idx + 2, ...``.
    """
    k = 0
    for i in range(new_messages_start - 1, -1, -1):
        messages[i]["is_old"] = True
        if messages[i]["role"] == "assistant" and i >= 4:
            messages[i]["msg_id"] = idx - k
            k += 1

    for i in range(new_messages_start, len(messages)):
        if messages[i]["role"] == "assistant" and i >= 3:
            idx += 1
            messages[i]["msg_id"] = idx
    return idx


def get_trajectories(log_dir: str | Path) -> Tuple[List[Trajectory], Optional[str], Optional[str]]:
    """
    Split a run's logs into training trajectories around every history compression.

    For every ``<n>_compress_history.json`` a ``"compress_before"`` trajectory is
    built from step ``n - 1``. The trajectory of the last LLM step
    (``"llm_final"``) follows, and an ``"uncompressed_messages"`` trajectory
    closes the list; it concatenates every newly produced message across all
    segments. Each record's ``compressed_section`` is a deep copy of the
    messages that the following compression summarised.

    Returns ``(trajectories, background_color, component_color)``, where the
    colours are the latest pair found in the LLM requests (``None`` if none).

    Raises:
        FileNotFoundError: *log_dir* is not a directory.
    """
    log_dir = Path(log_dir)
    if not log_dir.is_dir():
        raise FileNotFoundError(log_dir)

    traj: List[Trajectory] = []

    comp_re = re.compile(r"^(\d+)_compress_history\.json$")
    compress_steps: List[int] = sorted(
        int(m.group(1)) for p in log_dir.iterdir()
        if (m := comp_re.match(p.name))
    )

    curr_background_color = None
    curr_component_color = None
    tools = None

    compressed_section: List[Dict[str, Any]] = []
    new_messages_start = 0
    idx = 0
    full_messages: List[Dict[str, Any]] = []
    for n in compress_steps:
        prev = n - 1

        messages, tools, background_color, component_color, is_tool = get_messages_tools(log_dir, prev)
        if background_color is not None and component_color is not None:
            curr_background_color, curr_component_color = background_color, component_color

        # If step n has no LLM response, the compression was triggered by exceeding the
        # context length and the post-compression transcript is logged at step n + 1.
        after_step = n if (log_dir / f"{n}_llm_response.json").is_file() else n + 1
        after_compress_messages, _, _, _, is_tool = get_messages_tools(log_dir, after_step)

        # One compress message, plus the newly added turn: 2 messages with a tool reply, 1 without.
        compress_length = len(messages) - len(after_compress_messages) + (3 if is_tool else 2)

        idx = _number_messages(messages, new_messages_start, idx)

        # Deep copy, as for the final step, so the uncompressed trajectory does not share
        # (and get converted a second time through) the dicts of this trajectory.
        full_messages.extend(copy.deepcopy(messages[new_messages_start:]))

        traj.append(
            dict(
                step=prev,
                kind="compress_before",
                messages=messages,
                tools=tools,
                compressed_section=compressed_section,
            )
        )

        compressed_section = copy.deepcopy(messages[3: 3 + compress_length])  # the part compressed in the next message
        new_messages_start = len(after_compress_messages) - (2 if is_tool else 1)

    llm_re = re.compile(r"^(\d+)_llm_response\.json$")
    llm_steps = sorted(
        int(m.group(1)) for p in log_dir.iterdir()
        if (m := llm_re.match(p.name))
    )
    if llm_steps:                       # the directory *should* have at least one
        last_llm = llm_steps[-1]
        messages, tools, background_color, component_color, is_tool = get_messages_tools(log_dir, last_llm)
        if background_color is not None and component_color is not None:
            curr_background_color, curr_component_color = background_color, component_color

        idx = _number_messages(messages, new_messages_start, idx)
        full_messages.extend(copy.deepcopy(messages[new_messages_start:]))

        traj.append(
            dict(
                step=last_llm,
                kind="llm_final",
                messages=messages,
                tools=tools,
                compressed_section=compressed_section,
            )
        )

    traj.append(
        dict(
            step=-1,
            kind="uncompressed_messages",
            messages=full_messages,
            tools=tools,
            compressed_section=[],
        )
    )

    return traj, curr_background_color, curr_component_color


def process_trajectories(trajectories, working_dir, log_dir, frontend_plan, backend_plan, template, user_instruction, is_pure_frontend, adjusted_working_dir_root):
    """
    Annotate, rewrite and replay every trajectory of one sample.

    Each trajectory is tagged with ``log_dir``, ``orig_working_dir`` and the
    adjusted ``working_dir``. Its ``messages`` and ``compressed_section`` are
    passed through ``convert_messages``, and finally ``adjust_tool_output`` is
    applied to the whole list, whose result is returned.
    """
    project_path = os.path.join(adjusted_working_dir_root, os.path.basename(working_dir))
    shared_args = (working_dir, log_dir, frontend_plan, backend_plan, template, user_instruction, is_pure_frontend, adjusted_working_dir_root)

    for record in trajectories:
        record.update(log_dir=log_dir, orig_working_dir=working_dir, working_dir=project_path)
        record["messages"] = convert_messages(record["messages"], *shared_args)
        record["compressed_section"] = convert_messages(record["compressed_section"], *shared_args)

    return adjust_tool_output(trajectories, project_path, log_dir, template)


_COLOR_PATTERN = re.compile(
    r"""
    Make\ the\ background\ color\s+
    (?P<bg>.+?)                        # capture background colour (lazy)
    \s+and\ the\ compo(?:m|n)ent\ color\s+   # "compoment" (original spelling) or "component"
    (?P<comp>.+?)                      # capture component colour (lazy)
    (?=[.?!]|$)                        # stop just before ., ?, ! or end of string
    """,
    re.IGNORECASE | re.VERBOSE,
)


def extract_colors(text: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract the colour theme from a sentence like
    "Make the background color X and the compoment color Y."

    Matching is case-insensitive and accepts both "compoment" and
    "component". Returns ``(background, component)`` with surrounding
    whitespace stripped, or ``(None, None)`` when no such sentence is found.
    """
    m = _COLOR_PATTERN.search(text)
    if m is None:
        return None, None
    return m.group("bg").strip(), m.group("comp").strip()


def find_colors(trajectories):
    """
    Return the first (background, component) colour pair stated in a user message.

    Trajectories and their messages are scanned in order, and the first user
    message that ``extract_colors`` recognises wins. The previous inner-loop
    ``break`` let later trajectories overwrite an earlier hit. User messages
    whose content is not a string are skipped. Returns ``(None, None)`` when
    nothing is found.
    """
    for traj in trajectories:
        for message in traj["messages"]:
            if message["role"] != "user":
                continue
            content = message.get("content")
            if not isinstance(content, str):
                continue
            bg_color, cp_color = extract_colors(content)
            if bg_color is not None:
                return bg_color, cp_color
    return None, None


def _process_one_sample(
    info_data: Dict[str, Any],
    *,
    log_root_dir: str,
    adjusted_working_dir_root: str,
    template: Dict[str, Any],
) -> List[Trajectory] | None:
    """
    Run the per-sample pipeline for one info record (executed in a worker process).

    Returns the processed trajectories, or ``None`` when the sample is skipped:

    * its log directory has no ``finished.json``;
    * no colour theme can be determined;
    * any exception is raised.

    Exceptions are reported on stderr (message and traceback) rather than
    propagated, so the parent process never has to unpickle them; some carry
    un-picklable objects such as ``os.DirEntry``. To make failures fatal,
    re-raise ``RuntimeError(str(exc)) from None`` in the handler instead.
    """
    try:
        log_dir = os.path.join(log_root_dir, info_data["id"])
        if not os.path.isfile(os.path.join(log_dir, "finished.json")):
            return None

        trajectories, background_color, component_color = get_trajectories(log_dir)

        # Prefer colours recorded in the requests; otherwise parse them from user messages.
        if background_color is None or component_color is None:
            background_color, component_color = find_colors(trajectories)
            if background_color is None or component_color is None:
                return None

        # Workspaces mirror the log layout under a different root directory.
        working_dir = log_dir.replace("logs_root", "workspaces_root")
        summary = info_data["summary"]
        frontend_plan = summary["frontendPlan"]
        backend_plan = summary["backendPlan"]
        user_instruction = (
            summary["userInstruction"]
            + " "
            + wrap_color_theme(background_color, component_color)
        )

        # No backend plan, or one without API endpoints (missing, null or empty),
        # means a pure-frontend project.
        is_pure_frontend = not (backend_plan or {}).get("apiEndpoints")

        return process_trajectories(
            trajectories,
            working_dir,
            log_dir,
            frontend_plan,
            backend_plan,
            template,
            user_instruction,
            is_pure_frontend,
            adjusted_working_dir_root,
        )

    except Exception as exc:  # noqa: BLE001 - one bad sample must not take down the pool
        import traceback

        sys.stderr.write(
            f"[worker] sample {info_data.get('id', '<unknown>')} failed: {exc}\n"
        )
        sys.stderr.write(traceback.format_exc())
        sys.stderr.flush()
        return None


def get_train(
    log_root_dir: str,
    info_path: str,
    out_path: str,
    adjusted_working_dir_root: str,
    *,
    template_name: str = "nextjs-nextjs-postresql",
    max_workers: int | None = None,         # let the caller tune the pool size
) -> None:
    """
    Convert every sample listed in *info_path* into trajectories and write them to *out_path*.

    Samples are processed in parallel with ``ProcessPoolExecutor``. Results are
    emitted in input order, not completion order, so the output file and the
    index passed to ``convert_messages_replace_working_dir`` are deterministic
    across runs.

    Raises:
        ValueError: no template named *template_name* exists in ``TEMPLATES``.
    """
    # Locate the template once; it is broadcast to the workers.
    try:
        template = next(
            temp for temp in TEMPLATES["templates"] if temp["name"] == template_name
        )
    except StopIteration:
        raise ValueError(f"unknown template: {template_name!r}") from None

    info_datas = load_jsonl(info_path)

    process_one = partial(
        _process_one_sample,
        log_root_dir=log_root_dir,
        adjusted_working_dir_root=adjusted_working_dir_root,
        template=template,
    )

    # Successful results keyed by the sample's position in info_datas.
    results: Dict[int, List[Trajectory]] = {}
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(process_one, info): pos for pos, info in enumerate(info_datas)}

        with tqdm(total=len(futures), desc="processing samples") as pbar:
            for fut in as_completed(futures):
                result = fut.result()
                if result:
                    # each successful result is a list[Trajectory]
                    results[futures[fut]] = result
                pbar.update(1)

    finished_samples = len(results)
    datas: List[Trajectory] = [traj for pos in sorted(results) for traj in results[pos]]

    print(f"\n\nfinished samples: {finished_samples}")
    for i, data in enumerate(datas):
        convert_messages_replace_working_dir(data["messages"], data["working_dir"], i)
        convert_messages_replace_working_dir(data["compressed_section"], data["working_dir"], i)
    save_jsonl(datas, out_path)


def main() -> None:
    """Script entry point: convert the NestJS back-translation run into a JSONL training file."""
    get_train(
        log_root_dir="logs_root/model-Qwen3-Coder-480B-A35B-Instruct-FP8_hist-100_iter-500_compress-0.5_val-1_sum-5_nestjs_backtrans4",
        info_path="src/run_process_data/jsonl_files/nestjs_github-repos_filtered-with-info.jsonl",
        out_path="src/run_process_data/jsonl_files/nestjs_github-repos_filtered-with-info_backtranslated.jsonl",
        adjusted_working_dir_root="workspaces_root/model-Qwen3-Coder-480B-A35B-Instruct-FP8_hist-100_iter-500_compress-0.5_val-1_sum-5_nestjs_backtrans4_adjusted1",
        template_name="nextjs-nextjs-postresql",
        max_workers=16,
    )


# Run only when executed as a script. This also matters for ProcessPoolExecutor:
# under the "spawn" start method, worker processes re-import this module.
if __name__ == "__main__":
    main()