# Constants for algorithm configuration
# NOTE(review): INF is only 10, not a true infinity, and nothing in the visible
# part of this file reads it -- confirm the bound is large enough where it is used.
INF = 10  # Used for initialization of min/max values
# Scores are divided by TEMPERATURE before being exponentiated into sampling
# weights, so values below 1 sharpen the distribution.
TEMPERATURE = 0.3  # Temperature parameter for softmax sampling


def _str_to_bool(value):
    """Convert a command-line token such as 'true' / 'no' / '0' into a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, which is True
    for every non-empty value (including "False"); this converter parses the
    text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse command line arguments.

    Reads ``sys.argv`` through argparse and returns the populated namespace.
    Every option has a default, so the script can run without arguments.

    Returns:
        argparse.Namespace: the parsed configuration.
    """
    parser = argparse.ArgumentParser(description="Phi-Decoding Algorithm")

    # Model configuration
    parser.add_argument('--model_id', type=str, default='llm-base',
                        help='Model identifier')
    parser.add_argument('--model_path', type=str, default='./models/default',
                        help='Model path')
    parser.add_argument('--gpus', type=int, default=1,
                        help='Number of GPUs to use')

    # Data configuration
    parser.add_argument('--datasets', type=str, default='default_ds',
                        help='Dataset type')  # Extensible: gsm, math, reclor, logiqa, gpqa, arc
    parser.add_argument('--data_path', type=str,
                        default='./data/input.json',
                        help='Path to input data')
    parser.add_argument('--output_dir', type=str,
                        default='./results/',
                        help='Output directory for results')

    # Algorithm parameters
    parser.add_argument('--step_beam_size', type=int, default=4,
                        help='Beam size for each step')
    parser.add_argument('--num_rollout', type=int, default=10,
                        help='Number of rollouts')
    parser.add_argument('--num_foresight', type=int, default=8,
                        help='Number of foresight steps')
    parser.add_argument('--strategy', type=str, default='cluster',
                        help='Response selection strategy')
    parser.add_argument('--width_pruning_strategy', type=str, default='low_sigma',
                        help='Width pruning strategy')
    parser.add_argument('--depth_pruning_strategy', type=str, default='cluster',
                        help='Depth pruning strategy')
    parser.add_argument('--cluster_num', type=int, default=2,
                        help='Number of clusters for clustering strategy')
    parser.add_argument('--threshold', type=float, default=0.75,
                        help='Threshold for early stopping')
    parser.add_argument('--least_foresight_num', type=int, default=4,
                        help='Minimum number of foresight steps')
    parser.add_argument('--sigma_rate', type=float, default=0.8, help='Sigma rate for width pruning')

    # Execution configuration
    # A real string-to-bool converter is required here: with type=bool,
    # "--record_process False" would still evaluate to True.
    parser.add_argument('--record_process', type=_str_to_bool, default=True,
                        help='Whether to record the decoding process')
    parser.add_argument('--file_name', type=str, default='test_3',
                        help='Output file name')
    parser.add_argument('--time_path', type=str,
                        default='./results/time/',
                        help='Path to save timing information')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--max_examples', type=int, default=50)
    parser.add_argument('--shot_mode', type=str, default='fewshot', choices=['zeroshot', 'fewshot'],
                        help='选择zeroshot或fewshot模式')
    return parser.parse_args()


def softmax(x):
    """
    Compute softmax values for the input array
    Args:
        x: Input array (or list) of values; normalisation runs along axis 0
    Returns:
        Softmax probabilities as a NumPy array (empty input yields an empty array)
    """
    values = np.asarray(x, dtype=float)
    if values.size == 0:
        return values
    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps exp() from overflowing to inf (and the ratio from becoming NaN)
    # when the scores are large, e.g. after dividing by a small temperature.
    shifted = values - np.max(values, axis=0, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=0)


def segment_refinement(segment_tokens, hidden_states, halu_scores, global_context, llm_refine=None):
    """
    Filter and optionally refine the tokens of one segment.

    The segment is first filtered by per-token score, then checked for TF-IDF
    cosine similarity against the global context, and finally either handed to
    ``llm_refine`` or grouped into runs of consecutive retained tokens.

    Args:
        segment_tokens: List[str], token sequence of the current segment
        hidden_states: List[np.array], hidden state representations for each token
            (accepted for interface compatibility; not read by this function)
        halu_scores: List[float], one score per token; tokens scoring below 0.3
            are dropped, so a higher score means the token is kept
        global_context: str, global context for consistency checking
        llm_refine: Optional[Callable], called with the retained token list; its
            return value becomes the "tokens" of the result

    Returns:
        dict: {"status": "ok"|"discarded"|"refined", "tokens": List[str]}
    """
    if not segment_tokens:
        return {"status": "discarded", "tokens": []}

    # 1. Initial filtering based on halu_scores
    threshold = 0.3  # Adjustable parameter
    filtered_tokens = [tok for tok, score in zip(segment_tokens, halu_scores) if score >= threshold]

    if not filtered_tokens:
        return {"status": "discarded", "tokens": []}

    # 2. Contextual consistency check: TF-IDF cosine similarity between the
    #    retained tokens and global_context
    try:
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity

        corpus = [" ".join(filtered_tokens), global_context]
        tfidf = TfidfVectorizer().fit_transform(corpus)
        sim = cosine_similarity(tfidf[0:1], tfidf[1:2])[0][0]
    except Exception:
        # Deliberately best-effort: a missing sklearn, a non-string context or
        # an empty vocabulary must not discard the segment.
        sim = 0.5  # Use default similarity if error occurs

    # If similarity with global context is too low, mark as discarded
    if sim < 0.2:
        return {"status": "discarded", "tokens": []}

    # 3. Optionally call LLM refine for further optimization
    if llm_refine is not None:
        refined_tokens = llm_refine(filtered_tokens)
        return {"status": "refined", "tokens": refined_tokens}

    # Default optimization: merge consecutive retained tokens into one string
    # per run, with every low-score token acting as a run boundary. The walk
    # uses the ORIGINAL token/score pairs so each score stays aligned with its
    # own token; pairing the already-filtered tokens with the unfiltered score
    # list would shift the scores and drop reliable tokens.
    optimized_tokens = []
    buffer = []
    for tok, score in zip(segment_tokens, halu_scores):
        if score >= threshold:
            buffer.append(tok)
        elif buffer:
            optimized_tokens.append(" ".join(buffer))
            buffer = []
    if buffer:
        optimized_tokens.append(" ".join(buffer))
    return {"status": "ok", "tokens": optimized_tokens}

def default_llm_refine(tokens, halu_scores=None, global_context=None):
    """
    Default refine implementation: length filter, optional score filter,
    optional context check, then merging of short tokens.

    Args:
        tokens: List[str], tokens to be optimized
        halu_scores: Optional[List[float]], one score per entry of ``tokens``
            (the higher, the more reliable); ignored when its length does not
            match ``tokens``
        global_context: Optional[str], global context for semantic consistency checking

    Returns:
        List[str]: optimized tokens
    """
    if not tokens:
        return []

    threshold = 0.3
    if halu_scores is not None and len(halu_scores) == len(tokens):
        # 1 + 2. Apply the length filter and the score filter in one pass over
        # the original (token, score) pairs. Filtering by length first and then
        # comparing list lengths would silently disable the score filter as
        # soon as a single short token had been removed.
        tokens = [tok for tok, score in zip(tokens, halu_scores)
                  if len(tok) > 1 and score >= threshold]
    else:
        # 1. Remove tokens with length less than 2
        tokens = [t for t in tokens if len(t) > 1]

    # 3. Simple contextual consistency check (optional)
    if global_context is not None and tokens:
        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
            from sklearn.metrics.pairwise import cosine_similarity

            corpus = [" ".join(tokens), global_context]
            tfidf = TfidfVectorizer().fit_transform(corpus)
            sim = cosine_similarity(tfidf[0:1], tfidf[1:2])[0][0]
            # If similarity is too low, fall back to a single token
            if sim < 0.2:
                # Simple strategy: keep the longest token as the core information
                tokens = [max(tokens, key=len)]
        except Exception:
            # Best-effort step: skip the consistency check if sklearn is
            # unavailable or the inputs cannot be vectorised.
            pass

    # 4. Merge consecutive short tokens (fewer than 3 characters) into one
    #    space-joined entry; longer tokens are kept as their own entries
    optimized_tokens = []
    buffer = []
    for tok in tokens:
        if len(tok) < 3:
            buffer.append(tok)
        else:
            if buffer:
                optimized_tokens.append(" ".join(buffer))
                buffer = []
            optimized_tokens.append(tok)
    if buffer:
        optimized_tokens.append(" ".join(buffer))

    return optimized_tokens



class PhiDecoder:
    """
    Main class for phi-decoding algorithm implementation.
    Combines clustering and sampling strategies for response selection.
    """

    def __init__(self, args):
        """
        Initialize the decoder
        Args:
            args: Command line arguments containing configuration
        """
        self.args = args
        self.model = None
        self.tokenizer = None
        self.initialize_model()
        # Integrate the ForesightPipeline, sharing this decoder's model and tokenizer
        self.pipeline = ForesightPipeline(self.model, self.tokenizer, self.args)

    def initialize_model(self):
        """Initialize the language model and tokenizer, then seed NumPy's RNG"""
        model_path = self._get_model_path()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, max_length=32768)

        # Fall back to the EOS token when the tokenizer defines no pad token
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = LLM(
            model=model_path,
            tensor_parallel_size=self.args.gpus,
            trust_remote_code=True,
            max_model_len=32768
        )

        np.random.seed(self.args.seed)

    def _get_model_path(self):
        """Get the appropriate model path"""
        return self.args.model_path

    @staticmethod
    def _lookup_ignore_case(mapping, key, default):
        """Return mapping[key], matching string keys case-insensitively; else default."""
        if key in mapping:
            return mapping[key]
        if isinstance(key, str):
            lowered = key.lower()
            for candidate, value in mapping.items():
                if candidate.lower() == lowered:
                    return value
        return default

    @staticmethod
    def _semantic_bonus(resp):
        """Heuristic bonus for a response: length, conclusion words, arithmetic symbols."""
        bonus = 0
        lowered = resp.lower()
        if len(resp.strip()) > 50:
            bonus += 0.2
        if any(ind in lowered for ind in ('answer:', 'therefore', 'thus', 'hence')):
            bonus += 0.1
        if any(ind in lowered for ind in ('=', '+', '-', '*', '/')):
            bonus += 0.1
        return bonus

    def get_system_prompt(self, dataset_type=None):
        """
        Get the appropriate system prompt based on dataset type
        Args:
            dataset_type: Type of dataset (e.g., 'pubmedqa', 'financebench', etc.).
                When None, it is inferred from the file name of args.data_path.
                Lookups are case-insensitive because the zero-shot and few-shot
                tables spell some keys differently ('pubmedqa' vs 'pubmedQA').
        Returns:
            The zero-shot prompt in 'zeroshot' mode; in 'fewshot' mode the
            zero-shot prompt followed by the few-shot examples, or just the
            zero-shot prompt when no examples exist for the dataset.
        """

        # === Define zero-shot prompts for each dataset ===
        # NOTE(review): every prompt body below is an empty tuple in this
        # version of the file -- confirm the real prompt text is restored.
        zeroshot_map = {
            'pubmedqa': (
            ),
            'financebench': (
            ),
            'halueval': (
            ),
            'history': (
            ),
            'ragtruth': (
            ),
            'covidQA': (
            ),
        }

        default_zeroshot = (
        )

        fewshot_map = {
            "history": HISTORY_8_FEW_SHOT,
            "nfl": NFL_8_FEW_SHOT,
            "covidQA": covidQA_4_FEW_SHOT,
            "halueval": halueval_6_FEW_SHOT,
            "financebench": financebench_5_FEW_SHOT,
            "pubmedQA": pubmedQA_4_FEW_SHOT,
            "ragtruth": RAGTruth_5_FEW_SHOT,
        }

        if dataset_type is None:
            filename = os.path.basename(self.args.data_path).lower()
            for dtype in zeroshot_map:
                # The file name is lowercased, so the key must be too;
                # otherwise mixed-case keys such as 'covidQA' never match.
                if dtype.lower() in filename:
                    dataset_type = dtype
                    break
            else:
                return default_zeroshot

        base_prompt = self._lookup_ignore_case(zeroshot_map, dataset_type, default_zeroshot)
        if self.args.shot_mode == 'zeroshot':
            return base_prompt

        # few-shot mode: fetch the few-shot examples for this dataset
        fewshot_examples = self._lookup_ignore_case(fewshot_map, dataset_type, "")
        if not fewshot_examples:
            return base_prompt

        system_prompt = (
            f"{base_prompt}\n\n"
            "I will give you some examples for reference:\n"
            f"{fewshot_examples}"
            "请根据以上示例，继续推理并给出完整答案。"
        )
        return system_prompt

    def cluster_responses(self, responses, advantages):
        """
        Cluster responses using TF-IDF and K-means with improved robustness
        Args:
            responses: List of response texts
            advantages: List of advantage values for each response
                (accepted for interface compatibility; not read here)
        Returns:
            Tuple of (clusters, cluster_info). ``clusters`` is a list of
            args.cluster_num lists holding indices into ``responses``, or None
            when clustering is skipped or fails; ``cluster_info`` always
            carries a "state" key describing the outcome.
        """
        # Filter out empty responses
        valid_indices = [i for i, r in enumerate(responses) if r.strip()]
        if len(valid_indices) < self.args.step_beam_size:
            return None, {"state": "cannot cluster", "reason": "insufficient valid responses"}

        try:
            valid_responses = [responses[i] for i in valid_indices]

            # Check if responses are too similar for meaningful clustering
            if len(set(valid_responses)) <= 1:
                return None, {"state": "cannot cluster", "reason": "all responses identical"}

            # Vectorize responses with optimized preprocessing
            vectorizer = TfidfVectorizer(
                max_features=500,   # Reduced from 1000 to avoid overfitting
                min_df=1,          # Include all terms
                max_df=0.9,        # Slightly adjusted from 0.95
                ngram_range=(1, 1) # Use unigrams only for better stability
            )
            X = vectorizer.fit_transform(valid_responses)

            # Check if we have enough features for clustering
            if X.shape[1] < 2:
                return None, {"state": "cannot cluster", "reason": "insufficient features"}

            # Perform clustering with silhouette analysis
            kmeans = KMeans(n_clusters=self.args.cluster_num, random_state=42, n_init=10)
            kmeans.fit(X)

            # Calculate silhouette score to assess clustering quality
            if len(valid_responses) > 2 and len(set(kmeans.labels_)) > 1:
                silhouette_avg = silhouette_score(X, kmeans.labels_)
                if silhouette_avg < 0.05:  # Relaxed threshold from 0.1 to 0.05
                    return None, {"state": "cannot cluster", "reason": f"poor clustering quality (silhouette: {silhouette_avg:.3f})"}
            else:
                silhouette_avg = 0.0

            # Group responses by cluster, mapping back to indices of `responses`
            clusters = [[] for _ in range(self.args.cluster_num)]
            for idx, label in enumerate(kmeans.labels_):
                clusters[label].append(valid_indices[idx])

            return clusters, {
                "state": "success",
                "cluster_sizes": [len(c) for c in clusters],
                "silhouette_score": silhouette_avg,
                "feature_count": X.shape[1]
            }

        except Exception as e:
            # Any vectoriser / K-means failure is reported to the caller, not raised
            return None, {"state": "fail", "error": str(e)}

    def select_response(self, responses, logprobs, advantages):
        """Select final response based on strategy with robustness and semantic bonus

        Args:
            responses: List of candidate response texts
            logprobs: List of log-probabilities (not read by this method)
            advantages: List of advantage values, parallel to ``responses``
        Returns:
            Index into ``responses`` of the sampled candidate
        Raises:
            ValueError: if args.strategy is not "cluster"
        """
        if self.args.strategy != "cluster":
            raise ValueError(f"Unknown strategy: {self.args.strategy}")

        # filter out empty responses
        valid_indices = [idx for idx, r in enumerate(responses) if r.strip() != '']
        if len(valid_indices) == 0:
            print('all responses in the final generation are empty, use -adv no replace')
            weights = softmax([-adv/TEMPERATURE for adv in advantages])
            return np.random.choice(len(advantages), p=weights)

        if len(valid_indices) < self.args.step_beam_size:
            print('valid responses are less than step_beam_size, use adv no replace')
            weights = softmax([adv/TEMPERATURE for adv in advantages])
            return np.random.choice(len(advantages), p=weights)

        try:
            # prepare cluster data (compress to valid items)
            valid_responses = [responses[i] for i in valid_indices]
            valid_advantages = [advantages[i] for i in valid_indices]

            # Use improved clustering with robustness checks
            clusters, cluster_info = self.cluster_responses(valid_responses, valid_advantages)

            if clusters is None:
                # "cannot cluster" results carry 'reason'; hard failures carry 'error'
                reason = cluster_info.get('reason', cluster_info.get('error', 'unknown'))
                print(f"Clustering failed: {reason}, using advantage-based selection with semantic bonus")
                # Fallback with semantic quality bonus
                enhanced_advantages = [
                    base + self._semantic_bonus(resp)
                    for resp, base in zip(valid_responses, valid_advantages)
                ]
                weights = softmax([adv/TEMPERATURE for adv in enhanced_advantages])
                selected_index_in_valid = np.random.choice(len(weights), p=weights)
                return valid_indices[selected_index_in_valid]

            # Select from the largest cluster
            cluster_sizes = [len(c) for c in clusters]
            largest_idx = int(np.argmax(cluster_sizes))
            selected_cluster = clusters[largest_idx]  # indices in valid_responses

            # Enhanced advantages in selected cluster
            enhanced_advantages = [
                valid_advantages[ddi] + self._semantic_bonus(valid_responses[ddi])
                for ddi in selected_cluster
            ]

            weights = softmax([adv/TEMPERATURE for adv in enhanced_advantages])
            selected_index_in_cluster = np.random.choice(len(weights), p=weights)
            selected_index_in_valid = selected_cluster[selected_index_in_cluster]
            selected_index_final = valid_indices[selected_index_in_valid]

            print(f'Selected from largest cluster (size: {len(selected_cluster)}) with enhanced advantage')
            return selected_index_final

        except Exception as e:
            # Last-resort fallback over ALL responses (including empty ones)
            print(f'Cannot select response based on cluster: {e}, using advantage-based selection with semantic bonus')
            enhanced_advantages = [
                base + self._semantic_bonus(resp)
                for resp, base in zip(responses, advantages)
            ]
            weights = softmax([adv/TEMPERATURE for adv in enhanced_advantages])
            return np.random.choice(len(weights), p=weights)

    def process_example(self, example, system_prompt):
        """
        Process a single example through the phi-decoding pipeline

        Args:
            example: dict with 'passage' and 'question' keys (required) and
                optional 'id' / 'answer' keys
            system_prompt: System prompt passed to every pipeline stage
        Returns:
            dict with the final response, token / rollout statistics, the
            final trajectories and the collected trajectory info
        """
        # Call each pipeline stage directly
        token_stats = {"input": 0, "output": 0}
        rollout_stats = {"total": 0, "saved": 0}
        previous_steps = ["The reasoning steps are:\n\n" for _ in range(self.args.step_beam_size)]
        previous_values = [0.0 for _ in range(self.args.step_beam_size)]
        traj_info = {
            'question_idx': example.get('id', 0),
            'passage': example['passage'],
            'question': example['question'],
            'ground_truth': example.get('answer'),
            'foresight_part': [],
            'final_part': {},
            'config': {
                'num_rollout': self.args.num_rollout,
                'num_foresight': self.args.num_foresight,
                'step_beam_size': self.args.step_beam_size,
                'strategy': self.args.strategy,
                'width_pruning_strategy': self.args.width_pruning_strategy,
                'depth_pruning_strategy': self.args.depth_pruning_strategy,
                'threshold': self.args.threshold,
                'sigma_rate': self.args.sigma_rate,
                'cluster_num': self.args.cluster_num
            }
        }

        # Multi-step reasoning using ForesightPipeline
        for step in range(self.args.num_foresight):
            # 1. Candidate Generation
            responses, logprobs, advantages, token_nums = self.pipeline.stage1_generate(
                example, system_prompt, previous_steps, previous_values, token_stats, rollout_stats
            )
            # 2. Pruning & Local Enhancement
            filtered_responses, filtered_logprobs, filtered_advantages, keep_indices = self.pipeline.stage2_refine(
                responses, logprobs, advantages, self.args.num_rollout, token_stats, rollout_stats
            )
            # 3. Completion + Clustering Selection
            completed_responses, completed_logprobs, completed_advantages, selected, stop_foresight = self.pipeline.stage3_select(
                example, system_prompt, previous_steps, keep_indices, filtered_responses, previous_values,
                self.args.num_rollout, token_stats, rollout_stats
            )

            # Update state: extend the parent beam's text with the chosen step.
            # keep_indices[idx] // num_rollout recovers the parent beam because
            # stage 1 emits num_rollout candidates per beam, in beam order.
            previous_steps = [previous_steps[keep_indices[idx]//self.args.num_rollout] +
                              responses[keep_indices[idx]] + "\n" for idx in selected]
            previous_values = [completed_logprobs[idx] for idx in selected]

            # Early stopping
            step_results = {
                "trajectories": completed_responses,
                "steps": [keep_indices[idx] for idx in selected],
                "logprobs": completed_logprobs,
                "advantages": completed_advantages,
                "stop_foresight": stop_foresight
            }
            if self.pipeline._should_stop_early(step_results, step):
                break

        # Final response
        final_result = self.pipeline._generate_final_response(
            example, system_prompt, previous_steps, previous_values, token_stats, rollout_stats, traj_info
        )
        traj_info['token_num'] = token_stats["input"] + token_stats["output"]

        return {
            "response": final_result["response"],
            "token_stats": token_stats,
            "rollout_stats": rollout_stats,
            "trajectories": {
                "final": final_result["trajectories"]
            },
            "traj_info": traj_info
        }

class ForesightPipeline:
    def __init__(self, model, tokenizer, args):
        """Keep references to the generation model, the tokenizer and the run configuration.

        Args:
            model: Object whose ``generate`` method the stages call
            tokenizer: Tokenizer used to render chat templates and count tokens
            args: Configuration namespace read by the stages
        """
        self.model, self.tokenizer, self.args = model, tokenizer, args
        # Default refinement callable; stage2_refine forwards it to segment_refinement
        self.llm_refine = default_llm_refine

    # ----------- Stage 1: Candidate Generation -----------
    def stage1_generate(self, example, system_prompt, previous_steps, previous_values, token_stats, rollout_stats):
        """Stage 1: sample args.num_rollout candidate continuations for every beam.

        Args:
            example: Current example (not read here; the user message is the
                literal "<PLACEHOLDER_INPUT>" string)
            system_prompt: System message placed in front of every prompt
            previous_steps: Per-beam text generated so far, appended to the rendered chat
            previous_values: Per-beam baseline subtracted from each candidate's score
            token_stats: Dict whose "input"/"output" counters are incremented in place
            rollout_stats: Dict whose "total" counter is incremented in place

        Returns:
            Tuple of flat, beam-major lists: (responses, logprobs, advantages, token_nums).
            Each logprob is the cumulative log-probability divided by the token count.
        """
        prompts = []
        for beam_idx in range(self.args.step_beam_size):
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "<PLACEHOLDER_INPUT>"},
            ]
            rendered = self.tokenizer.apply_chat_template(messages, tokenize=False).rstrip()
            prompt = rendered + previous_steps[beam_idx]
            token_stats["input"] += len(self.tokenizer(prompt)["input_ids"])
            prompts.append(prompt)

        sampling_params = {
            "max_tokens": 1024,
            "n": self.args.num_rollout,
            "temperature": 0.4,
            "stop": ["<end>"],
        }
        outputs = self.model.generate(prompts, sampling_params)
        rollout_stats["total"] += self.args.num_rollout * self.args.step_beam_size

        responses = []
        logprobs = []
        advantages = []
        token_nums = []
        for beam_idx, beam_outputs in enumerate(outputs):
            for candidate in beam_outputs.outputs:
                text = candidate.text.strip()
                # The 1e-8 term guards against division by zero for empty generations
                mean_logp = candidate.cumulative_logprob / (len(candidate.token_ids) + 1e-8)
                gain = mean_logp - previous_values[beam_idx]

                responses.append(text)
                logprobs.append(mean_logp)
                advantages.append(gain)
                token_nums.append(len(candidate.token_ids))
                token_stats["output"] += len(candidate.token_ids)

        return responses, logprobs, advantages, token_nums

    # ----------- Stage 2: Pruning & Local Enhancement -----------
    def stage2_refine(self, responses, logprobs, advantages, num_rollout, token_stats, rollout_stats):
        """Stage 2: prune weak candidates, then refine the survivors segment by segment.

        Args:
            responses: Candidate texts; surviving entries are overwritten IN PLACE
                with their refined text
            logprobs: Per-candidate scores used for pruning and for top-up sampling
            advantages: Per-candidate advantages, filtered alongside the responses
            num_rollout: Accepted for interface compatibility; not read here
            token_stats: Looked up by candidate index for optional "hidden_states" /
                "halu_scores" entries; defaults are used when absent
            rollout_stats: Dict whose "saved" counter is incremented in place

        Returns:
            Tuple (filtered_responses, filtered_logprobs, filtered_advantages,
            keep_indices), where keep_indices are sorted indices into ``responses``.
        """
        keep_indices = []

        if self.args.width_pruning_strategy == "low_sigma":
            mean, std = np.mean(logprobs), np.std(logprobs)
            cutoff = mean - self.args.sigma_rate * std
            keep_indices = [idx for idx, logp in enumerate(logprobs) if logp > cutoff]

            # Rescue pruned candidates whose surface features look like a finished answer
            already_kept = set(keep_indices)
            for idx, resp in enumerate(responses):
                if idx in already_kept:
                    continue
                score = 0
                if len(resp) > 50:
                    score += 0.2
                if any(x in resp.lower() for x in ['answer:', 'thus', 'hence']):
                    score += 0.1
                if any(x in resp for x in ['=', '+', '-', '*', '/']):
                    score += 0.1
                if score >= 0.3:
                    keep_indices.append(idx)

        # Top up to at least step_beam_size candidates by sampling from the rest
        missing = self.args.step_beam_size - len(keep_indices)
        if missing > 0:
            already_kept = set(keep_indices)
            available = [i for i in range(len(responses)) if i not in already_kept]
            if available:
                scaled = np.array(logprobs, dtype=float)[available] / TEMPERATURE
                # Shift by the max before exp() so very negative scores cannot
                # underflow every weight to zero.
                weights = np.exp(scaled - scaled.max())
                add = np.random.choice(
                    available,
                    # Never ask for more draws than exist: with replace=False
                    # that raises ValueError.
                    size=min(missing, len(available)),
                    p=weights / weights.sum(),
                    replace=False
                ).tolist()
                keep_indices.extend(add)

        keep_indices = sorted(set(keep_indices))

        # Multi-segment refinement
        surviving = []
        for idx in keep_indices:
            response = responses[idx]

            # Assume segments are pre-split; here simple split by periods
            segments = [seg.strip() for seg in response.split('.') if seg.strip()]
            refined_segments = []
            discard_response = False

            for seg in segments:
                segment_tokens = seg.split()
                # NOTE(review): the token_stats dicts built in this file only hold
                # "input"/"output" counters, so these lookups fall back to the
                # defaults -- confirm where per-index stats are meant to come from.
                hidden_states = token_stats.get(idx, {}).get("hidden_states", [np.zeros(768)]*len(segment_tokens))
                halu_scores = token_stats.get(idx, {}).get("halu_scores", [0.5]*len(segment_tokens))
                # NOTE(review): this is an array, while segment_refinement documents
                # global_context as a string -- confirm the intended context.
                global_context = np.mean(hidden_states, axis=0)

                refinement_result = segment_refinement(
                    segment_tokens, hidden_states, halu_scores,
                    global_context, llm_refine=self.llm_refine
                )

                if refinement_result["status"] == "discarded":
                    # One discarded segment drops the whole response
                    discard_response = True
                    break
                refined_segments.append(" ".join(refinement_result["tokens"]))

            if not discard_response:
                responses[idx] = '. '.join(refined_segments)
                surviving.append(idx)
        keep_indices = surviving

        rollout_stats["saved"] += (self.args.step_beam_size * self.args.num_rollout - len(keep_indices))

        filtered_responses = [responses[i] for i in keep_indices]
        filtered_logprobs = [logprobs[i] for i in keep_indices]
        filtered_advantages = [advantages[i] for i in keep_indices]

        return filtered_responses, filtered_logprobs, filtered_advantages, keep_indices

    # ----------- Stage 3: Completion + Clustering Selection -----------
    def stage3_select(self, example, system_prompt, previous_steps, keep_indices,
                      filtered_responses, previous_values, num_rollout, token_stats, rollout_stats):
        """Stage 3: complete every surviving candidate, cluster the completions, pick the next beams.

        Args:
            example: Current example (not read here)
            system_prompt: System message for the completion prompts
            previous_steps: Per-beam text generated so far
            keep_indices: Stage-1 indices of the surviving candidates;
                ``keep_indices[i] // num_rollout`` identifies the parent beam
            filtered_responses: Surviving candidate texts, parallel to keep_indices
            previous_values: Per-beam baseline subtracted from each completion's score
            num_rollout: Number of stage-1 candidates generated per beam
            token_stats: Dict whose "input"/"output" counters are incremented in place
            rollout_stats: Dict whose "total" counter is incremented in place

        Returns:
            Tuple (completed_responses, completed_logprobs, completed_advantages,
            selected, stop_foresight). ``selected`` holds args.step_beam_size
            positions into the completed lists; ``stop_foresight`` is True when
            clustering succeeded and the largest cluster holds at least
            args.threshold of the completions.

        Raises:
            ValueError: if there are no completions to select from.
        """
        # --- completion ---
        completion_inputs = []
        for idx in range(len(keep_indices)):
            resp = filtered_responses[idx]
            beam_idx = keep_indices[idx] // num_rollout
            chat = [{"role": "system", "content": system_prompt},
                    {"role": "user", "content": previous_steps[beam_idx] + resp}]
            inputs = self.tokenizer.apply_chat_template(chat, tokenize=False).rstrip()
            completion_inputs.append(inputs)
            token_stats["input"] += len(self.tokenizer(inputs)["input_ids"])

        sampling_params = {"max_tokens": 1024, "n": 1, "stop": ["<end>"]}
        outputs = self.model.generate(completion_inputs, sampling_params)
        rollout_stats["total"] += len(completion_inputs)

        completed_responses, completed_logprobs, completed_advantages = [], [], []
        for idx, out in enumerate(outputs):
            o = out.outputs[0]
            resp = o.text.strip()
            # The 1e-8 term guards against division by zero for empty completions
            logp = o.cumulative_logprob / (len(o.token_ids) + 1e-8)
            beam_idx = keep_indices[idx] // num_rollout
            adv = logp - previous_values[beam_idx]

            completed_responses.append(resp)
            completed_logprobs.append(logp)
            completed_advantages.append(adv)
            token_stats["output"] += len(o.token_ids)

        num_completed = len(completed_responses)
        if num_completed == 0:
            raise ValueError("stage3_select: no completed responses to select from")

        # --- clustering & selection ---
        # NOTE(review): cluster_responses is not defined on this class in the
        # visible part of the file (PhiDecoder defines one) -- confirm it exists here.
        clusters, cluster_info = self.cluster_responses(completed_responses, completed_advantages)
        if clusters is None:
            # Clustering unavailable: sample proportionally to exp(advantage / T).
            # Shifting by the max keeps exp() from overflowing or underflowing.
            scaled = np.array(completed_advantages, dtype=float) / TEMPERATURE
            weights = np.exp(scaled - scaled.max())
            weights = weights / weights.sum()
            stop_foresight = False
        else:
            weights = np.ones(num_completed) / num_completed
            # Stop foresight only when one cluster dominates, which is the
            # condition the early-stop message reports
            # ("max cluster ratio >= args.threshold").
            largest_cluster = max(len(c) for c in clusters)
            stop_foresight = (largest_cluster / num_completed) >= self.args.threshold

        beam_width = self.args.step_beam_size
        # Sampling without replacement needs at least beam_width candidates.
        # When stage 2 left fewer, allow repeats so the beam keeps its width
        # instead of raising ValueError.
        selected = np.random.choice(num_completed, size=beam_width, p=weights,
                                    replace=num_completed < beam_width).tolist()

        return completed_responses, completed_logprobs, completed_advantages, selected, stop_foresight

    def _process_step(self, example, system_prompt, previous_steps, previous_values,
                      token_stats, rollout_stats, traj_info):
        """Run stage 1 -> stage 2 -> stage 3 once and package the outcome.

        Args:
            example: Current example, forwarded to stages 1 and 3
            system_prompt: System prompt, forwarded to stages 1 and 3
            previous_steps: Per-beam text generated so far
            previous_values: Per-beam baseline values
            token_stats: Token counters updated in place by the stages
            rollout_stats: Rollout counters updated in place by the stages
            traj_info: Not read by this method

        Returns:
            dict with the next beam texts ("next_steps") and values
            ("next_values"), plus the completions, chosen stage-1 indices,
            log-probs, advantages and the stop flag from stage 3.
        """
        rollouts = self.args.num_rollout

        # Stage 1: candidate generation
        raw_responses, raw_logprobs, raw_advantages, _token_nums = self.stage1_generate(
            example, system_prompt, previous_steps, previous_values,
            token_stats, rollout_stats)

        # Stage 2: pruning and refinement
        kept_responses, _kept_logprobs, _kept_advantages, keep_indices = self.stage2_refine(
            raw_responses, raw_logprobs, raw_advantages, rollouts,
            token_stats, rollout_stats)

        # Stage 3: completion and selection
        stage3_result = self.stage3_select(
            example, system_prompt, previous_steps, keep_indices,
            kept_responses, previous_values, rollouts,
            token_stats, rollout_stats)
        completed_responses, completed_logprobs, completed_advantages, selected, stop = stage3_result

        next_steps = []
        next_values = []
        chosen_steps = []
        for pick in selected:
            source_idx = keep_indices[pick]
            # source_idx // rollouts is the beam the chosen candidate was sampled from
            next_steps.append(previous_steps[source_idx // rollouts] + raw_responses[source_idx] + "\n")
            next_values.append(completed_logprobs[pick])
            chosen_steps.append(source_idx)

        return {
            "next_steps": next_steps,
            "next_values": next_values,
            "trajectories": completed_responses,
            "steps": chosen_steps,
            "logprobs": completed_logprobs,
            "advantages": completed_advantages,
            "stop_foresight": stop
        }

    def _should_stop_early(self, step_results, current_step):
        """Decide whether foresight should stop after the current step.

        The checks run in order and the first one that fires wins:

        1. Never stop while ``current_step < args.least_foresight_num``.
        2. Stop when every trajectory is identical.
        3. Stop when the mean advantage is below -2.0.
        4. Stop when the mean pairwise ``SequenceMatcher`` ratio between
           trajectories exceeds 0.8.
        5. With ``args.depth_pruning_strategy == "cluster"``, stop when
           ``step_results["stop_foresight"]`` is truthy.

        Args:
            step_results: Dict with ``"trajectories"`` (sequence of response
                strings), optionally ``"advantages"`` (sequence of numbers),
                and ``"stop_foresight"`` (read only for the cluster strategy).
            current_step: Index of the current foresight step.

        Returns:
            True if foresight should stop, False otherwise.
        """
        if current_step < self.args.least_foresight_num:
            return False

        trajectories = step_results["trajectories"]

        # All responses identical. An empty list no longer raises IndexError;
        # it simply does not count as "all the same".
        if len(trajectories) > 0:
            first_response = trajectories[0]
            if not any(response != first_response for response in trajectories[1:]):
                print(
                    f'Early stopping at depth {current_step} (all responses are the same)')
                return True

        # Quality assessment: average advantage too low.
        # step_results is a dict, so this must be a key lookup -- the previous
        # hasattr() test was always False and silently disabled this check.
        advantages = step_results.get('advantages')
        if advantages is not None and len(advantages) > 0:
            avg_advantage = np.mean(advantages)
            if avg_advantage < -2.0:  # Threshold for poor reasoning quality
                print(f'Early stopping at depth {current_step} (poor reasoning quality, avg advantage: {avg_advantage:.3f})')
                return True

        # Repetitive patterns: high average similarity between response pairs.
        if len(trajectories) > 1:
            from difflib import SequenceMatcher
            from itertools import combinations

            similarity_scores = [
                SequenceMatcher(None, left, right).ratio()
                for left, right in combinations(trajectories, 2)
            ]
            mean_similarity = np.mean(similarity_scores)
            if mean_similarity > 0.8:  # High similarity threshold
                print(f'Early stopping at depth {current_step} (high response similarity: {mean_similarity:.3f})')
                return True

        if self.args.depth_pruning_strategy == "cluster":
            # Flag computed by the selection stage and passed through as-is.
            if step_results["stop_foresight"]:
                print(
                    f'Early stopping at depth {current_step} (max cluster ratio >= args.threshold)')
                return True

        return False

    def _generate_final_response(self, example, system_prompt, previous_steps, previous_values, token_stats, rollout_stats, traj_info):
        """Generate one final completion per beam and select the response to return.

        For each of ``args.step_beam_size`` beams the chat template is
        rendered with ``previous_steps[beam_idx]`` as the assistant content,
        a single continuation is sampled (at most 3000 tokens, stopping at
        ``"<end_of_reasoning>"``), and ``select_response`` picks the winner.

        Args:
            example: Input example passed to ``_prepare_chat_template``.
            system_prompt: System prompt passed to ``_prepare_chat_template``.
            previous_steps: Per-beam text generated so far.
            previous_values: Per-beam baseline subtracted from the new
                length-normalised log-probability to form the advantage.
            token_stats: Dict whose ``"input"`` and ``"output"`` counters are
                incremented in place.
            rollout_stats: Dict whose ``"total"`` counter is incremented by
                ``args.step_beam_size``.
            traj_info: Accepted for interface compatibility; not used here.

        Returns:
            dict with ``"response"`` (selected beam's previous steps plus its
            new completion) and ``"trajectories"`` (all responses, logprobs,
            advantages and the selected index).
        """
        eos_token = self.tokenizer.eos_token

        # Prepare input for each beam
        all_inputs = []
        for beam_idx in range(self.args.step_beam_size):
            chat = self._prepare_chat_template(example, system_prompt)
            chat[-1]["content"] = previous_steps[beam_idx]

            inputs = self.tokenizer.apply_chat_template(
                chat,
                tokenize=False
            )
            # Remove the EOS token only as an exact suffix. str.rstrip(eos)
            # treats its argument as a set of characters and could also eat
            # trailing characters of the reasoning text that happen to occur
            # in the token.
            if eos_token and inputs.endswith(eos_token):
                inputs = inputs[:-len(eos_token)]
            inputs = inputs.rstrip()

            token_stats["input"] += len(self.tokenizer(inputs)["input_ids"])
            all_inputs.append(inputs)

        # Generate all beam responses in a single batched call
        sampling_params = SamplingParams(
            max_tokens=3000,
            n=1,
            logprobs=0,
            stop=["<end_of_reasoning>"]
        )
        outputs = self.model.generate(all_inputs, sampling_params)

        rollout_stats["total"] += self.args.step_beam_size

        # Collect all response results
        all_responses = []
        all_logprobs = []
        all_advantages = []

        for beam_idx, beam_outputs in enumerate(outputs):
            output = beam_outputs.outputs[0]
            response = output.text.strip()
            # Length-normalised log-probability; the epsilon avoids division
            # by zero when no tokens were generated.
            logprob = output.cumulative_logprob / \
                (len(output.token_ids) + 1e-8)
            advantage = logprob - previous_values[beam_idx]

            all_responses.append(response)
            all_logprobs.append(logprob)
            all_advantages.append(advantage)
            token_stats["output"] += len(output.token_ids)

        # Debug: Print final stage responses
        print(f"\n=== Final Stage Responses (Total: {len(all_responses)}) ===")
        for i, response in enumerate(all_responses):
            print(f"Final {i}: {response}")
        print("=" * 50)

        # Select final response
        selected_idx = self.select_response(
            all_responses,
            all_logprobs,
            all_advantages
        )

        # Debug: Print final selected response
        print("\n=== Final Selected Response ===")
        #...

        # Record final results

        return {
            "response": previous_steps[selected_idx] + all_responses[selected_idx],
            # "response_in_the_final_generation": all_responses[selected_idx],
            "trajectories": {
                "responses": all_responses,
                "logprobs": all_logprobs,
                "advantages": all_advantages,
                "selected_idx": selected_idx
            }
        }

    def _prepare_chat_template(self, example, system_prompt):
        """
        Build the three-message chat (system, user, assistant) for an example.
        Args:
            example: Mapping providing the 'passage' and 'question' strings
            system_prompt: Content of the system message
        Returns:
            List of three chat messages; the assistant message has empty
            content so the caller can fill it in
        """
        passage, question = example['passage'], example['question']
        user_content = ''.join((
            'Passage: ', passage,
            '\nQuestion: ', question,
            '\nPlease directly follow the previous reasoning steps (if provided) and generate the remaining ones.\n',
        ))

        system_msg = {'role': 'system', 'content': system_prompt}
        user_msg = {'role': 'user', 'content': user_content}
        assistant_msg = {'role': 'assistant', 'content': ''}
        return [system_msg, user_msg, assistant_msg]

    def _prepare_chat_template_for_first_stage(self, example, system_prompt):
        """
        Build the chat messages for the first generation stage.

        The first stage uses the same prompt as ``_prepare_chat_template``,
        so this delegates to it rather than keeping a second copy of the
        template that could drift out of sync.
        Args:
            example: Mapping providing the 'passage' and 'question' strings
            system_prompt: Content of the system message
        Returns:
            List of chat messages (system, user, empty assistant)
        """
        return self._prepare_chat_template(example, system_prompt)


def main():
    """Run the decoder over the input dataset and write results to disk.

    Reads the JSON list at ``args.data_path``, processes up to
    ``args.max_examples`` entries (``-1`` means all) and writes:

    * ``<output_dir>/<file_name>.json`` -- one JSON object per line,
      appended after every example;
    * ``<time_path>/TRAJ_INFO-<file_name>.json`` -- all trajectory info so
      far, rewritten after every example when ``args.record_process`` is set;
    * ``<time_path>/<file_name>.txt`` -- total wall-clock time.

    All files are opened as UTF-8 so the run does not depend on the
    platform's default locale encoding.
    """
    args = parse_arguments()
    decoder = PhiDecoder(args)

    # JSON text is UTF-8; be explicit instead of relying on the locale.
    with open(args.data_path, encoding="utf-8") as f:
        test_data = json.load(f)
    max_num = len(test_data) if args.max_examples == -1 else min(len(test_data), args.max_examples)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.time_path, exist_ok=True)
    output_path = os.path.join(args.output_dir, f"{args.file_name}.json")
    traj_path = os.path.join(
        args.time_path, f"TRAJ_INFO-{args.file_name}.json")

    # Record start time
    start_time = time.time()

    # Statistics
    total_stats = {
        "total_rollouts": 0,
        "saved_rollouts": 0,
        "input_tokens": 0,
        "output_tokens": 0
    }

    # Used to store all trajectory information
    all_traj_info = []

    # Process each test example
    for i, example in enumerate(test_data[:max_num]):
        # Progress line: "processing sample i/max_num, question: <first 50 chars>"
        print(f"正在处理第{i+1}/{max_num}个样本，问题：{example.get('question', '')[:50]}")

        # Generate system prompt
        system_prompt = decoder.get_system_prompt()

        # Process example
        result = decoder.process_example(example, system_prompt)

        # Update statistics
        total_stats["total_rollouts"] += result["rollout_stats"]["total"]
        total_stats["saved_rollouts"] += result["rollout_stats"]["saved"]
        total_stats["input_tokens"] += result["token_stats"]["input"]
        total_stats["output_tokens"] += result["token_stats"]["output"]

        # Add trajectory information
        result["traj_info"]["question_idx"] = i
        all_traj_info.append(result["traj_info"])

        # Prepare output result
        output_result = {
            "id": i,
            "question": example["question"],
            "passage": example["passage"],
            "ground_truth": example.get("answer"),
            "response": result["response"],
            "response_all_beams": result["trajectories"]["final"]["responses"] if "final" in result["trajectories"] else []
        }

        # Append to the main output file right away so finished examples
        # are already on disk if a later one fails.
        with open(output_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(output_result) + "\n")

        print(
            f'output_token_num_for_question{i}: {result["token_stats"]["output"]}')
        print(
            f'input_token_num_for_question{i}: {result["token_stats"]["input"]}')
        print(f'all_output_token_num: {total_stats["output_tokens"]}')
        print(f'all_input_token_num: {total_stats["input_tokens"]}')

        # Save trajectory information (whole list rewritten each iteration)
        if args.record_process:
            with open(traj_path, "w", encoding="utf-8") as f:
                json.dump(all_traj_info, f, indent=2)

    # Calculate total time
    end_time = time.time()
    time_span = end_time - start_time

    # Save time information to separate file
    time_info_path = os.path.join(args.time_path, f"{args.file_name}.txt")
    with open(time_info_path, "w", encoding="utf-8") as f:
        f.write(f'time:  {time_span}\n')
        #...
    print('total rollouts: ', total_stats["total_rollouts"])
    print('saved rollouts: ', total_stats["saved_rollouts"])
    print('all_output_token_num: ', total_stats["output_tokens"])
    print('all_input_token_num: ', total_stats["input_tokens"])


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


