# demo_iterative_update_activation.py

################################################################################
# 1. Imports
################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import jvp, vjp # Using jvp and vjp for matrix-free products
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from tqdm import tqdm
import re
import copy
import pandas as pd
import gc # 新增: 用于清理CUDA缓存

# Import for SciPy CG solver and NumPy
try:
    from scipy.sparse.linalg import cg, LinearOperator
    import numpy as np
except ImportError:
    print("To use the SciPy solver, please install it first by running `pip install scipy`.")
    exit()

# Import for dataset loading
try:
    from datasets import Dataset, get_dataset_config_names, load_dataset, concatenate_datasets
except ImportError:
    print("To load the dataset, please install it first by running `pip install datasets`.")
    exit()

try:
    from math_verify.errors import TimeoutException
    from math_verify.metric import math_metric
    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
except ImportError:
    print("To use Math-Verify, please install it first by running `pip install math-verify`.")
    exit()

################################################################################
# 2. Reward Function and Data Loading
################################################################################

# String normalization functions from lm-evaluation-harness
def remove_boxed(s):
    r"""Return the contents of a ``\boxed`` expression, or ``None``.

    Two spellings are recognised:

    * Space form, ``\boxed <answer>$``: if ``"\boxed "`` occurs anywhere in
      ``s``, everything after its first occurrence up to the next ``$`` is
      returned.
    * Brace form, ``\boxed{<answer>}``: ``s`` must start with ``\boxed{`` and
      end with ``}``; the text in between is returned.

    Any other input, including ``None`` and the empty string, yields ``None``.
    """
    if s is None:
        return None

    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        return s[s.find(spaced_prefix) + len(spaced_prefix):].split("$")[0]

    # Explicit checks instead of ``assert``: assertions are stripped under
    # ``python -O``, which would make every string look like a boxed answer.
    braced_prefix = "\\boxed{"
    if s.startswith(braced_prefix) and s.endswith("}"):
        return s[len(braced_prefix):-1]
    return None

def last_boxed_only_string(string):
    r"""Return the last ``\boxed``/``\fbox`` expression found in ``string``.

    The search starts at the last ``\boxed`` (falling back to the last
    ``\fbox``).  If the character six positions after that start is a space
    (the ``\boxed <answer>`` spelling), the rest of the string is returned
    unchanged.  Otherwise braces are counted from the start position and the
    substring up to the brace that brings the count back to zero is returned.

    Returns ``None`` when ``string`` is ``None``, when neither command is
    present, or when no balancing closing brace is found.
    """
    if string is None:
        return None

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # "\boxed " spelling: no braces to balance, hand back the whole tail.
    if start + 6 < len(string) and string[start + 6] == ' ':
        return string[start:]

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        if char == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]
    return None

def strip_string(string):
    r"""Normalise a LaTeX-flavoured answer so two answers can be compared with ``==``.

    The input is converted with ``str`` and then, in order:

    1. newlines and ``\!`` are removed and ``\\`` is collapsed to ``\``;
    2. ``tfrac``/``dfrac`` become ``frac``;
    3. ``\left``, ``\right``, degree markers (``^{\circ}``, ``^\circ``),
       ``\$`` and ``\%`` are removed;
    4. a short left-hand side such as ``x =`` (at most two characters before a
       single ``=``) is dropped;
    5. all spaces are removed.
    """
    string = str(string)

    # Applied sequentially; the order matters (e.g. "\\\\" must be collapsed
    # before the single-backslash commands are matched).
    # The original code also replaced "\%" – an invalid escape sequence that is
    # byte-identical to "\\%" – so that redundant step is dropped here.
    replacements = (
        ("\n", ""),
        ("\\!", ""),
        ("\\\\", "\\"),
        ("tfrac", "frac"),
        ("dfrac", "frac"),
        ("\\left", ""),
        ("\\right", ""),
        ("^{\\circ}", ""),
        ("^\\circ", ""),
        ("\\$", ""),
        ("\\%", ""),
    )
    for old, new in replacements:
        string = string.replace(old, new)

    # Drop a short variable assignment prefix like "k = ".
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    return string.replace(" ", "")

def is_equiv(str1, str2):
    """Report whether two answers are equal after ``strip_string`` normalisation.

    Two ``None`` values are equivalent; ``None`` never matches a non-``None``
    value.  If normalisation raises, the raw values are compared instead.
    """
    if str1 is None or str2 is None:
        return str1 is None and str2 is None
    try:
        normalised_pair = (strip_string(str1), strip_string(str2))
    except Exception:
        # Best effort: fall back to comparing the inputs as given.
        return str1 == str2
    return normalised_pair[0] == normalised_pair[1]

def compute_score(solution_str, ground_truth) -> float:
    """Score a model response against the reference answer.

    Takes the last boxed expression in ``solution_str``, unwraps it and
    compares it with ``ground_truth`` via ``is_equiv``.

    Returns:
        ``1.0`` when the boxed answer is equivalent to ``ground_truth``;
        ``0.0`` otherwise, including when there is no boxed expression or any
        exception occurs along the way.
    """
    try:
        boxed = last_boxed_only_string(solution_str)
        if boxed is None:
            return 0.0
        matched = is_equiv(remove_boxed(boxed), ground_truth)
        return 1.0 if matched else 0.0
    except Exception:
        # Scoring is deliberately best effort: a failure counts as "wrong".
        return 0.0

class MathVerifyRewardModel:
    """Binary reward for one problem: +1 for a correct boxed answer, -1 otherwise.

    Correctness is decided by ``compute_score`` against the ground-truth
    answer supplied at construction time.
    """

    def __init__(self, ground_truth_answer: str):
        # Reference answer every response is compared against.
        self.ground_truth_answer = ground_truth_answer

    def get_reward(self, response_text: str) -> float:
        """Return ``1.0`` if ``response_text`` scores exactly 1.0, else ``-1.0``."""
        score = compute_score(response_text, self.ground_truth_answer)
        if score == 1.0:
            return 1.0
        return -1.0

################################################################################
# math
################################################################################
# Prompt suffix that asks the model to reason step by step and put its final
# answer inside \boxed{}; the dataset loaders in this file append it to each
# problem statement.
instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
def get_math_questions():
    """Load the ``test`` split of the local MATH-benchmark dataset.

    Each row is mapped to ``problem`` (original problem text followed by
    ``instruction_following``), ``solution`` and ``answer``; rows whose
    ``answer`` is ``None`` are dropped.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        return {
            'problem': row['problem'] + instruction_following,
            'solution': row['solution'],
            'answer': row['answer'],
        }

    def _has_answer(row):
        return row['answer'] is not None

    try:
        raw = load_dataset("/data_train/kitwei/grpo/dataset/MATH-benchmark", split="test")
        return raw.map(_to_record).filter(_has_answer)
    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_math_500_questions():
    """Load the ``test`` split of the local MATH-500 dataset.

    Each row is mapped to ``problem`` (original problem text followed by
    ``instruction_following``), ``solution`` and ``answer``; rows whose
    ``answer`` is ``None`` are dropped.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        return {
            'problem': row['problem'] + instruction_following,
            'solution': row['solution'],
            'answer': row['answer'],
        }

    def _has_answer(row):
        return row['answer'] is not None

    try:
        raw = load_dataset("/data_train/kitwei/grpo/dataset/datasets--HuggingFaceH4--MATH-500", split="test")
        return raw.map(_to_record).filter(_has_answer)
    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_AIME25_questions():
    """Load the ``test`` split of the local AIME25 dataset.

    Each row is mapped to ``problem`` (original problem text followed by
    ``instruction_following``), an empty ``solution`` and the row's
    ``answer``; rows whose ``answer`` is ``None`` are dropped.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        return {
            'problem': row['problem'] + instruction_following,
            'solution': "",
            'answer': row['answer'],
        }

    def _has_answer(row):
        return row['answer'] is not None

    try:
        raw = load_dataset("/data_train/kitwei/grpo/dataset/datasets--math-ai--aime25", split="test")
        return raw.map(_to_record).filter(_has_answer)
    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def split_by_last_separator(s: str) -> list:
    """Split ``s`` at its last ``'.'`` or ``'\\n'``, whichever occurs later.

    Args:
        s: The string to split.

    Returns:
        ``[before, after]`` with the separator itself removed, or ``[s]`` when
        the string contains neither separator.
    """
    # rfind returns -1 for "not found", so max() picks the later separator and
    # stays at -1 only when both are missing.
    cut = max(s.rfind('.'), s.rfind('\n'))
    if cut < 0:
        return [s]
    return [s[:cut], s[cut + 1:]]

def get_mathdial_questions():
    """Load the ``test`` split of the local MathDial dataset.

    ``ground_truth`` is split at its last ``'.'`` or newline with
    ``split_by_last_separator``: the leading part becomes ``solution`` and the
    trailing part (stripped) becomes ``answer``.  The original code took the
    leading part for both fields, so the gold answer was the solution text and
    could never match a boxed final answer.  When no separator exists the
    whole string is used for both, as before.

    Rows whose extracted answer is empty are given ``None`` so that the final
    filter removes them.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        # Split once and reuse both halves.
        parts = split_by_last_separator(row['ground_truth'].strip())
        final_answer = parts[-1].strip()
        return {
            'problem': row['question'] + instruction_following,
            'solution': parts[0],
            # An empty tail (e.g. text ending in '.') means nothing was extracted.
            'answer': final_answer or None,
        }

    try:
        data = load_dataset("/data_train/kitwei/grpo/dataset/datasets--eth-nlped--mathdial", split="test")
        data = data.map(_to_record)
        # Filter out examples where the answer could not be extracted
        return data.filter(lambda x: x['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_gsm8k_questions():
    """Load the ``test`` split of the local GSM8K dataset (``main`` config).

    The raw ``answer`` field is split at its last ``'####'``: the text before
    it becomes ``solution`` and the text after it becomes ``answer``.
    ``problem`` is the question followed by ``instruction_following``.  Rows
    whose ``answer`` is ``None`` are dropped.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        pieces = row['answer'].rsplit('####', 1)
        return {
            'problem': row['question'] + instruction_following,
            'solution': pieces[0],
            'answer': pieces[1],
        }

    def _has_answer(row):
        return row['answer'] is not None

    try:
        raw = load_dataset("/data_train/kitwei/grpo/dataset/gsm8k", 'main', split="test")
        return raw.map(_to_record).filter(_has_answer)
    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None


################################################################################
# general task
################################################################################

def get_gpqa_questions():
    """Load the ``test`` split of the local GPQA-Diamond dataset.

    Each row is mapped to ``problem`` (question followed by
    ``instruction_following``), an empty ``solution`` and the row's
    ``answer``; rows whose ``answer`` is ``None`` are dropped.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        return {
            'problem': row['question'] + instruction_following,
            'solution': "",
            'answer': row['answer'],
        }

    def _has_answer(row):
        return row['answer'] is not None

    try:
        raw = load_dataset("/data_train/kitwei/grpo/dataset/datasets--fingertap--GPQA-Diamond", split="test")
        return raw.map(_to_record).filter(_has_answer)
    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None



def get_mmlu_redux_questions():
    """Load every config and split of the local MMLU-Redux dataset as one dataset.

    Each row becomes a multiple-choice prompt: the question, one line per
    choice labelled ``A.``, ``B.``, ... and finally ``instruction_following``.
    The integer ``answer`` index is converted to its letter; an out-of-range
    index yields ``None`` and the row is filtered out.  ``solution`` is always
    the empty string.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        options = row['choices']
        letters = [chr(ord('A') + k) for k in range(len(options))]
        listing = "\n".join(f"{letter}. {text}" for letter, text in zip(letters, options))
        gold_index = row['answer']
        in_range = 0 <= gold_index < len(options)
        return {
            'problem': f"{row['question']}\n" + listing + instruction_following,
            'solution': "",
            'answer': letters[gold_index] if in_range else None,
        }

    try:
        dataset_name = "/data_train/kitwei/grpo/dataset/datasets--edinburgh-dawg--mmlu-redux"

        # Collect every split (train, test, ...) of every config.
        pieces = []
        for config in get_dataset_config_names(dataset_name):
            pieces.extend(load_dataset(dataset_name, config).values())

        merged = concatenate_datasets(pieces)
        return merged.map(_to_record).filter(lambda row: row['answer'] is not None)

    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset name 'edinburgh-dawg/mmlu-redux' is correct and the 'datasets' library is installed.")
        return None

def get_superGPQA_questions():
    """Load every config and split of the local SuperGPQA dataset as one dataset.

    Each row becomes a multiple-choice prompt: the question, one line per
    entry of ``options`` labelled ``A.``, ``B.``, ... and finally
    ``instruction_following``.  The gold answer is taken directly from the
    row's ``answer_letter`` field; ``solution`` is always the empty string.
    Rows whose answer is ``None`` are filtered out.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        options = row['options']
        letters = [chr(ord('A') + k) for k in range(len(options))]
        listing = "\n".join(f"{letter}. {text}" for letter, text in zip(letters, options))
        return {
            'problem': f"{row['question']}\n" + listing + instruction_following,
            'solution': "",
            'answer': row['answer_letter'],
        }

    try:
        dataset_name = "/data_train/kitwei/grpo/dataset/datasets--m-a-p--SuperGPQA"

        # Collect every split (train, test, ...) of every config.
        pieces = []
        for config in get_dataset_config_names(dataset_name):
            pieces.extend(load_dataset(dataset_name, config).values())

        merged = concatenate_datasets(pieces)
        return merged.map(_to_record).filter(lambda row: row['answer'] is not None)

    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset name 'SuperGPQA' is correct and the 'datasets' library is installed.")
        return None

################################################################################
# Multilingual
################################################################################

def transform_to_qa_format(dataset: Dataset) -> Dataset:
    """Reshape a wide multilingual dataset into one (language, question, answer) row per question.

    Every column other than ``answer`` is treated as a language column holding
    that language's version of the question (this is an assumption about the
    input schema, not something checked here).  Questions that are ``None`` or
    blank are skipped.

    Returns:
        A new dataset with the columns ``language``, ``question`` and
        ``answer``, built through an intermediate pandas DataFrame.
    """
    question_columns = [name for name in dataset.column_names if name != 'answer']

    rows = []
    for record in dataset:
        gold = record['answer']
        for column in question_columns:
            text = record[column]
            # Skip missing or whitespace-only questions.
            if text is None or not str(text).strip():
                continue
            rows.append({
                'language': column,
                'question': text,
                'answer': gold
            })

    frame = pd.DataFrame(rows)
    return Dataset.from_pandas(frame)

def get_M_IMO_questions():
    """Load the ``M-IMO`` config (``test`` split) of the local MCLM dataset.

    The wide multilingual table is flattened with ``transform_to_qa_format``;
    each resulting row is mapped to ``problem`` (question followed by
    ``instruction_following``), an empty ``solution`` and ``answer``.  Rows
    whose ``answer`` is ``None`` are dropped.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        return {
            'problem': row['question'] + instruction_following,
            'solution': "",
            'answer': row['answer'],
        }

    def _has_answer(row):
        return row['answer'] is not None

    try:
        wide = load_dataset("/data_train/kitwei/grpo/dataset/datasets--amphora--MCLM", "M-IMO", split="test")
        return transform_to_qa_format(wide).map(_to_record).filter(_has_answer)
    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_MT_AIME24_questions():
    """Load the ``MT-AIME2024`` config (``test`` split) of the local MCLM dataset.

    The wide multilingual table is flattened with ``transform_to_qa_format``;
    each resulting row is mapped to ``problem`` (question followed by
    ``instruction_following``), an empty ``solution`` and ``answer``.  Rows
    whose ``answer`` is ``None`` are dropped.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        return {
            'problem': row['question'] + instruction_following,
            'solution': "",
            'answer': row['answer'],
        }

    def _has_answer(row):
        return row['answer'] is not None

    try:
        wide = load_dataset("/data_train/kitwei/grpo/dataset/datasets--amphora--MCLM", "MT-AIME2024", split="test")
        return transform_to_qa_format(wide).map(_to_record).filter(_has_answer)
    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_MT_MATH100_questions():
    """Load the ``MT-MATH100`` config (``test`` split) of the local MCLM dataset.

    The wide multilingual table is flattened with ``transform_to_qa_format``;
    each resulting row is mapped to ``problem`` (question followed by
    ``instruction_following``), an empty ``solution`` and ``answer``.  Rows
    whose ``answer`` is ``None`` are dropped.

    Returns:
        The processed dataset, or ``None`` if anything fails (the error is
        printed).
    """
    def _to_record(row):
        return {
            'problem': row['question'] + instruction_following,
            'solution': "",
            'answer': row['answer'],
        }

    def _has_answer(row):
        return row['answer'] is not None

    try:
        wide = load_dataset("/data_train/kitwei/grpo/dataset/datasets--amphora--MCLM", "MT-MATH100", split="test")
        return transform_to_qa_format(wide).map(_to_record).filter(_has_answer)
    except Exception as err:
        print(f"Error loading dataset: {err}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

################################################################################
# 3. Core Algorithm Components
################################################################################

class ActivationModifier(nn.Module):
    """Adds a single learnable offset vector to every hidden state.

    ``delta`` has shape ``(1, 1, hidden_dim)`` and starts at zero, so a fresh
    modifier leaves its input unchanged.  The parameter count is printed on
    construction.
    """

    def __init__(self, hidden_dim, device='cuda'):
        super().__init__()
        zero_offset = torch.zeros((1, 1, hidden_dim), device=device)
        self.delta = nn.Parameter(zero_offset)
        print(f"Delta 参数量: {self.delta.numel()}")

    def forward(self, hidden_states):
        """Return ``hidden_states`` shifted by ``delta`` (broadcast over leading dims)."""
        return torch.add(hidden_states, self.delta)

class PiStarCalculator:
    """Builds target ("pi-star") logits by boosting the logits of the sampled answer tokens.

    The sequence-level log-probability bonus ``reward / beta`` is spread
    evenly over the answer tokens (positions where ``labels != -100``) and
    added to the reference logit that predicted each of those tokens.
    """

    def __init__(self, tokenizer, beta=1.0, device='cuda'):
        # tokenizer and device are stored but not used by the method below.
        self.tokenizer = tokenizer
        self.beta = beta
        self.device = device

    @torch.no_grad()
    def calculate_pi_star_logits(self, ref_logits: torch.Tensor, target_input_ids: torch.Tensor, labels: torch.Tensor, reward: float) -> torch.Tensor:
        """Return a copy of ``ref_logits`` with the answer tokens' logits shifted.

        Args:
            ref_logits: Reference logits; indexed as ``[0, position, token_id]``.
            target_input_ids: Token ids of the full sequence; only row 0 is read.
            labels: Same layout as ``target_input_ids`` with ``-100`` marking
                non-answer positions.  Only batch size 1 is supported.
            reward: Sequence-level reward; may be negative.

        Returns:
            A new tensor.  For every answer token at position ``t > 0`` the
            logit ``[0, t - 1, token_id]`` is increased by
            ``(reward / beta) / num_answer_tokens``.  When there are no such
            tokens an unmodified clone is returned.
        """
        nonzero_indices = (labels != -100).nonzero(as_tuple=False)
        if nonzero_indices.numel() == 0:
            return ref_logits.clone()

        answer_indices = nonzero_indices[:, 1]
        # The token at position 0 has no preceding logit that predicted it;
        # skipping it avoids ``answer_indices - 1`` wrapping around to the
        # last position.
        answer_indices = answer_indices[answer_indices > 0]
        if answer_indices.numel() == 0:
            return ref_logits.clone()

        answer_token_ids = target_input_ids[0, answer_indices]

        # log pi*(y) - log pi_ref(y) is exactly reward / beta.  Computing it
        # as (log_p_ref + r / beta) - log_p_ref, as the original code did,
        # only added floating-point cancellation error for long answers.
        per_token_log_bonus = (reward / self.beta) / answer_indices.numel()

        pi_star_logits = ref_logits.clone()
        # Positions are unique, so one vectorised in-place add is equivalent
        # to looping over the tokens.
        pi_star_logits[0, answer_indices - 1, answer_token_ids] += per_token_log_bonus
        return pi_star_logits

def calculate_delta_for_hidden_state_cg(model, activation_modifier, hidden_states, logits, pi_star_logits, labels, cg_steps=10, cg_tolerance=1e-10) -> torch.Tensor | None:
    if logits.shape[0] > 1:
        raise NotImplementedError("This function currently only supports batch_size=1.")

    nonzero_indices = (labels != -100).nonzero(as_tuple=False)
    if nonzero_indices.numel() == 0:
        return None
    answer_token_indices = nonzero_indices[:, 1]

    device = model.device
    theta_delta = activation_modifier.delta.flatten().detach()

    def F_policy_fn(delta_input):
        delta_reshaped = delta_input.view(1, 1, -1)
        all_policies = []
        for t in answer_token_indices:
            hidden_state_t = hidden_states[0, t, :].detach().unsqueeze(0).unsqueeze(0)
            modified_hidden_state = hidden_state_t + delta_reshaped
            new_logits = model.lm_head(modified_hidden_state).squeeze(0).squeeze(0)
            all_policies.append(F.softmax(new_logits, dim=-1))
        return torch.cat(all_policies)

    def A_prime_mvp(p_vec_np):
        p_vec = torch.from_numpy(p_vec_np).to(device)
        with torch.no_grad():
            _, jvp_result = jvp(F_policy_fn, (theta_delta,), (p_vec,))
            _, vjp_fn = vjp(F_policy_fn, theta_delta)
            A_prime_p = vjp_fn(jvp_result)[0]
        return A_prime_p.cpu().numpy()

    policy_diffs = []
    for t in answer_token_indices:
        pi_star_t = F.softmax(pi_star_logits[0, t, :], dim=-1)
        pi_theta_t = F.softmax(logits[0, t, :], dim=-1)
        policy_diffs.append(pi_star_t - pi_theta_t)
    b = torch.cat(policy_diffs)
    
    _, vjp_fn_b = vjp(F_policy_fn, theta_delta)
    b_prime = vjp_fn_b(b)[0]
    
    print("使用 SciPy 共轭梯度法 (CG) 求解 Δδ...")
    
    n_params = theta_delta.numel()
    b_prime_np = b_prime.detach().cpu().numpy()
    
    A_operator = LinearOperator((n_params, n_params), matvec=A_prime_mvp, dtype=b_prime_np.dtype)
    
    delta_update_flat_np, info = cg(A_operator, b_prime_np, maxiter=cg_steps, rtol=cg_tolerance)
    
    if info != 0:
        print(f"  [警告] CG 未在 {cg_steps} 步内收敛 (info={info})。")

    delta_update = torch.from_numpy(delta_update_flat_np).to(device).view(1, 1, -1)
    
    return delta_update

################################################################################
# 4. Iterative Problem Solving and Main Evaluation Loop
################################################################################

def jaccard_similarity(text1, text2):
    """Return the Jaccard index of the whitespace-separated word sets of two texts.

    The result is ``|A & B| / |A | B|``; when both texts contain no words the
    integer ``0`` is returned.
    """
    words_a = set(text1.split())
    words_b = set(text2.split())
    combined = words_a | words_b
    if len(combined) == 0:
        return 0
    return len(words_a & words_b) / len(combined)

def solve_problem_iteratively(problem_data, original_base_model, tokenizer, device):
    """Try to solve one problem over several turns while steering activations.

    A deep copy of the base model is frozen and a forward hook on
    ``model.norm`` adds a learnable offset (``ActivationModifier.delta``) to
    its output.  Each turn the steered copy and the untouched base model both
    answer; answers are scored, a feedback message is appended to each
    conversation, and the offset is updated from the steered model's own
    sample via ``calculate_delta_for_hidden_state_cg``.

    The loop ends as soon as the steered model answers correctly, so the
    baseline gets no further turns after that point.

    Args:
        problem_data: Mapping with the keys ``'problem'`` and ``'answer'``.
        original_base_model: Causal LM exposing ``generate``, ``config`` and
            ``model.norm`` (NOTE(review): the hook location assumes a
            Qwen/Llama-style module layout -- confirm for other models).
        tokenizer: Tokenizer with a chat template.
        device: Device the inputs and the offset are placed on.

    Returns:
        ``[steered_solved, base_solved, steered_turn, base_turn]`` where the
        turn entries are 1-based and equal ``max_turns + 1`` (11) when the
        corresponding model never produced a correct answer.
    """
    math_problem = problem_data['problem']
    standard_answer = problem_data['answer']
    MAX_INPUT_LENGTH = 4096
    max_turns = 10
    # Sentinel turn meaning "never solved" (one past the last turn).
    unsolved_turn = max_turns + 1
    retry_message = "Wrong answer, please rethink and try another way of thinking!"

    frozen_model = copy.deepcopy(original_base_model)
    frozen_model.eval()
    for param in frozen_model.parameters():
        param.requires_grad = False

    activation_modifier = ActivationModifier(frozen_model.config.hidden_size, device=device)

    reward_model = MathVerifyRewardModel(ground_truth_answer=standard_answer)
    pi_star_calculator = PiStarCalculator(tokenizer, beta=1.0, device=device)

    # The dataset loaders already append the instruction (without a
    # separator); strip it first so the prompt carries it exactly once.
    question_text = math_problem
    if question_text.endswith(instruction_following):
        question_text = question_text[:-len(instruction_following)]
    user_prompt = question_text.rstrip() + " " + instruction_following

    messages = [{"role": "user", "content": user_prompt}]
    base_messages = [{"role": "user", "content": user_prompt}]

    captured_states = {}

    def forward_hook(module, inputs, output):
        # Remember the unmodified activations, then hand the shifted ones on.
        captured_states['unmodified_hidden_states'] = output[0] if isinstance(output, tuple) else output
        return activation_modifier(captured_states['unmodified_hidden_states'])

    result = [False, False, unsolved_turn, unsolved_turn]
    for turn in range(1, max_turns + 1):
        print("\n" + "-"*25 + f" 第 {turn}/{max_turns} 轮迭代 " + "-"*25)

        context_text_loop = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        context_ids = tokenizer(context_text_loop, return_tensors='pt', max_length=MAX_INPUT_LENGTH, truncation=True)['input_ids']

        base_context_text_loop = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=True)
        base_context_ids = tokenizer(base_context_text_loop, return_tensors='pt', max_length=MAX_INPUT_LENGTH, truncation=True)['input_ids']

        # Steered generation; the hook is always removed, even on failure.
        hook_handle = frozen_model.model.norm.register_forward_hook(forward_hook)
        try:
            with torch.no_grad():
                generated_ids = frozen_model.generate(
                    input_ids=context_ids.to(device),
                    max_new_tokens=512, top_p=0.9, temperature=1.0,
                    pad_token_id=tokenizer.eos_token_id
                )
        finally:
            hook_handle.remove()
        generated_text = tokenizer.decode(generated_ids[0, context_ids.shape[1]:], skip_special_tokens=True)

        # Baseline generation from the untouched model, for comparison.
        with torch.no_grad():
            base_model_generated_ids = original_base_model.generate(
                input_ids=base_context_ids.to(device),
                max_new_tokens=512,  top_p=0.9, temperature=1.0,
                pad_token_id=tokenizer.eos_token_id
            )
        base_model_generated_text = tokenizer.decode(base_model_generated_ids[0, base_context_ids.shape[1]:], skip_special_tokens=True)

        similarity = jaccard_similarity(generated_text, base_model_generated_text)

        print(f"模型回答: '{generated_text}'")
        print(f"原始模型回答: '{base_model_generated_text}'")
        print(f"回答相似度 (Jaccard): {similarity:.4f}")

        reward = reward_model.get_reward(generated_text)
        base_reward = reward_model.get_reward(base_model_generated_text)
        print(f"获得奖励: {reward}")
        print(f"base model 获得奖励 {base_reward}")

        # Record the baseline outcome BEFORE a possible early exit; otherwise
        # a baseline success on the same turn as the steered success is lost.
        base_solved_now = base_reward == 1.0
        if base_solved_now:
            print(f"  [成功] 在第 {turn} 轮回答正确！")
            result[1] = True
            if result[3] == unsolved_turn:
                result[3] = turn

        if reward == 1.0:
            print(f"  [成功] 在第 {turn} 轮回答正确！")
            result[0] = True
            result[2] = turn
            break

        base_messages.append({"role": "assistant", "content": base_model_generated_text})
        base_messages.append({"role": "user", "content": "Right answer" if base_solved_now else retry_message})

        messages.append({"role": "assistant", "content": generated_text})
        messages.append({"role": "user", "content": retry_message})

        target_input_ids = generated_ids.to(device)
        initial_context_len = context_ids.shape[1]
        labels = target_input_ids.clone()
        labels[0, :initial_context_len] = -100

        # One forward pass through the steered model yields both the current
        # policy logits and, via the hook, the unmodified hidden states.  The
        # original ran a second, identical pass with autograd enabled just to
        # refill ``captured_states``.
        hook_handle = frozen_model.model.norm.register_forward_hook(forward_hook)
        try:
            with torch.no_grad():
                ref_logits = frozen_model(target_input_ids).logits
        finally:
            hook_handle.remove()
        hidden_states_pre = captured_states['unmodified_hidden_states']
        logits_pre = ref_logits

        pi_star_logits = pi_star_calculator.calculate_pi_star_logits(ref_logits, target_input_ids, labels, reward)

        # Diagnostics: how far the target policy is from the current one.
        with torch.no_grad():
            nonzero_indices_diag = (labels != -100).nonzero(as_tuple=False)
            if nonzero_indices_diag.numel() > 0:
                answer_indices_diag = nonzero_indices_diag[:, 1]
                pi_star = F.softmax(pi_star_logits[0, answer_indices_diag, :], dim=-1)
                # log_softmax keeps the log-probabilities finite where
                # softmax(...).log() would underflow to -inf and yield NaN.
                log_pi_theta = F.log_softmax(logits_pre[0, answer_indices_diag, :], dim=-1)
                pi_theta = log_pi_theta.exp()
                kl_div = F.kl_div(log_pi_theta, pi_star, reduction='sum', log_target=False)
                tvd = 0.5 * torch.sum(torch.abs(pi_star - pi_theta))
                print(f"策略差异诊断: KL(π*||πθ)={kl_div.item():.4f}, TVD={tvd.item():.4f}")

        delta_update = calculate_delta_for_hidden_state_cg(
            frozen_model, activation_modifier, hidden_states_pre, logits_pre, pi_star_logits, labels
        )

        if delta_update is None:
            print("  [警告] 无法计算Δδ，跳过本轮更新。")
            continue

        with torch.no_grad():
            activation_modifier.delta.data += delta_update
        print(f"成功应用 Δδ (范数: {delta_update.norm().item():.4f}) 到隐藏层扰动向量。")
    else:
        # Reached only when the loop was not left through ``break``; the
        # original printed this failure message after a success as well.
        print(f"  [失败] 达到最大迭代次数 {max_turns}，未能解决问题。")

    return result

def main():
    """Evaluate the iterative activation-steering loop on a dataset and print a report.

    Loads the tokenizer and model from a local path, loads one of the
    datasets defined in this file, runs ``solve_problem_iteratively`` on up to
    1000 problems and prints solve rates plus per-turn histograms for both the
    steered model and the base model.  Returns early if the dataset cannot be
    loaded or is empty.
    """
    print("="*80)
    print("开始在 MATH 数据集上进行迭代式评估 (激活工程 + Hook 版本)")
    print("="*80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    #model_name = "/data_train/kitwei/model/Qwen2.5-0.5B-Instruct"
    model_name = "/data_train/kitwei/model/Qwen3-0.6B"
    #model_name = "/data_train/kitwei/model/Qwen2.5-7B-Instruct"
    #model_name = "/data_train/kitwei/model/Qwen3-8B"
    #model_name = "/data_train/kitwei/model/DeepSeek-R1-Distill-Qwen-7B"
    #model_name = "/data_train/kitwei/model/DeepSeek-R1-Distill-Llama-8B"
    #model_name = "/data_train/kitwei/model/Ministral-8B-Instruct-2410"
    #model_name = "/data_train/kitwei/model/Mistral-7B-Instruct-v0.3"
    print(f"使用设备: {device}")
    print(f"加载模型: {model_name}...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Keep the most recent turns when a conversation exceeds the length limit.
    tokenizer.truncation_side = 'left'

    original_base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        #torch_dtype=torch.bfloat16,
        #attn_implementation="flash_attention_2",
        device_map="auto"
        )
    original_base_model.eval()

    print("模型加载完成。")

    print("\n--- 1. 加载数据集 ---")
    #math_dataset = get_math_questions()
    #math_dataset = get_math_500_questions()
    #math_dataset = get_AIME25_questions()
    #math_dataset = get_mathdial_questions()
    #math_dataset = get_gsm8k_questions()
    #math_dataset = get_gpqa_questions()
    #math_dataset = get_mmlu_redux_questions()
    #math_dataset = get_superGPQA_questions()
    #math_dataset = get_M_IMO_questions()
    #math_dataset = get_MT_MATH100_questions()
    math_dataset = get_MT_AIME24_questions()
    if math_dataset is None:
        return
    print(f"成功加载 {len(math_dataset)} 个 MATH 问题。")

    num_problems_to_evaluate = min(len(math_dataset), 1000) # For a quick demo
    if num_problems_to_evaluate == 0:
        # Nothing to run; also avoids dividing by zero in the report below.
        print("No problems to evaluate: the dataset is empty.")
        return
    problems_to_run = math_dataset.select(range(num_problems_to_evaluate))

    num_solved = 0
    base_num_solved = 0
    # Histograms of the turn on which each problem was solved; the last slot
    # (index 10) collects the "never solved" sentinel turn 11.
    turn_histogram = [0]*11
    base_turn_histogram = [0]*11

    for i, problem_data in enumerate(problems_to_run):
        print("\n" + "#"*30 + f" 评估问题 {i+1}/{num_problems_to_evaluate} " + "#"*30)
        print(f"问题: {problem_data['problem'][:100]}...")
        print(f"答案: {problem_data['answer']}")

        result = solve_problem_iteratively(problem_data, original_base_model, tokenizer, device)
        if result[0]:
            num_solved += 1
        if result[1]:
            base_num_solved += 1
        turn_histogram[result[2] - 1] += 1
        base_turn_histogram[result[3] - 1] += 1

        # Collect the per-problem model copy first so that its CUDA memory is
        # actually free by the time the cache is emptied.
        gc.collect()
        if device == 'cuda':
            torch.cuda.empty_cache()

    # --- Final Report ---
    print("\n" + "="*80)
    print("评估完成！")
    print("="*80)
    accuracy = (num_solved / num_problems_to_evaluate) * 100
    base_accuracy = (base_num_solved / num_problems_to_evaluate) * 100
    print(f"总共评估问题数: {num_problems_to_evaluate}")
    print(f"成功解决问题数: {num_solved}")
    print(f"10轮内解决率: {accuracy:.2f}%")
    print(f"base model 成功解决问题数: {base_num_solved}")
    print(f"base model 10轮内解决率: {base_accuracy:.2f}%")
    print(f"解决问题使用的turn: {turn_histogram}")
    print(f"base model 解决问题使用的turn: {base_turn_histogram}")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
