

import logging
import openai
import argparse
import json
from sal.utils.score import score, aggregate_scores
from sal.utils.qwen_math_parser import extract_answer, math_equal

logger = logging.getLogger(__name__)
# Give the root logger an INFO-level handler so progress messages reach the
# console. basicConfig does nothing if the root logger already has handlers,
# so the module logger's own level is also pinned to INFO explicitly.
logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO)


from sal.utils.data import get_dataset

def save_results(results, path):
    """Write ``results`` to ``path`` as pretty-printed UTF-8 JSON.

    Args:
        results: Any JSON-serialisable object.
        path: Destination file path. Missing parent directories are created
            and an existing file is overwritten.
    """
    from pathlib import Path

    out_path = Path(path)
    # Create the destination folder up front: failing here would otherwise
    # discard results that took a full generation run to produce.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # encoding must be explicit; the platform default (e.g. cp1252 on
    # Windows) cannot encode every character.
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)



from tqdm import tqdm

def generate_with_openai(problems, system_prompt, model, api_key, temperature=1.0, max_tokens=None, top_p=1.0):
    """Send one chat-completion request per problem and collect the replies.

    Args:
        problems: Iterable of user prompts; one request is made per item.
        system_prompt: Text sent as the system message with every request.
        model: OpenAI model name. Substring matches against it choose the
            temperature handling and the token-limit parameter name.
        api_key: Assigned to the module-level ``openai.api_key``.
        temperature: Sampling temperature. Replaced by 1 for models matched
            by ``fixed_temperature_models`` below.
        max_tokens: Optional generation cap, sent as ``max_completion_tokens``
            or ``max_tokens`` depending on the model; omitted when ``None``.
        top_p: Passed through unchanged.

    Returns:
        List of reply strings in the same order as ``problems``. A reply with
        no text content is recorded as ``""``.
    """
    openai.api_key = api_key

    # Model-family substrings that decide request parameters.
    # "o1" was previously only in the token-parameter list, so o1 requests
    # were sent the caller's temperature (default CLI value 0.8) although
    # reasoning models only accept the default of 1.
    # NOTE(review): the original comment says gpt-4o / gpt-4-turbo also only
    # allow temperature=1; confirm against current OpenAI docs — if not,
    # --temperature is being silently overridden for those models.
    fixed_temperature_models = ("gpt-4o", "gpt-4-turbo", "gpt-5", "o1")
    completion_tokens_models = ("gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo-0125", "gpt-5", "o1")

    # Everything except the messages depends only on the model and arguments,
    # so it is computed once instead of on every iteration.
    request_kwargs = {
        "model": model,
        "top_p": top_p,
        "temperature": 1 if any(x in model for x in fixed_temperature_models) else temperature,
    }
    if max_tokens is not None:
        if any(x in model for x in completion_tokens_models):
            request_kwargs["max_completion_tokens"] = max_tokens
        else:
            request_kwargs["max_tokens"] = max_tokens

    completions = []
    for idx, prompt in enumerate(tqdm(problems, desc="OpenAI Generating")):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        response = openai.chat.completions.create(messages=messages, **request_kwargs)
        choice = response.choices[0]
        content = choice.message.content
        if not content:
            # The API can return no visible text (content may be None, e.g. on
            # a refusal); store "" so downstream string handling never sees None.
            logger.warning(
                "Empty completion for problem %d (finish_reason=%s)", idx, choice.finish_reason
            )
        completions.append(content or "")
    return completions


def _parse_args():
    """Define and parse this script's command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--openai_api_key', type=str, required=True)
    parser.add_argument('--openai_model', type=str, required=True)
    parser.add_argument('--system_prompt', type=str, default="Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n## Step 1: [Concise description]\n[Brief explanation and calculations]\n\n## Step 2: [Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem.")
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--dataset_split', type=str, default="train")
    parser.add_argument('--output_path', type=str, default="openai_results.json")
    parser.add_argument('--temperature', type=float, default=0.8)
    parser.add_argument('--max_tokens', type=int, default=4096)
    parser.add_argument('--top_p', type=float, default=1.0)
    return parser.parse_args()


def _build_output(problems, completions, answers, dataset_name):
    """Assemble the JSON output, scoring completions when references exist.

    Args:
        problems: Problem statements, aligned with ``completions``.
        completions: One model output per problem; ``None`` entries are
            treated as empty text.
        answers: Reference answers aligned with ``problems``, or a falsy value
            (``None`` / empty) when there is nothing to score against.
        dataset_name: Forwarded to ``extract_answer``.

    Returns:
        ``{"results": [...]}`` with ``problem`` / ``completion`` per item, or,
        when ``answers`` is non-empty, items that also carry ``predict``,
        ``answer`` and ``correct`` plus a top-level ``"accuracy"``.
    """
    # A missing reply must not crash answer extraction after the (costly)
    # generation step has already finished.
    completions = [c or "" for c in completions]

    if not answers:
        return {
            "results": [
                {"problem": prob, "completion": comp}
                for prob, comp in zip(problems, completions)
            ]
        }

    preds = [extract_answer(c, dataset_name) for c in completions]
    # Reference answers are typically bare expressions (e.g. "\frac{14}{3}"),
    # whereas completions end in \boxed{...} per the default system prompt.
    # Wrapping each reference in \boxed{} makes extract_answer parse and
    # normalise both sides the same way rather than applying its non-boxed
    # fallback to the reference; the f-string also accepts non-str answers.
    # NOTE(review): relies on extract_answer's \boxed{...} handling — confirm.
    refs = [extract_answer(f"\\boxed{{{a}}}", dataset_name) for a in answers]
    verdicts = [bool(math_equal(p, r)) for p, r in zip(preds, refs)]
    correct = sum(verdicts)
    acc = correct / len(answers)
    logger.info(f"Accuracy: {acc:.4f} ({correct}/{len(answers)})")

    results = [
        {"problem": prob, "completion": comp, "predict": pred, "answer": ans, "correct": ok}
        for prob, comp, pred, ans, ok in zip(problems, completions, preds, answers, verdicts)
    ]
    return {"results": results, "accuracy": acc}


def main():
    """Command-line entry point: generate answers, optionally score, save JSON.

    Loads the requested dataset split via ``get_dataset``, queries the OpenAI
    model once per ``problem``, scores predictions with ``math_equal`` when an
    ``answer`` column exists, and writes the output to ``--output_path``.
    """
    from types import SimpleNamespace

    args = _parse_args()

    # Attribute bag for get_dataset; SimpleNamespace behaves like the ad-hoc
    # local class it replaces for attribute reads.
    # NOTE(review): the None bounds / sample count presumably mean "use the
    # whole split" — confirm against get_dataset.
    config = SimpleNamespace(
        dataset_name=args.dataset_name,
        dataset_split=args.dataset_split,
        dataset_start=None,
        dataset_end=None,
        num_samples=None,
    )

    logger.info(f"Loading benchmark from {args.dataset_name} [{args.dataset_split}]")
    dataset = get_dataset(config)
    problems = dataset["problem"]
    answers = dataset["answer"] if "answer" in dataset.column_names else None

    logger.info(f"Generating answers with OpenAI model {args.openai_model}")
    completions = generate_with_openai(
        problems,
        args.system_prompt,
        args.openai_model,
        args.openai_api_key,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
    )

    output = _build_output(problems, completions, answers, args.dataset_name)

    logger.info(f"Saving results to {args.output_path}")
    save_results(output, args.output_path)
    logger.info("Done 🔥!")


if __name__ == "__main__":
    main()
