"""
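Rewrite the task_description field of every entry in tasks_*.json files
using an LLM; each result is written next to its input as <stem>_rewritten.json
(suffix configurable with --out-suffix). Unrecognised arguments are passed on
to the project's own config parser.

Usage:
  python3 refine_tasks.py --dirs ./generated_tasks [--out-suffix SUFFIX] [--sleep SECONDS]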

"""

import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import List

# Configure the root logger once for this script: INFO and above, each record
# prefixed with a timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Few-shot prompt sent to the LLM once per task description. The literal
# "{instruction}" placeholder is substituted by rewrite_task_description()
# with str.replace, and the reply is parsed by looking for the "REFINED:" and
# "REASONING:" labels requested at the end of this prompt.
REFINE_INSTRUCTION_PROMPT = """You will be given a task instruction. Please refine it with the following requirements:
1. Imitate the speech style of a human user. Make it more natural and diverse.
2. Remove specific hints on how to achieve the task. (e.g. press which button, click which link, etc.)
3. The task should remain unambiguous and clear, and the task's goal must remain the same.

For example, you should refine:
"Use the \"Price range\" filter on the Ryanair website to limit the flight options to \u20ac50-\u20ac100."
into:
"Find flights from London to Paris with a price range of \u20ac50-\u20ac100."
or
"I only have a limited budget for flights. Could you help find flights from London to Paris with a price range of \u20ac50-\u20ac100?"

You should refine:
"Import email messages from another email program by clicking the \"Import\" button under the \"Import from Another Program\" section."
into:
"Please import email messages from another email program."
or
"I have some emails in my other email program. Could you help me import them into Thunderbird?"

You should refine:
"Click on the \"About Us\" link to learn more about the company's history and mission."
into:
"Please provide information about the company's history and mission."
or
"Find an \"About Us\" or similar page on the website that describes the company's history and mission."

The original instruction is:
{instruction}

Refine it and return the refined instruction text in this exact format:
ORIGINAL: <the original instruction>
REFINED: <the refined instruction>
REASONING: <your reasoning about the changes you made>

"""

def rewrite_task_description(llm, original_text: str, prompt_template: str) -> str:
    """
    Ask the LLM to rewrite a single task description.

    Args:
        llm: Client object exposing ``chat(messages)``; the ``content``
            attribute of its reply is read as the response text.
        original_text: The task description to refine.
        prompt_template: Prompt containing an ``{instruction}`` placeholder.
            Falls back to REFINE_INSTRUCTION_PROMPT when empty.

    Returns:
        The rewritten description, or ``original_text`` when the LLM call
        fails, the reply is empty, or the reply uses the structured
        ORIGINAL/REFINED/REASONING format but has no REFINED section.
    """
    # Honour the caller's template (previously the argument was ignored and the
    # module-level prompt was always used).
    template = prompt_template or REFINE_INSTRUCTION_PROMPT
    # str.replace rather than str.format so any other braces in a template are
    # left alone.
    prompt = template.replace("{instruction}", original_text)
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]

    try:
        resp = llm.chat(messages)
        # A reply with no/None content counts as empty; non-str content raises
        # here and falls into the best-effort fallback below.
        reply = (getattr(resp, "content", "") or "").strip()
    except Exception as e:  # best effort: any client failure keeps the original
        logger.warning(f"LLM call failed for text [{original_text[:60]}...]: {e}")
        return original_text

    if "REFINED:" in reply:
        # Keep only the text after the first "REFINED:" label and before an
        # optional "REASONING:" label.
        new_text = reply.split("REFINED:")[1].split("REASONING:")[0].strip()
    elif "ORIGINAL:" in reply or "REASONING:" in reply:
        # The model followed the structured format but omitted REFINED; writing
        # the whole blob into task_description would corrupt the task.
        logger.warning(f"LLM reply has no REFINED section for text [{original_text[:60]}...]")
        return original_text
    else:
        # Plain-text reply: treat it as the rewritten description itself.
        new_text = reply

    logger.info(f"Original: {original_text} | Rewritten: {new_text}")
    return new_text or original_text


def process_file(llm, filepath: Path, prompt_template: str, out_suffix: str, sleep_s: float):
    """
    Rewrite every ``task_description`` in one JSON task file and save the result.

    The file must contain a JSON array. Entries that are not objects, or whose
    ``task_description`` is missing, not a string, or blank, are copied through
    unchanged. The output is written next to the input as
    ``<stem><out_suffix><suffix>`` even when nothing changed. If ``out_suffix``
    is empty, this overwrites the input file. Errors while reading/parsing or
    writing are logged rather than raised.

    Args:
        llm: LLM client passed through to rewrite_task_description.
        filepath: Input JSON file.
        prompt_template: Prompt template passed through to rewrite_task_description.
        out_suffix: Text appended to the file stem to build the output name.
        sleep_s: Seconds to pause after each LLM call; values <= 0 disable it.
    """
    # Outputs of a previous run (e.g. tasks_x_rewritten.json) also match the
    # tasks_*.json glob; re-processing them would produce
    # tasks_x_rewritten_rewritten.json and waste LLM calls.
    if out_suffix and filepath.stem.endswith(out_suffix):
        logger.info(f"Skipping {filepath}: already a rewritten output")
        return

    logger.info(f"Processing {filepath}")
    try:
        with filepath.open('r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError
        logger.error(f"Failed to load JSON {filepath}: {e}")
        return

    if not isinstance(data, list):
        logger.warning(f"Skipping {filepath} because JSON root is not a list")
        return

    modified = False
    for obj in data:
        if not isinstance(obj, dict):
            continue
        original = obj.get('task_description')
        # Only non-blank string descriptions are sent to the LLM.
        if not isinstance(original, str) or not original.strip():
            continue
        new_desc = rewrite_task_description(llm, original, prompt_template)
        if new_desc != original:
            obj['task_description'] = new_desc
            modified = True
        # Throttle between LLM calls.
        if sleep_s > 0:
            time.sleep(sleep_s)

    out_path = filepath.with_name(filepath.stem + out_suffix + filepath.suffix)
    # Always write the output (even if unchanged) so the user gets a complete
    # set of rewritten files.
    try:
        with out_path.open('w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
    except OSError as e:
        logger.error(f"Failed to write output {out_path}: {e}")
        return
    logger.info(f"Wrote rewritten file to {out_path} (modified={modified})")


def find_json_files(dirs: List[Path]) -> List[Path]:
    """
    Collect the ``tasks_*.json`` files directly inside each given directory.

    Directories are scanned in the order given; within one directory the
    matches are sorted. Entries that are missing or are not directories are
    logged as warnings and skipped.
    """
    found: List[Path] = []
    for directory in dirs:
        # is_dir() is False for paths that do not exist, so it covers both cases.
        if directory.is_dir():
            found.extend(sorted(directory.glob("tasks_*.json")))
        else:
            logger.warning(f"Directory does not exist or is not a directory: {directory}")
    return found


def main():
    """
    Command-line entry point.

    Parses this script's own flags with ``parse_known_args``. Every unrecognised
    argument is left in ``sys.argv`` for the project's config parser. It then
    loads the tool LLM and rewrites each ``tasks_*.json`` file found in the
    requested directories.
    """
    arg_parser = argparse.ArgumentParser(description="Rewrite task_description fields using local LLM")
    arg_parser.add_argument(
        "--dirs",
        nargs='+',
        default=["./memory_evolution/generated_tasks"],
        help="Directories to scan for tasks_*.json (space separated)",
    )
    arg_parser.add_argument(
        "--out-suffix",
        default="_rewritten",
        help="Suffix to append to output filename (default: _rewritten)",
    )
    arg_parser.add_argument(
        "--sleep",
        type=float,
        default=0.5,
        help="Seconds to sleep between LLM calls (default 0.5)",
    )
    opts, leftover = arg_parser.parse_known_args()
    # Drop the flags consumed above so that the deferred project imports /
    # config() (which presumably read sys.argv themselves) only see the rest.
    sys.argv = [sys.argv[0], *leftover]

    # Imported here, after sys.argv has been rewritten, on purpose.
    from agent.llm_config import load_tool_llm
    from config.argument_parser import config

    llm = load_tool_llm(config())

    json_files = find_json_files([Path(d) for d in opts.dirs])
    if not json_files:
        logger.error("No matching tasks_*.json files found in provided directories.")
        return

    logger.info(f"Found {len(json_files)} files to process.")
    for fpath in json_files:
        process_file(llm, fpath, REFINE_INSTRUCTION_PROMPT, opts.out_suffix, opts.sleep)

    logger.info("Done.")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()