#!/usr/bin/env python3
"""
测试SCSD v3在不同设置下的接受率
接受率 = (gen_length - forward次数) / gen_length * 100%
"""

import argparse
import json
import re
import time
from typing import Dict, List, Union

import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from utils.generate_function_v3 import generate_v3

# ANSI color codes
class Colors:
    """Namespace of ANSI escape sequences used to colour terminal output.

    The attributes are plain strings meant to be interpolated into printed
    text; follow coloured text with ``RESET`` to restore the default style.
    """

    # Text attributes
    RESET = "\033[0m"
    BOLD = "\033[1m"

    # Bright foreground colours
    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    MAGENTA = "\033[95m"
    CYAN = "\033[96m"


# Few-shot examples from lm_eval Minerva math (same as in demo_inference_v3.py).
# Each entry is a dict with a "problem" string and a worked "solution" string
# that ends with a "Final Answer: ..." line; create_few_shot_prompt() renders
# every entry, in order, ahead of the problem being asked.
FEW_SHOT_EXAMPLES = [
    {
        # NOTE(review): the problem text ends with a "}" that has no matching
        # opening brace — presumably carried over from the upstream prompt; confirm.
        "problem": "Find the domain of the expression  $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}",
        "solution": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.",
    },
    {
        "problem": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$",
        "solution": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.",
    },
    {
        "problem": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?",
        # NOTE(review): around the align environment this solution contains the
        # two-character sequence backslash + "n" (an escaped backslash followed
        # by n) where the other examples use real newline characters — confirm
        # this is intended before changing it, since it alters the prompt.
        "solution": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.  If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight.  Equating this to 480 pounds, we can solve for $n$:\\n\\begin{align*}\\n30n&=480\\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\\n\\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.",
    },
    {
        "problem": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero,\nfind $\\frac{a}{b},$ assuming $b$ is nonzero.",
        "solution": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.",
    },
]


def doc_to_text(problem: str) -> str:
    """Render a problem statement in the lm_eval-style prompt layout.

    The result is ``"Problem:"``, a newline, the problem text, a blank line,
    and a trailing ``"Solution:"`` marker with nothing after it.
    """
    # The empty element produces the blank line between problem and marker.
    pieces = ("Problem:", problem, "", "Solution:")
    return "\n".join(pieces)


def create_few_shot_prompt(
    problem: str, use_chat_template: bool = True
) -> Union[str, List[Dict[str, str]]]:
    """Build the few-shot prompt for ``problem``.

    Every entry of ``FEW_SHOT_EXAMPLES`` is rendered with ``doc_to_text`` and
    followed by its solution; the target problem is rendered last with an
    open "Solution:" marker. All pieces are joined with blank lines.

    Args:
        problem: The problem statement to append after the examples.
        use_chat_template: When true, wrap the prompt as a single user
            message so it can be passed to ``tokenizer.apply_chat_template``.

    Returns:
        A one-element list ``[{"role": "user", "content": prompt}]`` when
        ``use_chat_template`` is true, otherwise the prompt string itself.
    """
    prompt_parts = []

    # NOTE(review): because every part is joined with "\n\n", each example's
    # solution is separated from its "Solution:" marker by a blank line plus
    # the leading space added here. Left unchanged so prompts stay identical
    # to earlier runs — confirm this layout is the intended one.
    for example in FEW_SHOT_EXAMPLES:
        prompt_parts.append(doc_to_text(example["problem"]))
        prompt_parts.append(" " + example["solution"])

    # The problem to be solved comes last, leaving the solution open.
    prompt_parts.append(doc_to_text(problem))

    full_prompt = "\n\n".join(prompt_parts)

    if use_chat_template:
        # Instruct models expect a list of chat messages, not a raw string.
        return [{"role": "user", "content": full_prompt}]
    return full_prompt


def test_single_problem(
    model, tokenizer, problem, gen_length, block_length, draft_length, 
    tree_strategy, num_alternatives, device, verbose=False, mask_id=126336
):
    """Run ``generate_v3`` on one problem and compute its acceptance rate.

    Args:
        model: Model object forwarded to ``generate_v3``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        problem: Problem statement placed after the few-shot examples.
        gen_length: Generation length forwarded to ``generate_v3``; also the
            denominator of the acceptance rate.
        block_length: Forwarded to ``generate_v3``.
        draft_length: Forwarded to ``generate_v3``.
        tree_strategy: Forwarded to ``generate_v3``.
        num_alternatives: Forwarded to ``generate_v3``.
        device: Device the input tensors are moved to.
        verbose: Forwarded to ``generate_v3``.
        mask_id: Forwarded to ``generate_v3``; presumably the tokenizer's
            mask token id — confirm against the model when overriding.

    Returns:
        A tuple ``(forward_passes, acceptance_rate)``, where ``forward_passes``
        is the second value returned by ``generate_v3`` and ``acceptance_rate``
        is ``(gen_length - forward_passes) / gen_length * 100``.
    """
    # Build the chat-formatted prompt text.
    messages = create_few_shot_prompt(problem, use_chat_template=True)
    formatted_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )

    # Tokenize once and reuse the result for both tensors; unsqueeze adds the
    # leading batch dimension of size 1.
    encoded = tokenizer(formatted_input)
    input_ids = torch.tensor(encoded["input_ids"]).to(device).unsqueeze(0)
    attention_mask = torch.tensor(encoded["attention_mask"]).to(device).unsqueeze(0)

    # Only the forward-pass count is needed; the generated ids are discarded.
    _, forward_passes = generate_v3(
        input_ids=input_ids,
        attention_mask=attention_mask,
        model=model,
        gen_length=gen_length,
        block_length=block_length,
        temperature=0.0,
        cfg_scale=0.0,
        mask_id=mask_id,
        draft_length=draft_length,
        tree_strategy=tree_strategy,
        num_alternatives=num_alternatives,
        verbose=verbose
    )

    # Share of the gen_length budget that did not need its own forward pass.
    acceptance_rate = (gen_length - forward_passes) / gen_length * 100

    return forward_passes, acceptance_rate


def _load_problems(path, limit):
    """Return up to ``limit`` "question" strings from the JSON-lines file ``path``."""
    problems = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if len(problems) >= limit:
                break
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            problems.append(json.loads(line)["question"])
    return problems


def _strategy_key(strategy_name, num_alt):
    """Return the label under which a strategy's results are stored and reported."""
    if strategy_name == "greedy":
        return strategy_name
    return f"{strategy_name}(alt={num_alt})"


def _write_markdown(path, results, num_problems, gen_lengths, block_lengths, draft_lengths):
    """Write the nested ``results`` dict to ``path`` as markdown tables."""
    # The report contains non-ASCII text, so the encoding must not depend on
    # the platform default.
    with open(path, "w", encoding="utf-8") as f:
        f.write("# SCSD v3 接受率测试结果\n\n")
        f.write(f"测试了test.jsonl前{num_problems}道题的平均接受率\n\n")
        f.write("接受率 = (gen_length - forward次数) / gen_length × 100%\n\n")

        # One section per gen_length, one table per block_length.
        for gen_length in gen_lengths:
            f.write(f"## Gen Length = {gen_length}\n\n")

            for block_length in block_lengths:
                f.write(f"### Block Length = {block_length}\n\n")

                # Table header
                f.write("| Draft Length | Strategy | Forward Passes | Acceptance Rate (%) |\n")
                f.write("|-------------|----------|----------------|--------------------|\n")

                for draft_length in draft_lengths:
                    per_strategy = results[gen_length][block_length][draft_length]
                    for strategy_key, data in per_strategy.items():
                        f.write(f"| {draft_length} | {strategy_key} | {data['forward_passes']:.1f} | {data['acceptance_rate']:.2f}% |\n")

                f.write("\n")


def main():
    """Sweep generation settings and write average acceptance rates to markdown.

    Loads the model and tokenizer from ``--model_path``, reads the first
    ``--num_problems`` questions from ``test.jsonl`` in the working directory,
    runs ``test_single_problem`` for every combination of generation length,
    block length, draft length and tree strategy, and writes the per-config
    averages to ``--output_file``.

    Raises:
        SystemExit: If ``--num_problems`` is not positive (via argparse) or no
            problems could be read from ``test.jsonl``.
    """
    parser = argparse.ArgumentParser(description="Test SCSD v3 Acceptance Rate")
    parser.add_argument("--model_path", type=str, 
                        default="/mnt/public/gpfs-jd/code/jinxiangqi/yf/models/GSAI-ML/LLaDA-8B-Instruct",
                        help="Path to the model")
    parser.add_argument("--num_problems", type=int, default=5, 
                        help="Number of problems to test from test.jsonl")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument("--output_file", type=str, default="acceptance_rate_results.md",
                        help="Output markdown file")
    args = parser.parse_args()

    # Reject a useless request before paying for the model load; averaging
    # over zero problems would otherwise divide by zero much later.
    if args.num_problems < 1:
        parser.error("--num_problems must be a positive integer")

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"{Colors.BOLD}Loading model...{Colors.RESET}")
    model = (
        AutoModel.from_pretrained(
            args.model_path, 
            trust_remote_code=True, 
            torch_dtype=torch.bfloat16
        )
        .to(device)
        .eval()
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path, 
        trust_remote_code=True
    )
    print(f"{Colors.GREEN}Model loaded{Colors.RESET}")

    # Load test problems
    print(f"{Colors.BOLD}Loading test problems...{Colors.RESET}")
    problems = _load_problems("test.jsonl", args.num_problems)
    if not problems:
        raise SystemExit("No problems found in test.jsonl; nothing to evaluate.")
    print(f"{Colors.GREEN}Loaded {len(problems)} problems{Colors.RESET}")

    # Parameter configurations
    gen_lengths = [64, 128, 256, 512]
    block_lengths = [4, 8, 16, 32]
    draft_lengths = [2, 3, 4, 5, 6]
    strategies = [
        ("greedy", 0),  # (strategy_name, num_alternatives)
        ("greedy_with_alternatives", 2),
        ("greedy_with_alternatives", 3),
    ]

    # Total iterations
    total_configs = len(gen_lengths) * len(block_lengths) * len(draft_lengths) * len(strategies)
    total_iterations = total_configs * len(problems)

    # results[gen_length][block_length][draft_length][strategy_key] -> averages
    results = {}

    # The context manager closes the progress bar even if a run fails. The
    # sweep is inference only, so autograd is disabled regardless of whether
    # generate_v3 does so itself.
    with tqdm(total=total_iterations, desc="Testing configurations") as pbar, torch.no_grad():
        for gen_length in gen_lengths:
            results[gen_length] = {}
            for block_length in block_lengths:
                results[gen_length][block_length] = {}
                for draft_length in draft_lengths:
                    per_strategy = {}
                    results[gen_length][block_length][draft_length] = per_strategy
                    for strategy_name, num_alt in strategies:
                        forward_counts = []
                        acceptance_rates = []

                        for problem in problems:
                            forward_count, acceptance_rate = test_single_problem(
                                model, tokenizer, problem,
                                gen_length, block_length, draft_length,
                                strategy_name, num_alt, device,
                                verbose=args.verbose,  # honour --verbose (off by default)
                            )
                            forward_counts.append(forward_count)
                            acceptance_rates.append(acceptance_rate)
                            pbar.update(1)

                        # Average over all problems for this configuration.
                        per_strategy[_strategy_key(strategy_name, num_alt)] = {
                            "forward_passes": sum(forward_counts) / len(forward_counts),
                            "acceptance_rate": sum(acceptance_rates) / len(acceptance_rates),
                        }

    # Generate markdown output
    print(f"\n{Colors.BOLD}Generating markdown output...{Colors.RESET}")
    _write_markdown(
        args.output_file, results, len(problems),
        gen_lengths, block_lengths, draft_lengths,
    )
    print(f"{Colors.GREEN}Results saved to {args.output_file}{Colors.RESET}")

    # Print summary
    print(f"\n{Colors.BOLD}Summary:{Colors.RESET}")
    print(f"Tested {len(problems)} problems")
    print(f"Total configurations tested: {total_configs}")
    print(f"Results saved to: {args.output_file}")


if __name__ == "__main__":
    main()