################################################################################
# 1. Imports
################################################################################
import torch
import torch.nn.functional as F
from torch.func import jvp, vjp 
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import re
import copy
import pandas as pd
import gc
import time
import torch.multiprocessing as mp
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import List, Set, Any


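# NOTE: TextGrad replaces the original numerical-gradient optimization for the TextGrad strategy; peft is still imported because the ROSA strategies below use LoRA adapters.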
from peft import get_peft_model, LoraConfig, TaskType
from peft.tuners.lora import LoraLayer
try:
    from scipy.sparse.linalg import cg, LinearOperator
    import numpy as np
except ImportError:
    print("To use the SciPy solver, please install it first by running `pip install scipy`.")
    exit()

# Import for dataset loading
try:
    from datasets import Dataset, get_dataset_config_names, load_dataset, concatenate_datasets
except ImportError:
    print("To load the dataset, please install it first by running `pip install datasets`.")
    exit()

try:
    from math_verify.errors import TimeoutException
    from math_verify.metric import math_metric
    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
except ImportError:
    pass 

################################################################################
# 2. Reward Function and Data Loading (User Provided)
################################################################################

# String normalization functions from lm-evaluation-harness
def remove_boxed(s):
    """Strip the ``\\boxed`` wrapper from an extracted answer string.

    Two forms are recognised:

    * ``\\boxed <answer>$...`` -- everything after the first ``"\\boxed "``
      up to the next ``$`` is returned.
    * ``\\boxed{<answer>}`` -- the string must start with ``\\boxed{`` and end
      with ``}``; the text in between is returned.

    Returns ``None`` when ``s`` is ``None`` or matches neither form.
    """
    if s is None:
        return None

    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        return s[s.find(spaced_prefix) + len(spaced_prefix):].split("$")[0]

    braced_prefix = "\\boxed{"
    # Explicit checks instead of `assert`: assertions are stripped under
    # `python -O`, which would make every string pass the validation.
    if s.startswith(braced_prefix) and s.endswith("}"):
        return s[len(braced_prefix):-1]
    return None

def last_boxed_only_string(string):
    """Return the last ``\\boxed...`` (or, failing that, ``\\fbox...``) span.

    * ``None`` input, or no ``\\boxed``/``\\fbox`` at all -> ``None``.
    * ``\\boxed <answer>`` (space form) -> everything from ``\\boxed`` to the
      end of the string.
    * Otherwise the span from the command up to the brace that closes its
      first ``{``; ``None`` if the braces never balance.
    """
    if string is None:
        return None

    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None
    elif string.startswith("\\boxed ", idx):
        # Space form. `startswith` is bounds-safe, so a string that ends
        # right after "\boxed" no longer raises IndexError, and the check is
        # not applied to "\fbox" (whose length differs from "\boxed").
        return string[idx:]

    open_braces = 0
    for i in range(idx, len(string)):
        char = string[i]
        if char == "{":
            open_braces += 1
        elif char == "}":
            open_braces -= 1
            if open_braces == 0:
                return string[idx:i + 1]
    return None

def fix_fracs(string):
    """Normalise shorthand ``\\frac`` arguments to the braced form.

    ``\\frac12`` -> ``\\frac{1}{2}`` and ``\\frac1{2}`` -> ``\\frac{1}{2}``;
    already-braced fractions are left alone. If a ``\\frac`` is followed by
    fewer than two characters (including nothing at all) the input is
    returned unchanged.
    """
    head, *tails = string.split("\\frac")
    new_str = head
    for substr in tails:
        new_str += "\\frac"
        if substr.startswith("{"):
            new_str += substr
            continue
        # Covers the empty tail too (string ending in "\frac"), which used to
        # raise IndexError on `substr[0]`. An explicit check is used rather
        # than `assert`, which is stripped under `python -O`.
        if len(substr) < 2:
            return string
        numerator, second = substr[0], substr[1]
        remainder = substr[2:]
        if second != "{":
            new_str += "{" + numerator + "}{" + second + "}" + remainder
        else:
            new_str += "{" + numerator + "}" + second + remainder
    return new_str

def fix_a_slash_b(string):
    """Rewrite a plain integer ratio ``a/b`` as ``\\frac{a}{b}``.

    The input is returned unchanged unless it contains exactly one ``/``,
    both sides parse as integers, and the string is exactly the canonical
    ``"{a}/{b}"`` rendering of those integers.
    """
    pieces = string.split("/")
    if len(pieces) != 2:
        return string
    try:
        numerator = int(pieces[0])
        denominator = int(pieces[1])
    except (AssertionError, ValueError):
        return string
    # Reject inputs such as " 1/2" or "01/2" that parse but are not canonical.
    if string != "{}/{}".format(numerator, denominator):
        return string
    return "\\frac{" + str(numerator) + "}{" + str(denominator) + "}"

def remove_right_units(string):
    """Drop everything from the first ``"\\text{ "`` onward.

    Strings that do not contain ``"\\text{ "`` are returned as they are.
    """
    marker = "\\text{ "
    before, _, _ = string.partition(marker)
    return before

def fix_sqrt(string):
    """Brace the single character following an unbraced ``\\sqrt``.

    ``\\sqrt2`` becomes ``\\sqrt{2}``; ``\\sqrt{...}`` and a trailing bare
    ``\\sqrt`` are left unchanged.
    """
    head, *tails = string.split("\\sqrt")
    if not tails:
        return string

    rebuilt = [head]
    for tail in tails:
        needs_braces = bool(tail) and tail[0] != "{"
        if needs_braces:
            rebuilt.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        else:
            rebuilt.append("\\sqrt" + tail)
    return "".join(rebuilt)

def strip_string(string):
    """Normalise a LaTeX-ish answer string for exact-match comparison.

    The input is converted with ``str()`` and then, in order: a fixed list of
    substrings is removed or rewritten, trailing ``\\text{ ...}`` units are
    dropped, bare leading decimals get a ``0``, a short ``lhs=`` prefix
    (left side of at most two characters) is removed, and ``\\sqrt``,
    spaces, ``\\frac`` and ``a/b`` forms are canonicalised by the sibling
    helpers. The order of the steps matters and is preserved.
    """
    string = str(string)

    # Ordered literal replacements. The former `"\%"` entry was an invalid
    # escape sequence (SyntaxWarning on recent Pythons) that denotes the same
    # two characters as "\\%", so it was a redundant duplicate and is dropped.
    literal_replacements = (
        ("\n", ""),
        ("\\!", ""),
        ("\\\\", "\\"),
        ("tfrac", "frac"),
        ("dfrac", "frac"),
        ("\\left", ""),
        ("\\right", ""),
        ("^{\\circ}", ""),
        ("^\\circ", ""),
        ("\\$", ""),
        ("\\%", ""),
    )
    for old, new in literal_replacements:
        string = string.replace(old, new)

    string = remove_right_units(string)

    # " .5" -> " 0.5", "{.5" -> "{0.5", ".5" -> "0.5"
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    if string.startswith("."):
        string = "0" + string

    # "k=5" -> "5" when there is exactly one "=" and a short left-hand side.
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = fix_sqrt(string)
    string = string.replace(" ", "")
    string = fix_fracs(string)
    string = fix_a_slash_b(string)
    return string

def is_equiv(str1, str2):
    """Compare two answers after normalisation with ``strip_string``.

    Two ``None`` values are equivalent; ``None`` never equals a non-``None``
    value. If normalisation raises, the raw values are compared instead.
    """
    if str1 is None or str2 is None:
        return str1 is None and str2 is None
    try:
        left, right = strip_string(str1), strip_string(str2)
    except Exception:
        # Best-effort fallback: compare the un-normalised inputs.
        return str1 == str2
    return left == right

def compute_score(solution_str, ground_truth) -> float:
    """Score a solution: 1.0 if its last boxed answer matches, else 0.0.

    The last ``\\boxed``/``\\fbox`` span is extracted from ``solution_str``,
    unwrapped, and compared to ``ground_truth`` with ``is_equiv``. Any
    exception raised along the way yields 0.0.
    """
    try:
        boxed_span = last_boxed_only_string(solution_str)
        if boxed_span is None:
            return 0.0
        extracted = remove_boxed(boxed_span)
        return 1.0 if is_equiv(extracted, ground_truth) else 0.0
    except Exception:
        return 0.0

class MathVerifyRewardModel:
    """Binary reward for a response, judged against a fixed ground truth."""

    def __init__(self, ground_truth_answer: str):
        # Reference answer every response is scored against.
        self.ground_truth_answer = ground_truth_answer

    def get_reward(self, response_text: str) -> float:
        """Return +1.0 when ``compute_score`` gives 1.0, otherwise -1.0."""
        score = compute_score(response_text, self.ground_truth_answer)
        if score == 1.0:
            return 1.0
        return -1.0


################################################################################
# math
################################################################################
# Instruction suffix that the dataset loaders below concatenate directly
# (no separator) onto each problem statement.
instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
def get_math_questions():
    """Load the MATH benchmark test split as problem/solution/answer records.

    The instruction suffix is appended to each problem and the first raw
    example is printed. Returns the mapped dataset without rows whose answer
    is ``None``, or ``None`` (after printing the error) if anything fails.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--nlile--hendrycks-MATH-benchmark"

    def _to_record(example):
        return {
            'problem': example['problem'] + instruction_following,
            'solution': example['solution'],
            'answer': example['answer'],
        }

    def _has_answer(example):
        return example['answer'] is not None

    try:
        data = load_dataset(dataset_path, split="test")
        print(data[0])
        # Keep only the examples that actually carry an answer.
        return data.map(_to_record).filter(_has_answer)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_math_500_questions():
    """Load the MATH-500 test split as problem/solution/answer records.

    The instruction suffix is appended to each problem. Returns the mapped
    dataset without rows whose answer is ``None``, or ``None`` (after
    printing the error) if anything fails.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--HuggingFaceH4--MATH-500"

    def _to_record(example):
        return {
            'problem': example['problem'] + instruction_following,
            'solution': example['solution'],
            'answer': example['answer'],
        }

    def _has_answer(example):
        return example['answer'] is not None

    try:
        data = load_dataset(dataset_path, split="test")
        # Keep only the examples that actually carry an answer.
        return data.map(_to_record).filter(_has_answer)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None

def get_AIME25_questions():
    """Load the AIME25 test split as problem/solution/answer records.

    The instruction suffix is appended to each problem and ``solution`` is
    set to the empty string. Returns the mapped dataset without rows whose
    answer is ``None``, or ``None`` (after printing the error) on failure.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--math-ai--aime25"

    def _to_record(example):
        return {
            'problem': example['problem'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    def _has_answer(example):
        return example['answer'] is not None

    try:
        data = load_dataset(dataset_path, split="test")
        # Keep only the examples that actually carry an answer.
        return data.map(_to_record).filter(_has_answer)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the dataset path is correct and the 'datasets' library is installed.")
        return None 

def split_by_last_separator(s: str) -> list:
    """
    Split the string at its last '.' or last '\\n', whichever comes later.

    Returns ``[before, after]`` with the separator itself removed, or ``[s]``
    when the string contains neither separator.
    """
    cut = max(s.rfind('.'), s.rfind('\n'))
    if cut < 0:
        return [s]
    return [s[:cut], s[cut + 1:]]

def get_mathdial_questions():
    """Load the MathDial test split as problem/solution/answer records.

    ``ground_truth`` is stripped and split at its last '.' or newline:
    the part before the separator becomes ``solution`` and the part after it
    becomes ``answer``. When there is no separator both fields hold the whole
    string. Returns the mapped dataset without rows whose answer is ``None``,
    or ``None`` (after printing the error) if anything fails.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--eth-nlped--mathdial"

    def _to_record(example):
        problem = example['question'] + instruction_following
        parts = split_by_last_separator(example['ground_truth'].strip())
        return {
            'problem': problem,
            'solution': parts[0],
            # Previously this reused parts[0], so the "answer" was the
            # solution text instead of the trailing final answer. [-1] also
            # covers the single-element (no separator) case.
            'answer': parts[-1].strip(),
        }

    try:
        data = load_dataset(dataset_path, split="test")
        data = data.map(_to_record)
        return data.filter(lambda x: x['answer'] is not None)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None     

def get_gsm8k_questions():
    """Load the GSM8K ``main`` test split as problem/solution/answer records.

    Each raw answer is split at its last ``####``: the text before it is the
    solution and the text after it is the answer. Returns the mapped dataset
    without rows whose answer is ``None``, or ``None`` (after printing the
    error) if anything fails.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ReDit/dataset/gsm8k"

    def _to_record(example):
        problem = example['question'] + instruction_following
        pieces = example['answer'].rsplit('####', 1)
        return {
            'problem': problem,
            'solution': pieces[0],
            'answer': pieces[1],
        }

    def _has_answer(example):
        return example['answer'] is not None

    try:
        data = load_dataset(dataset_path, 'main', split="test")
        return data.map(_to_record).filter(_has_answer)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None   


################################################################################
# general task
################################################################################

def get_gpqa_questions():
    """Load the GPQA-Diamond test split as problem/solution/answer records.

    The instruction suffix is appended to each question and ``solution`` is
    set to the empty string. Returns the mapped dataset without rows whose
    answer is ``None``, or ``None`` (after printing the error) on failure.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--fingertap--GPQA-Diamond"

    def _to_record(example):
        return {
            'problem': example['question'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    def _has_answer(example):
        return example['answer'] is not None

    try:
        data = load_dataset(dataset_path, split="test")
        return data.map(_to_record).filter(_has_answer)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None   

def get_mmlu_redux_questions():
    """Load every config and split of MMLU-Redux as multiple-choice records.

    All splits of all configs are concatenated. Each question is rendered
    with lettered options (A, B, ...) followed by the instruction suffix, and
    the integer answer index is converted to its letter (``None`` when the
    index is out of range). Returns the dataset without ``None`` answers, or
    ``None`` (after printing the error) if anything fails.
    """
    def _format_example(x):
        choices = x['choices']
        letters = [chr(ord('A') + position) for position in range(len(choices))]
        rendered_options = "\n".join(
            f"{letter}. {choice}" for letter, choice in zip(letters, choices)
        )
        problem_text = f"{x['question']}\n" + rendered_options

        answer_index = x['answer']
        if 0 <= answer_index < len(choices):
            answer_letter = letters[answer_index]
        else:
            answer_letter = None

        return {
            'problem': problem_text + instruction_following,
            'solution': "",
            'answer': answer_letter
        }

    try:
        dataset_name = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--edinburgh-dawg--mmlu-redux"
        collected_splits = []
        for config_name in get_dataset_config_names(dataset_name):
            collected_splits.extend(load_dataset(dataset_name, config_name).values())

        data = concatenate_datasets(collected_splits).map(_format_example)
        return data.filter(lambda x: x['answer'] is not None)

    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None

def get_superGPQA_questions():
    """Load every config and split of SuperGPQA as multiple-choice records.

    All splits of all configs are concatenated. Each question is rendered
    with lettered options (A, B, ...) followed by the instruction suffix; the
    answer is taken from the ``answer_letter`` field. Returns the dataset
    without ``None`` answers, or ``None`` (after printing the error) if
    anything fails.
    """
    def _format_example(x):
        options = x['options']
        letters = [chr(ord('A') + position) for position in range(len(options))]
        rendered_options = "\n".join(
            f"{letter}. {option}" for letter, option in zip(letters, options)
        )
        problem_text = f"{x['question']}\n" + rendered_options

        return {
            'problem': problem_text + instruction_following,
            'solution': "",
            'answer': x['answer_letter']
        }

    try:
        dataset_name = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--m-a-p--SuperGPQA"
        collected_splits = []
        for config_name in get_dataset_config_names(dataset_name):
            collected_splits.extend(load_dataset(dataset_name, config_name).values())

        data = concatenate_datasets(collected_splits).map(_format_example)
        return data.filter(lambda x: x['answer'] is not None)

    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None

################################################################################
# Multilingual
################################################################################

def transform_to_qa_format(dataset: Dataset) -> Dataset:
    """Flatten a one-column-per-language dataset into question/answer rows.

    Every column other than ``answer`` is treated as a language column. For
    each example, one row ``{language, question, answer}`` is produced per
    language column whose value is not ``None`` and not blank.
    """
    language_columns = [name for name in dataset.column_names if name != 'answer']

    rows = []
    for example in dataset:
        shared_answer = example['answer']
        for lang in language_columns:
            question = example[lang]
            # Skip languages with no (or whitespace-only) question text.
            if question is None or not str(question).strip():
                continue
            rows.append({
                'language': lang,
                'question': question,
                'answer': shared_answer
            })

    return Dataset.from_pandas(pd.DataFrame(rows))

def get_M_IMO_questions():
    """Load the MCLM ``M-IMO`` test split as per-language QA records.

    The dataset is flattened with ``transform_to_qa_format``; the instruction
    suffix is appended to each question and ``solution`` is left empty.
    Returns the dataset without ``None`` answers, or ``None`` (after printing
    the error) if anything fails.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--amphora--MCLM"

    def _to_record(example):
        return {
            'problem': example['question'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    def _has_answer(example):
        return example['answer'] is not None

    try:
        raw = load_dataset(dataset_path, "M-IMO", split="test")
        flattened = transform_to_qa_format(raw)
        return flattened.map(_to_record).filter(_has_answer)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None

def get_MT_AIME24_questions():
    """Load the MCLM ``MT-AIME2024`` test split as per-language QA records.

    The dataset is flattened with ``transform_to_qa_format``; the instruction
    suffix is appended to each question and ``solution`` is left empty.
    Returns the dataset without ``None`` answers, or ``None`` (after printing
    the error) if anything fails.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--amphora--MCLM"

    def _to_record(example):
        return {
            'problem': example['question'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    def _has_answer(example):
        return example['answer'] is not None

    try:
        raw = load_dataset(dataset_path, "MT-AIME2024", split="test")
        flattened = transform_to_qa_format(raw)
        return flattened.map(_to_record).filter(_has_answer)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None

def get_MT_MATH100_questions():
    """Load the MCLM ``MT-MATH100`` test split as per-language QA records.

    The dataset is flattened with ``transform_to_qa_format``; the instruction
    suffix is appended to each question and ``solution`` is left empty.
    Returns the dataset without ``None`` answers, or ``None`` (after printing
    the error) if anything fails.
    """
    dataset_path = "/mnt/bn/code-generation-100t-hl/weichenxing/ROSA/dataset/datasets--amphora--MCLM"

    def _to_record(example):
        return {
            'problem': example['question'] + instruction_following,
            'solution': "",
            'answer': example['answer'],
        }

    def _has_answer(example):
        return example['answer'] is not None

    try:
        raw = load_dataset(dataset_path, "MT-MATH100", split="test")
        flattened = transform_to_qa_format(raw)
        return flattened.map(_to_record).filter(_has_answer)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None
        
################################################################################
# 3.1 TextGrad Optimizer and Structures
################################################################################

# --- Basic Structures for TextGrad ---

class TextualGradient:
    """A piece of natural-language feedback attached to a ``Variable``."""

    def __init__(self, value: str) -> None:
        # The feedback text itself.
        self.value = value

class Variable:
    def __init__(self, value: str, role_description: str = ""):
        self.value = value
        self.role_description = role_description
        self.gradients: Set[TextualGradient] = set()
    
    def set_value(self, new_value: str):
        self.value = new_value
        # Usually, gradients are cleared by the optimizer, but for safety in this loop
        # we can consider them consumed when the value updates.
    
    def add_gradient(self, gradient: TextualGradient):
        self.gradients.add(gradient)

# --- Prompts and Constants ---

# Glossary of tag names that is appended to the optimizer system prompt.
GLOSSARY_TEXT = """
### Glossary of tags that will be sent to you:
# - <LM_SYSTEM_PROMPT>: The system prompt for the language model.
# - <LM_INPUT>: The input to the language model.
# - <LM_OUTPUT>: The output of the language model.
# - <FEEDBACK>: The feedback to the variable.
# - <CONVERSATION>: The conversation history.
# - <FOCUS>: The focus of the optimization.
# - <ROLE>: The role description of the variable."""

# System prompt template for the optimizer LLM. It is a str.format template:
# {new_variable_start_tag} / {new_variable_end_tag} are filled in later by
# TextualGradientDescent.__init__, and the doubled braces render as a literal
# "{improved variable}". Only the last fragment is an f-string, so the
# placeholders in the plain fragments survive until .format() is called.
OPTIMIZER_SYSTEM_PROMPT = (
    "You are part of an optimization system that improves text (i.e., variable) by analyzing how the responses evolve across multiple iterations. "
    "Your goal is not just to make a single improvement, but to ensure that the variable evolves naturally and meaningfully over time. "
    "You MUST give your response by sending the improved variable between {new_variable_start_tag} {{improved variable}} {new_variable_end_tag} tags. "
    f"{GLOSSARY_TEXT}"
)

def construct_tgd_prompt(**kwargs):
    """Build the user prompt asking the optimizer LLM to improve a variable.

    Only ``variable_value`` and ``variable_grad`` are read from ``kwargs``;
    a missing key is rendered as the text ``None``. Any other keyword
    arguments are accepted and ignored (simple demonstration implementation).
    """
    variable_value = kwargs.get('variable_value')
    variable_grad = kwargs.get('variable_grad')
    sections = (
        f"Here is the variable you need to improve:\n<VARIABLE>{variable_value}</VARIABLE>\n\n",
        f"Here is the feedback:\n<FEEDBACK>{variable_grad}</FEEDBACK>\n\n",
        "Improve the variable based on the feedback.",
    )
    return "".join(sections)

# --- Optimizer Classes ---

class Optimizer(ABC):
    """Abstract base for optimizers that update a list of ``Variable``s."""

    def __init__(self, parameters: List[Variable]):
        self.parameters = parameters

    def zero_grad(self):
        """Give every parameter a fresh, empty gradient set."""
        for parameter in self.parameters:
            # A new set is assigned rather than clearing the old one in place.
            parameter.gradients = set()

    @abstractmethod
    def step(self):
        """Apply one update to the parameters; subclasses must implement."""
        pass

class TextualGradientDescent(Optimizer):
    """Optimizer that rewrites each variable by prompting a local LLM.

    For every parameter, ``step`` sends the current value and its collected
    textual gradients to ``model`` and replaces the value with the text the
    model returns between the configured start/end tags.

    Args:
        parameters: Variables to optimise.
        model: Object exposing ``.device`` and ``.generate(...)``.
        tokenizer: Object exposing ``apply_chat_template``, ``__call__``,
            ``batch_decode`` and ``eos_token_id``.
        verbose: Stored on the instance; not read by the code in this class.
        constraints: Stored on the instance (``[]`` when omitted); not read
            by the code in this class.
        new_variable_tags: ``[start_tag, end_tag]`` delimiting the improved
            value in the model response. Defaults to
            ``<IMPROVED_VARIABLE>`` / ``</IMPROVED_VARIABLE>``.
    """

    def __init__(self, 
                 parameters: List[Variable], 
                 model,          
                 tokenizer,      
                 verbose: int = 0, 
                 constraints: List[str] = None,
                 new_variable_tags: List[str] = None):
        
        super().__init__(parameters)
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device
        self.verbose = verbose
        self.constraints = constraints or []
        self.new_variable_tags = new_variable_tags or ["<IMPROVED_VARIABLE>", "</IMPROVED_VARIABLE>"]
        self.optimizer_system_prompt = OPTIMIZER_SYSTEM_PROMPT.format(
            new_variable_start_tag=self.new_variable_tags[0],
            new_variable_end_tag=self.new_variable_tags[1]
        )
        # History of previous values, keyed by the Variable object itself.
        self.past_values_dict = defaultdict(list)

    def _get_grad_text(self, variable):
        """Join the variable's gradient texts, one per line (simplified)."""
        return "\n".join(g.value for g in variable.gradients)

    def step(self):
        """Core step: rewrite every parameter using ``model.generate``."""
        start_tag = self.new_variable_tags[0]
        end_tag = self.new_variable_tags[1]

        for parameter in self.parameters:
            # 1. Build the prompt. construct_tgd_prompt currently reads only
            #    variable_value and variable_grad; the other keywords are
            #    passed through but ignored there.
            prompt_content = construct_tgd_prompt(
                variable_value=parameter.value,
                variable_grad=self._get_grad_text(parameter),
                past_values="\n".join(self.past_values_dict[parameter]),
                new_variable_start_tag=start_tag,
                new_variable_end_tag=end_tag
            )

            # 2. Generate the improved value with the local model.
            messages = [
                {"role": "system", "content": self.optimizer_system_prompt},
                {"role": "user", "content": prompt_content}
            ]
            
            text_input = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = self.tokenizer([text_input], return_tensors="pt").to(self.device)

            with torch.no_grad():
                generated_ids = self.model.generate(
                    inputs.input_ids,
                    # Pass the mask explicitly (as generate_critique does via
                    # **inputs): pad_token_id is set to eos below, so the mask
                    # cannot be reliably inferred from the ids.
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=2048,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            # Drop the prompt tokens so only the completion is decoded.
            generated_ids = [out[len(in_ids):] for in_ids, out in zip(inputs.input_ids, generated_ids)]
            response_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

            # 3. Parse the result. Both tags are checked before splitting, so
            #    the indexing below cannot fail and needs no exception guard.
            if start_tag in response_text and end_tag in response_text:
                new_value = response_text.split(start_tag)[1].split(end_tag)[0].strip()
            else:
                # Fallback: tags missing, use the full response as the value.
                new_value = response_text

            # 4. Record the old value, then update the variable.
            self.past_values_dict[parameter].append(parameter.value)
            parameter.set_value(new_value)

# --- Critique Generation Helper ---

def generate_critique(model, tokenizer, device, problem, model_answer, ground_truth):
    """
    Ask the model to critique an answer already known to be incorrect.

    The returned critique text serves as a textual gradient. The prompt is
    built from ``problem`` and ``model_answer`` only; ``ground_truth`` is
    accepted but not used. The prompt is truncated to 2048 tokens and up to
    512 new tokens are sampled.
    """
    prompt_fragments = (
        f"Evaluate the following solution to the geometry problem: '{problem}'. ",
        f"The solution provided the answer '{model_answer}', which is verified to be INCORRECT.",
        "You are an expert math grader. You have verified that the provided answer is wrong. ",
        "Since the answer is incorrect, there must be a flaw in the logical reasoning or a calculation error. ",
        "Please analyze the solution step-by-step, carefully check the geometry theorems applied and the arithmetic calculations. ",
        "Point out specifically where the reasoning might be flawed. ",
        "Be very critical.",
    )
    critique_prompt = "".join(prompt_fragments)

    chat = [{"role": "user", "content": critique_prompt}]
    input_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors='pt', max_length=2048, truncation=True).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the tokens generated after the prompt.
    prompt_length = inputs['input_ids'].shape[1]
    completion_ids = outputs[0, prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)


################################################################################
# 3.2 PiStarCalculator and External Δθ Calculation
################################################################################

class PiStarCalculator:
    """Builds target ("pi-star") logits by boosting the answer tokens."""

    def __init__(self, reward_model, tokenizer, beta=1.0, device='cuda'):
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        self.beta = beta      # divisor applied to the reward
        self.device = device  # device the input ids are moved to

    @torch.no_grad()
    def calculate_pi_star_logits(self, reference_model, target_input_ids: torch.Tensor, labels: torch.Tensor, reward: float) -> torch.Tensor:
        """Return a copy of the reference logits with answer tokens boosted.

        Positions where ``labels != -100`` are the answer positions. A total
        log-bonus of ``reward / beta`` (computed as ``log_p_star - log_p_ref``)
        is split evenly across them and added to the logit of each answer
        token at the preceding position. If there is no answer position, an
        unmodified clone of the reference logits is returned. Only batch row
        0 is read and written.
        """
        ref_logits = reference_model(input_ids=target_input_ids.to(self.device)).logits
        log_probs_ref = F.log_softmax(ref_logits, dim=-1)

        labelled = (labels != -100).nonzero(as_tuple=False)
        if labelled.numel() == 0:
            return ref_logits.clone()
        answer_indices = labelled[:, 1]
        answer_token_ids = target_input_ids[0, answer_indices]

        # Sequence log-probability of the answer under the reference model:
        # token at position t is predicted by the logits at position t - 1.
        log_p_ref = log_probs_ref[0, answer_indices - 1, answer_token_ids].sum()
        log_p_star = log_p_ref + (reward / self.beta)
        per_token_log_bonus = (log_p_star - log_p_ref) / len(answer_indices)

        pi_star_logits = ref_logits.clone()
        for position, token_id in zip(answer_indices, answer_token_ids):
            pi_star_logits[0, position - 1, token_id] += per_token_log_bonus
        return pi_star_logits

def calculate_delta_theta_externally_lora_cg(model, hidden_states, logits, pi_star_logits, labels, adapter_name='default', cg_steps=10, cg_tolerance=1e-5) -> tuple | None:
    """Solve for a LoRA update of the lm_head that moves pi_theta toward pi*.

    With J the Jacobian of the stacked per-position softmax policies with
    respect to the flattened LoRA parameters [A, B] of the lm_head adapter,
    this solves (J^T J) delta = J^T (pi* - pi_theta) with SciPy's conjugate
    gradient, using matrix-free Jacobian products from ``torch.func``.

    Args:
        model: peft-wrapped model; must expose the submodules
            ``base_model.model.lm_head`` (with ``lora_A``/``lora_B``/``scaling``)
            and ``base_model.model.lm_head.base_layer``.
        hidden_states: indexed as ``[0, t, :]`` -- presumably the final hidden
            states fed to the lm_head; TODO confirm with the caller.
        logits: current model logits, indexed as ``[0, t, :]``.
        pi_star_logits: target logits, indexed as ``[0, t, :]``.
        labels: positions with ``labels != -100`` are the answer positions.
        adapter_name: which LoRA adapter on the lm_head to differentiate.
        cg_steps: maximum number of CG iterations.
        cg_tolerance: relative tolerance passed to CG as ``rtol``.

    Returns:
        ``(delta_A, delta_B)`` shaped like the transposed LoRA A/B weights,
        or ``None`` when no label differs from -100.

    Raises:
        NotImplementedError: if ``logits`` has a batch dimension > 1.
    """
    if logits.shape[0] > 1:
        raise NotImplementedError("This function currently only supports batch_size=1.")

    nonzero_indices = (labels != -100).nonzero(as_tuple=False)
    if nonzero_indices.numel() == 0:
        return None
    # Column index of every labelled position (assumes labels is 2-D).
    answer_token_indices = nonzero_indices[:, 1]

    device = model.device
    lm_head_layer = model.get_submodule("base_model.model.lm_head")
    # Transposed so that `hidden @ A @ B` below maps a hidden state to logits.
    lora_A = lm_head_layer.lora_A[adapter_name].weight.T
    lora_B = lm_head_layer.lora_B[adapter_name].weight.T
    scaling = lm_head_layer.scaling[adapter_name]
    
    # Flattened parameter vector theta = [A.flatten(), B.flatten()].
    theta_lora_flat = torch.cat([lora_A.flatten(), lora_B.flatten()]).detach()
    len_A_flat = lora_A.numel()

    def F_policy_fn(theta_flat_input):
        # Maps flat LoRA parameters to the concatenated softmax policies at
        # all answer positions; this is the function being differentiated.
        current_A = theta_flat_input[:len_A_flat].view(lora_A.shape)
        current_B = theta_flat_input[len_A_flat:].view(lora_B.shape)
        all_policies = []
        base_lm_head = model.get_submodule("base_model.model.lm_head.base_layer")
        for t in answer_token_indices:
            hidden_state_t = hidden_states[0, t, :].detach()
            base_logits_t = base_lm_head(hidden_state_t)
            lora_output = (hidden_state_t @ current_A) @ current_B * scaling
            new_logits = base_logits_t + lora_output
            all_policies.append(F.softmax(new_logits, dim=-1))
        return torch.cat(all_policies)

    def A_prime_mvp(p_vec_np):
        # Matrix-free product (J^T J) p: a JVP followed by a VJP.
        p_vec = torch.from_numpy(p_vec_np).to(device)
        with torch.no_grad():
            _, jvp_result = jvp(F_policy_fn, (theta_lora_flat,), (p_vec,))
            _, vjp_fn = vjp(F_policy_fn, theta_lora_flat)
            A_prime_p = vjp_fn(jvp_result)[0]
        return A_prime_p.cpu().numpy()

    # Right-hand side before projection: pi* - pi_theta at each answer position.
    # NOTE(review): positions are indexed at t here, whereas
    # PiStarCalculator.calculate_pi_star_logits adds its bonus at t - 1 for
    # the same kind of label mask -- confirm the intended alignment.
    policy_diffs = []
    for t in answer_token_indices:
        pi_star_t = F.softmax(pi_star_logits[0, t, :], dim=-1)
        pi_theta_t = F.softmax(logits[0, t, :], dim=-1)
        policy_diffs.append(pi_star_t - pi_theta_t)
    b = torch.cat(policy_diffs)
    
    # b' = J^T b
    _, vjp_fn_b = vjp(F_policy_fn, theta_lora_flat)
    b_prime = vjp_fn_b(b)[0]
    
    # Solve for delta-theta with SciPy's conjugate gradient (CG) method.
    
    n_params = theta_lora_flat.numel()
    # NOTE(review): .numpy() requires a NumPy-compatible dtype; presumably the
    # tensors are float32 here -- confirm (bfloat16 would fail).
    b_prime_np = b_prime.detach().cpu().numpy()
    
    A_operator = LinearOperator((n_params, n_params), matvec=A_prime_mvp, dtype=b_prime_np.dtype)
    
    # [MODIFIED] Re-enabled CG callback for verbose debugging
    class CgCallback:
        # Collects the norm of whatever vector CG passes to the callback at
        # each iteration; the values are stored but not used afterwards.
        def __init__(self):
            self.residuals = []
        def __call__(self, rk):
            self.residuals.append(np.linalg.norm(rk))
            # Optional: print self.residuals[-1] here for verbose CG progress.

    callback = CgCallback()
    # NOTE(review): the `rtol` keyword is only accepted by newer SciPy
    # releases (older ones call it `tol`) -- confirm the pinned version.
    delta_theta_lora_flat_np, info = cg(A_operator, b_prime_np, maxiter=cg_steps, rtol=cg_tolerance, callback=callback)
    
    # `info` is not checked: a non-zero value (CG stopped without meeting the
    # tolerance within cg_steps iterations) is silently accepted.

    # Unflatten the solution back into the (transposed) A and B shapes.
    delta_theta_lora_flat = torch.from_numpy(delta_theta_lora_flat_np).to(device)
    delta_A = delta_theta_lora_flat[:len_A_flat].view(lora_A.shape)
    delta_B = delta_theta_lora_flat[len_A_flat:].view(lora_B.shape)
    
    return (delta_A, delta_B)

################################################################################
# 4. Iterative Problem Solving and Main Evaluation Loop
################################################################################

def jaccard_similarity(text1, text2):
    """Return the Jaccard similarity of the whitespace-separated word sets.

    The result is ``|A & B| / |A | B|``, or ``0`` when both texts have no
    words.
    """
    words_a = set(text1.split())
    words_b = set(text2.split())
    combined = words_a | words_b
    if len(combined) == 0:
        return 0
    shared = words_a & words_b
    return len(shared) / len(combined)

def _sample_reply(gen_model, tokenizer, messages, device, max_input_length):
    """Render `messages` with the chat template, generate a reply and decode only the new tokens.

    Returns:
        (reply_text, gen_ids, context_len) where `gen_ids` is the full output of
        `generate` (prompt + continuation) and `context_len` is the number of
        prompt tokens that precede the continuation.
    """
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors='pt', max_length=max_input_length, truncation=True).to(device)
    with torch.no_grad():
        gen_ids = gen_model.generate(**inputs, max_new_tokens=1024, top_p=0.9, temperature=1.0, pad_token_id=tokenizer.eos_token_id)
    context_len = inputs.input_ids.shape[1]
    reply_text = tokenizer.decode(gen_ids[0, context_len:], skip_special_tokens=True)
    return reply_text, gen_ids, context_len


def _rosa_update_lm_head(model, pi_star_calculator, gen_ids, context_len, reward, adapter_name, device):
    """Apply one ROSA step: solve for a LoRA delta on `lm_head` and add it to adapter `adapter_name` in place."""
    target_input_ids = gen_ids.to(device)
    labels = target_input_ids.clone()
    # Mask the prompt so only the generated answer tokens carry labels.
    labels[0, :context_len] = -100

    pi_star_logits = pi_star_calculator.calculate_pi_star_logits(model, target_input_ids, labels, reward)

    model.train()
    outputs_pre = model(target_input_ids, output_hidden_states=True)

    delta_theta = calculate_delta_theta_externally_lora_cg(
        model, outputs_pre.hidden_states[-1], outputs_pre.logits, pi_star_logits, labels, adapter_name=adapter_name
    )

    if delta_theta:
        delta_A, delta_B = delta_theta
        with torch.no_grad():
            lm_head = model.get_submodule("base_model.model.lm_head")
            lm_head.lora_A[adapter_name].weight.data += delta_A.T
            lm_head.lora_B[adapter_name].weight.data += delta_B.T

    # Leave the model in inference mode: the next generations (including the
    # base-model strategies, which run on the same underlying modules) should
    # not execute in train mode.
    model.eval()


def solve_problem_iteratively(problem_data, model, original_base_model, tokenizer, device):
    """
    Simultaneously runs and compares FOUR strategies on the same problem:
    1. TextGrad: Context Optimization with Critique.
    2. Base Model: Generic Feedback ("Wrong answer..."), no param update.
    3. ROSA: Generic Feedback ("Wrong answer...") + LoRA Parameter Update.
    4. ROSA + TextGrad: Specific Critique Feedback + LoRA Parameter Update.

    Args:
        problem_data: mapping with at least the keys 'problem' and 'answer'.
        model: PEFT-wrapped model carrying the 'default' and 'rosa_textgrad'
            LoRA adapters (used by strategies 3 and 4).
        original_base_model: model used for strategies 1 and 2 and for critique
            generation.
        tokenizer: tokenizer providing `apply_chat_template`.
        device: device the tokenized inputs are moved to.

    Returns:
        An 8-element list
        [TG_Solved, Base_Solved, ROSA_Solved, ROSA_TG_Solved,
         TG_Turn, Base_Turn, ROSA_Turn, ROSA_TG_Turn].
        A turn value of 11 means the strategy did not solve the problem within
        `max_turns` (10) turns.

    NOTE(review): `get_peft_model` injects LoRA layers into the module tree it is
    given, so `original_base_model` is presumably the very same network that
    `model` wraps. Every "base model" call below therefore runs inside
    `model.disable_adapter()`; otherwise the baselines would silently use the
    LoRA weights already updated by ROSA. Confirm against the PEFT version in use.
    """
    math_problem = problem_data['problem']
    standard_answer = problem_data['answer']
    MAX_INPUT_LENGTH = 5120
    generic_feedback = "Wrong answer, please rethink and try another way of thinking!"

    # --- Reset ROSA Parameters (Strategy 3 & 4 specific) ---
    for module in model.modules():
        if isinstance(module, LoraLayer):
            if module.merged: module.unmerge()
            # Reset both adapters so every problem starts from freshly initialised LoRA weights.
            module.reset_lora_parameters('default', init_lora_weights=True)
            module.reset_lora_parameters('rosa_textgrad', init_lora_weights=True)

    reward_model = MathVerifyRewardModel(ground_truth_answer=standard_answer)

    # --- Initialize ROSA Helper ---
    pi_star_calculator = PiStarCalculator(reward_model, tokenizer, beta=1.0, device=device)

    # --- History & State Tracking ---
    initial_prompt = math_problem + " " + instruction_following

    # Strategy 1: TextGrad
    tg_messages = [{"role": "user", "content": initial_prompt}]
    tg_answer_var = None
    tg_optimizer = None

    # Strategy 2: Base Model
    base_messages = [{"role": "user", "content": initial_prompt}]

    # Strategy 3: ROSA
    rosa_messages = [{"role": "user", "content": initial_prompt}]

    # Strategy 4: ROSA + TextGrad
    rosa_tg_messages = [{"role": "user", "content": initial_prompt}]

    # Results: [TG_Solved, Base_Solved, ROSA_Solved, ROSA_TG_Solved, TG_Turn, Base_Turn, ROSA_Turn, ROSA_TG_Turn]
    result = [False, False, False, False, 11, 11, 11, 11]

    max_turns = 10

    for turn in range(1, max_turns + 1):
        # Strategies 1 and 2 must see the untouched base weights, so the LoRA
        # adapters are bypassed for the whole duration of both steps.
        with model.disable_adapter():
            # ==================================================================
            # Step 1: TextGrad Execution (Uses original_base_model)
            # ==================================================================
            if not result[0]:
                if tg_answer_var is None:
                    # First turn: sample the initial answer and wrap it in a
                    # TextGrad variable that later turns refine in place.
                    current_text, _, _ = _sample_reply(
                        original_base_model, tokenizer, tg_messages, device, MAX_INPUT_LENGTH
                    )
                    tg_answer_var = Variable(value=current_text)
                    tg_optimizer = TextualGradientDescent([tg_answer_var], original_base_model, tokenizer)
                else:
                    # Later turns: the optimizer step of the previous turn produced the new answer.
                    current_text = tg_answer_var.value
                tg_messages.append({"role": "assistant", "content": current_text})

                if reward_model.get_reward(current_text) == 1.0:
                    result[0] = True
                    result[4] = turn

                if not result[0] and turn < max_turns:
                    critique = generate_critique(original_base_model, tokenizer, device, math_problem, current_text, standard_answer)
                    tg_messages.append({"role": "user", "content": f"Feedback: {critique}"})
                    tg_optimizer.zero_grad()
                    tg_answer_var.add_gradient(TextualGradient(critique))
                    tg_optimizer.step()

            # ==================================================================
            # Step 2: Base Model Execution (Uses original_base_model)
            # ==================================================================
            if not result[1]:
                current_text, _, _ = _sample_reply(
                    original_base_model, tokenizer, base_messages, device, MAX_INPUT_LENGTH
                )
                base_messages.append({"role": "assistant", "content": current_text})

                if reward_model.get_reward(current_text) == 1.0:
                    result[1] = True
                    result[5] = turn

                if not result[1] and turn < max_turns:
                    base_messages.append({"role": "user", "content": generic_feedback})

        # ======================================================================
        # Step 3: ROSA Execution (Uses PEFT model 'default' adapter)
        # ======================================================================
        if not result[2]:
            model.set_adapter("default")
            model.eval()
            current_text, gen_ids, context_len = _sample_reply(
                model, tokenizer, rosa_messages, device, MAX_INPUT_LENGTH
            )
            rosa_messages.append({"role": "assistant", "content": current_text})

            reward = reward_model.get_reward(current_text)
            if reward == 1.0:
                result[2] = True
                result[6] = turn

            if not result[2] and turn < max_turns:
                rosa_messages.append({"role": "user", "content": generic_feedback})
                _rosa_update_lm_head(model, pi_star_calculator, gen_ids, context_len, reward, 'default', device)

        # ======================================================================
        # Step 4: ROSA + TextGrad Execution (Uses PEFT model 'rosa_textgrad' adapter)
        # ======================================================================
        if not result[3]:
            model.set_adapter("rosa_textgrad")
            model.eval()
            current_text, gen_ids, context_len = _sample_reply(
                model, tokenizer, rosa_tg_messages, device, MAX_INPUT_LENGTH
            )
            rosa_tg_messages.append({"role": "assistant", "content": current_text})

            reward = reward_model.get_reward(current_text)
            if reward == 1.0:
                result[3] = True
                result[7] = turn

            if not result[3] and turn < max_turns:
                # 1. Generate Critique (using base model for feedback, adapters bypassed)
                with model.disable_adapter():
                    critique = generate_critique(original_base_model, tokenizer, device, math_problem, current_text, standard_answer)
                rosa_tg_messages.append({"role": "user", "content": f"Feedback: {critique}"})

                # 2. Parameter Update (on rosa_textgrad adapter)
                _rosa_update_lm_head(model, pi_star_calculator, gen_ids, context_len, reward, 'rosa_textgrad', device)

        # Stop early only once ALL four strategies are solved. The first four
        # entries are the solved flags; the remaining entries are turn counters
        # (always truthy) and must not take part in this test.
        if all(result[:4]):
            break

    return result

def worker_process(rank, problems_subset, model_name, result_queue):
    """Evaluate `problems_subset` on GPU `rank` and push one result record per problem.

    Exactly ``len(problems_subset)`` records are put on `result_queue`, even when
    a problem raises or the worker fails during setup: the consumer counts
    records to know when it is done, so a missing record would block it forever.
    Failed or unprocessed problems are reported as unsolved for every strategy
    (solved flags False, turn 11) with an additional 'error' message.

    Args:
        rank: GPU index; the worker runs on ``cuda:{rank}``.
        problems_subset: sized iterable of problem records passed one by one to
            `solve_problem_iteratively`.
        model_name: name or path given to `from_pretrained`.
        result_queue: queue-like object with a `put` method receiving dicts with
            the keys 'rank', '<strategy>_solved' and '<strategy>_turn' for the
            strategies tg, base, rosa and rosa_tg.
    """
    total = len(problems_subset)
    reported = 0  # number of records already pushed to result_queue
    unsolved = [False, False, False, False, 11, 11, 11, 11]

    def _record(res, error=None):
        # Convert the 8-element result list into the dict layout the consumer reads.
        record = {
            'rank': rank,
            'tg_solved': res[0], 'base_solved': res[1], 'rosa_solved': res[2], 'rosa_tg_solved': res[3],
            'tg_turn': res[4],   'base_turn': res[5],   'rosa_turn': res[6],   'rosa_tg_turn': res[7]
        }
        if error is not None:
            record['error'] = error
        return record

    try:
        device = f"cuda:{rank}"
        torch.cuda.set_device(device)
        print(f"[GPU {rank}] 启动... 处理 {total} 个问题")

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
        # Keep the most recent conversation turns when the prompt exceeds the length limit.
        tokenizer.truncation_side = 'left'

        # Load Base Model
        original_base_model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map={"": rank}
        )
        original_base_model.eval()

        # Load Peft Model
        # NOTE(review): get_peft_model wraps the given model rather than copying it,
        # so `original_base_model` presumably shares its modules (including the
        # LoRA-wrapped lm_head) with `model` -- confirm against the PEFT docs.
        peft_config = LoraConfig(r=1, lora_alpha=8, target_modules=["lm_head"], task_type=TaskType.CAUSAL_LM)
        model = get_peft_model(original_base_model, peft_config)

        # Second adapter for Strategy 4
        model.add_adapter("rosa_textgrad", peft_config)

        for i, problem_data in enumerate(problems_subset):
            try:
                res = solve_problem_iteratively(problem_data, model, original_base_model, tokenizer, device)
                print(f" [GPU{rank}], result: {res}")
                record = _record(res)
            except Exception as e:
                print(f"[GPU {rank}] Error problem {i}: {e}")
                # Still report the problem so the consumer's record count stays correct.
                record = _record(unsolved, error=str(e))

            result_queue.put(record)
            reported += 1

            # Release cached GPU memory after every problem, including failed ones
            # (e.g. after an out-of-memory error).
            torch.cuda.empty_cache()
            gc.collect()

        print(f"[GPU {rank}] 完成任务。")
    except Exception as e:
        print(f"[GPU {rank}] Critical Error: {e}")
        # Setup or cleanup failed: account for every problem that was never reported.
        for _ in range(total - reported):
            result_queue.put(_record(unsolved, error=f"critical: {e}"))

def main():
    """Run the four-strategy comparison over the selected dataset, one worker process per GPU.

    The dataset is split into contiguous chunks (one per GPU), each chunk is
    handed to `worker_process`, and the per-problem records arriving on a shared
    queue are aggregated into accuracy figures and per-turn histograms.

    Returns early (printing the reason) when no CUDA device is available, when
    no dataset loader is selected, or when the dataset is empty.
    """
    import queue  # function-scope import: provides queue.Empty raised by Queue.get(timeout=...)

    mp.set_start_method('spawn', force=True)

    start_time = time.time()
    print("="*80)
    print("开始多策略对比评估 (TextGrad vs Base vs ROSA vs ROSA+TextGrad)")
    print("="*80)

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("No CUDA device available; nothing to evaluate.")
        return

    #model_name = "/mnt/bn/code-generation-100t-hl/weichenxing/model/models--Qwen--Qwen2.5-0.5B-Instruct"
    #model_name = "/mnt/bn/code-generation-100t-hl/weichenxing/model/models--Qwen--Qwen3-0.6B"
    #model_name = "/mnt/bn/code-generation-100t-hl/weichenxing/model/models--Qwen--Qwen2.5-7B"
    #model_name = "/mnt/bn/code-generation-100t-hl/weichenxing/model/models--Qwen--Qwen3-8B"
    #model_name = "/mnt/bn/code-generation-100t-hl/weichenxing/model/models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B"
    model_name = "/mnt/bn/code-generation-100t-hl/weichenxing/model/models--deepseek-ai--DeepSeek-R1-0528-Qwen3-8B"
    #model_name = "/data_train/kitwei/model/Ministral-8B-Instruct-2410"
    #model_name = "/data_train/kitwei/model/Mistral-7B-Instruct-v0.3"

    # Select the benchmark by uncommenting exactly one loader below. The explicit
    # None default keeps the guard after the list meaningful (instead of raising
    # NameError) when no loader is enabled.
    math_dataset = None
    #math_dataset = get_math_questions()
    #math_dataset = get_math_500_questions()
    #math_dataset = get_AIME25_questions()
    #math_dataset = get_mathdial_questions()
    #math_dataset = get_gsm8k_questions()
    #math_dataset = get_gpqa_questions()
    #math_dataset = get_mmlu_redux_questions()
    #math_dataset = get_superGPQA_questions()
    #math_dataset = get_M_IMO_questions()
    #math_dataset = get_MT_MATH100_questions()
    #math_dataset = get_MT_AIME24_questions()
    if math_dataset is None:
        print("No dataset selected: enable one of the get_*_questions() loaders in main().")
        return

    num_problems = min(len(math_dataset), 500)
    print(f"数据集大小: {num_problems}")
    if num_problems == 0:
        # Nothing to evaluate; also avoids dividing by zero in the summary below.
        print("Dataset is empty; nothing to evaluate.")
        return
    problems_to_run = math_dataset.select(range(num_problems))

    # Ceil-divide so every problem is assigned; trailing GPUs may receive no chunk.
    chunk_size = (num_problems + num_gpus - 1) // num_gpus
    chunks = [
        problems_to_run.select(range(i * chunk_size, min((i + 1) * chunk_size, num_problems)))
        for i in range(num_gpus)
        if i * chunk_size < num_problems
    ]

    result_queue = mp.Queue()
    processes = []

    for rank, chunk in enumerate(chunks):
        p = mp.Process(target=worker_process, args=(rank, chunk, model_name, result_queue))
        p.start()
        processes.append(p)

    strategies = ('tg', 'base', 'rosa', 'rosa_tg')
    solved = dict.fromkeys(strategies, 0)
    # Index t-1 counts problems finished at turn t; index 10 (turn 11) counts unsolved ones.
    turn_hist = {name: [0] * 11 for name in strategies}

    poll_seconds = 60
    pbar = tqdm(total=num_problems, desc="Evaluating")
    finished = 0
    while finished < num_problems:
        try:
            res = result_queue.get(timeout=poll_seconds)
        except queue.Empty:
            if any(p.is_alive() for p in processes):
                continue  # workers are still busy; keep waiting
            # Every worker has exited: pick up anything that arrived in the
            # meantime, otherwise stop instead of blocking forever.
            try:
                res = result_queue.get(timeout=1)
            except queue.Empty:
                print(f"\n[Warning] Workers exited after reporting only {finished}/{num_problems} problems; "
                      "missing problems are counted as unsolved.")
                break

        for name in strategies:
            if res[f'{name}_solved']:
                solved[name] += 1
            turn_value = res[f'{name}_turn']
            if 1 <= turn_value <= 11:
                turn_hist[name][turn_value - 1] += 1

        finished += 1
        pbar.update(1)
    pbar.close()

    for p in processes: p.join()

    print("\n" + "="*80)
    print(f"评估完成 (N={num_problems}) - 耗时: {time.time()-start_time:.2f}s")
    print(f"1. TextGrad (Critique Feedback) Accuracy:   {(solved['tg']/num_problems)*100:.2f}%")
    print(f"   Turns: {turn_hist['tg']}")
    print(f"2. Base Model (Generic Retry) Accuracy:     {(solved['base']/num_problems)*100:.2f}%")
    print(f"   Turns: {turn_hist['base']}")
    print(f"3. ROSA (Generic + Param Update) Accuracy:  {(solved['rosa']/num_problems)*100:.2f}%")
    print(f"   Turns: {turn_hist['rosa']}")
    print(f"4. ROSA + TextGrad (Critique + Param) Acc:  {(solved['rosa_tg']/num_problems)*100:.2f}%")
    print(f"   Turns: {turn_hist['rosa_tg']}")
    print("="*80)

# Entry-point guard: start the evaluation only when this file is executed as a
# script, not when it is merely imported as a module.
if __name__ == "__main__":
    main()