from collections import Counter
import argparse
import numpy as np
import logging
import torch
import random
import time
import os
from utils import *

def extract_approaches(text):
    pattern = r"Approach \d+: (.*?)(?=Approach \d+:|}$)"
    matches = re.findall(pattern, text, re.DOTALL)

    return [m.strip() for m in matches]

def main():
    args = parse_arguments()
    print('*****************************')
    print(args)
    print('*****************************')
    fix_seed(args.random_seed)
    print("OPENAI_API_KEY:")

    # Initialize decoder class (load model and tokenizer) ...
    decoder = Decoder(args)
    print("setup data loader ...")
    dataloader = setup_data_loader(args)
    print_now()

    meta_guidance_prompt = """
    What are the possible abstract solution strategies for the above mathematical problem? Please respond purely at a conceptual level by identifying four distinct approaches, without providing specific computational details. Present your answer in the format of the following example:

Example:
{
Approach 1: Algebraic method (factorization/completing the square). Treat the equation as a quadratic form and simplify it via factorization or completing the square to find when integer solutions exist.

Approach 2: Discriminant method (rational root/integer coefficient test). View the equation as a quadratic in one variable and use the discriminant to check for rational or integer roots within the given domain.

Approach 3: Number-theoretic method (divisibility/congruence analysis). Reformulate the equation to test divisibility or apply modular arithmetic to eliminate invalid solutions based on congruence conditions.

Approach 4: Geometric method (analysis of algebraic curves). Interpret the equation as an algebraic curve and analyze its geometric properties, such as symmetry and intersections, to count integer solutions.
}    
    """

    guidence_pre = "Let's think through the question step by step using the "
    check_pre = "Please check whether the above reasoning steps conform to the given question-solving method "
    check_after = "If not, revise the reasoning to ensure it aligns with the strategy."


    total = 0
    sample_n = 4  # the number of samples for each strategy
    correct_list = []

    for i, data in enumerate(dataloader):

        print('*************************')
        print("{}st data".format(i + 1))
        pre_list = []  # self-consistency

        # Prepare question template ...
        xx, y = data
        x = "Question: " + xx[0] + "\n"
        y = y[0].strip()

        # abstract strategies
        question_meta_guidance = xx[0] + "\n" + meta_guidance_prompt
        meta_paths = decoder.decode(args, question_meta_guidance, 1)
        meta_paths_ex = meta_paths[0].message.content

        meta_path_list = extract_approaches(meta_paths_ex)

        # four strategies
        for m in range(len(meta_path_list)):
            question_after = x + "\n" + guidence_pre + meta_path_list[m] + "\n"
            z = decoder.decode(args, question_after, sample_n)   # response.choices

            for j in range(len(z)): # multiple responses
                print("\n")
                print("=========== {}-th Strategy and {}-th CoT ============".format(m + 1, j + 1))
                z2 = question_after + z[j].message.content + "\n\n" + check_pre + "(" + meta_path_list[m].split('.')[0].strip() + ")." + check_after
                check_infor = decoder.decode(args, z2, 1)
                z3 = z2 + "\n\n" + check_infor[0].message.content

                z4 = z3 + "\n" + args.direct_answer_trigger_for_zeroshot_cot + " "
                pred = decoder.decode(args, z4,1)
                pred = pred[0].message.content
                print(z4 + pred)  # print Q + A

                # Clensing of predicted answer ...
                pred = answer_cleansing(args, pred)
                pre_list.append(pred)

        print("======= Final Answer of {}st Data ========".format(i + 1))
        print("pred_list: ", pre_list)

        # Choose the most frequent answer from the list ...
        last_pre = max(pre_list, key=pre_list.count)
        print("pred : {}".format(last_pre))
        print("GT : " + y)
        print('*************************')


        # Checking answer ...
        correct = (np.array([last_pre]) == np.array([y])).sum().item()
        correct_list.append(correct)
        total += 1  # np.array([y]).size(0)

        if (args.limit_dataset_size != 0) and ((i + 1) >= args.limit_dataset_size):
            break
            # raise ValueError("Stop !!")

        # Current Accuracy:
        accuracy_cur = (sum(correct_list) * 1.0 / total) * 100
        print("Current Accuracy: : {}%".format(accuracy_cur))

    # Calculate accuracy ...
    accuracy = (sum(correct_list) * 1.0 / total) * 100
    print("Accuracy : {}%".format(accuracy))



def parse_arguments():
    parser = argparse.ArgumentParser(description="Zero-shot-CoT")

    parser.add_argument(
        "--api_log_file_name", type=str, default=None,
        help="mandatory argument ! json['i>=1']['j==1']['k={1,2}'][{'request', response'}]"
    )

    parser.add_argument("--random_seed", type=int, default=1, help="random seed")

    parser.add_argument(
        "--dataset", type=str, default="aqua",
        choices=["aqua", "gsm8k", "commonsensqa", "addsub", "multiarith", "strategyqa", "svamp", "singleeq",
                 "bigbench_date", "object_tracking", "coin_flip", "last_letters", "aime24", "aime25"], help="dataset used for experiment"
    )

    parser.add_argument("--minibatch_size", type=int, default=1, choices=[1],
                        help="minibatch size should be 1 because GPT-3 API takes only 1 input for each request")

    parser.add_argument("--max_num_worker", type=int, default=3, help="maximum number of workers for dataloader")

    parser.add_argument(
        "--model", type=str, default="gpt3", choices=["gpt3", "gpt3-medium", "gpt3-large", "gpt3-xl"],
        help="model used for decoding. Note that 'gpt3' are the smallest models."
    )

    parser.add_argument(
        "--method", type=str, default="zero_shot_cot",
        choices=["zero_shot", "zero_shot_cot", "few_shot", "few_shot_cot", "random"], help="method"
    )
    parser.add_argument(
        "--cot_trigger_no", type=int, default=1,
        help="A trigger sentence that elicits a model to execute chain of thought"
    )
    parser.add_argument(
        "--max_length_cot", type=int, default=128,
        help="maximum length of output tokens by model for reasoning extraction"
    )
    parser.add_argument(
        "--max_length_direct", type=int, default=32,
        help="maximum length of output tokens by model for answer extraction"
    )
    parser.add_argument(
        "--limit_dataset_size", type=int, default=10,
        help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing."
    )
    parser.add_argument(
        "--api_time_interval", type=float, default=1, help=""
    )
    parser.add_argument(
        "--log_dir", type=str, default="./log2/", help="log directory"
    )

    args = parser.parse_args()

    if args.dataset == "aqua":
        args.dataset_path = "./dataset/AQuA/test.json"
        args.direct_answer_trigger = "\nTherefore, among A through E, the answer is"
    elif args.dataset == "gsm8k":
        args.dataset_path = "./dataset/grade-school-math/test.jsonl"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "commonsensqa":
        args.dataset_path = "./dataset/CommonsenseQA/dev_rand_split.jsonl"
        args.direct_answer_trigger = "\nTherefore, among A through E, the answer is"
        args.plausible_answer_trigger = "Choose the most plausible answer from among choices A through E."
    elif args.dataset == "addsub":
        args.dataset_path = "./dataset/AddSub/AddSub.json"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "aime24":
        args.dataset_path = "./dataset/AddSub/AddSub.json"
        args.direct_answer_trigger = "\nStop further reasoning. Just give the final answer (arabic numerals):"
    elif args.dataset == "aime25":
        args.dataset_path = "./dataset/AddSub/AddSub.json"
        args.direct_answer_trigger = "\nStop further reasoning. Just give the final answer (arabic numerals):"
    elif args.dataset == "multiarith":
        args.dataset_path = "./dataset/MultiArith/MultiArith.json"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "strategyqa":
        args.dataset_path = "./dataset/StrategyQA/task.json"
        args.direct_answer_trigger = "\nTherefore, the answer (Yes or No) is"
    elif args.dataset == "svamp":
        args.dataset_path = "./dataset/SVAMP/SVAMP.json"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "singleeq":
        args.dataset_path = "./dataset/SingleEq/questions.json"
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "bigbench_date":
        args.dataset_path = "./dataset/Bigbench_Date/task.json"
        args.direct_answer_trigger = "\nTherefore, among A through F, the answer is"
    elif args.dataset == "object_tracking":
        args.dataset_path = "./dataset/Bigbench_object_tracking/task.json"
        args.direct_answer_trigger = "\nTherefore, among A through C, the answer is"
    elif args.dataset == "coin_flip":
        args.dataset_path = "./dataset/coin_flip/coin_flip.json"
        args.direct_answer_trigger = "\nTherefore, the answer (Yes or No) is"
    elif args.dataset == "last_letters":
        args.dataset_path = "./dataset/last_letters/last_letters.json"
        args.direct_answer_trigger = "\nTherefore, the answer is"
    else:
        raise ValueError("dataset is not properly defined ...")

    # "Therefore, the answer ..." -> "The answer ..."
    trigger = args.direct_answer_trigger.replace("\nTherefore, ", "")
    args.direct_answer_trigger_for_zeroshot = trigger[0].upper() + trigger[1:]
    args.direct_answer_trigger_for_zeroshot_cot = args.direct_answer_trigger

    args.direct_answer_trigger_for_fewshot = "The answer is"

    if args.cot_trigger_no == 1:
        args.cot_trigger = "Let's think step by step."
    elif args.cot_trigger_no == 2:
        args.cot_trigger = "We should think about this step by step."
    elif args.cot_trigger_no == 3:
        args.cot_trigger = "First,"
    elif args.cot_trigger_no == 4:
        args.cot_trigger = "Before we dive into the answer,"
    elif args.cot_trigger_no == 5:
        args.cot_trigger = "Proof followed by the answer."
    elif args.cot_trigger_no == 6:
        args.cot_trigger = "Let's think step by step in a realistic way."
    elif args.cot_trigger_no == 7:
        args.cot_trigger = "Let's think step by step using common sense and knowledge."
    elif args.cot_trigger_no == 8:
        args.cot_trigger = "Let's think like a detective step by step."
    elif args.cot_trigger_no == 9:
        args.cot_trigger = "Let's think about this logically."
    elif args.cot_trigger_no == 10:
        args.cot_trigger = "Let's think step by step. First,"
    elif args.cot_trigger_no == 11:
        args.cot_trigger = "Let's think"
    elif args.cot_trigger_no == 12:
        args.cot_trigger = "Let's solve this problem by splitting it into steps."
    elif args.cot_trigger_no == 13:
        args.cot_trigger = "The answer is after the proof."
    elif args.cot_trigger_no == 14:
        args.cot_trigger = "Let's be realistic and think step by step."
    elif args.cot_trigger_no == 15:
        args.cot_trigger = "Approach it methodically, considering all possible cases."
    elif args.cot_trigger_no == 16:
        args.cot_trigger = "Break it down step-by-step and solve systematically."
    elif args.cot_trigger_no == 17:
        args.cot_trigger = "Let's think step by step, how to go about extracting and stitching together the last letter of each sentence."
    else:
        raise ValueError("cot_trigger_no is not properly defined ...")

    return args


if __name__ == "__main__":
    main()