# demo_iterative_update_lora_cg.py

################################################################################
# 1. Imports
################################################################################
import torch
import torch.nn.functional as F
from torch.func import jvp, vjp # Using jvp and vjp for matrix-free products
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import re
import copy
import pandas as pd
import gc # 新增: 用于清理CUDA缓存

from peft import get_peft_model, LoraConfig, TaskType
from peft.tuners.lora import LoraLayer

# Import for SciPy CG solver and NumPy
try:
    from scipy.sparse.linalg import cg, LinearOperator
    import numpy as np
except ImportError:
    print("To use the SciPy solver, please install it first by running `pip install scipy`.")
    exit()

# [NEW] Import for dataset loading
try:
    from datasets import Dataset, get_dataset_config_names, load_dataset, concatenate_datasets
except ImportError:
    print("To load the dataset, please install it first by running `pip install datasets`.")
    exit()

try:
    from math_verify.errors import TimeoutException
    from math_verify.metric import math_metric
    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
except ImportError:
    print("To use Math-Verify, please install it first by running `pip install math-verify`.")
    exit()

################################################################################
# 2. Reward Function and Data Loading (User Provided)
################################################################################

# String normalization functions from lm-evaluation-harness
def remove_boxed(s):
    r"""Return the content wrapped by a ``\boxed`` / ``\fbox`` expression.

    Supported forms:
      * ``\boxed 42$...``  -> everything after ``\boxed `` up to the first ``$``
      * ``\boxed{42}``     -> ``42``
      * ``\fbox{42}``      -> ``42`` (``last_boxed_only_string`` falls back to
        ``\fbox``, so it is unwrapped here as well)

    Args:
        s: A string expected to start with the box command, or None.

    Returns:
        The inner text, or None when ``s`` is None or not in one of the
        supported forms.
    """
    if s is None:
        return None

    # Space-delimited form: the answer runs until the next "$".
    if "\\boxed " in s:
        left = "\\boxed "
        return s[s.find(left) + len(left):].split("$")[0]

    # Brace-delimited forms. Explicit checks instead of `assert`, which would
    # be stripped when Python runs with -O.
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left):-1]
    return None

def last_boxed_only_string(string):
    r"""Extract the last ``\boxed...`` (or, failing that, ``\fbox...``) expression.

    Args:
        string: Text to search, or None.

    Returns:
        * For the space-delimited form ``\boxed 42...``: the tail of the text
          starting at ``\boxed``.
        * For the brace form: the substring from the command up to and
          including its matching closing brace.
        * None when the input is None, no box command is present, or the
          braces never balance (e.g. the text is cut off right after
          ``\boxed``).
    """
    if string is None:
        return None
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    # Space-delimited form. `startswith` with an offset cannot run past the
    # end of the text and never misfires on "\fbox{ ...}".
    if string.startswith("\\boxed ", idx):
        return string[idx:]

    # Brace form: walk forward until the first group opened closes again.
    depth = 0
    for i in range(idx, len(string)):
        ch = string[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[idx:i + 1]
    return None

def fix_fracs(string):
    r"""Add the braces that shorthand ``\frac`` arguments omit.

    ``\frac12`` becomes ``\frac{1}{2}`` and ``\frac1{2}`` becomes
    ``\frac{1}{2}``; arguments that already start with ``{`` are left alone.

    Args:
        string: LaTeX text.

    Returns:
        The rewritten text, or the input unchanged when some ``\frac`` is
        followed by fewer than two characters (including nothing at all).
    """
    pieces = string.split("\\frac")
    out = [pieces[0]]
    for piece in pieces[1:]:
        out.append("\\frac")
        if piece.startswith("{"):
            out.append(piece)
            continue
        # Need one character each for numerator and denominator; this also
        # covers an empty piece (text ending in "\frac").
        if len(piece) < 2:
            return string
        numerator, second, rest = piece[0], piece[1], piece[2:]
        if second != "{":
            out.append("{" + numerator + "}{" + second + "}" + rest)
        else:
            # Denominator is already braced; only wrap the numerator.
            out.append("{" + numerator + "}" + second + rest)
    return "".join(out)

def fix_a_slash_b(string):
    r"""Rewrite a plain integer ratio ``a/b`` as ``\frac{a}{b}``.

    The input is returned untouched unless it contains exactly one ``/``,
    both sides parse as integers, and the text is exactly the canonical
    ``"{a}/{b}"`` rendering of those integers (so no padding or extra
    whitespace).
    """
    if string.count("/") != 1:
        return string
    numerator_text, denominator_text = string.split("/")
    try:
        numerator = int(numerator_text)
        denominator = int(denominator_text)
    except ValueError:
        return string
    canonical = str(numerator) + "/" + str(denominator)
    if string != canonical:
        return string
    return "\\frac{" + str(numerator) + "}{" + str(denominator) + "}"

def remove_right_units(string):
    r"""Drop everything from the first ``\text{ `` marker onwards.

    Text without that marker is returned unchanged.
    """
    # partition() yields the whole string as the head when the marker is absent.
    kept, _, _ = string.partition("\\text{ ")
    return kept

def fix_sqrt(string):
    r"""Brace the single-character argument of every bare ``\sqrt``.

    ``\sqrt2`` becomes ``\sqrt{2}``; occurrences already followed by ``{``
    (or by nothing at all) are kept as they are.
    """
    head, *tails = string.split("\\sqrt")
    rebuilt = [head]
    for tail in tails:
        if tail[:1] in ("", "{"):
            # Nothing to wrap, or the argument is already braced.
            rebuilt.append("\\sqrt" + tail)
        else:
            rebuilt.append("\\sqrt{" + tail[0] + "}" + tail[1:])
    return "".join(rebuilt)

def strip_string(string):
    r"""Normalise a LaTeX answer string so that equivalent answers compare equal.

    Steps, in order: literal clean-ups (newlines, ``\!``, doubled
    backslashes, ``tfrac``/``dfrac``, ``\left``/``\right``, degree marks,
    escaped ``$`` and ``%``), removal of trailing ``\text{ ...}`` units,
    leading-zero insertion for bare decimals, dropping a short ``x=`` prefix,
    ``\sqrt`` bracing, whitespace removal, ``\frac`` bracing and ``a/b`` ->
    ``\frac{a}{b}``.

    Args:
        string: Any value; it is converted with ``str()`` first.

    Returns:
        The normalised string.
    """
    string = str(string)

    # Ordered literal replacements. The escaped-percent entry appears twice on
    # purpose: the second pass used to be spelled with the invalid escape
    # "\%" (same two characters, but a SyntaxWarning on recent Pythons). It is
    # kept so inputs where the first pass exposes a new match behave as before.
    literal_replacements = (
        ("\n", ""),
        ("\\!", ""),
        ("\\\\", "\\"),
        ("tfrac", "frac"),
        ("dfrac", "frac"),
        ("\\left", ""),
        ("\\right", ""),
        ("^{\\circ}", ""),
        ("^\\circ", ""),
        ("\\$", ""),
        ("\\%", ""),
        ("\\%", ""),
    )
    for old, new in literal_replacements:
        string = string.replace(old, new)

    string = remove_right_units(string)

    # ".5" -> "0.5" after a space, after "{", or at the very start.
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    if len(string) > 0 and string[0] == ".":
        string = "0" + string

    # Drop a short left-hand side such as "x=" or "k =".
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = fix_sqrt(string)
    string = string.replace(" ", "")
    string = fix_fracs(string)
    string = fix_a_slash_b(string)
    return string

def is_equiv(str1, str2):
    """Compare two answers after normalising both with ``strip_string``.

    Two None values are considered equal; a None never equals a non-None.
    If normalisation raises, the raw values are compared instead.
    """
    if str1 is None or str2 is None:
        return str1 is None and str2 is None
    try:
        normalized_first = strip_string(str1)
        normalized_second = strip_string(str2)
    except Exception:
        # Best-effort fallback: compare the inputs as given.
        return str1 == str2
    return normalized_first == normalized_second

def compute_score(solution_str, ground_truth) -> float:
    """Score a model response against the reference answer.

    Returns:
        1.0 when the last boxed expression of ``solution_str`` is equivalent
        to ``ground_truth``; 0.0 otherwise, including when no boxed
        expression exists or any step raises.
    """
    try:
        boxed = last_boxed_only_string(solution_str)
        if boxed is None:
            return 0.0
        extracted = remove_boxed(boxed)
        return 1.0 if is_equiv(extracted, ground_truth) else 0.0
    except Exception:
        # Scoring is best-effort: any failure counts as a wrong answer.
        return 0.0

class MathVerifyRewardModel:
    """Binary reward for one problem: +1.0 for a correct response, -1.0 otherwise."""

    def __init__(self, ground_truth_answer: str):
        # Reference answer every response is scored against.
        self.ground_truth_answer = ground_truth_answer

    def get_reward(self, response_text: str) -> float:
        """Return 1.0 if ``compute_score`` rates the response 1.0, else -1.0."""
        score = compute_score(response_text, self.ground_truth_answer)
        if score == 1.0:
            return 1.0
        return -1.0


################################################################################
# math
################################################################################
# Prompt suffix that the dataset loaders below concatenate onto each problem text.
instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
def get_math_questions():
    """Load the MATH-benchmark test split as problem/solution/answer records.

    Each problem has ``instruction_following`` appended. Rows whose answer is
    None are dropped.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        return {
            'problem': example['problem'] + instruction_following,
            'solution': example['solution'],
            'answer': example['answer'],
        }

    try:
        dataset = load_dataset("/data_train/kitwei/grpo/dataset/MATH-benchmark", split="test")
        dataset = dataset.map(_to_record)
        # Keep only rows that carry an answer.
        return dataset.filter(lambda row: row['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_math_500_questions():
    """Load the MATH-500 test split as problem/solution/answer records.

    Each problem has ``instruction_following`` appended. Rows whose answer is
    None are dropped.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        return {
            'problem': example['problem'] + instruction_following,
            'solution': example['solution'],
            'answer': example['answer'],
        }

    try:
        dataset = load_dataset("/data_train/kitwei/grpo/dataset/datasets--HuggingFaceH4--MATH-500", split="test")
        dataset = dataset.map(_to_record)
        # Keep only rows that carry an answer.
        return dataset.filter(lambda row: row['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_AIME25_questions():
    """Load the AIME25 test split as problem/solution/answer records.

    Each problem has ``instruction_following`` appended; ``solution`` is
    always the empty string. Rows whose answer is None are dropped.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        return {
            'problem': example['problem'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    try:
        dataset = load_dataset("/data_train/kitwei/grpo/dataset/datasets--math-ai--aime25", split="test")
        dataset = dataset.map(_to_record)
        # Keep only rows that carry an answer.
        return dataset.filter(lambda row: row['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def split_by_last_separator(s: str) -> list:
    """
    Split a string at its last '.' or last newline, whichever comes later.

    Args:
        s: The input string.

    Returns:
        A two-element list ``[before, after]`` with the separator itself
        removed, or a one-element list holding the original string when
        neither separator occurs.
    """
    # rfind() returns -1 when absent, so max() picks the later separator and
    # stays at -1 only when both are missing.
    cut = max(s.rfind('.'), s.rfind('\n'))
    if cut == -1:
        return [s]
    before, after = s[:cut], s[cut + 1:]
    return [before, after]

def get_mathdial_questions():
    """Load the MathDial test split as problem/solution/answer records.

    ``ground_truth`` is split at its last '.' or newline: the part before the
    separator becomes ``solution`` and the part after it becomes ``answer``.
    Both fields previously took the part before the separator, which made
    ``answer`` a copy of the solution text.

    NOTE(review): this assumes the final answer is the text after the last
    separator of ``ground_truth`` — confirm against the dataset.

    Rows for which no answer text can be extracted get ``answer=None`` and
    are removed by the final filter.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        parts = split_by_last_separator(example['ground_truth'].strip())
        # With no separator `parts` has one element, so [0] and [-1] coincide.
        answer_text = parts[-1].strip()
        return {
            'problem': example['question'] + instruction_following,
            'solution': parts[0],
            'answer': answer_text or None,
        }

    try:
        data = load_dataset("/data_train/kitwei/grpo/dataset/datasets--eth-nlped--mathdial", split="test")
        data = data.map(_to_record)
        # Filter out examples where the answer could not be extracted
        return data.filter(lambda x: x['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_gsm8k_questions():
    """Load the GSM8K ('main' config) test split as problem/solution/answer records.

    The raw ``answer`` field is split at its last ``####``: the text before
    it is the worked solution, the text after it (whitespace-stripped) is the
    final answer. A row without ``####`` no longer aborts the whole load with
    an IndexError; it gets ``answer=None`` and is removed by the final filter.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        raw = example['answer']
        solution, marker, final_answer = raw.rpartition('####')
        if not marker:
            # No final-answer marker: keep the text as the solution and let
            # the filter below discard the row.
            solution, final_answer = raw, None
        else:
            final_answer = final_answer.strip()
        return {
            'problem': example['question'] + instruction_following,
            'solution': solution,
            'answer': final_answer,
        }

    try:
        data = load_dataset("/data_train/kitwei/grpo/dataset/gsm8k", 'main', split="test")
        data = data.map(_to_record)
        # Filter out examples where the answer could not be extracted
        return data.filter(lambda x: x['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None


################################################################################
# general task
################################################################################

def get_gpqa_questions():
    """Load the GPQA-Diamond test split as problem/solution/answer records.

    Each question has ``instruction_following`` appended; ``solution`` is
    always the empty string. Rows whose answer is None are dropped.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        return {
            'problem': example['question'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    try:
        dataset = load_dataset("/data_train/kitwei/grpo/dataset/datasets--fingertap--GPQA-Diamond", split="test")
        dataset = dataset.map(_to_record)
        # Keep only rows that carry an answer.
        return dataset.filter(lambda row: row['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None



def get_mmlu_redux_questions():
    """Load every config and split of MMLU-Redux as one multiple-choice dataset.

    Each record's problem is the question followed by lettered options
    ("A. ...", "B. ...", ...) and ``instruction_following``; the answer is
    the letter for the ``answer`` index, or None when the index is out of
    range (such rows are dropped). ``solution`` is always empty.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        options = example['choices']
        letters = [chr(ord('A') + position) for position in range(len(options))]

        # Question, then one "<letter>. <option>" line per choice.
        option_lines = [f"{letter}. {option}" for letter, option in zip(letters, options)]
        prompt = f"{example['question']}\n" + "\n".join(option_lines)

        # Map the answer index to its letter; out-of-range indices yield None.
        index = example['answer']
        if 0 <= index < len(options):
            letter = letters[index]
        else:
            letter = None

        return {
            'problem': prompt + instruction_following,
            'solution': "",
            'answer': letter
        }

    try:
        dataset_name = "/data_train/kitwei/grpo/dataset/datasets--edinburgh-dawg--mmlu-redux"

        # Collect every split of every config, then merge them.
        collected = []
        for config_name in get_dataset_config_names(dataset_name):
            collected.extend(load_dataset(dataset_name, config_name).values())
        merged = concatenate_datasets(collected)

        merged = merged.map(_to_record)
        return merged.filter(lambda row: row['answer'] is not None)

    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset name 'edinburgh-dawg/mmlu-redux' is correct and the 'datasets' library is installed.")
        return None

def get_superGPQA_questions():
    """Load every config and split of SuperGPQA as one multiple-choice dataset.

    Each record's problem is the question followed by lettered options
    ("A. ...", "B. ...", ...) and ``instruction_following``; the answer is
    taken directly from the ``answer_letter`` field. ``solution`` is always
    empty. Rows whose answer is None are dropped.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        options = example['options']
        letters = [chr(ord('A') + position) for position in range(len(options))]

        # Question, then one "<letter>. <option>" line per choice.
        option_lines = [f"{letter}. {option}" for letter, option in zip(letters, options)]
        prompt = f"{example['question']}\n" + "\n".join(option_lines)

        return {
            'problem': prompt + instruction_following,
            'solution': "",
            'answer': example['answer_letter']
        }

    try:
        dataset_name = "/data_train/kitwei/grpo/dataset/datasets--m-a-p--SuperGPQA"

        # Collect every split of every config, then merge them.
        collected = []
        for config_name in get_dataset_config_names(dataset_name):
            collected.extend(load_dataset(dataset_name, config_name).values())
        merged = concatenate_datasets(collected)

        merged = merged.map(_to_record)
        return merged.filter(lambda row: row['answer'] is not None)

    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset name 'SuperGPQA' is correct and the 'datasets' library is installed.")
        return None

################################################################################
# Multilingual
################################################################################

def transform_to_qa_format(dataset: Dataset) -> Dataset:
    """
    Convert a wide multilingual dataset into a long question/answer dataset.

    Every column other than 'answer' is treated as one language's version of
    the question. Each non-empty cell becomes one output row with the keys
    'language' (the column name), 'question' and 'answer'.
    """
    question_columns = [name for name in dataset.column_names if name != 'answer']

    def _iter_rows():
        for record in dataset:
            shared_answer = record['answer']
            for column in question_columns:
                text = record[column]
                # Skip missing or whitespace-only questions.
                if text is None or not str(text).strip():
                    continue
                yield {
                    'language': column,
                    'question': text,
                    'answer': shared_answer
                }

    # A pandas DataFrame serves as the intermediate container.
    frame = pd.DataFrame(_iter_rows())
    return Dataset.from_pandas(frame)

def get_M_IMO_questions():
    """Load the MCLM "M-IMO" test split as per-language problem/answer records.

    The wide multilingual table is first reshaped by
    ``transform_to_qa_format``; each question then has
    ``instruction_following`` appended. ``solution`` is always empty and rows
    whose answer is None are dropped.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        return {
            'problem': example['question'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    try:
        wide = load_dataset("/data_train/kitwei/grpo/dataset/datasets--amphora--MCLM", "M-IMO", split="test")
        long_form = transform_to_qa_format(wide).map(_to_record)
        # Keep only rows that carry an answer.
        return long_form.filter(lambda row: row['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_MT_AIME24_questions():
    """Load the MCLM "MT-AIME2024" test split as per-language problem/answer records.

    The wide multilingual table is first reshaped by
    ``transform_to_qa_format``; each question then has
    ``instruction_following`` appended. ``solution`` is always empty and rows
    whose answer is None are dropped.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        return {
            'problem': example['question'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    try:
        wide = load_dataset("/data_train/kitwei/grpo/dataset/datasets--amphora--MCLM", "MT-AIME2024", split="test")
        long_form = transform_to_qa_format(wide).map(_to_record)
        # Keep only rows that carry an answer.
        return long_form.filter(lambda row: row['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_MT_MATH100_questions():
    """Load the MCLM "MT-MATH100" test split as per-language problem/answer records.

    The wide multilingual table is first reshaped by
    ``transform_to_qa_format``; each question then has
    ``instruction_following`` appended. ``solution`` is always empty and rows
    whose answer is None are dropped.

    Returns:
        The processed dataset, or None if loading or processing fails.
    """
    def _to_record(example):
        return {
            'problem': example['question'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    try:
        wide = load_dataset("/data_train/kitwei/grpo/dataset/datasets--amphora--MCLM", "MT-MATH100", split="test")
        long_form = transform_to_qa_format(wide).map(_to_record)
        # Keep only rows that carry an answer.
        return long_form.filter(lambda row: row['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None
        
################################################################################
# 3. PiStarCalculator and External Δθ Calculation
################################################################################

class PiStarCalculator:
    """Builds target ("pi-star") logits by shifting reference logits with a reward bonus."""

    def __init__(self, reward_model, tokenizer, beta=1.0, device='cuda'):
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        # The sequence-level log-probability bonus is reward / beta.
        self.beta = beta
        # Device the input ids are moved to before the reference forward pass.
        self.device = device

    @torch.no_grad()
    def calculate_pi_star_logits(self, reference_model, target_input_ids: torch.Tensor, labels: torch.Tensor, reward: float) -> torch.Tensor:
        """Return a copy of the reference logits with the reward bonus applied.

        For every position ``t`` where ``labels != -100``, the logit of token
        ``target_input_ids[0, t]`` at the preceding position ``t - 1`` is
        increased by an equal share of the sequence-level bonus. When no
        label is set, an unmodified copy of the reference logits is returned.

        Args:
            reference_model: Callable taking ``input_ids=`` and returning an
                object with a ``.logits`` attribute.
            target_input_ids: Token ids; only row 0 is read.
            labels: Same layout as ``target_input_ids``; -100 marks positions
                to ignore.
            reward: Scalar reward for the whole sequence.
        """
        reference_logits = reference_model(input_ids=target_input_ids.to(self.device)).logits

        labelled = (labels != -100).nonzero(as_tuple=False)
        if labelled.numel() == 0:
            return reference_logits.clone()

        token_positions = labelled[:, 1]
        token_ids = target_input_ids[0, token_positions]
        # Logits at position t - 1 are the ones that predict token t.
        predicting_positions = token_positions - 1

        reference_log_probs = F.log_softmax(reference_logits, dim=-1)
        sequence_log_p_ref = reference_log_probs[0, predicting_positions, token_ids].sum()
        sequence_log_p_star = sequence_log_p_ref + (reward / self.beta)

        # Spread the sequence-level bonus evenly over the labelled tokens.
        bonus_per_token = (sequence_log_p_star - sequence_log_p_ref) / len(token_positions)

        shifted_logits = reference_logits.clone()
        # accumulate=True adds once per listed (position, token) pair, exactly
        # like a one-by-one loop would.
        shifted_logits[0].index_put_(
            (predicting_positions, token_ids),
            bonus_per_token.expand(len(token_positions)),
            accumulate=True,
        )
        return shifted_logits

def calculate_delta_theta_externally_lora_cg(model, hidden_states, logits, pi_star_logits, labels, cg_steps=10, cg_tolerance=1e-5) -> tuple | None:
    """Solve for a LoRA update on ``lm_head`` that moves the policy towards pi-star.

    With ``F(theta)`` the concatenated softmax policies at the predicting
    positions as a function of the flattened LoRA parameters and ``J`` its
    Jacobian, this solves ``(J^T J) delta = J^T (pi_star - pi_theta)`` with
    SciPy's conjugate-gradient solver, using matrix-free JVP/VJP products.

    Args:
        model: PEFT model exposing ``base_model.model.lm_head`` with a
            'default' LoRA adapter and a ``base_layer``.
        hidden_states: Inputs to ``lm_head``; indexed as ``[0, position, :]``.
        logits: Current model logits; indexed as ``[0, position, :]``.
        pi_star_logits: Target logits in the same layout as ``logits``.
        labels: -100 marks ignored positions; every other position ``t`` is a
            token whose prediction (made at ``t - 1``) is included.
        cg_steps: Maximum number of CG iterations.
        cg_tolerance: Relative tolerance passed to CG as ``rtol``.

    Returns:
        ``(delta_A, delta_B)`` shaped like the transposed LoRA A and B
        weights, or None when there is no position to update on.

    Raises:
        NotImplementedError: if ``logits`` has a batch dimension larger than 1.
    """
    if logits.shape[0] > 1:
        raise NotImplementedError("This function currently only supports batch_size=1.")

    nonzero_indices = (labels != -100).nonzero(as_tuple=False)
    if nonzero_indices.numel() == 0:
        return None
    answer_token_indices = nonzero_indices[:, 1]

    # The logits that predict token t live at position t - 1. This is the
    # convention PiStarCalculator uses when it places its bonus, so the
    # policies must be compared at those positions.
    predict_positions = answer_token_indices - 1
    predict_positions = predict_positions[predict_positions >= 0]
    if predict_positions.numel() == 0:
        return None

    lm_head_layer = model.get_submodule("base_model.model.lm_head")
    base_lm_head = model.get_submodule("base_model.model.lm_head.base_layer")
    lora_A = lm_head_layer.lora_A['default'].weight.T
    lora_B = lm_head_layer.lora_B['default'].weight.T
    scaling = lm_head_layer.scaling['default']

    theta_lora_flat = torch.cat([lora_A.flatten(), lora_B.flatten()]).detach()
    len_A_flat = lora_A.numel()
    # Work on the adapter's own device/dtype so the result can be added
    # straight onto the LoRA weights.
    device = theta_lora_flat.device
    dtype = theta_lora_flat.dtype

    def _to_numpy(tensor):
        # NumPy has no bfloat16 (and half precision hurts CG), so upcast first.
        tensor = tensor.detach()
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.float()
        return tensor.cpu().numpy()

    # The selected hidden states and the frozen base logits do not depend on
    # theta: compute them once instead of on every JVP/VJP evaluation.
    with torch.no_grad():
        selected_hidden = hidden_states[0, predict_positions.to(hidden_states.device), :].detach()
        base_logits = base_lm_head(selected_hidden).to(device=device, dtype=dtype)
        selected_hidden = selected_hidden.to(device=device, dtype=dtype)

    def F_policy_fn(theta_flat_input):
        current_A = theta_flat_input[:len_A_flat].view(lora_A.shape)
        current_B = theta_flat_input[len_A_flat:].view(lora_B.shape)
        lora_output = (selected_hidden @ current_A) @ current_B * scaling
        # Row-major flattening == concatenating the per-position policies.
        return F.softmax(base_logits + lora_output, dim=-1).reshape(-1)

    def A_prime_mvp(p_vec_np):
        # Matrix-free product (J^T J) p.
        p_vec = torch.from_numpy(np.ascontiguousarray(p_vec_np)).to(device=device, dtype=dtype).reshape(-1)
        with torch.no_grad():
            _, jvp_result = jvp(F_policy_fn, (theta_lora_flat,), (p_vec,))
            _, vjp_fn = vjp(F_policy_fn, theta_lora_flat)
            A_prime_p = vjp_fn(jvp_result)[0]
        return _to_numpy(A_prime_p)

    # Right-hand side: b = pi_star - pi_theta at the predicting positions.
    pi_star = F.softmax(pi_star_logits[0, predict_positions.to(pi_star_logits.device), :], dim=-1)
    pi_theta = F.softmax(logits[0, predict_positions.to(logits.device), :], dim=-1)
    b = (pi_star.to(device=device, dtype=dtype) - pi_theta.to(device=device, dtype=dtype)).detach().reshape(-1)

    _, vjp_fn_b = vjp(F_policy_fn, theta_lora_flat)
    b_prime = vjp_fn_b(b)[0]

    n_params = theta_lora_flat.numel()
    b_prime_np = _to_numpy(b_prime)

    A_operator = LinearOperator((n_params, n_params), matvec=A_prime_mvp, dtype=b_prime_np.dtype)

    class CgCallback:
        """Per-iteration progress logger for verbose debugging."""

        def __init__(self):
            self.iterate_norms = []

        def __call__(self, xk):
            # SciPy hands the callback the current solution vector, not the
            # residual, so that is what gets reported.
            self.iterate_norms.append(np.linalg.norm(xk))
            print(f"  CG iterate norm ||x_k||: {self.iterate_norms[-1]:.6f}")

    callback = CgCallback()
    delta_theta_lora_flat_np, info = cg(A_operator, b_prime_np, maxiter=cg_steps, rtol=cg_tolerance, callback=callback)

    if info != 0:
        print(f"  [警告] CG 未在 {cg_steps} 步内收敛 (info={info})。")

    delta_theta_lora_flat = torch.from_numpy(delta_theta_lora_flat_np).to(device=device, dtype=dtype)
    delta_A = delta_theta_lora_flat[:len_A_flat].view(lora_A.shape)
    delta_B = delta_theta_lora_flat[len_A_flat:].view(lora_B.shape)

    return (delta_A, delta_B)

################################################################################
# 4. Iterative Problem Solving and Main Evaluation Loop
################################################################################

def jaccard_similarity(text1, text2):
    """Return the Jaccard similarity of the whitespace-separated word sets of two texts.

    The result is |A & B| / |A | B|, or 0 when both texts contain no words.
    """
    words_first = set(text1.split())
    words_second = set(text2.split())
    combined = words_first | words_second
    if not combined:
        return 0
    shared = words_first & words_second
    return len(shared) / len(combined)

def solve_problem_iteratively(problem_data, model, original_base_model, tokenizer, device):
    """
    Try to solve one problem in up to 10 feedback turns, updating the LoRA on lm_head after each miss.

    Every turn the adapted model and the base model each generate an answer
    to their own conversation. A wrong adapted answer is appended to the
    conversation with a retry request, and the 'default' lm_head LoRA weights
    are updated with the result of
    ``calculate_delta_theta_externally_lora_cg``. The loop stops as soon as
    the adapted model answers correctly.

    Args:
        problem_data: Mapping with 'problem' and 'answer'.
        model: PEFT model with a 'default' LoRA adapter on ``lm_head``; its
            LoRA parameters are reset at the start of the call.
        original_base_model: Model used for the baseline generations.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        device: Device that inputs are moved to.

    Returns:
        list: ``[solved, base_solved, turn, base_turn]`` — whether the adapted
        and base model ever produced a correct answer, and the turn of the
        (first) correct answer, with 11 meaning "never".
    """
    # --- Per-problem setup ---
    math_problem = problem_data['problem']
    standard_answer = problem_data['answer']
    MAX_INPUT_LENGTH = 5120
    # Start every problem from freshly initialised LoRA weights.
    for module in model.modules():
        if isinstance(module, LoraLayer):
            if module.merged:  # Unmerge if layers have been merged
                module.unmerge()
            module.reset_lora_parameters('default', init_lora_weights=True)

    reward_model = MathVerifyRewardModel(ground_truth_answer=standard_answer)
    pi_star_calculator = PiStarCalculator(reward_model, tokenizer, beta=1.0, device=device)

    # The dataset loaders already append the instruction to 'problem'; only
    # add it when it is missing so the prompt does not contain it twice.
    if instruction_following in math_problem:
        prompt = math_problem
    else:
        prompt = math_problem + " " + instruction_following
    messages = [{"role": "user", "content": prompt}]
    base_messages = [{"role": "user", "content": prompt}]
    max_turns = 10
    result = [False, False, 11, 11]
    for turn in range(1, max_turns + 1):
        print("\n" + "-"*25 + f" 第 {turn}/{max_turns} 轮迭代 " + "-"*25)

        context_text_loop = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        context_ids = tokenizer(context_text_loop, return_tensors='pt', max_length=MAX_INPUT_LENGTH, truncation=True)['input_ids']

        base_context_text_loop = tokenizer.apply_chat_template(base_messages, tokenize=False, add_generation_prompt=True)
        base_context_ids = tokenizer(base_context_text_loop, return_tensors='pt', max_length=MAX_INPUT_LENGTH, truncation=True)['input_ids']

        model.eval()
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=context_ids.to(device),
                max_new_tokens=1024, top_p=0.9, temperature=1.0,
                pad_token_id=tokenizer.eos_token_id
            )
            generated_text = tokenizer.decode(generated_ids[0, context_ids.shape[1]:], skip_special_tokens=True)

            # get_peft_model() injects the LoRA layers into the module it
            # wraps, so when `original_base_model` is that same module its
            # lm_head carries the adapter too. Disable the adapter while
            # generating the baseline so it reflects the unmodified model.
            with model.disable_adapter():
                base_model_generated_ids = original_base_model.generate(
                    input_ids=base_context_ids.to(device),
                    max_new_tokens=1024,  top_p=0.9, temperature=1.0,
                    pad_token_id=tokenizer.eos_token_id
                )
            base_model_generated_text = tokenizer.decode(base_model_generated_ids[0, base_context_ids.shape[1]:], skip_special_tokens=True)

        similarity = jaccard_similarity(generated_text, base_model_generated_text)

        print(f"模型回答: '{generated_text}'")
        print(f"原始模型回答: '{base_model_generated_text}'")
        print(f"回答相似度 (Jaccard): {similarity:.4f}")

        reward = reward_model.get_reward(generated_text)
        base_reward = reward_model.get_reward(base_model_generated_text)
        print(f"获得奖励: {reward}")
        print(f"base model 获得奖励 {base_reward}")

        if reward == 1.0:
            print(f"  [成功] 在第 {turn} 轮回答正确！")
            result[0] = True
            result[2] = turn
            break

        # Baseline bookkeeping: remember the first turn the base model got it
        # right, and give it the matching feedback for the next turn.
        base_messages.append({"role": "assistant", "content": base_model_generated_text})
        if base_reward == 1.0:
            print(f"  [成功] 在第 {turn} 轮回答正确！")
            result[1] = True
            if result[3] == 11:
                result[3] = turn
            base_messages.append({"role": "user", "content": "Right answer"})
        else:
            base_messages.append({"role": "user", "content": "Wrong answer, please rethink and try another way of thinking!"})

        messages.append({"role": "assistant", "content": generated_text})
        messages.append({"role": "user", "content": "Wrong answer, please rethink and try another way of thinking!"})

        # Only the generated tokens are labelled; the prompt is masked out.
        target_input_ids = generated_ids.to(device)
        initial_context_len = context_ids.shape[1]
        labels = target_input_ids.clone()
        labels[0, :initial_context_len] = -100

        pi_star_logits = pi_star_calculator.calculate_pi_star_logits(model, target_input_ids, labels, reward)

        model.train()
        outputs_pre_update = model(target_input_ids, output_hidden_states=True)

        with torch.no_grad():
            logits_pre = outputs_pre_update.logits

            nonzero_indices_diag = (labels != -100).nonzero(as_tuple=False)
            if nonzero_indices_diag.numel() > 0:
                # Token t is predicted by the logits at t - 1, which is where
                # PiStarCalculator applies its bonus; compare there.
                predict_positions_diag = nonzero_indices_diag[:, 1] - 1
                predict_positions_diag = predict_positions_diag[predict_positions_diag >= 0]
                pi_star = F.softmax(pi_star_logits[0, predict_positions_diag, :], dim=-1)
                pi_theta = F.softmax(logits_pre[0, predict_positions_diag, :], dim=-1)
                kl_div = F.kl_div(pi_theta.log(), pi_star, reduction='sum', log_target=False)
                tvd = 0.5 * torch.sum(torch.abs(pi_star - pi_theta))
                print(f"策略差异诊断: KL(π*||πθ)={kl_div.item():.4f}, TVD={tvd.item():.4f}")

        delta_theta_result = calculate_delta_theta_externally_lora_cg(
            model, outputs_pre_update.hidden_states[-1], outputs_pre_update.logits, pi_star_logits, labels
        )

        if not delta_theta_result:
            print("  [警告] 无法计算Δθ，跳过本轮更新。")
            continue

        # The deltas are laid out like the transposed weights; transpose back.
        delta_A, delta_B = delta_theta_result
        with torch.no_grad():
            lm_head_layer = model.get_submodule("base_model.model.lm_head")
            lm_head_layer.lora_A['default'].weight.data += delta_A.T
            lm_head_layer.lora_B['default'].weight.data += delta_B.T
        print(f"成功应用 Δθ 到 LoRA 矩阵 A (范数: {delta_A.norm():.4f}) 和 B (范数: {delta_B.norm():.4f})。")
    else:
        # Runs only when the loop was not left via `break`, i.e. the adapted
        # model never answered correctly.
        print(f"  [失败] 达到最大迭代次数 {max_turns}，未能解决问题。")
    return result

def main():
    """Run the iterative LoRA-update evaluation on the selected dataset and print a report.

    Loads the tokenizer and model, wraps the model with a rank-1 LoRA adapter
    on ``lm_head``, evaluates up to 500 problems with
    ``solve_problem_iteratively`` and prints solve rates and per-turn
    histograms for the adapted model and the baseline.
    """
    print("="*80)
    print("开始在 MATH 数据集上进行迭代式评估")
    print("="*80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    #model_name = "/data_train/kitwei/model/Qwen2.5-0.5B-Instruct"
    #model_name = "/data_train/kitwei/model/Qwen3-0.6B"
    model_name = "/data_train/kitwei/model/Qwen2.5-7B-Instruct"
    #model_name = "/data_train/kitwei/model/Qwen3-8B"
    #model_name = "/data_train/kitwei/model/DeepSeek-R1-Distill-Qwen-7B"
    #model_name = "/data_train/kitwei/model/DeepSeek-R1-Distill-Llama-8B"
    #model_name = "/data_train/kitwei/model/Ministral-8B-Instruct-2410"
    #model_name = "/data_train/kitwei/model/Mistral-7B-Instruct-v0.3"
    print(f"使用设备: {device}")
    print(f"加载模型: {model_name}...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
    # Over-long conversations lose their oldest tokens first.
    tokenizer.truncation_side = 'left'

    original_base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        #torch_dtype=torch.bfloat16,
        #attn_implementation="flash_attention_2",
        device_map="auto"
        )
    original_base_model.eval()

    print("Creating PEFT model with LoRA adapters...")
    peft_config = LoraConfig(r=1, lora_alpha=8, target_modules=["lm_head"], task_type=TaskType.CAUSAL_LM)
    # NOTE(review): get_peft_model wraps the module it is given rather than
    # copying it, so `model` and `original_base_model` appear to share the
    # same underlying network — confirm the baseline generations account for
    # the adapter.
    model = get_peft_model(original_base_model, peft_config)
    model.print_trainable_parameters()


    print("模型加载完成。")

    print("\n--- 1. 加载数据集 ---")
    #math_dataset = get_math_questions()
    #math_dataset = get_math_500_questions()
    math_dataset = get_AIME25_questions()
    #math_dataset = get_mathdial_questions()
    #math_dataset = get_gsm8k_questions()
    #math_dataset = get_gpqa_questions()
    #math_dataset = get_mmlu_redux_questions()
    #math_dataset = get_superGPQA_questions()
    #math_dataset = get_M_IMO_questions()
    #math_dataset = get_MT_MATH100_questions()
    #math_dataset = get_MT_AIME24_questions()
    if math_dataset is None:
        return
    print(f"成功加载 {len(math_dataset)} 个 MATH 问题。")

    # --- Evaluation Loop ---
    num_problems_to_evaluate = min(len(math_dataset), 500) # For a quick demo
    if num_problems_to_evaluate == 0:
        # Nothing to run; also avoids dividing by zero in the report below.
        print("Dataset is empty; nothing to evaluate.")
        return
    problems_to_run = math_dataset.select(range(num_problems_to_evaluate))

    num_solved = 0
    # Histograms over turns 1..10; the last slot (index 10) counts "never solved".
    turn = [0]*11
    base_num_solved = 0
    base_turn = [0]*11
    for i, problem_data in enumerate(problems_to_run):
        print("\n" + "#"*30 + f" 评估问题 {i+1}/{num_problems_to_evaluate} " + "#"*30)
        print(f"问题: {problem_data['problem'][:100]}...")
        print(f"答案: {problem_data['answer']}")

        result = solve_problem_iteratively(problem_data, model, original_base_model, tokenizer, device)
        if result[0]:
            num_solved += 1
        if result[1]:
            base_num_solved += 1
        turn[result[2] - 1] += 1
        base_turn[result[3] - 1] += 1
        if device == 'cuda':
            # Drop unreachable Python objects first so their CUDA memory can
            # actually be returned by empty_cache().
            gc.collect()
            torch.cuda.empty_cache()


    # --- Final Report ---
    print("\n" + "="*80)
    print("评估完成！")
    print("="*80)
    accuracy = (num_solved / num_problems_to_evaluate) * 100
    base_accuracy = (base_num_solved / num_problems_to_evaluate) * 100
    print(f"总共评估问题数: {num_problems_to_evaluate}")
    print(f"成功解决问题数: {num_solved}")
    print(f"10轮内解决率: {accuracy:.2f}%")
    print(f"base model 成功解决问题数: {base_num_solved}")
    print(f"base model 10轮内解决率: {base_accuracy:.2f}%")
    print(f"解决问题使用的turn: {turn}")
    print(f"base model 解决问题使用的turn: {base_turn}")

# Script entry point; the trailing comments are the author's running TODO notes.
if __name__ == "__main__":
    main()
    # pi_star update: done
    # Make the reward more specific / more informative (how much was answered correctly): done
    # slot: modify the hidden state
    # Try adding LoRA to the parameters of other layers
    # Ideally modify on the order of 1e3-1e4 parameters
    #