# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ============ Set up the CUDA environment BEFORE importing torch ============
# Ray worker processes may not correctly inherit CUDA_VISIBLE_DEVICES, so it is
# forced here; this must run before `import torch` below to take effect.
import os
if not os.environ.get("CUDA_VISIBLE_DEVICES"):
    # CUDA_VISIBLE_DEVICES is unset (or empty): default to GPU 0 so that torch
    # can detect a GPU when it is imported.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import regex as re
from typing import Dict, List, Optional, Tuple
import json
from mathruler.grader import extract_boxed_content, grade_answer
import time
import random
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed

from collections import Counter
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import torch

# Root directory for temp files and the memory bank.
# NOTE: must be set — if unset it is None and the os.path.join calls below
# raise TypeError at import time.
STORAGE_PATH = os.getenv("STORAGE_PATH")
# Number of Solver vLLM services, read from the environment (service k listens on port 5000 + k).
# In time-multiplexed mode the number of vLLM services equals the total GPU count.
VLLM_GPU_COUNT = int(os.getenv("VLLM_GPU_COUNT", "8"))

# ============ Memory Bank configuration ============
# Memory Bank path; MODEL_ABBR keeps different experiments apart.
MODEL_ABBR = os.getenv("MODEL_ABBR", "")
if MODEL_ABBR:
    MEMORY_BANK_PATH = os.path.join(STORAGE_PATH, "memory_bank", MODEL_ABBR)
else:
    # Backward compatibility: without MODEL_ABBR, use the original shared path.
    MEMORY_BANK_PATH = os.path.join(STORAGE_PATH, "memory_bank")
# Memory-based Penalty hyper-parameters
MEMORY_PENALTY_THRESHOLD = float(os.getenv("MEMORY_PENALTY_THRESHOLD", "0.5"))  # threshold on max_similarity
MEMORY_PENALTY_MEAN_THRESHOLD = float(os.getenv("MEMORY_PENALTY_MEAN_THRESHOLD", "0.3"))  # threshold on mean_similarity
MEMORY_PENALTY_GAMMA = float(os.getenv("MEMORY_PENALTY_GAMMA", "1.0"))  # weight of max_penalty (1.0 = max only, 0.0 = mean only)
PENALTY_ALPHA = float(os.getenv("PENALTY_ALPHA", "1.0"))  # batch penalty weight
PENALTY_BETA = float(os.getenv("PENALTY_BETA", "1.0"))   # memory penalty weight

# ============ Embedding type configuration ============
# EMBEDDING_TYPE: "nl" (natural language) or "code" (Python code); any value other
# than "code" is treated as NL mode by the dispatch functions in this module.
EMBEDDING_TYPE = os.getenv("EMBEDDING_TYPE", "nl")
# KEEP_BATCH_PENALTY_UNCHANGED: keep the original batch-penalty computation even when
# EMBEDDING_TYPE is "code". True only if the env var equals "true" (case-insensitive).
KEEP_BATCH_PENALTY_UNCHANGED = os.getenv("KEEP_BATCH_PENALTY_UNCHANGED", "false").lower() == "true"
# Embedding model identifiers (loaded from the local HF cache only).
NL_EMBEDDING_MODEL = os.getenv("NL_EMBEDDING_MODEL", "BAAI/bge-large-en-v1.5")
CODE_EMBEDDING_MODEL = os.getenv("CODE_EMBEDDING_MODEL", "jinaai/jina-code-embeddings-1.5b")
# Code-generation service configuration.
# In time-multiplexed mode the number of Code vLLM services equals the total GPU count;
# service k listens on CODE_VLLM_BASE_PORT + k.
CODE_VLLM_GPU_COUNT = int(os.getenv("CODE_VLLM_GPU_COUNT", "8"))
CODE_VLLM_BASE_PORT = int(os.getenv("CODE_VLLM_BASE_PORT", "6000"))

# Log the effective configuration once, at import time.
print(f"[caller_penalty] Configuration:")
print(f"  MODEL_ABBR: {MODEL_ABBR}")
print(f"  MEMORY_BANK_PATH: {MEMORY_BANK_PATH}")
print(f"  EMBEDDING_TYPE: {EMBEDDING_TYPE}")
print(f"  KEEP_BATCH_PENALTY_UNCHANGED: {KEEP_BATCH_PENALTY_UNCHANGED}")
print(f"  NL_EMBEDDING_MODEL: {NL_EMBEDDING_MODEL}")
print(f"  CODE_EMBEDDING_MODEL: {CODE_EMBEDDING_MODEL}")
print(f"  CODE_VLLM_GPU_COUNT: {CODE_VLLM_GPU_COUNT}")
print(f"  MEMORY_PENALTY_THRESHOLD (max): {MEMORY_PENALTY_THRESHOLD}")
print(f"  MEMORY_PENALTY_MEAN_THRESHOLD: {MEMORY_PENALTY_MEAN_THRESHOLD}")
print(f"  MEMORY_PENALTY_GAMMA: {MEMORY_PENALTY_GAMMA}")

# ============ Module-level state (lazily populated caches) ============
_nl_embedding_model = None  # set by _load_nl_embedding_model(), cleared by _unload_nl_embedding_model()
_nl_embedding_tokenizer = None  # tokenizer paired with _nl_embedding_model
_code_embedding_model = None  # set by _load_code_embedding_model(), cleared by _unload_code_embedding_model()
_code_embedding_tokenizer = None  # tokenizer paired with _code_embedding_model
_memory_bank_embeddings = None  # memory-bank embedding array cached by _load_memory_bank() (None if unavailable)
_memory_bank_loaded = False  # True once _load_memory_bank() has run, even if no file was found
_memory_stats = {
    "step_max_similarities": [],  # max_similarity values accumulated since the last reset
    "step_mean_similarities": [],  # mean_similarity values accumulated since the last reset
    "step_count": 0  # number of record_memory_stats() calls since the last reset
}


# ============ vLLM Sleep Mode Control Functions ============
# Put vLLM services to sleep / wake them up in time-multiplexed (GPU time-sharing) mode.

def wake_up_solver_vllm() -> bool:
    """
    Wake every Solver vLLM service (ports 5000 .. 5000 + VLLM_GPU_COUNT - 1).

    Sends ``POST /wake_up`` to all services concurrently (120 s timeout each).
    Call this before generate_results() in time-multiplexed mode.

    Returns:
        True only if every service answered with HTTP 200.
    """
    print("[vLLM Control] Waking up Solver vLLM services...")
    t_begin = time.time()

    def _post_wake(port: int) -> bool:
        try:
            resp = requests.post(f"http://127.0.0.1:{port}/wake_up", timeout=120)
        except Exception as e:
            print(f"[vLLM Control] Failed to wake up Solver on port {port}: {e}")
            return False
        return resp.status_code == 200

    ports = [5000 + offset for offset in range(VLLM_GPU_COUNT)]
    with ThreadPoolExecutor(max_workers=VLLM_GPU_COUNT) as pool:
        outcomes = list(pool.map(_post_wake, ports))

    elapsed = time.time() - t_begin
    print(f"[vLLM Control] Solver vLLM wake up complete: {sum(outcomes)}/{VLLM_GPU_COUNT} services, time: {elapsed:.2f}s")
    return all(outcomes)


def sleep_solver_vllm() -> bool:
    """
    Put every Solver vLLM service (ports 5000 .. 5000 + VLLM_GPU_COUNT - 1) to sleep.

    Sends ``POST /sleep?level=1`` to all services concurrently (60 s timeout each),
    then empties this process's CUDA caching allocator. Call this after
    generate_results() in time-multiplexed mode.

    Returns:
        True only if every service answered with HTTP 200.
    """
    print("[vLLM Control] Putting Solver vLLM services to sleep...")
    t_begin = time.time()

    def _post_sleep(port: int) -> bool:
        try:
            resp = requests.post(f"http://127.0.0.1:{port}/sleep?level=1", timeout=60)
        except Exception as e:
            print(f"[vLLM Control] Failed to sleep Solver on port {port}: {e}")
            return False
        return resp.status_code == 200

    ports = [5000 + offset for offset in range(VLLM_GPU_COUNT)]
    with ThreadPoolExecutor(max_workers=VLLM_GPU_COUNT) as pool:
        outcomes = list(pool.map(_post_sleep, ports))

    elapsed = time.time() - t_begin
    print(f"[vLLM Control] Solver vLLM sleep complete: {sum(outcomes)}/{VLLM_GPU_COUNT} services, time: {elapsed:.2f}s")

    # Return cached-but-unused CUDA blocks held by this process.
    torch.cuda.empty_cache()
    return all(outcomes)


def wake_up_code_vllm() -> bool:
    """
    Wake every Code vLLM service (ports CODE_VLLM_BASE_PORT .. + CODE_VLLM_GPU_COUNT - 1).

    Sends ``POST /wake_up`` to all services concurrently (120 s timeout each).
    Call this before generate_codes_from_questions() in time-multiplexed mode.

    Returns:
        True only if every service answered with HTTP 200.
    """
    print("[vLLM Control] Waking up Code vLLM services...")
    t_begin = time.time()

    def _post_wake(port: int) -> bool:
        try:
            resp = requests.post(f"http://127.0.0.1:{port}/wake_up", timeout=120)
        except Exception as e:
            print(f"[vLLM Control] Failed to wake up Code on port {port}: {e}")
            return False
        return resp.status_code == 200

    ports = [CODE_VLLM_BASE_PORT + offset for offset in range(CODE_VLLM_GPU_COUNT)]
    with ThreadPoolExecutor(max_workers=CODE_VLLM_GPU_COUNT) as pool:
        outcomes = list(pool.map(_post_wake, ports))

    elapsed = time.time() - t_begin
    print(f"[vLLM Control] Code vLLM wake up complete: {sum(outcomes)}/{CODE_VLLM_GPU_COUNT} services, time: {elapsed:.2f}s")
    return all(outcomes)


def sleep_code_vllm() -> bool:
    """
    Put every Code vLLM service (ports CODE_VLLM_BASE_PORT .. + CODE_VLLM_GPU_COUNT - 1) to sleep.

    Sends ``POST /sleep?level=1`` to all services concurrently (60 s timeout each),
    then empties this process's CUDA caching allocator. Call this after
    generate_codes_from_questions() in time-multiplexed mode.

    Returns:
        True only if every service answered with HTTP 200.
    """
    print("[vLLM Control] Putting Code vLLM services to sleep...")
    t_begin = time.time()

    def _post_sleep(port: int) -> bool:
        try:
            resp = requests.post(f"http://127.0.0.1:{port}/sleep?level=1", timeout=60)
        except Exception as e:
            print(f"[vLLM Control] Failed to sleep Code on port {port}: {e}")
            return False
        return resp.status_code == 200

    ports = [CODE_VLLM_BASE_PORT + offset for offset in range(CODE_VLLM_GPU_COUNT)]
    with ThreadPoolExecutor(max_workers=CODE_VLLM_GPU_COUNT) as pool:
        outcomes = list(pool.map(_post_sleep, ports))

    elapsed = time.time() - t_begin
    print(f"[vLLM Control] Code vLLM sleep complete: {sum(outcomes)}/{CODE_VLLM_GPU_COUNT} services, time: {elapsed:.2f}s")

    # Return cached-but-unused CUDA blocks held by this process.
    torch.cuda.empty_cache()
    return all(outcomes)


def _load_nl_embedding_model():
    """
    Return the cached ``(tokenizer, model)`` for NL_EMBEDDING_MODEL, loading it on first use.

    Weights come from the local Hugging Face cache only (``local_files_only=True``).
    The model is moved to ``cuda:0`` when possible and falls back to CPU on any
    error, then is switched to eval mode. The pair is kept in module globals and
    reused until _unload_nl_embedding_model() clears it.
    """
    global _nl_embedding_model, _nl_embedding_tokenizer

    if _nl_embedding_model is None:
        print(f"[Memory Bank] Loading NL embedding model {NL_EMBEDDING_MODEL}...")
        from transformers import AutoTokenizer, AutoModel

        _nl_embedding_tokenizer = AutoTokenizer.from_pretrained(NL_EMBEDDING_MODEL, local_files_only=True)
        _nl_embedding_model = AutoModel.from_pretrained(NL_EMBEDDING_MODEL, local_files_only=True)

        try:
            target = torch.device("cuda:0")
            _nl_embedding_model = _nl_embedding_model.to(target)
            # Allocate a tiny tensor so an unusable CUDA device fails here, not at inference time.
            torch.zeros(1, device=target)
            print(f"[Memory Bank] NL embedding model loaded on {target}")
        except Exception as e:
            print(f"[Memory Bank] Warning: Failed to load on GPU ({e}), falling back to CPU")
            target = torch.device("cpu")
            _nl_embedding_model = _nl_embedding_model.to(target)
            print(f"[Memory Bank] NL embedding model loaded on {target}")

        _nl_embedding_model.eval()

    return _nl_embedding_tokenizer, _nl_embedding_model


def _unload_nl_embedding_model():
    """
    Drop the cached NL embedding model and tokenizer, then empty the CUDA cache.

    No-op when no NL model is loaded. GPU memory is only actually returned if no
    other reference to the model object remains.
    """
    global _nl_embedding_model, _nl_embedding_tokenizer

    if _nl_embedding_model is None:
        return

    _nl_embedding_model = None
    _nl_embedding_tokenizer = None
    torch.cuda.empty_cache()
    print("[Memory Bank] NL embedding model unloaded from GPU")


def _load_code_embedding_model():
    """
    Return the cached ``(tokenizer, model)`` for CODE_EMBEDDING_MODEL, loading it on first use.

    Weights come from the local Hugging Face cache only (``local_files_only=True``)
    and are loaded with ``trust_remote_code=True``, which executes Python code
    shipped with the checkpoint, so only point CODE_EMBEDDING_MODEL at trusted
    models. The model is moved to ``cuda:0`` when possible and falls back to CPU on
    any error, then is switched to eval mode. The pair is kept in module globals
    until _unload_code_embedding_model() clears it.
    """
    global _code_embedding_model, _code_embedding_tokenizer

    if _code_embedding_model is None:
        print(f"[Memory Bank] Loading Code embedding model {CODE_EMBEDDING_MODEL}...")
        from transformers import AutoTokenizer, AutoModel

        _code_embedding_tokenizer = AutoTokenizer.from_pretrained(CODE_EMBEDDING_MODEL, trust_remote_code=True, local_files_only=True)
        _code_embedding_model = AutoModel.from_pretrained(CODE_EMBEDDING_MODEL, trust_remote_code=True, local_files_only=True)

        try:
            target = torch.device("cuda:0")
            _code_embedding_model = _code_embedding_model.to(target)
            # Allocate a tiny tensor so an unusable CUDA device fails here, not at inference time.
            torch.zeros(1, device=target)
            print(f"[Memory Bank] Code embedding model loaded on {target}")
        except Exception as e:
            print(f"[Memory Bank] Warning: Failed to load on GPU ({e}), falling back to CPU")
            target = torch.device("cpu")
            _code_embedding_model = _code_embedding_model.to(target)
            print(f"[Memory Bank] Code embedding model loaded on {target}")

        _code_embedding_model.eval()

    return _code_embedding_tokenizer, _code_embedding_model


def _unload_code_embedding_model():
    """
    Drop the cached code embedding model and tokenizer, then empty the CUDA cache.

    No-op when no code model is loaded. GPU memory is only actually returned if no
    other reference to the model object remains.
    """
    global _code_embedding_model, _code_embedding_tokenizer

    if _code_embedding_model is None:
        return

    _code_embedding_model = None
    _code_embedding_tokenizer = None
    torch.cuda.empty_cache()
    print("[Memory Bank] Code embedding model unloaded from GPU")


def _load_embedding_model():
    """
    Load (or reuse) the embedding model selected by EMBEDDING_TYPE.

    "code" selects the code embedding model; any other value selects the NL model.

    Returns:
        ``(tokenizer, model)`` from the selected loader.
    """
    use_code_model = EMBEDDING_TYPE == "code"
    return _load_code_embedding_model() if use_code_model else _load_nl_embedding_model()


def _unload_embedding_model():
    """
    Release every cached embedding model to free GPU memory.

    Call this after each batch-penalty / memory-penalty computation so the
    embedding models do not compete with the training process for GPU memory.

    Both the NL and the code model are unloaded regardless of EMBEDDING_TYPE. In
    "code" mode the NL model can still be loaded: compute_memory_based_penalty falls
    back to the NL path when ``codes`` is None. Unloading only the EMBEDDING_TYPE
    model would leave that NL model resident on the GPU. Unloading a model that is
    not loaded is a no-op.
    """
    _unload_code_embedding_model()
    _unload_nl_embedding_model()


def _load_memory_bank() -> Optional[np.ndarray]:
    """
    Return the memory-bank embedding array, reading it from disk on the first call only.

    The file is ``embedding_code.npy`` when EMBEDDING_TYPE is "code" and
    ``embeddings.npy`` otherwise, both under MEMORY_BANK_PATH. A missing or
    unreadable file yields None. Either way the result is cached and never re-read,
    even if the file appears later.

    Returns:
        The loaded array, or None when no usable memory bank exists.
    """
    global _memory_bank_embeddings, _memory_bank_loaded

    if not _memory_bank_loaded:
        file_name = "embedding_code.npy" if EMBEDDING_TYPE == "code" else "embeddings.npy"
        embeddings_path = os.path.join(MEMORY_BANK_PATH, file_name)
        print(f"[Memory Bank] Loading embeddings from {embeddings_path}...")

        loaded = None
        if not os.path.exists(embeddings_path):
            print(f"[Memory Bank] No embeddings found at {embeddings_path}, Memory-based Penalty will be 0")
        else:
            try:
                loaded = np.load(embeddings_path)
            except Exception as e:
                print(f"[Memory Bank] Error loading embeddings: {e}")
            else:
                print(f"[Memory Bank] Loaded embeddings with shape {loaded.shape}")

        _memory_bank_embeddings = loaded
        _memory_bank_loaded = True

    return _memory_bank_embeddings


def compute_question_embeddings(questions: List[str]) -> np.ndarray:
    """
    Embed natural-language questions with the NL embedding model.

    Each question is prefixed with a fixed instruction string, tokenized in chunks
    of 32 (truncated to 512 tokens) and encoded. The first-token ([CLS]) hidden
    state is taken and L2-normalised, so a dot product of two rows is their cosine
    similarity.

    NOTE(review): the BGE model card's documented query instruction differs from the
    prefix used here. Any memory-bank embeddings were presumably produced with this
    exact prefix, so do not change it without regenerating them — confirm.

    Args:
        questions: Question strings.

    Returns:
        Float array of shape (len(questions), hidden_dim); for empty input an
        empty (0, 1024) array.
    """
    if not questions:
        return np.array([]).reshape(0, 1024)

    tokenizer, model = _load_nl_embedding_model()
    device = next(model.parameters()).device

    query_prefix = "Represent this question for similarity search: "
    texts = [query_prefix + q for q in questions]

    chunk = 32
    pooled_chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), chunk):
            batch_inputs = tokenizer(
                texts[start:start + chunk],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            batch_inputs = {name: tensor.to(device) for name, tensor in batch_inputs.items()}
            hidden = model(**batch_inputs).last_hidden_state
            cls_vectors = torch.nn.functional.normalize(hidden[:, 0, :], p=2, dim=1)
            pooled_chunks.append(cls_vectors.cpu().numpy())

    return np.vstack(pooled_chunks)


def compute_code_embeddings(codes: List[str]) -> np.ndarray:
    """
    Embed Python code snippets with the code embedding model.

    Snippets are tokenized in chunks of 16 (truncated to 1024 tokens) and encoded.
    The hidden states are mean-pooled over non-padding positions (weighted by the
    attention mask) and L2-normalised, so a dot product of two rows is their cosine
    similarity.

    NOTE(review): the jina-code-embeddings model card describes last-token pooling,
    but mean pooling is used here. Changing it would invalidate memory-bank
    embeddings produced with this function — confirm before touching.

    Args:
        codes: Code strings.

    Returns:
        Float array of shape (len(codes), hidden_dim); for empty input an empty
        (0, 1536) array.
    """
    if not codes:
        return np.array([]).reshape(0, 1536)

    tokenizer, model = _load_code_embedding_model()
    device = next(model.parameters()).device

    chunk = 16  # code is usually longer than questions, so use smaller batches
    pooled_chunks = []
    with torch.no_grad():
        for start in range(0, len(codes), chunk):
            batch_inputs = tokenizer(
                codes[start:start + chunk],
                padding=True,
                truncation=True,
                max_length=1024,
                return_tensors="pt",
            )
            batch_inputs = {name: tensor.to(device) for name, tensor in batch_inputs.items()}
            hidden = model(**batch_inputs).last_hidden_state
            mask = batch_inputs['attention_mask']
            summed = (hidden * mask.unsqueeze(-1)).sum(1)
            counts = mask.sum(-1, keepdim=True)
            pooled = torch.nn.functional.normalize(summed / counts, p=2, dim=1)
            pooled_chunks.append(pooled.cpu().numpy())

    return np.vstack(pooled_chunks)


def generate_codes_from_questions(questions: List[str], max_retries: int = 3) -> List[Optional[str]]:
    """
    Convert questions to Python solution code via the Code vLLM services.

    Questions are distributed round-robin: question k goes to service
    ``k % CODE_VLLM_GPU_COUNT`` (port CODE_VLLM_BASE_PORT + service). Each service
    receives one ``POST /generate_with_retry`` with ``{"questions", "max_retries"}``
    and a 300 s timeout. The services are queried concurrently and their ``codes``
    lists are mapped back to the original question positions.

    Args:
        questions: Question strings.
        max_retries: Forwarded to the service as ``max_retries``.

    Returns:
        One entry per question; None wherever the service failed, errored, or
        returned fewer codes than questions.
    """
    if not questions:
        return []

    print(f"[Code Generation] Converting {len(questions)} questions to Python code...")
    t_begin = time.time()

    n_services = CODE_VLLM_GPU_COUNT

    # Round-robin sharding: shard_positions[s] lists original indices handled by service s.
    shard_positions = [[] for _ in range(n_services)]
    for pos in range(len(questions)):
        shard_positions[pos % n_services].append(pos)

    codes = [None] * len(questions)

    def _request_shard(shard: int):
        shard_questions = [questions[p] for p in shard_positions[shard]]
        if not shard_questions:
            return []

        url = f"http://127.0.0.1:{CODE_VLLM_BASE_PORT + shard}/generate_with_retry"
        fallback = [None] * len(shard_questions)
        try:
            resp = requests.post(
                url,
                json={"questions": shard_questions, "max_retries": max_retries},
                timeout=300,  # 5 minutes
            )
            if resp.status_code != 200:
                print(f"[Code Generation] Service {shard} error: {resp.status_code}")
                return fallback
            return resp.json().get('codes', fallback)
        except Exception as e:
            print(f"[Code Generation] Service {shard} exception: {e}")
            return fallback

    with ThreadPoolExecutor(max_workers=n_services) as pool:
        pending = {pool.submit(_request_shard, s): s for s in range(n_services)}
        for fut in as_completed(pending):
            shard = pending[fut]
            try:
                # zip drops any surplus codes beyond the shard's question count.
                for pos, code in zip(shard_positions[shard], fut.result()):
                    codes[pos] = code
            except Exception as e:
                print(f"[Code Generation] Error processing service {shard}: {e}")

    n_ok = sum(c is not None for c in codes)
    n_failed = len(codes) - n_ok
    elapsed = time.time() - t_begin

    print(f"[Code Generation] Complete: {n_ok} success, {n_failed} failed, time: {elapsed:.2f}s")

    return codes


def compute_memory_based_penalty_nl(
    questions: List[str],
    threshold_max: float = 0.5,
    threshold_mean: float = 0.3,
    gamma: float = 1.0
) -> Tuple[List[float], Dict[str, List[float]]]:
    """
    Memory-based penalty from natural-language question embeddings.

    Each non-blank question is embedded with compute_question_embeddings() and
    dot-multiplied against every row of the memory bank from _load_memory_bank().
    Two signals are derived per question:

    - max similarity: closeness to the most similar memory entry (discourages
      near-duplicates);
    - mean similarity: average closeness to the whole bank (discourages drifting
      into densely covered regions).

    Each becomes a hinge penalty ``max(0, sim - threshold)``, and the two are
    blended as ``gamma * max_penalty + (1 - gamma) * mean_penalty``.

    Args:
        questions: Question strings; empty/whitespace-only/None entries get 0.0.
        threshold_max: Hinge threshold on max similarity.
        threshold_mean: Hinge threshold on mean similarity.
        gamma: Weight of the max penalty (1.0 = max only, 0.0 = mean only, 0.5 = equal).

    Returns:
        ``(penalties, stats)``: ``penalties`` is aligned with ``questions``;
        ``stats`` holds aligned lists under ``max_similarities``,
        ``mean_similarities``, ``max_penalties`` and ``mean_penalties``. All
        values are 0.0 when the memory bank is missing or empty, or when its
        embedding width does not match the question embeddings.
    """
    memory_embeddings = _load_memory_bank()

    n = len(questions)
    empty_stats = {
        'max_similarities': [0.0] * n,
        'mean_similarities': [0.0] * n,
        'max_penalties': [0.0] * n,
        'mean_penalties': [0.0] * n,
    }

    if memory_embeddings is None or memory_embeddings.size == 0:
        return [0.0] * n, empty_stats

    valid_indices = []
    valid_questions = []
    for i, q in enumerate(questions):
        if q and q.strip():
            valid_indices.append(i)
            valid_questions.append(q)

    penalties = [0.0] * n
    stats = {
        'max_similarities': [0.0] * n,
        'mean_similarities': [0.0] * n,
        'max_penalties': [0.0] * n,
        'mean_penalties': [0.0] * n,
    }

    if not valid_questions:
        return penalties, stats

    question_embeddings = compute_question_embeddings(valid_questions)

    # _load_memory_bank() picks its file from EMBEDDING_TYPE, so in "code" mode it
    # holds code embeddings (different width) even when this NL path is used as a
    # fallback. Treat an incompatible bank like a missing one instead of letting
    # np.dot raise a shape error.
    if memory_embeddings.ndim != 2 or memory_embeddings.shape[1] != question_embeddings.shape[1]:
        print(f"[Memory Bank] Warning: embedding dim mismatch (questions {question_embeddings.shape}, "
              f"memory bank {memory_embeddings.shape}), Memory-based Penalty will be 0")
        return penalties, stats

    similarities = np.dot(question_embeddings, memory_embeddings.T)  # (n_valid, n_memory)

    valid_max_sims = np.max(similarities, axis=1)
    valid_mean_sims = np.mean(similarities, axis=1)

    valid_max_penalties = np.maximum(0, valid_max_sims - threshold_max)
    valid_mean_penalties = np.maximum(0, valid_mean_sims - threshold_mean)

    # Blend: gamma * max_penalty + (1 - gamma) * mean_penalty
    valid_combined_penalties = gamma * valid_max_penalties + (1 - gamma) * valid_mean_penalties

    # Scatter results back to the original positions.
    for idx, orig_idx in enumerate(valid_indices):
        penalties[orig_idx] = float(valid_combined_penalties[idx])
        stats['max_similarities'][orig_idx] = float(valid_max_sims[idx])
        stats['mean_similarities'][orig_idx] = float(valid_mean_sims[idx])
        stats['max_penalties'][orig_idx] = float(valid_max_penalties[idx])
        stats['mean_penalties'][orig_idx] = float(valid_mean_penalties[idx])

    return penalties, stats


def compute_memory_based_penalty_code(
    codes: List[Optional[str]],
    threshold_max: float = 0.5,
    threshold_mean: float = 0.3,
    gamma: float = 1.0
) -> Tuple[List[float], Dict[str, List[float]]]:
    """
    Memory-based penalty from code embeddings.

    Each non-blank code snippet is embedded with compute_code_embeddings() and
    dot-multiplied against every row of the memory bank from _load_memory_bank().
    Two signals are derived per snippet:

    - max similarity: closeness to the most similar memory entry (discourages
      near-duplicates);
    - mean similarity: average closeness to the whole bank (discourages drifting
      into densely covered regions).

    Each becomes a hinge penalty ``max(0, sim - threshold)``, and the two are
    blended as ``gamma * max_penalty + (1 - gamma) * mean_penalty``.

    Args:
        codes: Code strings; None (generation failed) or blank entries get 0.0.
        threshold_max: Hinge threshold on max similarity.
        threshold_mean: Hinge threshold on mean similarity.
        gamma: Weight of the max penalty (1.0 = max only, 0.0 = mean only, 0.5 = equal).

    Returns:
        ``(penalties, stats)``: ``penalties`` is aligned with ``codes``; ``stats``
        holds aligned lists under ``max_similarities``, ``mean_similarities``,
        ``max_penalties`` and ``mean_penalties``. All values are 0.0 when the
        memory bank is missing or empty, or when its embedding width does not
        match the code embeddings.
    """
    memory_embeddings = _load_memory_bank()

    n = len(codes)
    empty_stats = {
        'max_similarities': [0.0] * n,
        'mean_similarities': [0.0] * n,
        'max_penalties': [0.0] * n,
        'mean_penalties': [0.0] * n,
    }

    if memory_embeddings is None or memory_embeddings.size == 0:
        return [0.0] * n, empty_stats

    valid_indices = []
    valid_codes = []
    for i, c in enumerate(codes):
        if c and c.strip():
            valid_indices.append(i)
            valid_codes.append(c)

    penalties = [0.0] * n
    stats = {
        'max_similarities': [0.0] * n,
        'mean_similarities': [0.0] * n,
        'max_penalties': [0.0] * n,
        'mean_penalties': [0.0] * n,
    }

    if not valid_codes:
        return penalties, stats

    code_embeddings = compute_code_embeddings(valid_codes)

    # _load_memory_bank() picks its file from EMBEDDING_TYPE; if it holds
    # embeddings of a different width (e.g. the NL bank), treat it like a missing
    # bank instead of letting np.dot raise a shape error.
    if memory_embeddings.ndim != 2 or memory_embeddings.shape[1] != code_embeddings.shape[1]:
        print(f"[Memory Bank] Warning: embedding dim mismatch (codes {code_embeddings.shape}, "
              f"memory bank {memory_embeddings.shape}), Memory-based Penalty will be 0")
        return penalties, stats

    similarities = np.dot(code_embeddings, memory_embeddings.T)  # (n_valid, n_memory)

    valid_max_sims = np.max(similarities, axis=1)
    valid_mean_sims = np.mean(similarities, axis=1)

    valid_max_penalties = np.maximum(0, valid_max_sims - threshold_max)
    valid_mean_penalties = np.maximum(0, valid_mean_sims - threshold_mean)

    # Blend: gamma * max_penalty + (1 - gamma) * mean_penalty
    valid_combined_penalties = gamma * valid_max_penalties + (1 - gamma) * valid_mean_penalties

    # Scatter results back to the original positions.
    for idx, orig_idx in enumerate(valid_indices):
        penalties[orig_idx] = float(valid_combined_penalties[idx])
        stats['max_similarities'][orig_idx] = float(valid_max_sims[idx])
        stats['mean_similarities'][orig_idx] = float(valid_mean_sims[idx])
        stats['max_penalties'][orig_idx] = float(valid_max_penalties[idx])
        stats['mean_penalties'][orig_idx] = float(valid_mean_penalties[idx])

    return penalties, stats


def compute_memory_based_penalty(
    questions: List[str],
    codes: Optional[List[Optional[str]]] = None,
    threshold_max: float = 0.5,
    threshold_mean: float = 0.3,
    gamma: float = 1.0
) -> Tuple[List[float], Dict[str, List[float]]]:
    """
    Dispatch the memory-based penalty to the code or NL implementation.

    The code path (compute_memory_based_penalty_code) is used only when
    EMBEDDING_TYPE is "code" and ``codes`` is provided. Every other case uses the
    NL path (compute_memory_based_penalty_nl) on ``questions``. Both blend a
    max-similarity hinge penalty (near-duplicates) with a mean-similarity hinge
    penalty (dense regions) as ``gamma * max + (1 - gamma) * mean``.

    NOTE(review): in "code" mode with ``codes=None`` this falls back to the NL
    path, but _load_memory_bank() loads the code-embedding bank in that mode, so
    the two embedding spaces differ — confirm this fallback is intended.

    Args:
        questions: Question strings (NL path).
        codes: Code strings (code path); None forces the NL path.
        threshold_max: Hinge threshold on max similarity.
        threshold_mean: Hinge threshold on mean similarity.
        gamma: Weight of the max penalty (1.0 = max only, 0.0 = mean only, 0.5 = equal).

    Returns:
        ``(penalties, stats)`` from the selected implementation; ``stats`` has keys
        max_similarities, mean_similarities, max_penalties, mean_penalties.
    """
    if EMBEDDING_TYPE != "code" or codes is None:
        return compute_memory_based_penalty_nl(questions, threshold_max, threshold_mean, gamma)
    return compute_memory_based_penalty_code(codes, threshold_max, threshold_mean, gamma)


def record_memory_stats(stats: Dict[str, List[float]]):
    """
    Accumulate one batch of memory-bank similarities for later wandb logging.

    Strictly positive values from ``stats['max_similarities']`` and
    ``stats['mean_similarities']`` (missing keys count as empty) are appended to the
    module-level accumulator, and the call counter is incremented. The ``> 0`` filter
    drops the 0.0 placeholders that the penalty functions emit for invalid entries.

    NOTE(review): the same filter also drops genuinely zero or negative cosine
    similarities, which biases the logged statistics upward — confirm this is acceptable.

    Args:
        stats: Stats dict as returned by compute_memory_based_penalty().
    """
    field_map = (
        ('max_similarities', "step_max_similarities"),
        ('mean_similarities', "step_mean_similarities"),
    )
    # Filter both lists before mutating anything, so a bad input leaves the accumulator untouched.
    kept = {dst: [v for v in stats.get(src, []) if v > 0] for src, dst in field_map}

    for dst, values in kept.items():
        if values:
            _memory_stats[dst].extend(values)

    _memory_stats["step_count"] += 1


def get_and_reset_memory_stats() -> Dict[str, float]:
    """
    Summarise the accumulated memory-bank similarities and clear the accumulator.

    For each non-empty accumulated list (max and mean similarity), it reports the
    numpy mean, std, max and min under ``memory/...`` keys. Empty lists contribute
    no keys. Afterwards the accumulator is replaced by a fresh, empty one.

    Returns:
        Dict of summary statistics (numpy scalars), possibly empty.
    """
    global _memory_stats

    groups = (
        ("step_max_similarities",
         ("memory/avg_max_similarity", "memory/std_max_similarity",
          "memory/max_max_similarity", "memory/min_max_similarity")),
        ("step_mean_similarities",
         ("memory/avg_mean_similarity", "memory/std_mean_similarity",
          "memory/max_mean_similarity", "memory/min_mean_similarity")),
    )

    summary = {}
    for source_key, (avg_key, std_key, max_key, min_key) in groups:
        values = _memory_stats.get(source_key)
        if not values:
            continue
        summary[avg_key] = np.mean(values)
        summary[std_key] = np.std(values)
        summary[max_key] = np.max(values)
        summary[min_key] = np.min(values)

    _memory_stats = {
        "step_max_similarities": [],
        "step_mean_similarities": [],
        "step_count": 0
    }

    return summary


def _bleu_distance_matrix(sentences):
    """
    Build a symmetric ``1 - BLEU`` distance matrix over whitespace tokens.

    BLEU is asymmetric. For each pair i < j, sentence i is scored as the hypothesis
    against sentence j as the single reference (method1 smoothing). The resulting
    ``1 - score`` is written to both ``dist[i, j]`` and ``dist[j, i]``. The diagonal
    is 0.

    Args:
        sentences: Strings to compare.

    Returns:
        ``(n, n)`` float array of distances.
    """
    n = len(sentences)
    dist = np.zeros((n, n))
    smoother = SmoothingFunction().method1
    # Tokenize once up front instead of re-splitting both sentences for every pair.
    tokens = [s.split() for s in sentences]
    for i in range(n):
        hyp = tokens[i]
        for j in range(i + 1, n):
            score = sentence_bleu([tokens[j]], hyp, smoothing_function=smoother)
            dist[i, j] = dist[j, i] = 1 - score
    return dist


def _code_embedding_distance_matrix(codes: List[str]) -> np.ndarray:
    """
    Pairwise cosine distance (``1 - cosine similarity``) between code embeddings.

    Embeddings come from compute_code_embeddings(), which L2-normalises them, so the
    dot product is the cosine similarity.

    Args:
        codes: Code strings.

    Returns:
        ``(n, n)`` non-negative distance matrix with a zero diagonal.
    """
    embeddings = compute_code_embeddings(codes)
    similarities = np.dot(embeddings, embeddings.T)
    distances = 1 - similarities
    # Floating-point rounding can push the similarity of (near-)identical snippets
    # slightly above 1, giving tiny negative distances; clamp them so the
    # precomputed matrix is a valid distance matrix.
    np.clip(distances, 0.0, None, out=distances)
    np.fill_diagonal(distances, 0)
    return distances


def cluster_share_per_problem(
        problems,
        distance_threshold: float = 0.5,
        linkage: str = "average"):
    """
    Natural-language batch penalty: each problem's share of the batch in its BLEU cluster.

    Problems are grouped by agglomerative clustering over a precomputed
    ``1 - BLEU`` distance matrix (_bleu_distance_matrix). Clusters whose linkage
    distance is at or above ``distance_threshold`` are not merged. Each problem is
    scored as ``size(its cluster) / len(problems)``.

    Args:
        problems: Question strings for one batch.
        distance_threshold: Linkage distance at/above which clusters are not merged.
        linkage: Linkage criterion passed to AgglomerativeClustering.

    Returns:
        One float in (0, 1] per problem, aligned with ``problems``; ``[]`` for
        empty input and ``[1.0]`` for a single problem.
    """
    if not problems:
        return []
    if len(problems) == 1:
        # AgglomerativeClustering rejects fewer than 2 samples. A lone problem is
        # trivially its own cluster, matching the single-code case in cluster_share_per_code.
        return [1.0]

    print('[Batch Penalty] Starting NL clustering (BLEU distance)')
    start_time = time.time()
    dist_mat = _bleu_distance_matrix(problems)

    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        metric="precomputed",
        linkage=linkage
    )
    labels = clustering.fit_predict(dist_mat)
    print(f'[Batch Penalty] NL clustering complete, time: {time.time() - start_time:.2f}s')

    total = len(problems)
    cluster_size = Counter(labels)
    return [cluster_size[lab] / total for lab in labels]


def cluster_share_per_code(
        codes: List[Optional[str]],
        distance_threshold: float = 0.3,
        linkage: str = "average") -> List[float]:
    """
    Code-based batch penalty: each snippet's share of the valid codes in its embedding cluster.

    Valid (non-None, non-blank) snippets are clustered agglomeratively over a
    precomputed cosine-distance matrix of their code embeddings. Each valid snippet
    is scored as ``size(its cluster) / number of valid snippets``. Invalid entries
    get the maximum penalty 1.0.

    Args:
        codes: Python code strings; None means generation failed.
        distance_threshold: Linkage distance at/above which clusters are not merged.
        linkage: Linkage criterion passed to AgglomerativeClustering.

    Returns:
        One float per entry of ``codes`` (used as the penalty); ``[]`` for empty input.
    """
    if not codes:
        return []

    valid_indices = [i for i, c in enumerate(codes) if c and c.strip()]
    valid_codes = [codes[i] for i in valid_indices]

    # Invalid codes keep the maximum penalty 1.0.
    proportions = [1.0] * len(codes)

    # No valid code: everything stays at 1.0. A single valid code forms its own
    # cluster (share 1.0), and AgglomerativeClustering rejects fewer than 2 samples,
    # so return before loading the embedding model at all.
    if len(valid_codes) <= 1:
        return proportions

    print(f'[Batch Penalty] Starting Code clustering ({len(valid_codes)} valid codes)')
    start_time = time.time()

    dist_mat = _code_embedding_distance_matrix(valid_codes)

    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        metric="precomputed",
        linkage=linkage
    )
    labels = clustering.fit_predict(dist_mat)
    print(f'[Batch Penalty] Code clustering complete, time: {time.time() - start_time:.2f}s')

    total = len(valid_codes)
    cluster_size = Counter(labels)
    for i, orig_idx in enumerate(valid_indices):
        proportions[orig_idx] = cluster_size[labels[i]] / total

    return proportions


def generate_temp_filename(prefix="temp", suffix=".json"):
    """
    Build a temp-file path ``{STORAGE_PATH}/temp_results/{prefix}_{ms}_{rand}{suffix}``.

    ``ms`` is the current Unix time in milliseconds and ``rand`` a random integer in
    [0, 99999], which makes collisions between concurrent callers unlikely (not
    impossible). The directory is not created here.
    """
    stem = f"{prefix}_{int(time.time() * 1000)}_{random.randint(0, 99999)}"
    return f"{STORAGE_PATH}/temp_results/{stem}{suffix}"


def split_list(lst, n=None):
    """
    Split ``lst`` into ``n`` contiguous chunks whose sizes differ by at most one.

    The first ``len(lst) % n`` chunks receive one extra element. ``n`` defaults to
    VLLM_GPU_COUNT. Empty chunks are returned when ``n > len(lst)``.
    """
    parts = VLLM_GPU_COUNT if n is None else n
    base, extra = divmod(len(lst), parts)

    chunks = []
    start = 0
    for idx in range(parts):
        stop = start + base + (1 if idx < extra else 0)
        chunks.append(lst[start:stop])
        start = stop
    return chunks


# Bypass any configured HTTP(S) proxy for the local vLLM / Solver services, which
# this module contacts at 127.0.0.1 and 0.0.0.0. Runs at import time and
# overwrites any NO_PROXY value inherited from the environment.
os.environ["NO_PROXY"] = "0.0.0.0,127.0.0.1"


def fetch(index, i):
    """
    Trigger Solver service ``index`` (port 5000 + index) on the shard file ``i``.

    Sends ``GET /hello`` with query parameter ``name=i`` and prints the response.
    The status is not checked; callers detect failure when the expected results
    file is missing.

    Args:
        index: Service index (0-based).
        i: Path of the JSON shard the service should process.

    Returns:
        Always True.
    """
    # Let requests URL-encode the path: STORAGE_PATH-derived names may contain
    # characters such as '&', '#', '?', '+' or spaces that would corrupt a
    # hand-built query string.
    response = requests.get(f"http://0.0.0.0:{5000+index}/hello", params={"name": i})
    print(response)
    return True


def generate_results(data):
    """
    Shard ``data`` across the Solver services and collect their outputs.

    ``data`` is split into VLLM_GPU_COUNT contiguous shards (split_list). Shard k is
    dumped as JSON to a fresh temp file (generate_temp_filename, under
    ``STORAGE_PATH/temp_results``), and service k is triggered via
    ``fetch(k, path)``. Each service is expected to write its output next to the
    shard, with ``.json`` replaced by ``_results.json``, as a JSON list. These lists
    are concatenated in shard order.

    Shard and result files are removed afterwards, also on failure. Any exception
    from a service or from reading a result file still propagates.

    Returns:
        The concatenated list of results.
    """
    n_services = VLLM_GPU_COUNT
    shards = split_list(data, n_services)
    shard_paths = [generate_temp_filename(prefix=f"temp_{k}", suffix=".json") for k in range(n_services)]
    result_paths = [path.replace('.json', '_results.json') for path in shard_paths]

    final_results = []
    try:
        for path, shard in zip(shard_paths, shards):
            # temp_results/ may not exist yet under a fresh STORAGE_PATH.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, 'w') as f:
                json.dump(shard, f, indent=4)

        with ThreadPoolExecutor(max_workers=n_services) as executor:
            futures = [executor.submit(fetch, k, shard_paths[k]) for k in range(n_services)]
            for future in as_completed(futures):
                print(future.result())

        for path in result_paths:
            with open(path, 'r') as f:
                final_results.extend(json.load(f))
    finally:
        # Best-effort cleanup: a file may legitimately be absent if a service failed
        # before writing it or the service already removed it.
        for path in shard_paths + result_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
    return final_results


def format_reward(predict: str) -> float:
    r"""Score whether ``predict`` follows the ``<think>...</think> ... \boxed{...}`` layout.

    The entire string must match (``fullmatch``). ``.`` also matches newlines
    (``DOTALL``), so the reasoning and the answer may span several lines.

    Returns:
        1.0 if the layout matches, otherwise 0.0.
    """
    matched = re.fullmatch(r"<think>.*</think>.*\\boxed\{.*\}.*", predict, flags=re.DOTALL)
    if matched is None:
        return 0.0
    return 1.0


def accuracy_reward(predict: str, ground_truth: str) -> float:
    r"""Return 1.0 when the ``\boxed{}`` answer in ``predict`` grades equal to ``ground_truth``.

    The answer is extracted with ``extract_boxed_content`` and compared with
    ``grade_answer``. Any non-matching answer scores 0.0.
    """
    is_correct = grade_answer(extract_boxed_content(predict), ground_truth)
    return float(bool(is_correct))


def _parse_challenger_output(predict: str) -> Dict[str, str]:
    r"""Extract the question and boxed answer from one Challenger generation.

    The question is the last ``<question>...</question>`` span. The answer is the
    ``\boxed{...}`` content returned by ``extract_boxed_content``. Both are
    stripped. If either is missing, ``{"question": "", "answer": ""}`` is returned,
    which ``compute_score`` treats as an invalid sample.
    """
    questions = re.findall(r"<question>(.*?)</question>", predict, re.DOTALL)
    boxed = extract_boxed_content(predict)
    # mathruler's extract_boxed_content returns a single string: the last
    # \boxed{} content, or the literal "None" when there is no box. The old code
    # indexed it with [-1], which kept only its last character and never
    # detected a missing box. A list/tuple return is still tolerated.
    if isinstance(boxed, (list, tuple)):
        boxed = boxed[-1] if boxed else ""
    answer = boxed.strip() if isinstance(boxed, str) else ""
    if not questions or not answer or answer == "None":
        return {"question": "", "answer": ""}
    return {"question": questions[-1].strip(), "answer": answer}


def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1, file_path: str = "") -> List[Dict[str, float]]:
    """Compute the reward for each question generated by the Challenger.

    Pipeline:
      1. Parse the question and boxed answer out of every prediction.
      2. Send them to the Solver vLLM services (``generate_results``), which
         return a per-item ``score``.
      3. In ``code`` embedding mode, translate each question into Python code.
      4. Compute the batch-based and memory-based penalties.
      5. Combine: ``overall = base_reward - (PENALTY_ALPHA * batch_penalty
         + PENALTY_BETA * memory_penalty)``.

    Args:
        predicts: Raw Challenger generations, one per sample.
        ground_truths: Not used here (presumably kept to match the reward-function
            interface).
        format_weight: Not used here.
        file_path: Not used here.

    Returns:
        One dict per Solver result with keys ``overall``, ``format``,
        ``accuracy`` (a copy of ``batch_penalty`` kept for compatibility),
        ``batch_penalty``, ``memory_penalty``, ``weighted_max_penalty``,
        ``weighted_mean_penalty``, ``max_similarity``, ``mean_similarity`` and
        ``code_gen_failed``. Samples whose code generation failed get
        ``overall = -10.0``.
    """
    # Debug dump of the raw generations (overwritten on every call).
    with open('test.json', 'w') as f:
        json.dump(predicts, f, indent=4)

    results = [_parse_challenger_output(predict) for predict in predicts]

    # ============ Step 1: query the Solver services for uncertainty scores ============
    _t0 = time.perf_counter()
    print(f"[caller_penalty] results length: {len(results)}")

    # Time-multiplexing: wake the Solver vLLM services, and always put them back
    # to sleep afterwards, even if generation raises.
    wake_up_solver_vllm()
    try:
        final_results = generate_results(results)
    finally:
        sleep_solver_vllm()

    _t1 = time.perf_counter()
    t_generate = _t1 - _t0
    print(f"[caller_penalty] generate_results 耗时: {t_generate:.6f}s")

    # ============ Step 2: in code mode, translate questions into Python code ============
    generated_codes = None
    t_code_gen = 0

    # Code is only needed when one of the penalties uses code embeddings.
    need_code_for_batch = (EMBEDDING_TYPE == "code" and not KEEP_BATCH_PENALTY_UNCHANGED and PENALTY_ALPHA > 0)
    need_code_for_memory = (EMBEDDING_TYPE == "code" and PENALTY_BETA > 0)

    if need_code_for_batch or need_code_for_memory:
        _t_code_start = time.perf_counter()
        questions_for_code = [result['question'] for result in final_results]

        # Time-multiplexing: wake the Code vLLM service; always put it back to sleep.
        wake_up_code_vllm()
        try:
            generated_codes = generate_codes_from_questions(questions_for_code, max_retries=3)
        finally:
            sleep_code_vllm()

        _t_code_end = time.perf_counter()
        t_code_gen = _t_code_end - _t_code_start
        print(f"[caller_penalty] Code generation 耗时: {t_code_gen:.6f}s")

        # Code-generation success statistics (None marks a failed translation).
        code_success = sum(1 for c in generated_codes if c is not None)
        code_failed = len(generated_codes) - code_success
        print(f"[caller_penalty] Code generation: {code_success} success, {code_failed} failed")
    elif EMBEDDING_TYPE == "code":
        print(f"[caller_penalty] Skipping code generation since neither batch nor memory penalty need it")
        print(f"[caller_penalty] EMBEDDING_TYPE={EMBEDDING_TYPE}, KEEP_BATCH_PENALTY_UNCHANGED={KEEP_BATCH_PENALTY_UNCHANGED}, PENALTY_ALPHA={PENALTY_ALPHA}, PENALTY_BETA={PENALTY_BETA}")

    # ============ Step 3: batch-based penalty ============
    if PENALTY_ALPHA > 0:
        _t2 = time.perf_counter()
        questions_for_batch = [result['question'] for result in final_results]
        print(f"[caller_penalty] Computing batch penalty for {len(questions_for_batch)} items")

        if KEEP_BATCH_PENALTY_UNCHANGED or EMBEDDING_TYPE == "nl":
            # Cluster the natural-language questions (BLEU-based distance).
            batch_penalty = cluster_share_per_problem(questions_for_batch, distance_threshold=0.5)
        else:
            # Cluster the generated code (code mode with the batch penalty switched to code).
            batch_penalty = cluster_share_per_code(generated_codes, distance_threshold=0.3)

        _t3 = time.perf_counter()
        t_cluster = _t3 - _t2
        print(f"[caller_penalty] Batch penalty computation 耗时: {t_cluster:.6f}s")
    else:
        # PENALTY_ALPHA == 0: the batch penalty does not contribute, skip it.
        t_cluster = 0
        batch_penalty = [0.0] * len(final_results)
        print(f"[caller_penalty] Skipping batch penalty computation (PENALTY_ALPHA={PENALTY_ALPHA})")

    # ============ Step 4: memory-based penalty ============
    if PENALTY_BETA > 0:
        _t4 = time.perf_counter()
        questions_for_memory = [result['question'] for result in final_results]
        memory_penalty, memory_stats = compute_memory_based_penalty(
            questions_for_memory,
            codes=generated_codes,
            threshold_max=MEMORY_PENALTY_THRESHOLD,
            threshold_mean=MEMORY_PENALTY_MEAN_THRESHOLD,
            gamma=MEMORY_PENALTY_GAMMA
        )
        _t5 = time.perf_counter()
        t_memory = _t5 - _t4
        print(f"[caller_penalty] Memory-based penalty 耗时: {t_memory:.6f}s")

        record_memory_stats(memory_stats)

        # Per-sample similarity/penalty details for logging and the returned dicts.
        max_similarities = memory_stats['max_similarities']
        mean_similarities = memory_stats['mean_similarities']
        max_penalties = memory_stats['max_penalties']
        mean_penalties = memory_stats['mean_penalties']

        # Weighted contributions (for wandb): gamma * max_penalty and (1 - gamma) * mean_penalty.
        weighted_max_penalties = [MEMORY_PENALTY_GAMMA * p for p in max_penalties]
        weighted_mean_penalties = [(1 - MEMORY_PENALTY_GAMMA) * p for p in mean_penalties]

        # Averages over strictly positive similarities, for the log line only.
        valid_max_sims = [s for s in max_similarities if s > 0]
        valid_mean_sims = [s for s in mean_similarities if s > 0]
        avg_max_sim = np.mean(valid_max_sims) if valid_max_sims else 0.0
        avg_mean_sim = np.mean(valid_mean_sims) if valid_mean_sims else 0.0

        print(f"[caller_penalty] Memory-based Penalty: "
              f"avg_max_sim={avg_max_sim:.4f}, avg_mean_sim={avg_mean_sim:.4f}, "
              f"gamma={MEMORY_PENALTY_GAMMA}, avg_penalty={np.mean(memory_penalty):.4f}")
    else:
        # PENALTY_BETA == 0: the memory penalty does not contribute, skip it.
        t_memory = 0
        memory_penalty = [0.0] * len(final_results)
        max_similarities = [0.0] * len(final_results)
        mean_similarities = [0.0] * len(final_results)
        weighted_max_penalties = [0.0] * len(final_results)
        weighted_mean_penalties = [0.0] * len(final_results)
        print(f"[caller_penalty] Skipping memory penalty computation (PENALTY_BETA={PENALTY_BETA})")

    # ============ Unload the embedding model to free GPU memory ============
    # Called unconditionally so a loaded model never competes with the training
    # process for GPU memory; the original code notes the unload helpers are
    # safe to call when no model was loaded.
    _unload_embedding_model()

    _total = t_generate + t_code_gen + t_cluster + t_memory
    if _total > 0:
        print(
            f"[caller_penalty] 耗时占比: "
            f"solver={((t_generate / _total)):.2%}, "
            f"code_gen={((t_code_gen / _total)):.2%}, "
            f"batch_penalty={((t_cluster / _total)):.2%}, "
            f"memory_penalty={((t_memory / _total)):.2%}"
        )

    assert len(batch_penalty) == len(final_results)
    assert len(memory_penalty) == len(final_results)

    # ============ Step 5: combine into the final score ============
    scores = []
    for i, item in enumerate(final_results):
        # In code mode, a sample whose code generation failed gets a very low
        # sentinel score so that it does not take part in training.
        code_gen_failed = (
            EMBEDDING_TYPE == "code"
            and generated_codes is not None
            and generated_codes[i] is None
        )
        if code_gen_failed:
            scores.append({
                "overall": -10.0,
                "format": 0,
                "accuracy": 0.0,
                "batch_penalty": 0.0,
                "memory_penalty": 0.0,
                "weighted_max_penalty": 0.0,
                "weighted_mean_penalty": 0.0,
                "max_similarity": 0.0,
                "mean_similarity": 0.0,
                "code_gen_failed": True
            })
            continue

        has_question = bool(item['question'])
        # Uncertainty reward: min(p, 1 - p) is largest (0.5) when the Solver
        # score is 0.5; samples without a question get -1.
        base_reward = min(item["score"], 1 - item["score"]) if has_question else -1

        # total_penalty = alpha * batch_penalty + beta * memory_penalty
        total_penalty = PENALTY_ALPHA * batch_penalty[i] + PENALTY_BETA * memory_penalty[i]
        final_score = base_reward - total_penalty

        # NOTE: the paper's variant (not used here) is
        #   base = 2 * min(p, 1 - p); final = max(0, base - total_penalty),
        # with final = 0 for samples without a question.

        scores.append({
            "overall": final_score,
            "format": 1 if has_question else 0,
            "accuracy": batch_penalty[i],  # original field name kept for compatibility
            "batch_penalty": batch_penalty[i],
            "memory_penalty": memory_penalty[i],
            "weighted_max_penalty": weighted_max_penalties[i],    # gamma * max_penalty
            "weighted_mean_penalty": weighted_mean_penalties[i],  # (1 - gamma) * mean_penalty
            "max_similarity": max_similarities[i],
            "mean_similarity": mean_similarities[i],
            "code_gen_failed": False
        })

    return scores
