# -*- coding: utf-8 -*-
"""
Restaurant rubrics definition & rubric generation + validation entry.

1）定义“餐厅相关 rubrics”的模板（人类可读描述 + generate/validate 函数名）
2）在 __main__ 中：
   - 读取 city_trips 文件
   - 针对每个 trip 调用所有 generate_* 函数
   - 记录每个函数的耗时
   - （当前版本已移除二次生成验证）
"""

from typing import List, Dict, Any
from collections import defaultdict
import time
import json
import os
from multiprocessing import Pool

from val_restaurants import RestaurantEvaluator

# =========================
# Rubric 定义（价格）
# =========================

RUBRIC_PRICE = {
    # Price rubrics: cost per person for a single meal.
    # All four variants share one generate/validate pair. process_trip passes
    # the rubric key to the generator as ``generate_params["rubric_key"]``,
    # presumably so it can tell the variants apart — confirm in val_restaurants.

    # Upper bound: strictly below the slot price.
    "Per-person per-meal cost less than a certain price": dict(
        probability=1,
        description="Each selected restaurant has a per-person per-meal cost less than {slot}.",
        violation_description="Some recommended restaurants have a per-person per-meal cost equal to or higher than {slot}.",
        generate_func="generate_price_per_person_range",
        validate_func="validate_price_per_person_range",
    ),
    # Lower bound: strictly above the slot price.
    "Per-person per-meal cost more than a certain price": dict(
        probability=1,
        description="Each selected restaurant has a per-person per-meal cost more than {slot}.",
        violation_description="Some recommended restaurants have a per-person per-meal cost equal to or lower than {slot}.",
        generate_func="generate_price_per_person_range",
        validate_func="validate_price_per_person_range",
    ),
    # Approximate target price.
    "Per-person per-meal cost around a certain price": dict(
        probability=1,
        description="Each selected restaurant has a per-person per-meal cost around {slot}.",
        violation_description="Some recommended restaurants have a per-person per-meal cost far from {slot}.",
        generate_func="generate_price_per_person_range",
        validate_func="validate_price_per_person_range",
    ),
    # Closed price range.
    "Per-person per-meal cost between a certain price range": dict(
        probability=1,
        description="Each selected restaurant has a per-person per-meal cost that falls within {slot}.",
        violation_description="Some recommended restaurants have a per-person per-meal cost outside {slot}.",
        generate_func="generate_price_per_person_range",
        validate_func="validate_price_per_person_range",
    ),
}

# =========================
# 其他 rubrics
# =========================

RUBRIC_RATING = {
    # Overall star rating floor.
    "Minimum overall star rating": dict(
        probability=1,
        description="Only recommend restaurants with an overall rating of at least {slot} stars.",
        violation_description="Some recommended restaurants have an overall rating below {slot} stars.",
        generate_func="generate_min_overall_rating",
        validate_func="validate_min_overall_rating",
    ),
}

RUBRIC_REVIEW_COUNT = {
    # Popularity floor, measured as number of reviews.
    "Minimum review count": dict(
        probability=1,
        description="Prefer restaurants that have at least {slot} reviews.",
        violation_description="Some recommended restaurants have fewer than {slot} reviews.",
        generate_func="generate_min_review_count",
        validate_func="validate_min_review_count",
    ),
}

RUBRIC_INCLUDE_CUISINE = {
    # Cuisines that must appear somewhere in the plan.
    "Include certain cuisines": dict(
        probability=1,
        description="Make sure the plan includes restaurants serving these cuisines: {slot}.",
        violation_description="The plan does not include enough options for these cuisines: {slot}.",
        generate_func="generate_include_cuisines",
        validate_func="validate_include_cuisines",
    ),
}

RUBRIC_EXCLUDE_CUISINE = {
    # Cuisines that must not appear in the plan.
    "Exclude certain cuisines": dict(
        probability=1,
        description="Avoid restaurants that focus on these cuisines: {slot}.",
        violation_description="Some recommended restaurants still focus on these cuisines: {slot}.",
        generate_func="generate_exclude_cuisines",
        validate_func="validate_exclude_cuisines",
    ),
}

# ========= Reservations =========
RUBRIC_OPEN = {
    # Favour restaurants that accept reservations.
    # NOTE(review): the key says "Must" but the description says "Whenever
    # possible" — confirm whether this is a hard or soft constraint.
    "Must be reservable": dict(
        probability=1,
        description="Whenever possible, choose restaurants that support reservations, so that seats can be secured in advance.",
        violation_description="Some selected restaurants cannot be reserved in advance, which may cause difficulty in getting seats.",
        generate_func="generate_reservable",
        validate_func="validate_reservable",
    ),
    # Drop restaurants that can only be visited with a prior reservation.
    "Exclude must-reservation restaurants": dict(
        probability=1,
        description="Exclude restaurants that require mandatory advance reservations.",
        violation_description="Some selected restaurants still require mandatory advance reservations.",
        generate_func="generate_exclude_must_reserve",
        validate_func="validate_exclude_must_reserve",
    ),
}

# ========= Independent constraints on the three sub-ratings =========
RUBRIC_SUBRATING_FOOD = {
    "Minimum food rating": dict(
        probability=1,
        description="Prefer restaurants where the food quality rating is at least {slot}.",
        violation_description="Some selected restaurants do not satisfy the requirement on food quality rating >= {slot}.",
        generate_func="generate_min_food_rating",
        validate_func="validate_min_food_rating",
    ),
}

RUBRIC_SUBRATING_ENVIRONMENT = {
    "Minimum environment rating": dict(
        probability=1,
        description="Prefer restaurants where the environment/ambience rating is at least {slot}.",
        violation_description="Some selected restaurants do not satisfy the requirement on environment rating >= {slot}.",
        generate_func="generate_min_environment_rating",
        validate_func="validate_min_environment_rating",
    ),
}

RUBRIC_SUBRATING_SERVICE = {
    "Minimum service rating": dict(
        probability=1,
        description="Prefer restaurants where the service rating is at least {slot}.",
        violation_description="Some selected restaurants do not satisfy the requirement on service rating >= {slot}.",
        generate_func="generate_min_service_rating",
        validate_func="validate_min_service_rating",
    ),
}


# Registry of every rubric group, keyed by the group's variable name.
# process_trip iterates this mapping, so its insertion order is the order in
# which rubric results are generated and written.
RUBRIC_DICTS = {
    "RUBRIC_PRICE": RUBRIC_PRICE,
    "RUBRIC_RATING": RUBRIC_RATING,
    "RUBRIC_REVIEW_COUNT": RUBRIC_REVIEW_COUNT,
    "RUBRIC_INCLUDE_CUISINE": RUBRIC_INCLUDE_CUISINE,
    "RUBRIC_EXCLUDE_CUISINE": RUBRIC_EXCLUDE_CUISINE,
    "RUBRIC_OPEN": RUBRIC_OPEN,
    "RUBRIC_SUBRATING_FOOD": RUBRIC_SUBRATING_FOOD,
    "RUBRIC_SUBRATING_ENVIRONMENT": RUBRIC_SUBRATING_ENVIRONMENT,
    "RUBRIC_SUBRATING_SERVICE": RUBRIC_SUBRATING_SERVICE,
}

# Flat list of the same rubric groups, in the same order. It is derived from
# RUBRIC_DICTS so that adding a group in one place cannot leave the other
# registry out of date.
ALL_RUBRICS = list(RUBRIC_DICTS.values())

# =========================
# Multiprocessing: per-worker global evaluator
# =========================
# Holds the RestaurantEvaluator built by init_worker() in each pool worker.
# It is None until init_worker() has run in the current process.
re_eval = None


def init_worker(restaurants_path: str):
    """Build this process's ``RestaurantEvaluator`` and store it in ``re_eval``.

    The ``__main__`` script passes this as the ``initializer`` of
    ``multiprocessing.Pool``, so it runs once in each worker process. The
    evaluator is therefore built once per process rather than once per trip.

    Args:
        restaurants_path: Path forwarded unchanged to ``RestaurantEvaluator``.
    """
    global re_eval
    evaluator = RestaurantEvaluator(restaurants_path)
    re_eval = evaluator


def process_trip(trip: Dict[str, Any]):
    """Run every rubric generator once for a single trip (runs inside a pool worker).

    For each rubric group in ``RUBRIC_DICTS`` and each rubric key in that group,
    the configured ``generate_func`` is called once through the worker-global
    ``re_eval`` evaluator, and the call is timed. No second-pass validation is
    performed.

    Args:
        trip: Trip record. The function reads ``trip["id"]`` and
            ``trip["route"]``, which is a list of legs. Each leg's ``"to"`` is
            collected into the city list. ``"stay_days"`` is read from the first
            two legs and ``"number_of_people"`` from the first leg; both default
            to 1.

    Returns:
        A tuple ``(trip_result, local_time_stats)``.

        - ``trip_result`` has the keys ``trip_id``, ``route`` and
          ``rubric_results``. ``rubric_results`` maps group name -> rubric key
          -> entry. An entry holds either the rubric config plus the filtered
          generator ``result``, or an ``"error"`` string if the generator raised.
        - ``local_time_stats`` maps generate-function names to the list of
          elapsed seconds, one value per call.

    Raises:
        KeyError: if ``trip`` has no ``"id"`` or ``"route"``, or a leg has no ``"to"``.
    """
    trip_id = trip["id"]
    routes = trip["route"]  # use every leg, not only routes[0]

    city_list = [leg["to"] for leg in routes]

    # Days = stay_days of the first two legs (the original rule is kept).
    # Slicing instead of indexing routes[1] lets single-leg trips through.
    # Indexing would raise IndexError, which propagates through
    # Pool.imap_unordered and aborts the whole run.
    # NOTE(review): legs after the second do not contribute days — confirm intended.
    days = sum(leg.get("stay_days", 1) for leg in routes[:2])
    first_leg = routes[0] if routes else {}
    num_people = first_leg.get("number_of_people", 1)

    # The head count is put into anchor_info for the group-budget generators.
    anchor_info: Dict[str, Any] = {"num_people": num_people}
    generate_params_list = [city_list, days, anchor_info]

    trip_result: Dict[str, Any] = {
        "trip_id": trip_id,
        "route": routes,
        "rubric_results": {},
    }
    local_time_stats: Dict[str, List[float]] = defaultdict(list)

    for rubric_name, rubric_dict in RUBRIC_DICTS.items():
        group_results: Dict[str, Any] = {}
        trip_result["rubric_results"][rubric_name] = group_results

        for key, config in rubric_dict.items():
            gen_name = config["generate_func"]

            # The single generate call. It is timed whether it succeeds or raises.
            t0 = time.perf_counter()
            try:
                gen_result = re_eval.execute(
                    gen_name,
                    *generate_params_list,
                    generate_params={"rubric_key": key},
                )
            except Exception as e:
                # Record the failure and carry on with the remaining rubrics.
                group_results[key] = {
                    "generate_func": gen_name,
                    "error": str(e),
                }
                continue
            finally:
                local_time_stats[gen_name].append(time.perf_counter() - t0)

            # Keep only the parameters and explanation; candidate ID lists are dropped.
            filtered_result = dict(gen_result)
            filtered_result.pop("candidate_ids", None)
            filtered_result.pop("candidate_product_ids", None)

            group_results[key] = {
                "probability": config["probability"],
                "description": config["description"],
                "violation_description": config["violation_description"],
                "generate_func": gen_name,
                "validate_func": config["validate_func"],
                "result": filtered_result,
            }

    return trip_result, dict(local_time_stats)


if __name__ == "__main__":
    # Script entry point: generate rubric parameters for every trip in parallel
    # and stream the per-trip results to a JSONL file.
    #
    # The input/output paths come from the command line. Previously they were
    # hard-coded string literals that had been left unterminated, which is a
    # SyntaxError and made this whole module unimportable.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate restaurant rubric parameters for every trip and write them as JSONL."
    )
    parser.add_argument("--restaurants", required=True,
                        help="Restaurant data path passed to RestaurantEvaluator.")
    parser.add_argument("--trips", required=True,
                        help="JSON file containing a list of trip records.")
    parser.add_argument("--output", required=True,
                        help="Output JSONL path (one trip result per line).")
    parser.add_argument("--workers", type=int, default=32,
                        help="Number of worker processes (default: 32).")
    args = parser.parse_args()

    restaurants_path = args.restaurants
    trips_file = args.trips
    output_file = args.output
    num_workers = args.workers

    with open(trips_file, "r", encoding="utf-8") as f:
        trips = json.load(f)

    time_stats: Dict[str, List[float]] = defaultdict(list)

    total = len(trips)
    print(f"共需处理 {total} 条 trip\n")

    # os.makedirs("") raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Each worker builds its own evaluator once via init_worker. Results arrive
    # in completion order (imap_unordered), so the JSONL line order is not the
    # input order.
    with Pool(processes=num_workers, initializer=init_worker, initargs=(restaurants_path,)) as pool, \
            open(output_file, "w", encoding="utf-8") as fout:

        for i, (trip_result, local_time) in enumerate(pool.imap_unordered(process_trip, trips), start=1):
            fout.write(json.dumps(trip_result, ensure_ascii=False) + "\n")

            # Merge this trip's timings into the global per-function lists.
            for fn, arr in local_time.items():
                time_stats[fn].extend(arr)

            print(f"完成 {i}/{total}  ({i/total:.1%})")

    print(f"\n处理完成！结果已保存到 {output_file}")

    # Report the per-function timings collected above.
    if time_stats:
        print("\nTiming per generate function (calls / total s / mean s):")
        for fn, arr in sorted(time_stats.items()):
            print(f"  {fn}: {len(arr)} / {sum(arr):.3f} / {sum(arr) / len(arr):.4f}")
