# -*- coding: utf-8 -*-
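"""
Restaurant rubric definitions and rubric-parameter generation entry point.

1) Defines the restaurant-related rubric templates: a human-readable
   description and violation description (with a ``{slot}`` placeholder),
   plus the names of the matching generate/validate functions.
2) When run as ``__main__``:
   - reads the city_trips file;
   - calls every generate_* function once for each trip;
   - records the elapsed time of each call;
   - writes the generated parameters (with candidate id lists removed) to a
     JSON output file.
"""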

from typing import Any, Dict, List, Optional, Tuple

from val_restaurants import RestaurantEvaluator

# =========================
# Rubric definitions: price
# =========================
# Each rubric table maps a rubric key to its template. The key is what the
# __main__ driver forwards to the generator as generate_params["rubric_key"];
# "{slot}" in the descriptions is a placeholder, presumably filled in with the
# generated value elsewhere (TODO confirm against the template formatter).
#
# All four price rubrics share one generator/validator pair, so the table is
# built from (key, description, violation_description) rows. The generator
# presumably distinguishes them by rubric_key (verify in val_restaurants).

RUBRIC_PRICE = {
    rubric_key: {
        "probability": 1,
        "description": description,
        "violation_description": violation_description,
        "generate_func": "generate_price_per_person_range",
        "validate_func": "validate_price_per_person_range",
    }
    for rubric_key, description, violation_description in (
        # ---- per person, per meal ----
        (
            "Per-person per-meal cost less than a certain price",
            "Each selected restaurant has a per-person per-meal cost less than {slot}.",
            "Some recommended restaurants have a per-person per-meal cost equal to or higher than {slot}.",
        ),
        (
            "Per-person per-meal cost more than a certain price",
            "Each selected restaurant has a per-person per-meal cost more than {slot}.",
            "Some recommended restaurants have a per-person per-meal cost equal to or lower than {slot}.",
        ),
        (
            "Per-person per-meal cost around a certain price",
            "Each selected restaurant has a per-person per-meal cost around {slot}.",
            "Some recommended restaurants have a per-person per-meal cost far from {slot}.",
        ),
        (
            "Per-person per-meal cost between a certain price range",
            "Each selected restaurant has a per-person per-meal cost that falls within {slot}.",
            "Some recommended restaurants have a per-person per-meal cost outside {slot}.",
        ),
    )
}

# =========================
# Overall rating, review count and cuisine rubrics
# =========================
# Single-entry tables; each template names its own generator/validator pair.

RUBRIC_RATING = {
    "Minimum overall star rating": dict(
        probability=1,
        description="Only recommend restaurants with an overall rating of at least {slot} stars.",
        violation_description="Some recommended restaurants have an overall rating below {slot} stars.",
        generate_func="generate_min_overall_rating",
        validate_func="validate_min_overall_rating",
    ),
}

# NOTE(review): the description is phrased as a preference ("Prefer ...") while
# the violation text treats it as a hard threshold — confirm which is intended.
RUBRIC_REVIEW_COUNT = {
    "Minimum review count": dict(
        probability=1,
        description="Prefer restaurants that have at least {slot} reviews.",
        violation_description="Some recommended restaurants have fewer than {slot} reviews.",
        generate_func="generate_min_review_count",
        validate_func="validate_min_review_count",
    ),
}

RUBRIC_INCLUDE_CUISINE = {
    "Include certain cuisines": dict(
        probability=1,
        description="Make sure the plan includes restaurants serving these cuisines: {slot}.",
        violation_description="The plan does not include enough options for these cuisines: {slot}.",
        generate_func="generate_include_cuisines",
        validate_func="validate_include_cuisines",
    ),
}

RUBRIC_EXCLUDE_CUISINE = {
    "Exclude certain cuisines": dict(
        probability=1,
        description="Avoid restaurants that focus on these cuisines: {slot}.",
        violation_description="Some recommended restaurants still focus on these cuisines: {slot}.",
        generate_func="generate_exclude_cuisines",
        validate_func="validate_exclude_cuisines",
    ),
}

# ========= Open / Reservable =========
RUBRIC_OPEN = {
    # 1) Prefer restaurants that accept reservations.
    # NOTE(review): the key says "Must" but the description says "Whenever
    # possible". The key is forwarded to the generator as rubric_key, so change
    # the description text rather than the key if the wording is aligned.
    "Must be reservable": dict(
        probability=1,
        description="Whenever possible, choose restaurants that support reservations, so that seats can be secured in advance.",
        violation_description="Some selected restaurants cannot be reserved in advance, which may cause difficulty in getting seats.",
        generate_func="generate_reservable",
        validate_func="validate_reservable",
    ),
    # 2) Exclude restaurants that can only be visited with a reservation.
    "Exclude must-reservation restaurants": dict(
        probability=1,
        description="Exclude restaurants that require mandatory advance reservations.",
        violation_description="Some selected restaurants still require mandatory advance reservations.",
        generate_func="generate_exclude_must_reserve",
        validate_func="validate_exclude_must_reserve",
    ),
}

# ========= Three independent sub-rating constraints =========
# Food, environment and service each get their own single-entry table and
# generator/validator pair.
RUBRIC_SUBRATING_FOOD = {
    "Minimum food rating": dict(
        probability=1,
        description="Prefer restaurants where the food quality rating is at least {slot}.",
        violation_description="Some selected restaurants do not satisfy the requirement on food quality rating >= {slot}.",
        generate_func="generate_min_food_rating",
        validate_func="validate_min_food_rating",
    ),
}

RUBRIC_SUBRATING_ENVIRONMENT = {
    "Minimum environment rating": dict(
        probability=1,
        description="Prefer restaurants where the environment/ambience rating is at least {slot}.",
        violation_description="Some selected restaurants do not satisfy the requirement on environment rating >= {slot}.",
        generate_func="generate_min_environment_rating",
        validate_func="validate_min_environment_rating",
    ),
}

RUBRIC_SUBRATING_SERVICE = {
    "Minimum service rating": dict(
        probability=1,
        description="Prefer restaurants where the service rating is at least {slot}.",
        violation_description="Some selected restaurants do not satisfy the requirement on service rating >= {slot}.",
        generate_func="generate_min_service_rating",
        validate_func="validate_min_service_rating",
    ),
}


# Flat list of every rubric table defined above, in declaration order.
ALL_RUBRICS = (
    # price, overall rating and popularity
    [RUBRIC_PRICE, RUBRIC_RATING, RUBRIC_REVIEW_COUNT]
    # cuisine include/exclude filters
    + [RUBRIC_INCLUDE_CUISINE, RUBRIC_EXCLUDE_CUISINE]
    # reservation policy
    + [RUBRIC_OPEN]
    # food / environment / service sub-ratings
    + [RUBRIC_SUBRATING_FOOD, RUBRIC_SUBRATING_ENVIRONMENT, RUBRIC_SUBRATING_SERVICE]
)


if __name__ == "__main__":
    import json
    import os
    import time
    from typing import List, Dict, Any
    from collections import defaultdict

    num = 0
    # ---- 路径设置（按你的工程结构改） ----
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    restaurants_path = "    # trip 文件（结构与你给的一致）
    trips_file = "    # ---- 初始化 evaluator ----
    re_eval = RestaurantEvaluator(restaurants_path)

    with open(trips_file, "r", encoding="utf-8") as f:
        trips = json.load(f)

    # 所有 rubrics 的 dict
    RUBRIC_DICTS = {
        "RUBRIC_PRICE": RUBRIC_PRICE,
        "RUBRIC_RATING": RUBRIC_RATING,
        "RUBRIC_REVIEW_COUNT": RUBRIC_REVIEW_COUNT,
        "RUBRIC_INCLUDE_CUISINE": RUBRIC_INCLUDE_CUISINE,
        "RUBRIC_EXCLUDE_CUISINE": RUBRIC_EXCLUDE_CUISINE,
        "RUBRIC_OPEN": RUBRIC_OPEN,
        "RUBRIC_SUBRATING_FOOD": RUBRIC_SUBRATING_FOOD,
        "RUBRIC_SUBRATING_ENVIRONMENT": RUBRIC_SUBRATING_ENVIRONMENT,
        "RUBRIC_SUBRATING_SERVICE": RUBRIC_SUBRATING_SERVICE,
    }

    results: List[Dict[str, Any]] = []
    time_stats: Dict[str, List[float]] = defaultdict(list)

    # 只处理前两个 trip
    # for trip in trips[:2]:
    for trip in trips:
        
        # trip_id = trip["id"]
        # route = trip["route"][0]  # 假设每个 trip 只有一条 route

        # dest_city = route["to"]
        # days = route.get("stay_days", 1)
        # num_people = route.get("number_of_people", 1)

        # city_list = [dest_city]

        trip_id = trip["id"]

        routes = trip["route"]  # 不要只取 [0]

        city_list = []
        for r in routes:
            city_list.append(r["to"])

        days = routes[0].get("stay_days", 1) + routes[1].get("stay_days", 1)
        num_people = routes[0].get("number_of_people", 1)

        # 把人数塞进 anchor_info，供 group-budget 使用
        anchor_info: Dict[str, Any] = {
            "num_people": num_people
        }

        generate_params_list = [city_list, days, anchor_info]

        trip_result: Dict[str, Any] = {
            "trip_id": trip_id,
            "route": trip["route"],
            "rubric_results": {},
        }

        print(f"\n{num}====== Trip {trip_id} ({city_list}, days={days}, people={num_people}) ======")
        num += 1
        # ---- 遍历所有 rubrics（已删掉二次生成验证）----
        for rubric_name, rubric_dict in RUBRIC_DICTS.items():
            trip_result["rubric_results"][rubric_name] = {}

            for key, config in rubric_dict.items():
                gen_name = config["generate_func"]

                first_call_gen_params = {"rubric_key": key}

                # ---- 第一次 generate（唯一一次） ----
                t0 = time.perf_counter()
                try:
                    gen_result = re_eval.execute(
                        gen_name,
                        *generate_params_list,
                        generate_params=first_call_gen_params,
                    )
                except Exception as e:
                    elapsed = time.perf_counter() - t0
                    time_stats[gen_name].append(elapsed)
                    trip_result["rubric_results"][rubric_name][key] = {
                        "generate_func": gen_name,
                        "error": str(e),
                    }
                    print(f"  [ERROR] generate failed: {e} (time={elapsed:.3f}s)")
                    continue

                elapsed = time.perf_counter() - t0
                time_stats[gen_name].append(elapsed)

                # --- 去掉 candidate 列表，只保留参数与说明 ---
                filtered_result = dict(gen_result)
                filtered_result.pop("candidate_ids", None)
                filtered_result.pop("candidate_product_ids", None)

                trip_result["rubric_results"][rubric_name][key] = {
                    "probability": config["probability"],
                    "description": config["description"],
                    "violation_description": config["violation_description"],
                    "generate_func": gen_name,
                    "validate_func": config["validate_func"],
                    "result": filtered_result,
                }

        results.append(trip_result)

    # ---- 输出文件 ----
    output_file = "    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"\n处理完成！结果已保存到 {output_file}")

