import os
from pathlib import Path
from typing import List, Optional, Dict, Any

from .base import BaseTool, ToolResult, ToolFailure
import zipfile


def _has_kaggle_credentials() -> bool:
    # 1) ~/.kaggle/kaggle.json
    try:
        kaggle_json = Path.home() / ".kaggle" / "kaggle.json"
        if kaggle_json.exists() and kaggle_json.read_text(encoding="utf-8").strip():
            return True
    except Exception:
        pass
    # 2) env
    return bool(os.getenv("KAGGLE_USERNAME") and os.getenv("KAGGLE_KEY"))


def _ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)


# Substrings that identify a Kaggle "forbidden" error (missing permission or
# un-accepted dataset/competition rules) inside an exception message.
_FORBIDDEN_MARKERS = ("403", "Permission 'datasets.get'", "Permission 'competitions.get'")


def _snapshot_files(root: Path) -> set:
    """Return the resolved paths of all regular files under ``root`` (recursive)."""
    return {p.resolve() for p in root.rglob("*") if p.is_file()}


def _display_path(path: Path, root: Path) -> str:
    """Render ``path`` relative to ``root`` when it lies inside it, else unchanged.

    ``Path.relative_to`` raises ``ValueError`` for paths outside ``root`` (for
    example a ``download_dir`` that is not under ``./workspace``); returning the
    path as-is keeps the tool from crashing after a successful download.
    """
    try:
        return str(path.relative_to(root))
    except ValueError:
        return str(path)


def _download_failure(detail: str) -> ToolFailure:
    """Wrap a download error message in a ToolFailure.

    When the message looks like a 403 / permission error, a checklist of the
    usual causes is prepended; otherwise the message is reported verbatim.
    """
    if any(marker in detail for marker in _FORBIDDEN_MARKERS):
        return ToolFailure(error=(
            "kaggle download failed (403). 请检查：1) 已配置 ~/.kaggle/kaggle.json 或 KAGGLE_USERNAME/KAGGLE_KEY；"
            "2) 已在网页端 Join/Accept 对应数据集/竞赛规则；3) id_or_slug 是否正确。原始错误：" + detail
        ))
    return ToolFailure(error=f"kaggle download failed: {detail}")


class KaggleDownloader(BaseTool):
    """
    Download Kaggle datasets or competition data through the official Kaggle API.
    """

    name: str = "kaggle_downloader"
    description: str = (
        "使用 Kaggle 官方 API 下载kaggle上的数据集或竞赛数据；"
    )
    input: str = (
        "target_type: 'dataset'|'competition'; id_or_slug: 数据集如 'zynicide/wine-reviews' 或竞赛如 'titanic'; "
        "download_dir: 保存目录; unzip?: 是否解压; file?: 指定单个文件"
    )
    output: str = "{downloaded: [paths], details: {dir, target_type, id_or_slug, unzip, file}}"
    parameters: dict = {
        "type": "object",
        "properties": {
            "target_type": {
                "type": "string",
                "enum": ["dataset", "competition", "auto"],
                "description": "下载类型：数据集/竞赛/自动判断（auto: 含'/'视为数据集，否则视为竞赛）",
            },
            "id_or_slug": {
                "type": "string",
                "description": "数据集: 'owner/dataset'；竞赛: 'titanic' 等",
            },
            "download_dir": {"type": "string", "description": "保存目录(默认 workspace/data)"},
            "unzip": {"type": "boolean", "description": "是否解压（默认 True）"},
            "file": {"type": "string", "description": "可选：仅下载特定文件名"},
        },
        "required": ["id_or_slug"],
    }

    examples: List[str] = [
        "下载数据集：target_type='dataset', id_or_slug='zynicide/wine-reviews', download_dir='workspace/data', unzip=True",
        "下载竞赛：target_type='competition', id_or_slug='titanic', download_dir='workspace/data', unzip=True",
    ]

    def execute(
        self,
        target_type: str = "auto",
        id_or_slug: str = "",
        download_dir: str = "workspace/data",
        unzip: Optional[bool] = True,
        file: Optional[str] = None,
    ) -> ToolResult:
        """Download a Kaggle dataset or competition into ``download_dir``.

        Args:
            target_type: ``"dataset"``, ``"competition"`` or ``"auto"``. In auto
                mode an id containing ``"/"`` is tried as a dataset first and as
                a competition otherwise; if that attempt raises, the other kind
                is tried once.
            id_or_slug: ``"owner/dataset"`` for datasets, a slug such as
                ``"titanic"`` for competitions. Surrounding whitespace is ignored.
            download_dir: Destination directory, created if missing.
            unzip: For whole-dataset downloads this is forwarded to
                ``dataset_download_files``. Additionally, every ``.zip`` file
                that newly appears in ``download_dir`` is extracted there
                (competition downloads are fetched without an unzip flag).
                Archives are left in place after extraction.
            file: Download only this file instead of the whole bundle.

        Returns:
            ``ToolResult`` with output ``{"downloaded": [...], "details": {...}}``.
            ``downloaded`` lists files at paths that did not exist before the
            call (files overwritten in place are not listed), relative to
            ``./workspace`` when possible and absolute otherwise. ``details``
            gains an ``extract_errors`` list only when an archive failed to
            extract. A ``ToolFailure`` is returned for invalid arguments,
            missing credentials/package, authentication or download errors.
        """
        # --- Argument validation (cheap, before any network or disk work) ---
        id_or_slug = (id_or_slug or "").strip()
        if not id_or_slug:
            return ToolFailure(
                error="id_or_slug is required, e.g. 'zynicide/wine-reviews' (dataset) or 'titanic' (competition)."
            )
        if target_type not in ("auto", "dataset", "competition"):
            return ToolFailure(
                error=f"unsupported target_type: {target_type}. Use 'dataset', 'competition' or 'auto'."
            )
        if target_type == "dataset" and "/" not in id_or_slug:
            # Be explicit: an explicit dataset request needs the owner/dataset form.
            return ToolFailure(
                error=f"invalid dataset id_or_slug='{id_or_slug}'. Use 'owner/dataset' (e.g. 'zynicide/wine-reviews') "
                      f"or set target_type='competition' for '{id_or_slug}'."
            )

        # --- Environment checks: credentials, package, authentication ---
        if not _has_kaggle_credentials():
            return ToolFailure(
                error=(
                    "kaggle credentials not found. Configure ~/.kaggle/kaggle.json or set KAGGLE_USERNAME/KAGGLE_KEY."
                )
            )

        try:
            from kaggle.api.kaggle_api_extended import KaggleApi
        except Exception as e:
            return ToolFailure(error=f"kaggle package not installed or failed to import: {e}")

        try:
            api = KaggleApi()
            api.authenticate()
        except Exception as e:
            return ToolFailure(error=f"kaggle authenticate failed: {e}")

        dst = Path(download_dir)
        try:
            _ensure_dir(dst)
        except OSError as e:
            return ToolFailure(error=f"cannot create download_dir '{download_dir}': {e}")

        # Snapshot existing files so that only newly created paths are reported.
        before = _snapshot_files(dst)

        # Kind to try first, plus (auto mode only) the kind to retry with.
        fallback: Optional[str] = None
        if target_type == "auto":
            primary = "dataset" if "/" in id_or_slug else "competition"
            fallback = "competition" if primary == "dataset" else "dataset"
        else:
            primary = target_type

        def _do_download(kind: str) -> None:
            if kind == "dataset":
                if file:
                    api.dataset_download_file(id_or_slug, file, path=str(dst), force=True)
                else:
                    api.dataset_download_files(id_or_slug, path=str(dst), unzip=bool(unzip), force=True)
            elif kind == "competition":
                if file:
                    api.competition_download_file(id_or_slug, file, path=str(dst), force=True)
                else:
                    api.competition_download_files(id_or_slug, path=str(dst), force=True)
            else:
                raise ValueError(f"unsupported target_type: {kind}")

        used_kind = primary  # the kind that actually succeeded
        try:
            _do_download(primary)
        except Exception as e1:
            # 403 / 404 / wrong kind: in auto mode retry once with the other kind.
            if fallback is None:
                return _download_failure(str(e1))
            try:
                _do_download(fallback)
            except Exception as e2:
                return _download_failure(f"primary={primary} err={e1}; fallback={fallback} err={e2}")
            used_kind = fallback

        after = _snapshot_files(dst)
        added = sorted(after - before)

        # Extract newly downloaded zip archives into dst.
        extracted: List[Path] = []
        extract_errors: List[str] = []
        if unzip:
            archives = [p for p in added if p.suffix.lower() == ".zip"]
            for archive in archives:
                try:
                    with zipfile.ZipFile(archive, "r") as zf:
                        zf.extractall(dst)
                except Exception as e:
                    # Best-effort: a corrupt or encrypted archive must not fail an
                    # otherwise successful download, but the caller is told about it.
                    extract_errors.append(f"{archive.name}: {e}")
            if archives:
                extracted = sorted(_snapshot_files(dst) - after)

        workspace_root = Path("workspace").resolve()
        details: Dict[str, Any] = {
            "dir": _display_path(dst.resolve(), workspace_root),
            "target_type": used_kind,
            "id_or_slug": id_or_slug,
            "unzip": bool(unzip),
            "file": file,
        }
        if extract_errors:
            details["extract_errors"] = extract_errors
        return ToolResult(
            output={
                "downloaded": [_display_path(p, workspace_root) for p in added + extracted],
                "details": details,
            }
        )
