import re
import os
import json
import numpy as np
# from utils.deepseek_utils import DecodingArguments, get_tokens_number, ChatBot
from utils.deepseek_api_utils import DecodingArguments, get_tokens_number, ChatBot
# from utils.gpt_utils import OpenAIDecodingArguments, ChatBot
from utils.basic import remove_prefix, compute_metric, match_answers
from tqdm import tqdm


# Few-shot prefix prepended to every question sent to the model. It holds one
# worked example (facts rewritten as "#N." labels, steps citing earlier labels
# with "(by #i #j)", final answer after "####") followed by the instruction to
# answer new questions in the same format.
# NOTE(review): the text contains a typo ("mouthly") and a full-width comma
# ("Next，we"). They are left as-is because this string is sent verbatim to the
# model, so any edit changes the prompt — confirm before fixing.
PREFIX_PROMPT = "Question:\nMary spends $3 a day on lunch in the cafeteria. She bought a $5 mouthly membership card that gives her a 20% discount. How much did she spend on lunch in December? \n\nAnswer:\nFirst, let's extract the necessary information and rewrite the question using labels.\n#1. Mary spends $3 a day on lunch in the cafeteria.\n#2. She bought a $5 membership card.\n#3. The membership card gives her a 20% discount.\n#4. How much did she spend on lunch in December?\nNext，we add the necessary knowledge from the question if needed.\n#5. December has 31 days.\nFinally, let us solve the problem step by step with reference to the question and reasoning process:\n\n#6. (by #1, #3) Step 1: Calculate the discounted daily cost. \nNormal daily cost: $3, Discount: 20%\nDiscounted daily cost = $3 × (1 − 0.20) = $3 × 0.80 = $2.40\n\n#7. (by#5 #6) Step 2: Calculate the total cost for December lunches. \nDecember has 31 days, Daily discounted cost: $2.40, Total discounted lunch cost = 31 × $2.40 = $74.40\n\n#8. (by #2) Step 3: Include the cost of the membership card. \nMembership card cost: $5\n\n#9. (by #7 #8) Step 4: Determine Mary’s total spending. \nTotal lunch cost (discounted): $74.40\nMembership card cost: $5\nTotal spent = $74.40 + $5 = $79.40\n\n#10. (by #4 #9) The original question is #4. How much did she spend on lunch in December? We do not miss information on the rewritten labels. So the answer to this question is in December Mary spent $79.40. #### 79.40\nPlease answer the following questions using the reasoning format of the above example rather than the content and follow the example by placing your final answer after ####."


# extract the final answer X from the dataset ("#### X")
def parse_final_answer(ans_str):
    """
    Extract the number that follows '####' in a GSM8K 'answer' field.

    Args:
        ans_str: Answer text expected to contain a '#### <number>' marker.

    Returns:
        The first number found after '####' as a string with thousands
        separators (commas) removed, e.g. "1234.5"; None when the text
        contains no such marker.
    """
    match = re.search(r"####\s*([\d,]+(?:\.\d+)?)", ans_str)
    if match is None:
        return None
    # Strip commas so the ground truth has the same textual form as the
    # predicted answers extracted from the model output.
    return match.group(1).replace(",", "")


def transform_gsm8k_data_incremental(raw_data, existing_data=None, output_file=None,
                                     n_samples=10, temperature=0.7, max_tokens=1024,
                                     start_index=0):
    """
    Sample the model n_samples times per question and record the results.

    A sample's global index is start_index plus its position in raw_data.
    Samples whose global index already appears as "index_number" in
    existing_data are skipped, so an interrupted run can be resumed.

    Args:
        raw_data: Sequence of dicts with "question" and "answer" keys.
        existing_data: Result dicts from an earlier run (each carrying an
            "index_number" key), or None. A non-empty list is extended in place.
        output_file: Path of the JSON checkpoint, rewritten whenever the number
            of collected results is a multiple of 10. Nothing is written when falsy.
        n_samples: Number of model calls made per question.
        temperature: Sampling temperature passed to DecodingArguments.
        max_tokens: Token limit passed to DecodingArguments.
        start_index: Offset added to a sample's position to form its global index.

    Returns:
        The list of result dicts: existing entries followed by the new ones.
    """
    ChatBot.init()

    decoding_args = DecodingArguments(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=1.0,
        n=1,
    )

    transformed_data = existing_data or []
    processed_indices = {sample["index_number"] for sample in transformed_data}

    # Compiled once instead of on every model response.
    answer_pattern = re.compile(r"####\s*([\d,]+(?:\.\d+)?)")

    # Create the checkpoint directory up front so a missing folder fails (or is
    # fixed) before any API calls are spent, not at the first save.
    if output_file:
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    # No `initial=` here: skipped samples still advance the bar by one each, so
    # starting from len(existing_data) would count them twice and overshoot total.
    progress = tqdm(raw_data, desc="Processing", unit="sample", total=len(raw_data))
    for local_idx, sample in enumerate(progress):
        global_idx = start_index + local_idx
        if global_idx in processed_indices:
            print(f"[Skip] Sample #{global_idx} already processed.")
            continue

        question = sample["question"]
        answer = sample["answer"]
        final_ans = parse_final_answer(answer)

        # The prompt ends with the opening sentence of the example answer so the
        # model continues in the same labelled format.
        model_input = (
            PREFIX_PROMPT
            + "\n\nQuestion:\n"
            + question
            + "\n\nAnswer:\nFirst, let's extract the necessary information and rewrite the question using labels.\n\n"
        )

        model_outputs = []
        pred_answers = []

        for _ in range(n_samples):
            outputs = ChatBot.call_chat_deepseek(
                prompt=model_input,
                decoding_args=decoding_args,
                return_list=True
            )
            # An empty result is recorded as an empty response.
            single_output = outputs[0] if outputs else ""
            model_outputs.append(single_output)

            match = answer_pattern.search(single_output)
            if match:
                pred_answers.append(match.group(1).replace(",", ""))
            else:
                # Placeholder kept so candidate_answers stays aligned with n_responses.
                pred_answers.append("Not exist!")

        per_sample_correct, majority_results, majority_corrects, majority_count = compute_metric(pred_answers, final_ans)
        gt_count = sum(per_sample_correct)
        mean_expectation = np.mean(per_sample_correct) if per_sample_correct else 0.0
        majority_correct = 1.0 if majority_corrects[0] else 0.0

        new_sample = {
            "question": question,
            "answer": answer,
            "final_answer": final_ans,
            "index_number": global_idx,
            "model_input": model_input,
            "n_responses": model_outputs,
            "candidate_answers": pred_answers,
            "majority_answer_percentage": majority_correct,
            "per_answer_correct": per_sample_correct,
            "mean_expectation": mean_expectation,
            "majority_number": majority_count,
            "ground_number": gt_count,
            "voting_results": majority_results,
            "voting_corrects": majority_corrects,
            "answers_need_verify": list(range(n_samples))
        }

        transformed_data.append(new_sample)

        # save every 10 samples
        if output_file and len(transformed_data) % 10 == 0:
            # Write to a sibling temp file and swap it in, so an interruption
            # mid-write cannot leave a truncated checkpoint that breaks resume.
            tmp_file = f"{output_file}.tmp"
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(transformed_data, f, indent=4, ensure_ascii=False)
            os.replace(tmp_file, output_file)
            print(f"[Saved] {len(transformed_data)} samples written to {output_file}")

    return transformed_data


if __name__ == "__main__":
    # Slice of the test set handled by this run. The same start value names the
    # output file and is passed as start_index, so the three cannot drift apart.
    start_index = 150
    end_index = 300

    raw_file = "data/dataset/gsm8k/test.json"
    output_file = f"./deepseek_results/gsm8k/{start_index}.json"

    # Loading raw data
    with open(raw_file, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    raw_data = raw_data[start_index:end_index]

    # Check if a partial result already exists
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            existing_data = json.load(f)
        print(f"[Resume] Loaded {len(existing_data)} samples from {output_file}")
    else:
        existing_data = []
        print("[Start] No existing results found. Starting from scratch.")

    # Make sure the output directory exists before any model calls are made, so
    # neither the periodic checkpoints nor the final save can fail on a missing folder.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Running incremental conversion
    transformed = transform_gsm8k_data_incremental(
        raw_data=raw_data,
        existing_data=existing_data,
        output_file=output_file,
        n_samples=10,
        temperature=0.7,
        start_index=start_index
    )

    # Save at the end
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(transformed, f, indent=4, ensure_ascii=False)
    print(f"[Done] Final results saved to {output_file}")