from fastapi import FastAPI, BackgroundTasks, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from pathlib import Path
from typing import Optional, Dict, Any, Set
import threading
import time
import uuid
import json
import subprocess
import sys
import asyncio
import logger
from knowledgeBase.knowledge import ExpertFeedbackStore
import re
from collections import defaultdict
import os
from contextlib import asynccontextmanager

# Guards every read/write of RUN_CLIENTS; it is taken both from the event loop
# (WebSocket handler) and from worker threads (log tailers).
RUN_CLIENTS_LOCK = threading.Lock()
# WebSocket connection pool, grouped by run_id.
RUN_CLIENTS: Dict[str, Set[WebSocket]] = defaultdict(set)

# Main event loop, captured at startup so that threads can schedule sends on it safely.
EVENT_LOOP: Optional[asyncio.AbstractEventLoop] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: remember the running event loop at startup.

    The loop is stored in the module-level ``EVENT_LOOP`` so that code running
    in other threads can schedule coroutines on it. There is currently nothing
    to clean up at shutdown.
    """
    global EVENT_LOOP
    EVENT_LOOP = asyncio.get_running_loop()
    # Startup ends here; whatever follows the yield would run at shutdown.
    yield

app = FastAPI(
    title="autoDS API",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS: every origin, method and header is accepted, and credentials are allowed.
# NOTE(review): this is fully permissive - confirm it is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# The static frontend lives under /app so that it never shadows the /api/* routes.
# It is only mounted when the "frontend" directory is present.
frontend_dir = Path("frontend")
if frontend_dir.exists():
    app.mount(
        "/app",
        StaticFiles(directory=str(frontend_dir), html=True),
        name="frontend",
    )

# Files produced by runs (images/notebooks) are served from /results/<run_id>/.
# The directory is created up front so the mount below always has a target.
results_dir = Path("output", "results")
results_dir.mkdir(parents=True, exist_ok=True)
app.mount(
    "/results",
    StaticFiles(directory=str(results_dir)),
    name="results",
)

@app.get("/")
def root_index():
    """Redirect the bare root URL to the static frontend mounted at ``/app/``."""
    return RedirectResponse("/app/")

# In-memory registry of runs, keyed by run_id. RUNS_LOCK guards all access to it.
RUNS_LOCK = threading.Lock()
RUNS: Dict[str, Dict[str, Any]] = {}

# Log level -> display colour, matching the colouring rules used by the frontend.
LEVEL_COLOR_MAP = dict(
    TRACE="white",
    DEBUG="grey",
    INFO="blue",
    WARNING="yellow",
    ERROR="red",
    SUCCESS="green",
    SPECIAL="cyan",
)


# 从日志文件中按行追踪（tail -f），并按“时间戳 + [LEVEL]”进行聚合后推送到前端
def _tail_log_file(log_path: Path, rid: str):
    """Follow a run's log file (like ``tail -f``) and broadcast grouped entries.

    The file is read from its beginning. A line that starts with
    ``YYYY-MM-DD HH:MM:SS.mmm [LEVEL]`` opens a new entry; following lines that
    do not match are treated as continuation lines of that entry. A finished
    entry is sent through ``_broadcast_run_log`` as
    ``{"level", "color", "message"}``; entries whose level is ``TRACE`` are
    dropped. An entry is also flushed when no new line arrived for more than
    one second.

    The function returns once the process registered for ``rid`` in ``RUNS``
    has exited and the file has been read to its end. Each entry is broadcast
    exactly once, and the file is read one more time after the exit is
    detected so lines written just before the exit are not lost.

    Args:
        log_path: Path of the log file; it may not exist yet when called.
        rid: Run id used to look up the process and to address the broadcast.

    Any exception is reported on stderr and ends the tailing; nothing is raised.
    """

    def process_exited() -> bool:
        # Hold the lock only for the lookup; poll the process outside of it.
        with RUNS_LOCK:
            info = RUNS.get(rid)
            proc = info.get("process") if info else None
        return proc is not None and proc.poll() is not None

    try:
        start_pattern = re.compile(r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}) \[([A-Z]+)\]')
        ansi_escape = re.compile(r'\x1b\[[0-9;]*[A-Za-z]')

        # Wait for the log file to be created; give up if the process ended
        # without ever producing it.
        while not log_path.exists():
            if process_exited():
                return
            time.sleep(0.1)

        buffer_lines = []
        current_level = None
        last_activity = time.time()

        def flush():
            """Broadcast the buffered entry (unless TRACE) and reset the buffer."""
            nonlocal buffer_lines, current_level
            if buffer_lines and current_level and current_level != "TRACE":
                _broadcast_run_log(rid, {
                    "level": current_level,
                    "color": LEVEL_COLOR_MAP.get(current_level, "white"),
                    "message": "\n".join(buffer_lines),
                })
            # Always reset, so the same entry can never be sent a second time.
            buffer_lines = []
            current_level = None

        exited = False
        with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
            while True:
                line = f.readline()
                if not line:
                    if exited:
                        # The process was already gone before this empty read,
                        # so nothing more can be appended.
                        break
                    # Idle: flush an entry that has been quiet for over a second.
                    if buffer_lines and (time.time() - last_activity) > 1.0:
                        flush()
                    if process_exited():
                        # Read once more: lines may have been written between
                        # the empty read above and the process exiting.
                        exited = True
                        continue
                    time.sleep(0.1)
                    continue

                text = line.rstrip('\n')
                if not text:
                    continue
                text = ansi_escape.sub('', text)

                m = start_pattern.match(text)
                if m:
                    # A new entry starts: send the previous one first.
                    flush()
                    current_level = m.group(2)
                    buffer_lines = [text]
                elif buffer_lines:
                    # Continuation line of a multi-line entry.
                    buffer_lines.append(text)
                else:
                    # Text with no open entry (e.g. at the very start of the
                    # file) is reported as an INFO entry.
                    current_level = "INFO"
                    buffer_lines = [text]
                last_activity = time.time()

        # End of file after process exit: send whatever is still buffered.
        flush()
    except Exception as e:
        print(f"[_tail_log_file] error: {e}", file=sys.stderr)



def _broadcast_run_log(run_id: str, payload: Dict[str, Any]):
    """Send one JSON payload to every WebSocket client registered for a run.

    The function is meant to be called from a thread other than the event
    loop's: each ``send_json`` is scheduled on ``EVENT_LOOP`` with
    ``asyncio.run_coroutine_threadsafe`` and the call does not wait for it.
    A client is removed from the pool when scheduling fails or when the send
    itself later fails or is cancelled.

    Args:
        run_id: Run whose clients should receive the payload.
        payload: JSON-serialisable message.

    Does nothing when the event loop has not been captured yet.
    """
    loop = EVENT_LOOP
    if loop is None:
        return

    # Snapshot the clients so the lock is not held while scheduling sends.
    with RUN_CLIENTS_LOCK:
        clients = list(RUN_CLIENTS.get(run_id, ()))

    def drop(ws):
        with RUN_CLIENTS_LOCK:
            RUN_CLIENTS.get(run_id, set()).discard(ws)

    def drop_if_failed(ws, future):
        # Check cancelled() first: exception() raises on a cancelled future.
        if future.cancelled() or future.exception() is not None:
            drop(ws)

    for ws in clients:
        try:
            future = asyncio.run_coroutine_threadsafe(ws.send_json(payload), loop)
        except Exception:
            # Could not even schedule the send (e.g. the loop is closed).
            drop(ws)
            continue
        # A failing send only shows up on the returned future, so it has to be
        # inspected there; otherwise dead connections are never removed.
        future.add_done_callback(lambda fut, ws=ws: drop_if_failed(ws, fut))


class RunRequest(BaseModel):
    # Request body of POST /api/run.
    # The requirement text is handed unchanged to run_worker.py as a command-line argument.
    requirement: str

class ExtractRequest(BaseModel):
    # Request body of POST /api/extract-experience.
    # Path of the plan file; the endpoint rejects the request when it does not exist.
    plan_path: str
    # Optional task type, passed through to ExpertFeedbackStore.extract_from_success_case.
    task_type: Optional[str] = None

@app.get("/api/health")
def health():
    """Liveness probe: always answers with a fixed ``ok`` status."""
    response = {"status": "ok"}
    return response


@app.get("/api/results")
def list_results():
    """List result folders under ``output/results``, newest name first.

    Each entry carries the folder path plus the paths of ``plan.json`` and
    ``code.ipynb`` inside it, or ``None`` for a file that does not exist.
    """

    def path_if_present(candidate):
        return str(candidate) if candidate.exists() else None

    root = Path("output/results")
    if not root.exists():
        return {"results": []}

    folders = sorted(root.iterdir(), key=lambda entry: entry.name, reverse=True)
    return {
        "results": [
            {
                "dir": str(folder),
                "plan": path_if_present(folder / "plan.json"),
                "notebook": path_if_present(folder / "code.ipynb"),
            }
            for folder in folders
            if folder.is_dir()
        ]
    }

@app.post("/api/run")
def start_run(req: RunRequest, background_tasks: BackgroundTasks):
    """Launch ``run_worker.py`` in a separate process and start tailing its log.

    The run is registered in ``RUNS`` and a daemon thread follows
    ``output/logs/run_<run_id>.log`` to forward entries to WebSocket clients.

    Returns:
        ``{"run_id": <id>}`` for the new run.

    Raises:
        HTTPException: 500 when the worker process cannot be started.
    """
    run_id = f'{time.strftime("%Y%m%d_%H%M%S")}_{uuid.uuid4().hex[:6]}'

    # "-u" keeps the worker's output unbuffered.
    worker_script = Path(__file__).parent / "run_worker.py"
    command = [sys.executable, "-u", str(worker_script), run_id, req.requirement]

    # Force UTF-8 in the child (matters on Windows, where GBK cannot encode
    # characters such as the check/cross marks).
    worker_env = os.environ.copy()
    worker_env.update(PYTHONIOENCODING="utf-8", PYTHONUTF8="1")

    try:
        # stdout/stderr are discarded: logs are read from the log file instead.
        worker = subprocess.Popen(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            bufsize=1,
            text=True,
            encoding="utf-8",
            errors="replace",
            env=worker_env,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"failed to start run worker: {e}")

    log_path = Path('output', 'logs', f'run_{run_id}.log')
    record = {
        "status": "running",
        "created_at": time.time(),
        "started_at": time.time(),
        "requirement": req.requirement,
        "pid": worker.pid,
        "process": worker,
        "log_path": str(log_path),
    }
    with RUNS_LOCK:
        RUNS[run_id] = record

    # Follow the log file in the background.
    tailer = threading.Thread(target=_tail_log_file, args=(log_path, run_id), daemon=True)
    tailer.start()

    return {"run_id": run_id}

@app.post("/api/extract-experience")
def extract_experience(req: ExtractRequest):
    """Extract experience from a plan file via ``ExpertFeedbackStore``.

    Raises:
        HTTPException: 400 when ``plan_path`` does not exist.
    """
    source = Path(req.plan_path)
    if source.exists():
        # Logging is not initialised separately here.
        ExpertFeedbackStore().extract_from_success_case(str(source), task_type=req.task_type)
        return {"status": "ok"}
    raise HTTPException(status_code=400, detail="plan_path not found")


class NotifyFileRequest(BaseModel):
    # Request body of POST /api/runs/{run_id}/notify_file.
    # Name of the produced file; the endpoint joins it onto output/results/<run_id>.
    filename: str


@app.post("/api/runs/{run_id}/notify_file")
def notify_run_file(run_id: str, body: NotifyFileRequest):
    """Announce a newly produced run file to the run's WebSocket clients.

    Called by worker processes once a file (image/notebook) exists under
    ``output/results/<run_id>/``. A ``{"type": "file", ...}`` message with
    the file's URL, name and size is broadcast so the frontend can load it.

    Both ``run_id`` and ``body.filename`` come from the request, so they are
    checked lexically: absolute paths and ``..`` components are rejected, which
    keeps the lookup inside the run's results directory. (Symlinks are not
    resolved by this check.)

    Raises:
        HTTPException: 404 when the run directory or the file does not exist,
            400 when the filename tries to leave the run directory.
    """

    def stays_inside(name: str) -> bool:
        # Relative, non-empty, and without any parent-directory component.
        candidate = Path(name)
        return bool(candidate.parts) and not candidate.anchor and ".." not in candidate.parts

    if not stays_inside(run_id):
        raise HTTPException(status_code=404, detail="run results dir not found")
    base = Path("output") / "results" / run_id
    if not base.exists():
        raise HTTPException(status_code=404, detail="run results dir not found")

    if not stays_inside(body.filename):
        raise HTTPException(status_code=400, detail="invalid filename")
    target = base / body.filename
    # is_file() rather than exists(): a directory is not something to announce.
    if not target.is_file():
        raise HTTPException(status_code=404, detail="file not found")

    image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".svg")
    payload = {
        "type": "file",
        "subtype": "image" if target.suffix.lower() in image_suffixes else "file",
        "url": f"/results/{run_id}/{body.filename}",
        "name": body.filename,
        "size": target.stat().st_size,
    }
    _broadcast_run_log(run_id, payload)
    return {"status": "ok"}


@app.get("/api/runs/{run_id}/logs")
def get_run_logs(run_id: str):
    """Return the whole log file of a run as ``{"text": ...}``.

    The path recorded in ``RUNS`` is used when available; otherwise the
    default ``output/logs/run_<run_id>.log`` convention applies.

    Raises:
        HTTPException: 404 when the log file does not exist, 500 when it
            cannot be read.
    """
    with RUNS_LOCK:
        record = RUNS.get(run_id)

    recorded_path = record.get("log_path") if record else None
    if recorded_path:
        log_path = Path(recorded_path)
    else:
        # Unknown run (or no recorded path): fall back to the naming convention.
        log_path = Path('output') / 'logs' / f'run_{run_id}.log'

    if not log_path.exists():
        raise HTTPException(status_code=404, detail="log file not found")

    try:
        return {"text": log_path.read_text(encoding='utf-8', errors='replace')}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/runs/{run_id}/files")
def list_run_files(run_id: str):
    """List the regular files in ``output/results/<run_id>``, sorted by name.

    Returns:
        ``{"files": [{"name", "url", "size"}, ...]}``.

    Raises:
        HTTPException: 404 when the directory does not exist, 500 when it
            cannot be listed.
    """
    run_dir = Path('output', 'results', run_id)
    if not (run_dir.exists() and run_dir.is_dir()):
        raise HTTPException(status_code=404, detail="run results dir not found")

    try:
        listing = [
            {
                "name": entry.name,
                "url": f"/results/{run_id}/{entry.name}",
                "size": entry.stat().st_size,
            }
            for entry in sorted(run_dir.iterdir(), key=lambda x: x.name)
            if entry.is_file()
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"files": listing}


@app.get("/api/results/{folder}/files")
def list_results_folder_files(folder: str):
    """Results-browser listing of the regular files in ``output/results/<folder>``.

    Returns:
        ``{"files": [{"name", "url", "size"}, ...]}`` sorted by file name.

    Raises:
        HTTPException: 404 when the folder does not exist, 500 when it cannot
            be listed.
    """
    folder_path = Path('output') / 'results' / folder
    if not (folder_path.exists() and folder_path.is_dir()):
        raise HTTPException(status_code=404, detail="results folder not found")

    def describe(item):
        return {
            "name": item.name,
            "url": f"/results/{folder}/{item.name}",
            "size": item.stat().st_size,
        }

    try:
        ordered = sorted(folder_path.iterdir(), key=lambda x: x.name)
        return {"files": [describe(item) for item in ordered if item.is_file()]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/knowledgeBase/summary")
def knowledge_summary():
    """Summarise the knowledge system: external knowledge, Kaggle cases, expert experience.

    Items loaded from ``ExpertFeedbackStore`` are split by their source value
    (``external_knowledge`` / ``success_case``; other sources are left out).
    Kaggle cases are read line by line from ``knowledgeBase/case_library.jsonl``.

    Returns:
        A dict with at most ten entries per category plus the full counts
        under ``"total"``.
    """
    store = ExpertFeedbackStore()
    pack = store.load()

    # Split the stored items by source.
    external = []
    kaggle = []
    expert = []

    for item in pack.items:
        data = {
            "id": item.source_id or "unknown",
            "task_type": item.task_type,
            "recommendation": item.recommendation,
            "rationale": item.rationale or "",
            # Only a missing score falls back to 0.5; a real score of 0 is kept.
            "score": 0.5 if item.score is None else item.score,
            "tags": item.tags or [],
            "source": item.source.value if item.source else "unknown",
        }

        if item.source and item.source.value == "external_knowledge":
            external.append(data)
        elif item.source and item.source.value == "success_case":
            expert.append(data)

    # Load the Kaggle cases from case_library.jsonl (one JSON object per line).
    kaggle_cases_path = Path("knowledgeBase") / "case_library.jsonl"
    if kaggle_cases_path.exists():
        try:
            with open(kaggle_cases_path, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    # Handle each line on its own so that one malformed record
                    # does not hide every case that follows it.
                    try:
                        case = json.loads(line)
                        solutions = case.get("solutions") or {}
                        kaggle.append({
                            "id": case.get("title", "Unknown"),
                            "task_type": case.get("title", "Kaggle Competition"),
                            # "or" also covers explicit nulls, which cannot be sliced.
                            "recommendation": (case.get("description") or "")[:200],
                            "rationale": (solutions.get("Top Solutions Summary") or "")[:300],
                            "score": 0.8,
                            "tags": case.get("tags", []),
                            "source": "kaggle"
                        })
                    except Exception as e:
                        logger.error(f"Skipping malformed Kaggle case on line {line_no}: {e}")
        except Exception as e:
            logger.error(f"Failed to load Kaggle cases: {e}")

    return {
        "external_knowledge": external[:10],
        "kaggle_cases": kaggle[:10],
        "expert_feedback": expert[:10],
        "total": {
            "external": len(external),
            "kaggle": len(kaggle),
            "expert": len(expert),
        }
    }

@app.get("/api/feedback/summary")
def feedback_summary():
    """Return whatever ``ExpertFeedbackStore.get_feedback_summary()`` reports."""
    return ExpertFeedbackStore().get_feedback_summary()


@app.get("/api/runs")
def list_runs():
    """List every known run, newest first, refreshing each run's status.

    The status is derived from the process exit code: still running ->
    ``running``, exit code 0 -> ``completed``, anything else -> ``failed``.
    A run that was explicitly marked ``terminated`` keeps that status instead
    of being relabelled ``running`` or ``failed``.

    Returns:
        ``{"runs": [...]}`` where each entry has ``run_id``, ``status``,
        ``created_at``, ``started_at``, ``finished_at``, ``elapsed_seconds``,
        ``requirement`` and ``pid``.
    """
    now = time.time()
    with RUNS_LOCK:
        runs_list = []
        for run_id, info in RUNS.items():
            # Bring the recorded status in line with the process state.
            proc = info.get("process")
            status = info.get("status", "unknown")
            if proc is not None:
                code = proc.poll()
                was_terminated = status == "terminated"
                if code is None:
                    # A killed process may not have been reaped yet; do not
                    # turn a terminated run back into a running one.
                    if not was_terminated:
                        status = "running"
                else:
                    if code == 0:
                        status = "completed"
                    elif not was_terminated:
                        status = "failed"
                    info["finished_at"] = info.get("finished_at") or now
                info["status"] = status

            started_at = info.get("started_at") or info.get("created_at") or now
            finished_at = info.get("finished_at")
            if status == "running":
                elapsed = now - started_at
            else:
                elapsed = (finished_at or now) - started_at

            runs_list.append({
                "run_id": run_id,
                "status": status,
                "created_at": info.get("created_at"),
                "started_at": started_at,
                "finished_at": finished_at,
                "elapsed_seconds": elapsed,
                "requirement": info.get("requirement", ""),
                "pid": info.get("pid"),
            })

        runs_list.sort(key=lambda x: x.get("created_at") or 0, reverse=True)
        return {"runs": runs_list}


@app.post("/api/runs/{run_id}/terminate")
def terminate_run(run_id: str):
    """Terminate the process of a run.

    The process is asked to terminate and gets five seconds to exit before it
    is killed. ``RUNS_LOCK`` is only held for the registry lookup and for the
    final status update, never while waiting for the process, so other
    endpoints and the log tailers are not blocked meanwhile. A run whose
    process had already exited is left untouched (its status and finish time
    are not overwritten).

    Returns:
        ``{"message": "terminated", "run_id": <id>}``.

    Raises:
        HTTPException: 404 for an unknown run, 400 when the run has no
            process, 500 when terminating fails.
    """
    with RUNS_LOCK:
        info = RUNS.get(run_id)
        proc = info.get("process") if info else None
    if not info:
        raise HTTPException(status_code=404, detail="run not found")
    if proc is None:
        raise HTTPException(status_code=400, detail="no process bound to run")

    already_exited = proc.poll() is not None
    if not already_exited:
        try:
            # Try a graceful stop first, then force it.
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"terminate failed: {e}")

        with RUNS_LOCK:
            info["status"] = "terminated"
            info["finished_at"] = time.time()

    return {"message": "terminated", "run_id": run_id}


@app.post("/api/runs/{run_id}/restart")
def restart_run(run_id: str):
    """Restart a run: stop it if still running, then start a new run with the same requirement.

    Stopping the old process is best effort (terminate, wait three seconds,
    then kill) and happens without holding ``RUNS_LOCK``, so other endpoints
    and the log tailers are not blocked while waiting. When the old process
    was stopped here, the old run is marked ``terminated``.

    Returns:
        ``{"message": "restarted", "run_id": <new id>, "old_run_id": <old id>}``.

    Raises:
        HTTPException: 404 for an unknown run, 500 when the new worker cannot
            be started.
    """
    with RUNS_LOCK:
        old = RUNS.get(run_id)
        if not old:
            raise HTTPException(status_code=404, detail="run not found")
        requirement = old.get("requirement", "")
        proc = old.get("process")

    # Stop the old process first if it is still running.
    if proc is not None and proc.poll() is None:
        try:
            proc.terminate()
            try:
                proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                proc.kill()
        except Exception as e:
            # Best effort: the restart goes ahead, but the failure is reported.
            print(f"[restart_run] failed to stop run {run_id}: {e}", file=sys.stderr)
        else:
            with RUNS_LOCK:
                old["status"] = "terminated"
                old["finished_at"] = time.time()

    # Start the new run.
    new_run_id = time.strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:6]
    args = [sys.executable, "-u", str(Path(__file__).parent / "run_worker.py"), new_run_id, requirement]
    # Force UTF-8 in the child (matters on Windows, where GBK cannot encode
    # characters such as the check/cross marks).
    child_env = os.environ.copy()
    child_env["PYTHONIOENCODING"] = "utf-8"
    child_env["PYTHONUTF8"] = "1"
    try:
        new_proc = subprocess.Popen(
            args,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            bufsize=1,
            text=True,
            encoding="utf-8",
            errors="replace",
            env=child_env,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"failed to restart: {e}")

    log_path = Path('output') / 'logs' / f'run_{new_run_id}.log'
    with RUNS_LOCK:
        RUNS[new_run_id] = {
            "status": "running",
            "created_at": time.time(),
            "started_at": time.time(),
            "requirement": requirement,
            "pid": new_proc.pid,
            "process": new_proc,
            "log_path": str(log_path),
        }

    # Follow the new run's log file in the background.
    threading.Thread(target=_tail_log_file, args=(log_path, new_run_id), daemon=True).start()

    # The frontend expects the new id under the key "run_id".
    return {"message": "restarted", "run_id": new_run_id, "old_run_id": run_id}


@app.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket, run_id: Optional[str] = None):
    """WebSocket endpoint that streams a run's log entries to the frontend.

    - With ``run_id``: the socket joins that run's push group and receives
      whatever ``_broadcast_run_log`` sends for the run.
    - Without ``run_id``: not supported; an error message is sent and the
      socket is closed.

    The handler itself never sends log data. It waits on ``receive()`` so that
    a client disconnect is actually noticed, then removes the socket from the
    pool (and drops the run's entry once no client is left).
    """
    await websocket.accept()

    if not run_id:
        await websocket.send_json({"error": "run_id is required for live logs"})
        await websocket.close()
        return

    # Join the run's connection pool.
    with RUN_CLIENTS_LOCK:
        RUN_CLIENTS[run_id].add(websocket)

    try:
        while True:
            # Incoming messages are ignored; receiving is only needed because
            # a disconnect is reported through the receive channel.
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                break
    except WebSocketDisconnect:
        pass
    finally:
        with RUN_CLIENTS_LOCK:
            clients = RUN_CLIENTS.get(run_id)
            if clients is not None:
                clients.discard(websocket)
                if not clients:
                    # Do not keep empty sets around for finished runs.
                    RUN_CLIENTS.pop(run_id, None)

if __name__ == "__main__":
    # Run the API with uvicorn on all interfaces, port 8001, when executed as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8001)
