import os
import re
import json
import csv
import math
import mimetypes
import hashlib
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple
import logger
import requests
import pandas as pd
import numpy as np
from datetime import datetime

from .base import BaseTool, ToolResult, ToolFailure


URL_RE = re.compile(r"https?://[^\s]+", re.I)


def _ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    Calling this on a directory that already exists is a no-op.
    """
    p.mkdir(exist_ok=True, parents=True)


def _safe_filename(name: str) -> str:
    """Reduce ``name`` to a filesystem-friendly token.

    Each run of characters other than word characters, ``-`` and ``.`` is
    collapsed into a single underscore, then leading/trailing dots and
    underscores are trimmed. ``"file"`` is returned when nothing is left.
    """
    collapsed = re.sub(r"[^\w\-\.]+", "_", name)
    trimmed = collapsed.strip("._")
    if trimmed:
        return trimmed
    return "file"


def _guess_ext_from_content_type(ct: Optional[str]) -> Optional[str]:
    """Map a ``Content-Type`` header value to a file extension.

    Anything after the first ``;`` (e.g. a charset parameter) is ignored.
    Returns ``None`` for a missing/empty header or when ``mimetypes`` has no
    extension registered for the media type.
    """
    if ct:
        media_type = ct.partition(";")[0].strip()
        return mimetypes.guess_extension(media_type)
    return None


def _infer_name_meaning(fname: str) -> List[str]:
    """Derive human-readable hints from conventions visible in a file name.

    Recognised conventions, reported in this order:
      * data-split markers (train/test/valid/dev),
      * an embedded date (YYYY[-_]M[M][-_]D[D], years 19xx/20xx),
      * a version tag (``v<digits>``),
      * sample/small/mini markers.

    Tokens are delimited by anything that is not an ASCII letter or digit.
    ``\\b`` is intentionally NOT used: it treats ``_`` as a word character, so
    typical names such as ``train_data.csv`` or ``data_v2.csv`` would never
    match.

    Args:
        fname: File name (with or without extension) to inspect.

    Returns:
        The list of hint strings that apply; empty when none do.
    """
    lead = r"(?<![A-Za-z0-9])"
    trail = r"(?![A-Za-z0-9])"
    rules = [
        (lead + r"(train|test|valid|dev)" + trail, re.I, "数据分割文件"),
        (lead + r"(19\d{2}|20\d{2})([-_]?\d{1,2})([-_]?\d{1,2})" + trail, 0, "文件名包含日期信息"),
        (lead + r"(v\d+)" + trail, re.I, "版本化文件"),
        (lead + r"(sample|small|mini)" + trail, re.I, "样本/小规模文件"),
    ]
    return [hint for pattern, flags, hint in rules if re.search(pattern, fname, flags)]


class DataDownloader(BaseTool):
    """Search for and download publicly accessible data resources.

    Accepts either search keywords or a direct URL, optionally filters by file
    extension, rejects empty downloads, and returns per-file metadata
    (path, size, format, sha256, source URL).
    """

    name: str = "data_downloader"
    description: str = (
        "根据关键词或直链下载公开数据资源；支持类型过滤与完整性校验，返回本地路径及元信息。"
    )
    input: str = "keywords_or_url: 关键词或URL; download_dir: 本地保存目录; file_types: 允许的扩展名数组; max_files: 最大下载数"
    output: str = "{downloads: [{path,size,format,source_url}], errors: [..]}"
    parameters: dict = {
        "type": "object",
        "properties": {
            "keywords_or_url": {"type": "string", "description": "搜索关键词或直接URL"},
            "download_dir": {"type": "string", "description": "保存目录，默认 workspace/"},
            "file_types": {
                "type": "array",
                "items": {"type": "string"},
                "description": "允许的扩展名（不带点），如 ['csv','zip','json']",
            },
            "max_files": {"type": "integer", "minimum": 1, "maximum": 20, "description": "最多下载数量，默认3"},
        },
        "required": ["keywords_or_url"],
    }
    notices: List[str] = [
        "Kaggle/部分源可能需要认证；当前实现不处理认证，仅下载公开可访问链接。",
        "若需要特定来源限制，可在关键词中加入 site:kaggle.com 等。",
    ]
    examples: List[str] = [
        "下载与 Iris 相关的 CSV 数据到 workspace/data 目录并限制 csv/json：'Iris dataset csv'",
    ]

    @staticmethod
    def _url_path(url: str) -> str:
        """Return ``url`` with its query string and fragment removed."""
        return re.split(r"[?#]", url, maxsplit=1)[0]

    @staticmethod
    def _display_path(fp: Path) -> str:
        """Return ``fp`` relative to ``workspace/`` when possible, else absolute.

        A download directory outside the workspace is legitimate, so failing to
        relativise the path must not turn a successful download into an error.
        """
        resolved = fp.resolve()
        try:
            return str(resolved.relative_to(Path("workspace").resolve()))
        except ValueError:
            return str(resolved)

    @staticmethod
    def _discard(fp: Path) -> None:
        """Best-effort removal of an empty or partially written file."""
        try:
            fp.unlink()
        except OSError:
            pass

    def _search_candidates(self, query: str, limit: int) -> List[Dict[str, str]]:
        """Run a generic web search and return ``[{title, url}, ...]``.

        Any failure (import error, search error) yields an empty list so the
        caller can report "no candidates" instead of crashing.
        """
        try:
            # Imported lazily; reuses the existing WebSearch tool.
            from .web_tools import WebSearch
            ws = WebSearch()
            res = ws.execute(query=query, max_results=limit)
            items = res.output.get("results", []) if res and not res.error else []
            # Normalise every hit to {title, url}.
            return [{"title": it.get("title"), "url": it.get("url")} for it in items]
        except Exception:
            return []

    def _should_accept_url(self, url: str, allowed_exts: List[str]) -> bool:
        """Return True when the URL path ends with one of ``allowed_exts``.

        An empty ``allowed_exts`` accepts everything. The comparison is
        case-insensitive and ignores the query string and fragment.
        """
        if not allowed_exts:
            return True
        path = self._url_path(url).lower()
        return any(path.endswith("." + ext.lower()) for ext in allowed_exts)

    def _download_one(self, url: str, dst_dir: Path, allowed_exts: List[str]) -> Dict[str, Any]:
        """Download ``url`` into ``dst_dir`` and describe the result.

        Returns a dict that always contains ``source_url``. On success it also
        has ``path``, ``size``, ``format`` and ``sha256``; on failure it has
        ``error`` instead. This method never raises.
        """
        out: Dict[str, Any] = {"source_url": url}
        try:
            headers = {"User-Agent": "Mozilla/5.0 (compatible; AutoDS/1.0)"}
            with requests.get(url, headers=headers, timeout=20, stream=True) as r:
                r.raise_for_status()
                ct = r.headers.get("Content-Type")
                name = _safe_filename(Path(self._url_path(url)).name or "download")
                if not Path(name).suffix and ct:
                    # No extension in the URL: fall back to the Content-Type.
                    guess = _guess_ext_from_content_type(ct)
                    if guess:
                        name += guess

                if allowed_exts:
                    if not any(name.lower().endswith("." + e.lower()) for e in allowed_exts):
                        out["error"] = f"format not allowed: {name}"
                        return out

                _ensure_dir(dst_dir)
                fp = dst_dir / name
                h = hashlib.sha256()
                size = 0
                try:
                    with open(fp, "wb") as f:
                        for chunk in r.iter_content(chunk_size=8192):
                            if not chunk:
                                continue
                            f.write(chunk)
                            h.update(chunk)
                            size += len(chunk)
                except Exception:
                    # Do not leave a truncated file that looks like a dataset.
                    self._discard(fp)
                    raise
                if size <= 0:
                    self._discard(fp)
                    out["error"] = "empty file"
                    return out
                out.update({
                    # Reported relative to workspace/ when the file lives there.
                    "path": self._display_path(fp),
                    "size": size,
                    "format": Path(name).suffix[1:],
                    "sha256": h.hexdigest(),
                })
                return out
        except Exception as e:
            out["error"] = f"download error: {e}"
            return out

    def execute(self, keywords_or_url: str, download_dir: Optional[str] = "workspace", file_types: Optional[List[str]] = None, max_files: int = 3) -> ToolResult:
        """Download data by direct URL or by keyword search.

        Args:
            keywords_or_url: Search keywords, or text containing an http(s) URL.
            download_dir: Target directory; defaults to ``workspace``.
            file_types: Allowed extensions (leading dots and blanks ignored);
                empty/None means no filtering.
            max_files: Maximum number of files to fetch, clamped to 1..20.

        Returns:
            ``ToolResult`` with ``{"downloads": [...], "errors": [...]}``, or a
            ``ToolFailure`` when the call itself cannot be carried out.
        """
        try:
            download_dir = download_dir or "workspace"
            allowed = [
                e.strip().lstrip(".")
                for e in (file_types or [])
                if isinstance(e, str) and e.strip().lstrip(".")
            ]
            max_files = max(1, min(int(max_files or 3), 20))
            dst_dir = Path(download_dir)

            downloads: List[Dict[str, Any]] = []
            errors: List[str] = []

            query = (keywords_or_url or "").strip()
            if not query:
                return ToolFailure(error="downloader failed: keywords_or_url is empty")

            # Direct-link download.
            url_match = URL_RE.search(query)
            if url_match:
                # Use the URL itself, not the whole input, in case it is
                # embedded in surrounding text.
                url = url_match.group(0)
                if url != query:
                    # Drop sentence punctuation glued to the end of the URL.
                    url = url.rstrip(".,;:!?\"')")
                if self._should_accept_url(url, allowed):
                    info = self._download_one(url, dst_dir, allowed)
                    if "error" in info:
                        errors.append(json.dumps(info, ensure_ascii=False))
                    else:
                        downloads.append(info)
                else:
                    errors.append("url format not allowed by file_types")
                return ToolResult(output={"downloads": downloads, "errors": errors})

            # Keyword search: over-fetch candidates because many get filtered out.
            candidates = self._search_candidates(query, max_files * 3)
            for c in candidates:
                u = c.get("url")
                if not u:
                    continue
                if not self._should_accept_url(u, allowed):
                    continue
                info = self._download_one(u, dst_dir, allowed)
                if "error" in info:
                    errors.append(json.dumps(info, ensure_ascii=False))
                    continue
                downloads.append(info)
                if len(downloads) >= max_files:
                    break

            if not downloads and not errors:
                errors.append("no candidates found")

            return ToolResult(output={"downloads": downloads, "errors": errors})
        except Exception as e:
            return ToolFailure(error=f"downloader failed: {e}")


class DataExplorer(BaseTool):
    """Explore a directory or a single file and describe its structure.

    Supports CSV/TXT, Parquet, Excel and JSON. For each file it reports basic
    file info plus one entry per column (name, dtype, example value), and then
    tries to attach a "meaning" to files and columns via a single LLM call,
    falling back to simple name-based heuristics.
    """

    name: str = "data_explorer"
    description: str = (
        "分析目录/文件的数据结构与字段特征，输出文件清单、字段名称/类型/示例及潜在问题提示；对数据集进行初步探索，在不清楚文件或文件夹结构时可用。"
    )
    input: str = "path: 目录或文件"
    output: str = "{files:[{name,size,ext,hints[:]}], schemas:{filename:[{name,dtype,example,hint?}]}}"
    parameters: dict = {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "目录或文件路径"},
            # NOTE: execute() also accepts sample_ratio, max_fields and
            # file_limit, but they are currently not exposed in this schema.
        },
        "required": ["path"],
    }
    notices: List[str] = [
        "Excel/Parquet 解析可能需要 openpyxl/pyarrow；若未安装将跳过并给出提示。",
        "JSON 自动尝试记录级加载，若结构复杂建议提供示例或字段路径。",
    ]

    @staticmethod
    def _is_within(candidate: Path, root: Path) -> bool:
        """Return True when ``candidate`` is ``root`` or located beneath it."""
        try:
            candidate.relative_to(root)
        except ValueError:
            return False
        return True

    @staticmethod
    def _is_scalar_na(value: Any) -> bool:
        """Return True when ``value`` is None or a scalar missing value.

        Containers (for which ``pd.isna`` returns an array) count as data.
        """
        if value is None:
            return True
        try:
            check = pd.isna(value)
        except Exception:
            return False
        return bool(check) if isinstance(check, (bool, np.bool_)) else False

    def _list_files(self, base: Path, limit: int) -> List[Path]:
        """Recursively collect up to ``limit`` supported data files under ``base``.

        Paths are visited in sorted order so the result is deterministic.
        """
        allowed = {".csv", ".parquet", ".xlsx", ".xls", ".json", ".txt"}
        files: List[Path] = []
        for p in sorted(base.rglob("*")):
            if len(files) >= limit:
                break
            if p.is_file() and p.suffix.lower() in allowed:
                files.append(p)
        return files

    def _sniff_delimited(self, fp: Path, ext: str) -> Dict[str, Any]:
        """Best-effort detection of delimiter and header row for a text table.

        Returns a dict that may contain ``sep`` (tab is stored as the two
        characters ``\\t`` so it stays readable in the tool output),
        ``engine`` and ``header`` (0 = first row is a header, None = no header).
        Never raises; whatever was detected before a failure is returned.
        """
        params: Dict[str, Any] = {}
        try:
            with open(fp, "r", encoding="utf-8", errors="ignore") as f:
                # Whole lines (not a fixed byte count) make sniffing more reliable.
                sample_text = "".join(line for _, line in zip(range(10), f))
            if not sample_text:
                return params

            sniffer = csv.Sniffer()
            try:
                delimiter = sniffer.sniff(sample_text).delimiter
                params["sep"] = "\\t" if delimiter == "\t" else delimiter
            except csv.Error:
                if ext == ".txt":
                    # Let pandas' python engine auto-detect the separator.
                    params["sep"] = None
                    params["engine"] = "python"

            try:
                params["header"] = 0 if sniffer.has_header(sample_text) else None
            except csv.Error:
                pass
        except Exception:
            # Sniffing is only a hint; pandas defaults are used when it fails.
            pass
        return params

    def _read_csv_skipping_preamble(self, fp: Path, sep: Optional[str], max_rows: int, load_params: Dict[str, Any]) -> Optional[pd.DataFrame]:
        """Retry ``read_csv`` while skipping 1..9 leading lines.

        Stops at the first attempt that yields more than one column and records
        the number of skipped lines in ``load_params["skiprows"]``. Returns the
        last frame that parsed (possibly single-column) or None if none did.
        """
        with open(fp, "r", encoding="utf-8", errors="ignore") as f:
            # Only the first 10 lines matter; avoid loading a large file.
            n_lines = sum(1 for _ in zip(range(10), f))
        df = None
        for skip in range(1, n_lines):
            try:
                df = pd.read_csv(
                    fp,
                    skiprows=skip,
                    nrows=max_rows,
                    sep=sep,
                    engine="python" if sep is None else None,
                )
            except pd.errors.ParserError:
                continue
            if df.shape[1] > 1:
                load_params["skiprows"] = skip
                break
        return df

    def _sample_read(self, fp: Path, sample_ratio: float, max_rows: int = 20) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]:
        """Load a small sample of ``fp`` and report how it was loaded.

        Args:
            fp: File to read; the format is chosen from its extension.
            sample_ratio: Fraction used to down-sample frames longer than
                ``max_rows`` (only applied when strictly between 0 and 1).
            max_rows: Row cap for CSV/TXT reads and lower bound for sampling.

        Returns:
            ``(df, load_params)``. ``df`` is None when the file cannot be read
            or has an unsupported extension; ``load_params`` holds detected
            settings (``sep``, ``header``, ``skiprows``, ``lines``).
        """
        load_params: Dict[str, Any] = {}
        try:
            ext = fp.suffix.lower()
            if ext in {".csv", ".txt"}:
                load_params.update(self._sniff_delimited(fp, ext))

                read_kwargs: Dict[str, Any] = {}
                if "sep" in load_params:
                    sep = load_params["sep"]
                    if sep is None and load_params.get("engine") == "python":
                        read_kwargs["sep"] = None
                        read_kwargs["engine"] = "python"
                    else:
                        # Undo the display encoding of the tab delimiter.
                        read_kwargs["sep"] = "\t" if sep == "\\t" else sep
                if "header" in load_params:
                    read_kwargs["header"] = load_params["header"]

                try:
                    df = pd.read_csv(fp, nrows=max_rows, **read_kwargs)
                except pd.errors.ParserError:
                    # Possibly a free-text preamble before the table.
                    df = self._read_csv_skipping_preamble(fp, read_kwargs.get("sep"), max_rows, load_params)
                    if df is None:
                        raise

            elif ext in {".xlsx", ".xls"}:
                try:
                    df = pd.read_excel(fp)
                except Exception:
                    return None, {}
            elif ext == ".parquet":
                try:
                    df = pd.read_parquet(fp)
                except Exception:
                    return None, {}
            elif ext == ".json":
                try:
                    df = pd.read_json(fp, lines=True)
                    load_params["lines"] = True
                except ValueError:
                    try:
                        df = pd.read_json(fp)
                    except Exception:
                        return None, {}
            else:
                return None, {}

            if df is None or df.empty:
                return df, load_params

            if 0 < sample_ratio < 1.0 and len(df) > 0:
                # Keep at least ~max_rows rows, never more than the whole frame.
                frac = max(sample_ratio, min(1.0, max_rows / max(len(df), 1)))
                frac = min(frac, 1.0)
                df = df.sample(frac=frac, random_state=42) if 0 < frac < 1 else df
            return df, load_params
        except Exception as e:
            logger.warning(f"Error reading file {fp}: {e}")
            return None, {}

    def _example_value(self, s: pd.Series) -> Any:
        """Return the first non-missing value among the first 10 items of ``s``.

        Numpy scalars are converted with ``.item()``; values that cannot be
        converted are returned as ``str``. Returns None when nothing is found.
        """
        for v in s.head(10):
            # pd.notna returns an array for list/array cells; treat those as data.
            is_not_na = False
            try:
                check = pd.notna(v)
                if isinstance(check, (bool, np.bool_)):
                    is_not_na = check
                else:
                    is_not_na = True
            except Exception:
                is_not_na = True

            if is_not_na:
                try:
                    return v.item() if hasattr(v, "item") else v
                except Exception:
                    return str(v)
        return None

    def _field_hint(self, name: str) -> Optional[str]:
        """Return a naming warning for column ``name``, or None if it looks fine."""
        if re.fullmatch(r"col\d+", name, re.I):
            return "字段名像自动生成（colN），建议重命名"
        if len(name) <= 2:
            return "字段名过短，含义可能不明确"
        if re.search(r"\s", name):
            return "字段名包含空白，建议改为下划线或驼峰"
        if re.search(r"[^A-Za-z0-9_\-]", name):
            return "字段名包含特殊字符，可能影响加载或SQL兼容性"
        return None

    def _infer_meanings_with_llm(self, files: List[Dict[str, Any]], schemas: Dict[str, List[Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
        """Infer file-name and field meanings with a single LLM call.

        Expected reply shape::

            {
              "file_meanings": { filename: meaning },
              "field_meanings": { filename: { field: meaning } }
            }

        Returns the parsed dict, or None on any failure (LLM unavailable,
        empty reply, unparseable JSON).
        """
        try:
            # Imported lazily to avoid a circular import during tool registration.
            from llm import LLM
        except Exception as e:
            logger.warning(f"LLM unavailable for meaning inference: {e}")
            return None

        # Send only what is needed, to keep the token count down.
        compact = {
            "files": [
                {
                    "name": f.get("name"),
                    "ext": f.get("ext"),
                    "size": f.get("size"),
                }
                for f in files
            ],
            "schemas": {
                fn: [
                    {"name": c.get("name"), "dtype": c.get("dtype"), "example": c.get("example")}
                    for c in (cols or [])
                ]
                for fn, cols in schemas.items()
            },
        }
        system = (
            "你是数据理解助手。请基于文件名、扩展名、大小以及示例字段信息，推断：\n"
            "1) 每个文件名的含义（例如是否为训练/测试、是否包含日期或版本等）；\n"
            "2) 每个字段的业务含义或直观解释（尽量简洁准确）。\n"
            "仅输出 JSON，格式：{\"file_meanings\":{..},\"field_meanings\":{\"<filename>\":{\"<col>\":\"<meaning>\"}}}。不要输出多余文本。"
        )
        user = json.dumps(compact, ensure_ascii=False)

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        try:
            resp = LLM().ask(messages=messages, tools=[])
            content = resp.choices[0].message.content if hasattr(resp, "choices") else None
            if not content:
                return None
            # Strip a surrounding code fence (and its language tag line).
            content_str = content.strip()
            if content_str.startswith("```"):
                content_str = content_str.strip("`\n").split("\n", 1)[-1]
            try:
                data = json.loads(content_str)
                if isinstance(data, dict):
                    return data
            except Exception:
                # Fall back to the first ```json ... ``` fragment in the reply.
                m = re.search(r"```json\s*(\{[\s\S]*?\})\s*```", content, re.I)
                if m:
                    try:
                        data = json.loads(m.group(1))
                        if isinstance(data, dict):
                            return data
                    except Exception:
                        pass
            return None
        except Exception:
            return None

    def execute(self, path: str, sample_ratio: Optional[float] = 0.05, max_fields: Optional[int] = 100, file_limit: Optional[int] = 20) -> ToolResult:
        """Explore ``path`` and return file summaries plus per-file schemas.

        Args:
            path: File or directory inside the workspace. Accepted either
                relative to ``workspace/`` or including that prefix (the form
                other tools in this module return).
            sample_ratio: Sampling fraction forwarded to ``_sample_read``.
            max_fields: Maximum number of columns described per file.
            file_limit: Maximum number of files collected from a directory.

        Returns:
            ``ToolResult`` with ``{"files": [...], "schemas": {...}}``, or a
            ``ToolFailure`` when the path is missing, outside the workspace, or
            an unexpected error occurs.
        """
        try:
            workspace_root = Path("workspace").resolve()
            # Try the workspace-relative reading first, then the path as given.
            candidates = [(workspace_root / path).resolve(), Path(path).resolve()]
            existing = [c for c in candidates if c.exists()]
            if not existing:
                return ToolFailure(error=f"path not found: {path}")
            inside = [c for c in existing if self._is_within(c, workspace_root)]
            if not inside:
                # Entries are reported relative to the workspace, so anything
                # outside it cannot be described; fail early and clearly.
                return ToolFailure(error=f"path is outside workspace: {path}")
            p = inside[0]

            sample_ratio = float(sample_ratio or 0.05)
            max_fields = int(max_fields or 200)
            file_limit = int(file_limit or 20)

            files: List[Path] = []
            if p.is_dir():
                files = self._list_files(p, file_limit)
                # Few data files may mean an unstructured dataset (e.g. image
                # folders): also list first-level subdirectories as a hint.
                if len(files) < 10:
                    try:
                        subdirs = [x for x in p.iterdir() if x.is_dir() and not x.name.startswith('.')]
                        files.extend(subdirs[:10])
                    except Exception:
                        pass
            elif p.is_file():
                files = [p]
            else:
                return ToolFailure(error=f"path not found: {path}")

            # De-duplicate (order preserved) and cap the number of entries.
            files = list(dict.fromkeys(files))[:15]

            file_summaries: List[Dict[str, Any]] = []
            # NOTE: keyed by bare file name, so same-named files in different
            # subdirectories overwrite each other.
            schemas: Dict[str, List[Dict[str, Any]]] = {}

            for fp in files:
                try:
                    if fp.is_dir():
                        # Preview a few children of the directory.
                        try:
                            children = [x.name for x in fp.iterdir() if not x.name.startswith('.')]
                            total_children = len(children)
                            children_preview = children[:5]
                            if total_children > 5:
                                children_preview.append(f"...({total_children} items)")
                            example_str = ", ".join(children_preview)
                        except Exception:
                            example_str = "access denied"

                        file_summaries.append({
                            "name": fp.name,
                            "path": str(fp.relative_to(workspace_root)),
                            "size": 0,
                            "ext": "dir",
                        })
                        schemas[fp.name] = [{"name": "contents", "dtype": "dir_items", "example": example_str}]
                        continue

                    file_entry: Dict[str, Any] = {
                        "name": fp.name,
                        "path": str(fp.relative_to(workspace_root)),
                        "size": fp.stat().st_size,
                        "ext": fp.suffix.lower().lstrip('.'),
                    }
                    # Recorded before reading so a read failure keeps the summary.
                    file_summaries.append(file_entry)

                    df, load_params = self._sample_read(fp, sample_ratio)
                    if load_params:
                        # Rename loader parameters to clearer, file-oriented keys.
                        file_info: Dict[str, Any] = {}
                        if "sep" in load_params:
                            file_info["delimiter"] = load_params["sep"]
                        if "header" in load_params:
                            # header=None means the file has no header row.
                            file_info["has_header"] = (load_params["header"] is not None)
                        elif df is not None:
                            # pandas infers the header by default.
                            file_info["has_header"] = True
                        if "skiprows" in load_params:
                            file_info["skip_rows"] = load_params["skiprows"]
                        if "lines" in load_params:
                            file_info["json_lines"] = load_params["lines"]
                        file_entry["file_info"] = file_info

                    cols: List[Dict[str, Any]] = []
                    if df is not None and not df.empty:
                        for col in list(df.columns)[:max_fields]:
                            s = df[col]
                            example = self._example_value(s)
                            cols.append({
                                "name": str(col),
                                "dtype": str(s.dtype),
                                "example": None if self._is_scalar_na(example) else example,
                            })
                    else:
                        # Placeholder marking an unreadable or empty file.
                        cols.append({"name": "", "dtype": "", "example": None})
                    schemas[fp.name] = cols
                except Exception as e:
                    logger.warning(f"Error processing {fp}: {e}")
                    # Still record something for this entry.
                    if fp.name not in schemas:
                        schemas[fp.name] = [{"name": "error", "dtype": "error", "example": str(e)}]

            # One LLM call for both file and field meanings; heuristics fill gaps.
            infer = self._infer_meanings_with_llm(file_summaries, schemas)
            fmean: Any = {}
            cmean: Any = {}
            if infer and isinstance(infer, dict):
                fmean = infer.get("file_meanings") or {}
                cmean = infer.get("field_meanings") or {}

            for f in file_summaries:
                m = fmean.get(f.get("name")) if isinstance(fmean, dict) else None
                if m:
                    f["meaning"] = str(m)
                    continue
                hints = _infer_name_meaning(f.get("name") or "")
                if hints:
                    f["meaning"] = "; ".join(hints)

            for fn, cols in schemas.items():
                mapping = cmean.get(fn) if isinstance(cmean, dict) else None
                for c in cols:
                    nm = c.get("name")
                    if not nm:
                        c["meaning"] = "无法解析或空文件/工作表"
                        continue
                    val = mapping.get(nm) if isinstance(mapping, dict) else None
                    if val:
                        c["meaning"] = str(val)
                        continue
                    hint = self._field_hint(str(nm))
                    if hint:
                        c["meaning"] = hint

            return ToolResult(output={"files": file_summaries, "schemas": schemas})
        except Exception as e:
            logger.warning(f"explorer failed: {e}")
            return ToolFailure(error=f"explorer failed: {e}")



class PrepareBuiltinDataset(BaseTool):
    """Export a scikit-learn built-in dataset to a local CSV/Parquet file.

    Supported datasets: iris, wine, breast_cancer, diabetes, digits, linnerud,
    california_housing. The format is csv (default) or parquet, and files are
    written to ``workspace/`` unless ``save_dir`` says otherwise.
    """

    name: str = "prepare_builtin_dataset"
    description: str = (
        "导出 sklearn 内置数据集为本地 CSV/Parquet 文件，返回保存路径、行列数。"
    )
    input: str = (
        "dataset_name: 数据集名称(iris|wine|breast_cancer|diabetes|digits|linnerud|california_housing); "
        "save_dir?: 保存目录，默认 workspace; format?: csv|parquet; file_basename?: 文件基础名"
    )
    output: str = "{dataset: 名称, path: 文件路径, rows: 行数, cols: 列数, format: 保存格式}"
    parameters: dict = {
        "type": "object",
        "properties": {
            "dataset_name": {"type": "string", "description": "sklearn 内置数据集名称"},
            "save_dir": {"type": "string", "description": "保存目录，默认 workspace"},
            "format": {"type": "string", "enum": ["csv", "parquet"], "description": "保存格式，默认 csv"},
            "file_basename": {"type": "string", "description": "输出文件基础名（可选）"},
        },
        "required": ["dataset_name"],
    }

    @staticmethod
    def _normalize_name(name: Optional[str]) -> str:
        """Lower-case ``name``, trim it, and turn hyphens into underscores."""
        return (name or "").strip().lower().replace("-", "_")

    def _load_dataset(self, name: str) -> pd.DataFrame:
        """Load the named scikit-learn dataset as a DataFrame.

        Raises:
            RuntimeError: scikit-learn cannot be imported.
            ValueError: ``name`` is not one of the supported datasets.
        """
        name_norm = self._normalize_name(name)
        try:
            from sklearn import datasets
        except Exception as e:
            raise RuntimeError(f"scikit-learn 不可用: {e}") from e

        def to_frame(bunch) -> pd.DataFrame:
            """Use ``bunch.frame`` when present, else build a frame from arrays."""
            frame = getattr(bunch, "frame", None)
            if frame is not None:
                return frame
            X = np.asarray(bunch.data)
            df = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
            target = getattr(bunch, "target", None)
            if target is not None:
                target = np.asarray(target)
                if target.ndim == 2:
                    # Multi-output target (e.g. linnerud): one column each.
                    for j in range(target.shape[1]):
                        df[f"target_{j}"] = target[:, j]
                else:
                    df["target"] = target
            return df

        def load_digits():
            """Request a frame from load_digits only if it accepts ``as_frame``."""
            import inspect

            # inspect.signature follows __wrapped__, so this also works when the
            # loader is wrapped by a decorator; __code__.co_varnames would then
            # describe the wrapper instead of the loader.
            if "as_frame" in inspect.signature(datasets.load_digits).parameters:
                return datasets.load_digits(as_frame=True)
            return datasets.load_digits()

        # Lazy loaders: only the requested dataset is actually loaded.
        loaders = {
            "iris": lambda: to_frame(datasets.load_iris(as_frame=True)),
            "wine": lambda: to_frame(datasets.load_wine(as_frame=True)),
            "breast_cancer": lambda: to_frame(datasets.load_breast_cancer(as_frame=True)),
            "diabetes": lambda: to_frame(datasets.load_diabetes(as_frame=True)),
            "digits": lambda: to_frame(load_digits()),
            "linnerud": lambda: to_frame(datasets.load_linnerud(as_frame=True)),
            # May need network access on first use; errors propagate to the caller.
            "california_housing": lambda: to_frame(datasets.fetch_california_housing(as_frame=True)),
        }
        if name_norm not in loaders:
            raise ValueError("不支持的数据集名称，支持: iris|wine|breast_cancer|diabetes|digits|linnerud|california_housing")
        return loaders[name_norm]()

    def execute(
        self,
        dataset_name: str,
        save_dir: Optional[str] = "workspace",
        format: Optional[str] = "csv",
        file_basename: Optional[str] = None,
    ) -> ToolResult:
        """Load ``dataset_name`` and write it to ``save_dir``.

        Args:
            dataset_name: One of the supported scikit-learn dataset names.
            save_dir: Output directory (created if missing); default ``workspace``.
            format: ``csv`` or ``parquet``; anything else falls back to ``csv``.
            file_basename: Output file name without extension; defaults to
                ``sk_<dataset>``. Directory components are dropped so the file
                always lands inside ``save_dir``.

        Returns:
            ``ToolResult`` with dataset, path, rows, cols and format, or a
            ``ToolFailure`` describing what went wrong.
        """
        try:
            df = self._load_dataset(dataset_name)
            out_dir = Path(save_dir or "workspace")
            out_dir.mkdir(parents=True, exist_ok=True)
            fmt = (format or "csv").lower()
            if fmt not in {"csv", "parquet"}:
                fmt = "csv"
            base = Path(file_basename).name if file_basename else ""
            if not base:
                base = f"sk_{self._normalize_name(dataset_name)}"
            path = out_dir / (base + (".parquet" if fmt == "parquet" else ".csv"))
            if fmt == "parquet":
                df.to_parquet(path, index=False)
            else:
                df.to_csv(path, index=False)
            return ToolResult(output={
                "dataset": dataset_name,
                "path": str(path),
                "rows": int(df.shape[0]),
                "cols": int(df.shape[1]),
                "format": fmt,
            })
        except Exception as e:
            return ToolFailure(error=f"prepare builtin dataset failed: {e}")


class RunDataScience(BaseTool):
    """Tool wrapper around the data-science agent.

    Meant to be invoked by the model whenever analysis or modelling is needed.
    It delegates the request to the agent and returns the resulting plan with
    all ``code`` fields stripped out.
    """

    name: str = "run_datascience"
    description: str = (
        "当用户需求需要进行数据探索、特征工程、建模训练或绘图分析时，调用该工具来运行数据科学 Agent。"
    )
    input: str = "requirement: 任务需求描述; run_id: 可选，用于结果归档的标识; complexity: auto|simple|complex（可选，默认auto）"
    output: str = "{plan} 其中 plan 为数据科学任务的规划与执行详情(JSON)"
    parameters: dict = {
        "type": "object",
        "properties": {
            "requirement": {"type": "string", "description": "数据科学任务的自然语言描述"},
            "run_id": {"type": "string", "description": "可选运行ID，不提供则自动生成"},
            "complexity": {"type": "string", "enum": ["auto", "simple", "complex"], "description": "任务复杂度：auto|simple|complex，影响是否进行澄清与是否进一步拆分"},
        },
        "required": ["requirement"],
    }

    def execute(self, requirement: str, run_id: Optional[str] = "", complexity: Optional[str] = "auto") -> ToolResult:
        """Run the data-science agent for ``requirement`` and return its plan.

        A timestamp-based run id is generated when ``run_id`` is empty, and an
        empty ``complexity`` is treated as ``"auto"``. Any exception is turned
        into a ``ToolFailure``.
        """
        try:
            # Imported lazily to avoid the llm -> registry -> ds_tools import cycle.
            from agent.datascience import DSagent  # type: ignore
            runner = DSagent()
            effective_run_id = run_id if run_id else datetime.now().strftime("%Y%m%d_%H%M%S")
            effective_complexity = complexity if complexity else "auto"
            plan_dict, _ignored = runner.act(
                requirement=requirement,
                run_id=effective_run_id,
                complexity=effective_complexity,
            )
        except Exception as e:
            return ToolFailure(error=f"run datascience failed: {e}")

        try:
            # Drop "code" fields to reduce token consumption.
            self._remove_code_recursive(plan_dict)
            return ToolResult(output={"plan": plan_dict})
        except Exception as e:
            return ToolFailure(error=f"run datascience failed: {e}")

    def _remove_code_recursive(self, data: Any):
        """Delete every ``"code"`` key found in nested dicts/lists, in place."""
        if isinstance(data, dict):
            data.pop("code", None)
            children = data.values()
        elif isinstance(data, list):
            children = data
        else:
            return
        for child in children:
            self._remove_code_recursive(child)
        
if __name__ == "__main__":
    # Quick manual check of DataExplorer's directory scan.
    # The path is built with pathlib so it works on POSIX as well as Windows;
    # a literal "workspace\\..." string is a single file name on POSIX.
    explorer = DataExplorer()
    res = explorer._list_files(Path("workspace") / "PVODdatasets_v1.0", limit=20)
    print(res)
