from collections import Counter
import argparse
import numpy as np
import logging
import torch
import random
import time
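import os
import re  # used by extract_approaches; imported explicitly rather than relying on `from utils import *`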
from utils import *

def extract_approaches(text):
    pattern = r"Approach\s*\d+:\s*(.*?)(?=Approach\s*\d+:|$)"
    matches = re.findall(pattern, text, flags=re.DOTALL)

    cleaned_matches = []
    for match in matches:
        first_line = match.strip().split('\n')[0].strip()
        cleaned_matches.append(first_line)

    return cleaned_matches
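
# Illustrative helper, not part of the original pipeline: main() recovers the short
# strategy title for the check prompt with `meta_path_list[m].split('.')[0].strip()`.
# The sketch below wraps that expression. The name `strategy_title` is an assumption
# introduced here, and it assumes each extracted line looks like "<Title>. <details>".
def strategy_title(approach_line):
    """Return the text before the first period of an extracted approach line.

    Example: strategy_title("Modular-Arithmetic Congruence. Work modulo b+7.")
    gives "Modular-Arithmetic Congruence".
    """
    return approach_line.split('.')[0].strip()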

def main():
    args = parse_arguments()
    print('*****************************')
    print(args)
    print('*****************************')
    fix_seed(args.random_seed)
    print("OPENAI_API_KEY:")

    # Initialize decoder class (load model and tokenizer) ...
    decoder = Decoder(args)
    print("setup data loader ...")
    dataloader = setup_data_loader(args)
    print_now()

    meta_guidance_prompt = """
What are the possible abstract solution strategies for the mathematical question? Please respond purely at a conceptual level by identifying four distinct approaches, without providing specific computational details. Format your response to match the structure and reasoning style demonstrated in the following examples:

## Example 1 ##
There are $8!=40320$ eight-digit positive integers that use each of the digits $1,2,3,4,5,6,7,8$ exactly once. Let $N$ be the number of these integers that are divisible by 22. Find the difference between $N$ and 2025.

## Answer 1 ##
This problem can be approached from four abstract perspectives: Rule-by-Rule Combinatorial Counting, Probability & Expected Count, Symmetry Class Partitioning, and the Algebraic/Generating-Function Method.

Specifically,
Approach 1: Rule-by-Rule Combinatorial Counting. Treat the 22-divisibility requirement as the intersection of two independent rules: (i) the last digit must be even (divisibility by 2) and (ii) the alternating-sum of the digits must be a multiple of 11 (divisibility by 11). Count permutations satisfying each rule sequentially with basic combinatorial arguments, then combine the counts to obtain N.

Approach 2: Probability & Expected Count. View the 40 320 permutations as equally likely outcomes of a uniform random shuffle. Compute the probability that a randomly chosen permutation meets the two divisibility conditions (using symmetry considerations rather than full enumeration), and multiply this probability by 40 320 to get N.

Approach 3: Symmetry Class Partitioning. Partition the set of all 8-digit permutations into symmetry classes generated by operations that preserve divisibility by 22 (e.g., cyclic rotations combined with sign changes in the alternating sum). Show that each class contributes the same, or predictably related, number of 22-divisible permutations, allowing N to be determined by analyzing just one representative class.

Approach 4: Algebraic/Generating-Function Method. Construct a bivariate generating function that records (a) the choice of last digit’s parity and (b) the alternating-sum residue modulo 11 for the remaining positions. Extract the coefficient corresponding to an alternating-sum residue of 0 and an even last digit; this coefficient equals N.




## Example 2 ##
Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.

## Answer 2 ##
This problem can be approached from four abstract perspectives: Direct Base‑to‑Decimal Translation & Diophantine Equation, Modular‑Arithmetic Congruence, Euclidean Algorithm & Quotient Bounding, and the Finite Enumeration via Parameterisation.

Specifically,
Approach 1: Direct Base‑to‑Decimal Translation & Diophantine Equation. Rewrite the base‑$b$ numerals in ordinary integers: $17_{b}=b+7$ and $97_{b}=9b+7$.  Impose the divisibility condition $(b+7)\mid(9b+7)$, introduce an unknown integer quotient $k$, and obtain a linear Diophantine equation $9b+7=k(b+7)$.  Analyse the resulting integer equation for admissible $(b,k)$ pairs with $b>9$ to determine all valid bases.

Approach 2: Modular‑Arithmetic Congruence. Work entirely modulo $b+7$.  The requirement that $9b+7\equiv0\pmod{b+7}$ simplifies to a congruence in $b$ because $b\equiv-7$ modulo $b+7$.  Reduce the expression, solve the resulting congruence for $b$, and apply base‑size constraints to isolate the permissible bases.

Approach 3: Euclidean Algorithm & Quotient Bounding. Use the Euclidean algorithm to divide $9b+7$ by $b+7$ and express the remainder explicitly as a function of $b$.  Set this remainder to zero, which yields an inequality‑bounded equation for the integer quotient.  Solve the small resulting system to extract all viable $b>9$.

Approach 4: Finite Enumeration via Parameterisation. Parameterise the quotient $k$ (e.g., $1\le k\le9$ after basic size reasoning) and, for each candidate $k$, solve $9b+7=k(b+7)$ for integer $b$.  The finite list of $k$ values produces a short list of candidate bases, which can then be quickly screened against $b>9$ to obtain the complete set.




## Example 3 ##
Let $k$ be a real number such that the system $|25+20i-z|=5$ and $|z-4-k|=|z-3i-k|$ has exactly one complex solution $z$. The sum of all possible values of $k$ can be written as $\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$. Here $i=\sqrt{-1}$.

## Answer 3 ##
This problem can be approached from four abstract perspectives: Euclidean‐geometry (circle–line tangency), Cartesian‐algebraic (discriminant test), Complex‐vector (translation & dot‑product), and the Parametric‐trigonometric (angle sweep).

Specifically,
Approach 1: Euclidean‐geometry (circle–line tangency). View $|25+20i-z|=5$ as a circle and $|z-4-k|=|z-3i-k|$ as the perpendicular‑bisector line of the two fixed points $4+k$ and $k+3i$.  Require that the line be tangent to the circle; impose “distance from the circle’s center to the line = radius” to isolate the admissible $k$.

Approach 2: Cartesian‐algebraic (discriminant test). Write $z=x+yi$.  Convert each modulus equation into a quadratic (for the circle) and a linear equation (for the perpendicular bisector).  Substitute the linear relation into the circle equation to obtain a single quadratic in one real variable and force its discriminant to vanish, guaranteeing exactly one intersection point and hence one solution for $k$.

Approach 3: Complex‐vector (translation & dot‑product). Translate the plane so the circle is centered at the origin (let $w=z-(25+20i)$).  Express the second condition as a dot‑product relation showing that $w$ lies on a specific line through the origin whose direction depends on $k$.  Impose that this line intersect the circle in exactly one point by setting its direction to be orthogonal to the radius at the contact point, yielding a condition on $k$.

Approach 4: Parametric‐trigonometric (angle sweep). Parameterize the circle by $z=25+20i+5e^{i\theta}$.  Substitute this form into the bisector condition to obtain an equation $f_k(\theta)=0$.  Analyze the resulting trigonometric expression and enforce that it admits exactly one angle $\theta$ on $[0,2\pi)$; translating this uniqueness requirement back gives the allowable values of $k$.
    """

    guidence_pre = "Let's think through the question step by step using the "

    check_pre = "Please check whether the above reasoning steps conform to the given question-solving method "
    # Leading space keeps this sentence from running into the "(<strategy>)." that precedes it.
    check_after = " If not, revise the reasoning to ensure it aligns with the strategy."


    total = 0
    sample_n = 4  # the number of samples for each meta-strategy

    correct_list = []

    for i, data in enumerate(dataloader):
        print('*************************')
        print("Data #{}".format(i + 1))
        pre_list = []  # self-consistency

        # Prepare question template ...
        xx, y = data
        x = "Question: " + xx[0] + "\n"
        y = y[0].strip()

        # abstract strategies
        q_context_pre = """## Question ##"""
        q_context_aft = """## Answer ##
This problem can be approached from four abstract perspectives: ???

Specifically,
???"""

        question_meta_guidance = meta_guidance_prompt + "\n" + q_context_pre + "\n" + xx[0] + "\n\n" + q_context_aft

        meta_paths = decoder.decode(args, question_meta_guidance, 1)
        meta_paths_ex = meta_paths[0].message.content
        meta_path_list = extract_approaches(meta_paths_ex)

        # four strategies
        for m in range(len(meta_path_list)):
            question_after = x + "\n" + guidence_pre + meta_path_list[m] + "\n"
            z = decoder.decode(args, question_after, sample_n)   # response.choices

            for j in range(len(z)):
                print("\n")
                print("=========== {}-th Strategy and {}-th CoT ============".format(m + 1, j + 1))
                z2 = question_after + z[j].message.content + "\n\n" + check_pre + "(" + meta_path_list[m].split('.')[0].strip() + ")." + check_after
                check_infor = decoder.decode(args, z2, 1)
                z3 = z2 + "\n\n" + check_infor[0].message.content

                z4 = z3 + "\n" + args.direct_answer_trigger_for_zeroshot_cot + " "
                pred = decoder.decode(args, z4, 1)
                pred = pred[0].message.content
                print(z4 + pred)  # print Q + A

                # Cleansing of predicted answer ...
                pred = answer_cleansing(args, pred)
                pre_list.append(pred)

        print("======= Final Answer for Data #{} ========".format(i + 1))
        print("pred_list: ", pre_list)

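        # Choose the most frequent answer from the list (ties go to the answer seen first).
        # Fall back to an empty string if no approach was extracted, so max() never sees an empty list ...
        last_pre = max(pre_list, key=pre_list.count) if pre_list else ""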
        print("pred : {}".format(last_pre))
        print("GT : " + y)
        print('*************************')


        # Checking answer ...
        correct = (np.array([last_pre]) == np.array([y])).sum().item()
        correct_list.append(correct)
        total += 1  # np.array([y]).size(0)

        if (args.limit_dataset_size != 0) and ((i + 1) >= args.limit_dataset_size):
            break
            # raise ValueError("Stop !!")

        # Current Accuracy:
        accuracy_cur = (sum(correct_list) * 1.0 / total) * 100
        print("Current Accuracy : {}%".format(accuracy_cur))

    # Calculate accuracy ...
    # Guard against an empty dataloader, where total stays 0.
    accuracy = (sum(correct_list) * 1.0 / total) * 100 if total else 0.0
    print("Accuracy : {}%".format(accuracy))
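

# Hedged sketch (not called by main()): the self-consistency vote and the accuracy
# computation above, restated as two small helpers. The names `majority_vote` and
# `accuracy_percent` are assumptions introduced for illustration only.
def majority_vote(predictions, default=""):
    """Return the most frequent prediction; ties go to the one that appears first."""
    if not predictions:
        return default  # nothing to vote on, e.g. no approaches were extracted
    return max(predictions, key=predictions.count)


def accuracy_percent(correct_flags):
    """Return the percentage of 1s in a list of 0/1 flags (0.0 for an empty list)."""
    if not correct_flags:
        return 0.0
    return sum(correct_flags) * 100.0 / len(correct_flags)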



def parse_arguments():
    parser = argparse.ArgumentParser(description="Zero-shot-CoT")

    parser.add_argument(
        "--api_log_file_name", type=str, default=None,
        help="mandatory argument ! json['i>=1']['j==1']['k={1,2}'][{'request', 'response'}]"
    )

    parser.add_argument("--random_seed", type=int, default=1, help="random seed")

    parser.add_argument(
        "--dataset", type=str, default="aqua",
        choices=["aqua", "gsm8k", "commonsensqa", "addsub", "multiarith", "strategyqa", "svamp", "singleeq",
                 "bigbench_date", "object_tracking", "coin_flip", "last_letters", "aime24", "aime25"], help="dataset used for experiment"
    )

    parser.add_argument("--minibatch_size", type=int, default=1, choices=[1],
                        help="minibatch size should be 1 because GPT-3 API takes only 1 input for each request")

    parser.add_argument("--max_num_worker", type=int, default=3, help="maximum number of workers for dataloader")

    parser.add_argument(
        "--model", type=str, default="gpt3", choices=["gpt3", "gpt3-medium", "gpt3-large", "gpt3-xl"],
        help="model used for decoding. Note that 'gpt3' is the smallest model."
    )

    parser.add_argument(
        "--method", type=str, default="zero_shot_cot",
        choices=["zero_shot", "zero_shot_cot", "few_shot", "few_shot_cot", "random"], help="method"
    )
    parser.add_argument(
        "--cot_trigger_no", type=int, default=1,
        help="index (1-17) of the trigger sentence that elicits a model to execute chain of thought"
    )
    parser.add_argument(
        "--max_length_cot", type=int, default=128,
        help="maximum length of output tokens by model for reasoning extraction"
    )
    parser.add_argument(
        "--max_length_direct", type=int, default=32,
        help="maximum length of output tokens by model for answer extraction"
    )
    parser.add_argument(
        "--limit_dataset_size", type=int, default=10,
        help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing."
    )
    parser.add_argument(
        "--api_time_interval", type=float, default=1,
        help="time interval between API requests (presumably in seconds)"
    )
    parser.add_argument(
        "--log_dir", type=str, default="./log2/", help="log directory"
    )

    args = parser.parse_args()

    if args.dataset == "aqua":
        args.dataset_path = "./dataset/AQuA/test.json"
        args.direct_answer_trigger = "\nTherefore, among A through E, the answer is"
    elif args.dataset == "gsm8k":
        args.dataset_path = "./dataset/grade-school-math/test.jsonl"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "commonsensqa":
        args.dataset_path = "./dataset/CommonsenseQA/dev_rand_split.jsonl"
        args.direct_answer_trigger = "\nTherefore, among A through E, the answer is"
        args.plausible_answer_trigger = "Choose the most plausible answer from among choices A through E."
    elif args.dataset == "addsub":
        args.dataset_path = "./dataset/AddSub/AddSub.json"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
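    elif args.dataset == "aime24":
        # NOTE: this is the AddSub path, identical to the "addsub" branch above; it looks like
        # a copy-paste leftover and likely needs to point at the AIME 2024 data file.
        args.dataset_path = "./dataset/AddSub/AddSub.json"
        args.direct_answer_trigger = "\nStop further reasoning. Just give the final answer (arabic numerals):"
    elif args.dataset == "aime25":
        # NOTE: same AddSub path as above; likely needs to point at the AIME 2025 data file.
        args.dataset_path = "./dataset/AddSub/AddSub.json"
        args.direct_answer_trigger = "\nStop further reasoning. Just give the final answer (arabic numerals):"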
    elif args.dataset == "multiarith":
        args.dataset_path = "./dataset/MultiArith/MultiArith.json"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "strategyqa":
        args.dataset_path = "./dataset/StrategyQA/task.json"
        args.direct_answer_trigger = "\nTherefore, the answer (Yes or No) is"
    elif args.dataset == "svamp":
        args.dataset_path = "./dataset/SVAMP/SVAMP.json"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "singleeq":
        args.dataset_path = "./dataset/SingleEq/questions.json"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "bigbench_date":
        args.dataset_path = "./dataset/Bigbench_Date/task.json"
        args.direct_answer_trigger = "\nTherefore, among A through F, the answer is"
    elif args.dataset == "object_tracking":
        args.dataset_path = "./dataset/Bigbench_object_tracking/task.json"
        args.direct_answer_trigger = "\nTherefore, among A through C, the answer is"
    elif args.dataset == "coin_flip":
        args.dataset_path = "./dataset/coin_flip/coin_flip.json"
        args.direct_answer_trigger = "\nTherefore, the answer (Yes or No) is"
    elif args.dataset == "last_letters":
        args.dataset_path = "./dataset/last_letters/last_letters.json"
        args.direct_answer_trigger = "\nTherefore, the answer is"
    else:
        raise ValueError("dataset is not properly defined ...")

    # "Therefore, the answer ..." -> "The answer ..."
    trigger = args.direct_answer_trigger.replace("\nTherefore, ", "")
    args.direct_answer_trigger_for_zeroshot = trigger[0].upper() + trigger[1:]
    args.direct_answer_trigger_for_zeroshot_cot = args.direct_answer_trigger

    args.direct_answer_trigger_for_fewshot = "The answer is"

    if args.cot_trigger_no == 1:
        args.cot_trigger = "Let's think step by step."
    elif args.cot_trigger_no == 2:
        args.cot_trigger = "We should think about this step by step."
    elif args.cot_trigger_no == 3:
        args.cot_trigger = "First,"
    elif args.cot_trigger_no == 4:
        args.cot_trigger = "Before we dive into the answer,"
    elif args.cot_trigger_no == 5:
        args.cot_trigger = "Proof followed by the answer."
    elif args.cot_trigger_no == 6:
        args.cot_trigger = "Let's think step by step in a realistic way."
    elif args.cot_trigger_no == 7:
        args.cot_trigger = "Let's think step by step using common sense and knowledge."
    elif args.cot_trigger_no == 8:
        args.cot_trigger = "Let's think like a detective step by step."
    elif args.cot_trigger_no == 9:
        args.cot_trigger = "Let's think about this logically."
    elif args.cot_trigger_no == 10:
        args.cot_trigger = "Let's think step by step. First,"
    elif args.cot_trigger_no == 11:
        args.cot_trigger = "Let's think"
    elif args.cot_trigger_no == 12:
        args.cot_trigger = "Let's solve this problem by splitting it into steps."
    elif args.cot_trigger_no == 13:
        args.cot_trigger = "The answer is after the proof."
    elif args.cot_trigger_no == 14:
        args.cot_trigger = "Let's be realistic and think step by step."
    elif args.cot_trigger_no == 15:
        args.cot_trigger = "Approach it methodically, considering all possible cases."
    elif args.cot_trigger_no == 16:
        args.cot_trigger = "Break it down step-by-step and solve systematically."
    elif args.cot_trigger_no == 17:
        args.cot_trigger = "Let's think step by step, how to go about extracting and stitching together the last letter of each sentence."
    else:
        raise ValueError("cot_trigger_no is not properly defined ...")

    return args
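

# Hedged sketch (not called by parse_arguments()): the "Therefore, the answer ..." ->
# "The answer ..." rewrite above as a standalone function. The name
# `to_zeroshot_trigger` is an assumption introduced for illustration only.
def to_zeroshot_trigger(direct_answer_trigger):
    """Drop the leading "Therefore, " prefix and capitalize the first remaining character."""
    trigger = direct_answer_trigger.replace("\nTherefore, ", "")
    # Triggers without that prefix (e.g. the AIME ones) come back unchanged.
    return trigger[0].upper() + trigger[1:] if trigger else trigger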


if __name__ == "__main__":
    main()