from tqdm import tqdm
import os
import argparse
import json
import concurrent.futures
import threading
import sqlite3
import time

from openai import OpenAI

from utils_dump import (
    apply_instruct_template,
    completion_with_nudging,
    completion_with_baseline,
    remove_redundant_repetitions,
)
from dataset_utils import extract_ans, parse_pred_ans, get_dataset, PROMPTS
from time import sleep

# When True, exp_nudging imports an existing legacy JSON ``.ckpt`` checkpoint
# (saved next to the results JSON) into its SQLite results database before
# resuming, so samples finished under the old format are not regenerated.
LOAD_CKPT = True


def exp_nudging(
    client_base: OpenAI,
    client_nudging: OpenAI,
    dataset_name: str,
    num_samples: int,
    base_model: str,
    nudging_model: str,
    max_token_total: int,
    input_data: list,
    output_data: list,
    base_temperature: float,
    nudging_temperature: float,
    base_top_p: float,
    answer_start_prompt_base: str = None,
    answer_start_prompt_nudging: str = None,
    print_intermediate_output: bool = False,
    rerun: bool = False,
    exp_prefix: str = "",
    exp: str = "nudging",
    completion_token_num=16,
    completion_token_num_nudging=16,
    top_prob_thres: float = 0.3,
    num_threads: int = 10,
    output_root_dir: str = "./outputs",
    not_doing_evaluation: bool = False,
    nudging_method: str = "top_prob",
    judge_client: OpenAI = None,
    judge_model: str = None,
    retry_min_length: int = None,
    retry_times: int = 3,
    override_eos_min_length: int = -1,
):
    """Run the nudging experiment over ``input_data`` and optionally evaluate it.

    Each sample is generated with ``completion_with_nudging`` (base model
    guided by the nudging model), scored with ``extract_ans``, and committed
    as one row to a SQLite database stored next to the JSON output.  An
    interrupted run therefore resumes with only the missing indices.  If a
    legacy JSON ``.ckpt`` file exists and ``LOAD_CKPT`` is set, its entries are
    merged into the database first.  Once every sample is present, results
    are exported to a JSON file (full per-sample info) and a readable ``.txt``
    file, which ``parse_pred_ans`` evaluates.

    Args:
        client_base / client_nudging: clients for the base / nudging models.
        dataset_name: key into ``PROMPTS``; also names the output subdirectory
            and is passed to ``extract_ans`` / ``parse_pred_ans``.
        num_samples: only used in printed settings and the output filename.
        input_data: per-sample dicts read via their ``"input"`` and
            ``"context"`` keys.
        output_data: gold answers, aligned index-by-index with ``input_data``.
        answer_start_prompt_base / answer_start_prompt_nudging: default to
            ``PROMPTS[dataset_name]["answer_start"]`` when ``None``.
        rerun: recompute even if the ``.txt`` result already exists (see Note).
        exp_prefix: filename prefix; a trailing ``"_"`` is added if missing.
        num_threads: number of worker threads / index chunks.
        nudging_method, judge_client, judge_model, top_prob_thres,
        completion_token_num, completion_token_num_nudging,
        override_eos_min_length: forwarded to ``completion_with_nudging``.
        retry_min_length: if a positive int, a sample whose output has fewer
            tokens is regenerated up to ``retry_times`` more times and the
            longest attempt is kept.
        not_doing_evaluation: skip ``parse_pred_ans`` and return ``None``.

    Returns:
        ``None`` when ``not_doing_evaluation`` is true, otherwise whatever
        ``parse_pred_ans`` returns for the saved ``.txt`` file.

    Note:
        ``rerun`` deletes only the ``.txt`` / ``.json`` exports.  The SQLite
        database (and any legacy ``.ckpt``) is kept, so already-cached samples
        are re-exported rather than regenerated.
    """
    print("=" * 20)
    print(f"{exp} experiments")
    print("experiment settings:")
    print(f"dataset_name: {dataset_name}")
    print(f"num_samples: {num_samples}")
    print(f"base_model: {base_model}")
    print(f"nudging_model: {nudging_model}")
    print(f"max_token_total: {max_token_total}")
    print(f"base_temperature: {base_temperature}")
    print(f"base_top_p: {base_top_p}")
    print(f"nudging_method: {nudging_method}")
    print(f"nudging_temperature: {nudging_temperature}")
    print(f"num_threads: {num_threads}")
    print(f"top probability threshold: {top_prob_thres}")
    print(f"completion_token_num_base: {completion_token_num}")
    print(f"completion_token_num_nudging: {completion_token_num_nudging}")
    if retry_min_length is not None and retry_min_length > 0:
        print(f"retry_min_length: {retry_min_length}")
        print(f"retry_times: {retry_times}")
    else:
        print("retry_min_length: None (no retry)")
    if override_eos_min_length > 0:
        print(f"override_eos_min_length: {override_eos_min_length}")
    else:
        print("override_eos_min_length: -1 (no EOS override)")
    print("=" * 20)

    # Print FINAL judge_model and judge_host if using textual_judge
    if nudging_method == "textual_judge":
        resolved_judge_model = judge_model if judge_model is not None else nudging_model
        if hasattr(judge_client, "base_url"):
            resolved_judge_host = judge_client.base_url
        else:
            resolved_judge_host = getattr(client_nudging, "base_url", None)
        print(f"[textual_judge] FINAL judge_model: {resolved_judge_model}")
        print(f"[textual_judge] FINAL judge_host: {resolved_judge_host}")

    base_dir = f"{output_root_dir}/{dataset_name}/tbase_{base_temperature}_tnudge_{nudging_temperature}_thres_{top_prob_thres}"  # for saving the txt file that contains the questions and answers
    os.makedirs(base_dir, exist_ok=True)

    all_info_base_dir = f"{base_dir}/all_info"  # for saving the json file that contains all the information
    os.makedirs(all_info_base_dir, exist_ok=True)

    base_model_name = base_model.split("/")[-1]
    nudging_model_name = nudging_model.split("/")[-1]

    if len(exp_prefix) > 0 and exp_prefix[-1] != "_":
        exp_prefix += "_"
    save_filename = (
        exp_prefix
        + f"{nudging_method}_{top_prob_thres}_thres_{exp}_{base_model_name}_{nudging_model_name}_{num_samples}_samples.txt"
    )

    save_path = os.path.join(base_dir, save_filename)
    save_path_all_info = os.path.join(
        all_info_base_dir, save_filename.replace(".txt", ".json")
    )
    # Per-sample results are checkpointed here so interrupted runs can resume.
    db_path = save_path_all_info + ".sqlite3"
    # Older runs checkpointed to a single JSON list; imported when LOAD_CKPT.
    legacy_ckpt_path = save_path_all_info + ".ckpt"

    DEBUG = False

    if DEBUG:
        import sys
        import atexit

        class Logger:
            """Tee stdout to both the terminal and a log file."""

            def __init__(self, logfile_path):
                self.terminal = sys.stdout
                self.logfile = open(logfile_path, "a")

            def write(self, message):
                self.terminal.write(message)
                self.logfile.write(message)

            def flush(self):
                self.terminal.flush()
                self.logfile.flush()

            def close(self):
                self.logfile.close()

        log_path = os.path.join(base_dir, "debug.log")
        logger = Logger(log_path)
        sys.stdout = logger

        atexit.register(logger.close)

        print("DEBUGGING.")

    def process_nudging_sample(
        client_base,
        client_nudging,
        base_model,
        nudging_model,
        system_prompt_base,
        system_prompt_nudging,
        question_prompt,
        answer_start_prompt_base,
        answer_start_prompt_nudging,
        max_token_total,
        base_temperature,
        base_top_p,
        nudging_temperature,
        print_intermediate_output,
        top_prob_thres,
        q,
        a,
        dataset_name,
        judge_client,
        judge_model,
        sample_idx,
        retry_min_length=None,
        retry_times=3,
        override_eos_min_length=-1,
    ):
        """Generate, score and annotate one sample, retrying short outputs.

        When ``retry_min_length`` is a positive int, generation is repeated
        (at most ``retry_times`` extra attempts) while the output is shorter
        than ``retry_min_length`` tokens.  The longest attempt is kept.
        Returns the ``completion_with_nudging`` info dict extended with
        ``extracted_answer``, ``scores``, ``gold_answer`` and ``q_prefix``.
        """
        question = q["input"]
        context = q["context"]

        need_retry = retry_min_length is not None and retry_min_length > 0
        # Clamp so that at least one attempt is always made; with a negative
        # retry_times the old loop never ran and ``all_info`` was unbound.
        max_retries = max(retry_times, 0)

        best_all_info = None
        best_length = -1  # any real attempt (even length 0) beats this
        retries_done = 0

        while True:
            all_info = completion_with_nudging(
                base_model=base_model,
                client_base=client_base,
                client_nudging=client_nudging,
                nudging_model=nudging_model,
                system_prompt_base=system_prompt_base,
                system_prompt_nudging=system_prompt_nudging,
                question=question,
                context=context,
                question_prompt=question_prompt,
                answer_start_prompt_base=answer_start_prompt_base,
                answer_start_prompt_nudging=answer_start_prompt_nudging,
                completion_token_num=completion_token_num,
                completion_token_num_nudging=completion_token_num_nudging,
                max_token_total=max_token_total,
                max_round=max_token_total,
                base_temperature=base_temperature,
                top_p=base_top_p,
                nudging_temperature=nudging_temperature,
                print_intermediate_output=print_intermediate_output,
                top_prob_thres=top_prob_thres,
                nudging_method=nudging_method,
                judge_client=judge_client,
                judge_model=judge_model,
                override_eos_min_length=override_eos_min_length,
            )

            # Output length = total token count over all generation steps.
            current_length = sum(
                len(x) for x in all_info["all_per_step_stats"]["tokens"]
            )

            # Keep the longest attempt seen so far.
            if current_length > best_length:
                best_all_info = all_info
                best_length = current_length

            if not need_retry or current_length >= retry_min_length:
                break
            if retries_done >= max_retries:
                break

            retries_done += 1
            print(
                f"Sample {sample_idx}: Output length {current_length} below threshold {retry_min_length}, retrying ({retries_done}/{retry_times})"
            )

        # ``retries_done`` counts retries actually performed (previously this
        # reported one too many once retries were exhausted).
        if need_retry and retries_done > 0:
            print(
                f"Sample {sample_idx}: Final output length {best_length} after {retries_done} retries"
            )

        raw_answer = best_all_info["raw_answer"]
        ans_, scores = extract_ans(
            dataset_name,
            raw_answer,
            input=question,
            question_start=question_prompt,
            ans_gold=a,
        )

        # Whatever extract_ans derives from the raw answer.
        best_all_info["extracted_answer"] = ans_
        # Scores returned by extract_ans (judge scores, if any).
        best_all_info["scores"] = scores
        best_all_info["gold_answer"] = a
        best_all_info["q_prefix"] = question_prompt
        return best_all_info

    if os.path.exists(save_path) and not rerun:
        print(f"Load saved results from {save_path}")
    else:
        # delete file (the SQLite cache is intentionally kept; see docstring)
        if os.path.exists(save_path):
            os.remove(save_path)
        if os.path.exists(save_path_all_info):
            os.remove(save_path_all_info)

        system_prompt_base = PROMPTS[dataset_name]["system"]
        system_prompt_nudging = PROMPTS[dataset_name]["system_nudging"]
        if answer_start_prompt_base is None:
            answer_start_prompt_base = PROMPTS[dataset_name]["answer_start"]
        if answer_start_prompt_nudging is None:
            answer_start_prompt_nudging = PROMPTS[dataset_name]["answer_start"]
        question_prompt = PROMPTS[dataset_name]["question"]

        print(f"system_prompt_base: {system_prompt_base}")
        print(f"system_prompt_nudging: {system_prompt_nudging}")
        print(f"question_prompt: {question_prompt}")
        print(f"answer_start_prompt_base: {answer_start_prompt_base}")
        print(f"answer_start_prompt_nudging: {answer_start_prompt_nudging}")
        # Setup SQLite DB
        conn = sqlite3.connect(db_path, check_same_thread=False)
        c = conn.cursor()
        c.execute(
            """CREATE TABLE IF NOT EXISTS results (
                        idx INTEGER PRIMARY KEY,
                        result_json TEXT
                    )"""
        )
        conn.commit()
        c.execute("SELECT idx FROM results")
        print(f"Processed indices in DB: {len(c.fetchall())}")

        # If legacy .ckpt file exists, load and insert to DB
        if os.path.exists(legacy_ckpt_path) and LOAD_CKPT:
            print(
                f"Loading legacy checkpoint from {legacy_ckpt_path} and inserting to DB..."
            )
            with open(legacy_ckpt_path, "r") as f_ckpt:
                legacy_all_info_list = json.load(f_ckpt)
            for idx, entry in enumerate(legacy_all_info_list):
                if entry is not None:
                    c.execute(
                        "INSERT OR REPLACE INTO results (idx, result_json) VALUES (?, ?)",
                        (idx, json.dumps(entry)),
                    )
            conn.commit()
        print(f"Checking consistency of processed indices...")
        # Verify every cached row still matches the current input at that
        # index; a mismatch means the cache belongs to different data.
        c.execute("SELECT idx, result_json FROM results")
        processed_data = c.fetchall()
        processed_indices = set()

        for idx, result_json in processed_data:
            # A cached index beyond the current input cannot be matched (and
            # would crash the export below with a bare IndexError).
            if not 0 <= idx < len(input_data):
                print(
                    f"ERROR: Stored index {idx} is out of range for {len(input_data)} input samples"
                )
                print("Exiting due to data inconsistency.")
                raise SystemExit(1)

            processed_indices.add(idx)
            stored_info = json.loads(result_json)

            current_q = input_data[idx]
            current_question = current_q["input"]
            current_context = current_q["context"]

            stored_question = stored_info.get("question", "")
            stored_context = stored_info.get("context", "")

            if stored_question != current_question or stored_context != current_context:
                print(f"ERROR: Data consistency check failed for index {idx}")
                print(f"Current question: {current_question}")
                print(f"Stored question: {stored_question}")
                print(f"Current context: {current_context}")
                print(f"Stored context: {stored_context}")
                print("Exiting due to data inconsistency.")
                raise SystemExit(1)

        unprocessed_indices = [
            i for i in range(len(input_data)) if i not in processed_indices
        ]

        if len(unprocessed_indices) == 0:
            print("All samples already processed. Skipping computation.")
        else:
            print(f"Processing {len(unprocessed_indices)} unprocessed samples...")
            # Ceil-divide so there are at most num_threads chunks.
            chunk_size = (len(unprocessed_indices) + num_threads - 1) // num_threads
            chunks = [
                unprocessed_indices[i : i + chunk_size]
                for i in range(0, len(unprocessed_indices), chunk_size)
            ]
            progress_bar = tqdm(
                total=len(unprocessed_indices), desc="Processing samples"
            )

            def process_indices(indices):
                """Worker: generate each index in ``indices`` and commit it to the DB."""
                # Each worker uses its own connection; try/finally releases it
                # even when a sample raises.
                local_conn = sqlite3.connect(db_path, check_same_thread=False)
                try:
                    local_c = local_conn.cursor()
                    for idx in indices:
                        time.sleep(0.1)  # light pacing between requests
                        q = input_data[idx]
                        a = output_data[idx]
                        all_info = process_nudging_sample(
                            client_base=client_base,
                            client_nudging=client_nudging,
                            base_model=base_model,
                            nudging_model=nudging_model,
                            system_prompt_base=system_prompt_base,
                            system_prompt_nudging=system_prompt_nudging,
                            question_prompt=question_prompt,
                            answer_start_prompt_base=answer_start_prompt_base,
                            answer_start_prompt_nudging=answer_start_prompt_nudging,
                            max_token_total=max_token_total,
                            base_temperature=base_temperature,
                            base_top_p=base_top_p,
                            nudging_temperature=nudging_temperature,
                            print_intermediate_output=print_intermediate_output,
                            top_prob_thres=top_prob_thres,
                            q=q,
                            a=a,
                            dataset_name=dataset_name,
                            judge_client=judge_client,
                            judge_model=judge_model,
                            sample_idx=idx,
                            retry_min_length=retry_min_length,
                            retry_times=retry_times,
                            override_eos_min_length=override_eos_min_length,
                        )
                        # Commit per sample so progress survives interruption.
                        local_c.execute(
                            "INSERT OR REPLACE INTO results (idx, result_json) VALUES (?, ?)",
                            (idx, json.dumps(all_info)),
                        )
                        local_conn.commit()
                        progress_bar.update(1)
                finally:
                    local_conn.close()

            with concurrent.futures.ThreadPoolExecutor(
                max_workers=num_threads
            ) as executor:
                futures = [executor.submit(process_indices, chunk) for chunk in chunks]
                for future in futures:
                    future.result()  # Wait for all; re-raises worker errors
            progress_bar.close()
        # Export all results from DB to JSON file, ordered by sample index.
        c.execute("SELECT idx, result_json FROM results ORDER BY idx")
        all_info_list = [None] * len(input_data)
        for idx, result_json in c.fetchall():
            all_info_list[idx] = json.loads(result_json)
        with open(save_path_all_info, "w") as f:
            json.dump(all_info_list, f)
        # Save the answers and scores
        with open(save_path, "a") as fd:
            for info in all_info_list:
                fd.write(
                    "Input_q: %s\nNudging_words:\n%s\nA_model:\n%s\nScores:\n%s\nA:\n%s\n\n"
                    % (
                        info["context"] + info["question"],
                        info["all_nudging_words"],
                        info["extracted_answer"],
                        json.dumps(info["scores"], indent=4),
                        info["gold_answer"],
                    )
                )
        conn.close()

        print(f"Saved to {all_info_base_dir}")

    if not_doing_evaluation:
        return
    else:
        return parse_pred_ans(dataset_name, save_path, print_aggregated_metric=True)


def exp_baseline(
    client_base: OpenAI,
    client_proxy_chat: OpenAI,
    client_proxy_base: OpenAI,
    dataset_name: str,
    num_samples: int,
    base_model: str,
    proxy_chat_model: str,
    proxy_base_model: str,
    max_token_total: int,
    baseline_method: str,
    input_data: list,
    output_data: list,
    rerun: bool = False,
    exp_prefix: str = "",
    temperature: float = 0.0,
    num_threads: int = 100,
    output_root_dir: str = "./outputs",
    not_doing_evaluation: bool = False,  # if True, do not do evaluation
):
    """Run a baseline experiment over ``input_data`` and optionally evaluate it.

    Each sample is answered with ``completion_with_baseline`` using
    ``baseline_method`` (base model plus proxy chat/base models) and scored
    with ``extract_ans``.  Results are written as a readable ``.txt`` file
    under ``<output_root_dir>/<dataset_name>`` and as a JSON dump under
    ``<output_root_dir>/<dataset_name>/t_<temperature>/all_info``.  An
    existing ``.txt`` result is reused unless ``rerun`` is set.

    Args:
        input_data: per-sample dicts read via ``"input"`` and ``"context"``.
        output_data: gold answers aligned index-by-index with ``input_data``.
        num_samples: only used in printed settings and the output filename.
        exp_prefix: filename prefix; a trailing ``"_"`` is added if missing.
        num_threads: worker thread count; samples are split into at most
            this many contiguous chunks.
        not_doing_evaluation: skip ``parse_pred_ans`` and return ``None``.

    Returns:
        ``None`` if ``not_doing_evaluation`` is true, otherwise the result of
        ``parse_pred_ans`` on the saved ``.txt`` file.
    """
    print("+" * 20)
    print("Baseline experiments")
    print("experiment settings:")
    print(f"baseline_method: {baseline_method}")
    print(f"dataset_name: {dataset_name}")
    print(f"num_samples: {num_samples}")
    print(f"base_model: {base_model}")
    print(f"proxy_chat_model: {proxy_chat_model}")
    print(f"proxy_base_model: {proxy_base_model}")
    print(f"max_token_total: {max_token_total}")
    print(f"temperature: {temperature}")
    print(f"num_threads: {num_threads}")
    print("+" * 20)
    base_dir = f"{output_root_dir}/{dataset_name}"
    os.makedirs(base_dir, exist_ok=True)
    base_model_name = base_model.split("/")[-1]
    proxy_chat_model_name = proxy_chat_model.split("/")[-1]
    if len(exp_prefix) > 0 and exp_prefix[-1] != "_":
        exp_prefix += "_"
    save_filename = (
        exp_prefix
        + f"{baseline_method}_{base_model_name}_{proxy_chat_model_name}_{num_samples}_samples.txt"
    )
    save_path = os.path.join(base_dir, save_filename)
    # Honour output_root_dir (this was hard-coded to "./outputs", so the JSON
    # dump ignored the caller's output directory; the default is unchanged).
    all_info_base_dir = f"{output_root_dir}/{dataset_name}/t_{temperature}/all_info"
    os.makedirs(all_info_base_dir, exist_ok=True)
    save_path_all_info = os.path.join(
        all_info_base_dir, save_filename.replace(".txt", ".json")
    )
    all_info_list = [None] * len(input_data)

    def process_sample(
        client_base,
        client_proxy_chat,
        client_proxy_base,
        base_model,
        proxy_chat_model,
        proxy_base_model,
        baseline_method,
        max_token_total,
        instruction_prompt,
        q_prefix,
        answer_start_prompt,
        temperature,
        q,
        a,
        dataset_name,
    ):
        """Generate and score a single sample; return its info dict."""
        context = q["context"]
        question = q["input"]

        all_info = completion_with_baseline(
            client_base=client_base,
            client_proxy_chat=client_proxy_chat,
            client_proxy_base=client_proxy_base,
            base_model=base_model,
            proxy_chat_model=proxy_chat_model,
            proxy_base_model=proxy_base_model,
            baseline_method=baseline_method,
            max_token_total=max_token_total,
            instruction_prompt=instruction_prompt,
            q_prefix=q_prefix,
            answer_start_prompt=answer_start_prompt,
            temperature=temperature,
            context=context,
            question=question,
        )
        raw_answer = all_info["raw_answer"]
        ans_, scores = extract_ans(
            dataset_name,
            raw_answer,
            input=question,
            question_start=q_prefix,
            ans_gold=a,
        )
        all_info["extracted_answer"] = ans_
        all_info["scores"] = scores
        all_info["gold_answer"] = a
        all_info["q_prefix"] = q_prefix

        return all_info

    def process_chunk(
        progress_bar,
        client_base,
        client_proxy_chat,
        client_proxy_base,
        base_model,
        proxy_chat_model,
        proxy_base_model,
        baseline_method,
        max_token_total,
        instruction_prompt,
        q_prefix,
        answer_start_prompt,
        temperature,
        input_data,
        output_data,
        dataset_name,
    ):
        """Process one contiguous chunk sequentially, returning results in order."""
        chunk_results = []
        for q, a in zip(input_data, output_data):
            all_info = process_sample(
                client_base=client_base,
                client_proxy_chat=client_proxy_chat,
                client_proxy_base=client_proxy_base,
                base_model=base_model,
                proxy_chat_model=proxy_chat_model,
                proxy_base_model=proxy_base_model,
                baseline_method=baseline_method,
                max_token_total=max_token_total,
                instruction_prompt=instruction_prompt,
                q_prefix=q_prefix,
                answer_start_prompt=answer_start_prompt,
                temperature=temperature,
                q=q,
                a=a,
                dataset_name=dataset_name,
            )
            chunk_results.append(all_info)
            progress_bar.update(1)
        return chunk_results

    if os.path.exists(save_path) and not rerun:
        print(f"Load saved results from {save_path}")
    else:
        # delete file
        if os.path.exists(save_path):
            os.remove(save_path)
        if os.path.exists(save_path_all_info):
            os.remove(save_path_all_info)
        instruction_prompt = PROMPTS[dataset_name]["system"]
        answer_start_prompt = PROMPTS[dataset_name]["answer_start"]
        q_prefix = PROMPTS[dataset_name]["question"]
        print(f"instruction_prompt: {instruction_prompt}")
        print(f"q_prefix: {q_prefix}")
        print(f"answer_start_prompt: {answer_start_prompt}")

        # Ceil-divide (min 1): floor division gave 0 whenever there were fewer
        # samples than threads, and range() with step 0 raises ValueError.
        chunk_size = max(1, (len(input_data) + num_threads - 1) // num_threads)
        progress_bar = tqdm(total=len(input_data), desc="Processing samples")
        futures = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
            for i in range(0, len(input_data), chunk_size):
                chunk_input_data = input_data[i : i + chunk_size]
                chunk_output_data = output_data[i : i + chunk_size]
                future = executor.submit(
                    process_chunk,
                    progress_bar,
                    client_base,
                    client_proxy_chat,
                    client_proxy_base,
                    base_model,
                    proxy_chat_model,
                    proxy_base_model,
                    baseline_method,
                    max_token_total,
                    instruction_prompt,
                    q_prefix,
                    answer_start_prompt,
                    temperature,
                    chunk_input_data,
                    chunk_output_data,
                    dataset_name,
                )
                # Remember each chunk's start so results land at the right index.
                futures.append((future, i))

            for future, start_index in futures:
                result = future.result()
                all_info_list[start_index : start_index + len(result)] = result
        progress_bar.close()

        # Save results
        with open(save_path, "a") as fd:
            for info in all_info_list:
                fd.write(
                    "Input_q: %s\nA_model:\n%s\nScores:\n%s\nA:\n%s\n\n"
                    % (
                        info["context"] + info["question"],
                        info["extracted_answer"],
                        json.dumps(info["scores"], indent=4),
                        info["gold_answer"],
                    )
                )

        with open(save_path_all_info, "w") as f:
            json.dump(all_info_list, f)
    if not_doing_evaluation:
        return
    else:
        return parse_pred_ans(dataset_name, save_path, print_aggregated_metric=True)


def exp_single_model(
    client: OpenAI,
    dataset_name: str,
    num_samples: int,
    model: str,
    max_token_total: int,
    input_data: list,
    output_data: list,
    rerun: bool = False,
    exp_prefix: str = "",
    temperature: float = 0.0,
    top_p: float = 0.9,
    num_threads: int = 100,
    model_type: str = "nudging",
    output_root_dir: str = "./outputs",
    not_doing_evaluation: bool = False,  # if True, do not do evaluation
):
    """Run a single model alone over ``input_data`` and optionally evaluate it.

    Each sample's prompt is ``context + question prefix + input``.  It is
    wrapped with ``apply_instruct_template`` and completed with
    ``client.completions.create``.  Repetitions are trimmed with
    ``remove_redundant_repetitions``, and the answer is scored with
    ``extract_ans``.  Results go to a readable ``.txt`` file under
    ``<output_root_dir>/<dataset_name>`` and a JSON dump under
    ``<output_root_dir>/<dataset_name>/t_<temperature>/all_info``.

    Args:
        model_type: ``"nudging"`` uses ``PROMPTS[...]["system_nudging"]`` as
            the system prompt; any other value uses ``PROMPTS[...]["system"]``.
        input_data: per-sample dicts read via ``"input"`` and ``"context"``.
        output_data: gold answers aligned index-by-index with ``input_data``.
        num_threads: worker thread count; samples are split into at most
            this many contiguous chunks.
        rerun: recompute even if the ``.txt`` result already exists.
        not_doing_evaluation: skip ``parse_pred_ans`` and return ``None``.

    Returns:
        ``None`` if ``not_doing_evaluation`` is true, otherwise the result of
        ``parse_pred_ans`` on the saved ``.txt`` file.
    """
    print("*" * 20)
    print(f"{model_type} only experiments")
    print("experiment settings:")
    print(f"dataset_name: {dataset_name}")
    print(f"num_samples: {num_samples}")
    print(f"{model_type}_model: {model}")
    print(f"max_token_total: {max_token_total}")
    print(f"temperature: {temperature}")
    print(f"top_p: {top_p}")
    print(f"num_threads: {num_threads}")
    print("*" * 20)

    base_dir = f"{output_root_dir}/{dataset_name}"
    os.makedirs(base_dir, exist_ok=True)
    model_name = model.split("/")[-1]
    if len(exp_prefix) > 0 and exp_prefix[-1] != "_":
        exp_prefix += "_"
    filename = exp_prefix + f"{model_name}_{num_samples}_t{temperature}_samples.txt"
    save_path = os.path.join(base_dir, filename)
    all_info_base_dir = f"{output_root_dir}/{dataset_name}/t_{temperature}/all_info"
    os.makedirs(all_info_base_dir, exist_ok=True)
    save_path_all_info = os.path.join(
        all_info_base_dir, filename.replace(".txt", ".json")
    )
    all_info_list = [None] * len(input_data)

    def process_sample(
        client,
        model,
        max_token_total,
        instruction_prompt,
        q_prefix,
        answer_start_prompt,
        temperature,
        q,
        a,
        dataset_name,
    ):
        """Complete and score a single sample; return its info dict.

        ``top_p`` is taken from the enclosing call's arguments.
        """
        all_info = {}
        prompt_q = q["context"] + q_prefix + q["input"]

        prompt = apply_instruct_template(
            model_name=model,
            system_prompt=instruction_prompt,
            instruct_prompt=prompt_q,
            response_prompt=answer_start_prompt,
        )
        response = client.completions.create(
            model=model,
            max_tokens=max_token_total,
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
        )
        ans_model = response.choices[0].text

        ans_model = remove_redundant_repetitions(ans_model)

        all_info["raw_answer"] = ans_model
        ans_, scores = extract_ans(
            dataset_name,
            ans_model,
            input=q["input"],
            question_start=q_prefix,
            ans_gold=a,
        )

        all_info["question"] = q["input"]
        all_info["context"] = q["context"]
        all_info["q_prefix"] = q_prefix
        all_info["prompted_question"] = prompt
        all_info["system_prompt_nudging"] = instruction_prompt
        all_info["full_prompt"] = prompt
        all_info["answer_start_prompt"] = answer_start_prompt
        all_info["extracted_answer"] = ans_
        all_info["scores"] = scores
        all_info["gold_answer"] = a

        return all_info

    def process_chunk(
        progress_bar,
        client,
        model,
        max_token_total,
        instruction_prompt,
        q_prefix,
        answer_start_prompt,
        temperature,
        input_data,
        output_data,
        dataset_name,
    ):
        """Process one contiguous chunk sequentially, returning results in order."""
        chunk_results = []
        for q, a in zip(input_data, output_data):
            all_info = process_sample(
                client,
                model,
                max_token_total,
                instruction_prompt,
                q_prefix,
                answer_start_prompt,
                temperature,
                q,
                a,
                dataset_name,
            )
            chunk_results.append(all_info)
            progress_bar.update(1)
        return chunk_results

    if os.path.exists(save_path) and not rerun:
        print(f"Load saved results from {save_path}")
    else:
        # delete stale outputs (consistent with exp_nudging / exp_baseline)
        if os.path.exists(save_path):
            os.remove(save_path)
        if os.path.exists(save_path_all_info):
            os.remove(save_path_all_info)

        instruction_prompt = (
            PROMPTS[dataset_name]["system_nudging"]
            if model_type == "nudging"
            else PROMPTS[dataset_name]["system"]
        )
        answer_start_prompt = PROMPTS[dataset_name]["answer_start"]
        q_prefix = PROMPTS[dataset_name]["question"]
        print(f"instruction_prompt: {instruction_prompt}")
        print(f"q_prefix: {q_prefix}")
        print(f"answer_start_prompt: {answer_start_prompt}")

        # Ceil-divide (min 1): floor division gave 0 whenever there were fewer
        # samples than threads, and range() with step 0 raises ValueError.
        chunk_size = max(1, (len(input_data) + num_threads - 1) // num_threads)
        progress_bar = tqdm(total=len(input_data), desc="Processing samples")
        futures = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
            for i in range(0, len(input_data), chunk_size):
                chunk_input_data = input_data[i : i + chunk_size]
                chunk_output_data = output_data[i : i + chunk_size]
                future = executor.submit(
                    process_chunk,
                    progress_bar,
                    client,
                    model,
                    max_token_total,
                    instruction_prompt,
                    q_prefix,
                    answer_start_prompt,
                    temperature,
                    chunk_input_data,
                    chunk_output_data,
                    dataset_name,
                )
                # Remember each chunk's start so results land at the right index.
                futures.append((future, i))

            for future, start_index in futures:
                result = future.result()
                all_info_list[start_index : start_index + len(result)] = result
        progress_bar.close()

        # Save results
        with open(save_path, "a") as fd:
            for info in all_info_list:
                fd.write(
                    "Input_q: %s\nA_model:\n%s\nScores:\n%s\nA:\n%s\n\n"
                    % (
                        info["context"] + info["question"],
                        info["extracted_answer"],
                        json.dumps(info["scores"], indent=4),
                        info["gold_answer"],
                    )
                )

        with open(save_path_all_info, "w") as f:
            json.dump(all_info_list, f)
    if not_doing_evaluation:
        return
    else:
        return parse_pred_ans(dataset_name, save_path, print_aggregated_metric=True)


def _make_openai_client(host: str, use_local_host: bool) -> "OpenAI":
    """Create an OpenAI-compatible client pointed at ``host``.

    Args:
        host: Base URL of the server. ``None`` lets the ``OpenAI`` client fall
            back to its own default endpoint.
        use_local_host: When True, send the placeholder key ``"EMPTY"`` (the
            value used for local servers). When False, pass ``api_key=None``
            so the ``OpenAI`` client reads ``OPENAI_API_KEY`` from the
            environment instead of failing on an undefined variable.
    """
    api_key = "EMPTY" if use_local_host else None
    return OpenAI(api_key=api_key, base_url=host)


def _print_sample_example(input_data, output_data, input_key, output_key):
    """Print the first input/output pair of the loaded dataset.

    ``input_key`` / ``output_key`` are printed as labels when not ``None``.
    Prints a short notice instead of raising when no samples were loaded.
    """
    print("\nSample example: ")
    if not input_data or not output_data:
        print("(no samples loaded)")
        return
    sample_input = input_data[0]["input"]
    sample_output = output_data[0]
    # f-strings instead of "+" so non-string samples (e.g. numeric gold
    # answers) are printed rather than raising TypeError.
    print(sample_input if input_key is None else f"{input_key}:\n{sample_input}")
    print(sample_output if output_key is None else f"{output_key}:\n{sample_output}")


def main(
    dataset_name: str,
    num_samples: int,
    num_threads: int,
    base_model: str,
    nudging_model: str,
    proxy_base_model: str,  # for the baseline methods (proxy tuning)
    proxy_chat_model: str,  # for the baseline methods (proxy tuning)
    max_token_total: int,
    base_temperature: float,
    nudging_temperature: float,
    base_top_p: float,
    nudging_top_p: float,
    print_intermediate_output: bool,
    exp: str,
    exp_prefix: str,
    split: str,
    rerun: bool,
    baseline_method: str,
    top_prob_thres: float,
    completion_token_num: int,
    completion_token_num_nudging: int,
    base_host: str = None,
    nudging_host: str = None,
    proxy_base_host: str = None,
    proxy_chat_host: str = None,
    judge_host: str = None,
    judge_model: str = None,
    use_local_host: bool = True,
    generations_per_sample: int = 10,
    output_root_dir: str = "./outputs",
    not_doing_evaluation: bool = False,
    nudging_method: str = "top_prob",
    subset: str = None,
    retry_min_length: int = None,
    retry_times: int = 3,
    override_eos_min_length: int = -1,
):
    """Load a dataset, build the model clients and dispatch one experiment.

    ``exp`` selects the experiment:

    * ``"nudging"``: ``exp_nudging`` with the base and nudging clients (plus a
      judge client, which defaults to the nudging client when ``judge_host``
      is None).
    * ``"base_only"``: ``exp_single_model`` with the base model/settings.
    * ``"nudging_only"``: ``exp_single_model`` with the nudging model/settings.
    * ``"baseline"``: ``exp_baseline`` with the proxy base/chat clients.

    ``exp_prefix`` is prefixed with ``"split_{split}_"`` before being
    forwarded. ``generations_per_sample`` and ``subset`` are forwarded only to
    ``get_dataset``. The ``retry_*``, ``override_eos_min_length``,
    ``judge_*``, ``nudging_method``, ``top_prob_thres`` and
    ``completion_token_num*`` options are only used by the ``"nudging"``
    experiment.

    Raises:
        ValueError: if ``exp`` is unknown, or if ``exp == "baseline"`` but no
            ``proxy_base_model`` was given (so no proxy clients exist).
    """
    exp_prefix = f"split_{split}_" + exp_prefix

    input_data, output_data, input_key, output_key = get_dataset(
        dataset_name=dataset_name,
        split=split,
        num_sample=num_samples,
        generations_per_sample=generations_per_sample,
        subset=subset,
    )
    _print_sample_example(input_data, output_data, input_key, output_key)

    # Clients for the base and nudging models are always created, as before;
    # previously they were silently left undefined when use_local_host=False.
    client_base = _make_openai_client(base_host, use_local_host)
    client_nudging = _make_openai_client(nudging_host, use_local_host)

    client_proxy_base = client_proxy_chat = None
    if proxy_base_model is not None:
        client_proxy_base = _make_openai_client(proxy_base_host, use_local_host)
        client_proxy_chat = _make_openai_client(proxy_chat_host, use_local_host)

    if exp == "nudging":
        if judge_host is not None:
            judge_client = _make_openai_client(judge_host, use_local_host)
        else:
            judge_client = client_nudging
        exp_nudging(
            client_base=client_base,
            client_nudging=client_nudging,
            dataset_name=dataset_name,
            num_samples=num_samples,
            base_model=base_model,
            nudging_model=nudging_model,
            max_token_total=max_token_total,
            base_temperature=base_temperature,
            nudging_temperature=nudging_temperature,
            base_top_p=base_top_p,
            input_data=input_data,
            output_data=output_data,
            print_intermediate_output=print_intermediate_output,
            rerun=rerun,
            exp_prefix=exp_prefix,
            exp=exp,
            completion_token_num=completion_token_num,
            completion_token_num_nudging=completion_token_num_nudging,
            top_prob_thres=top_prob_thres,
            num_threads=num_threads,
            output_root_dir=output_root_dir,
            not_doing_evaluation=not_doing_evaluation,
            nudging_method=nudging_method,
            judge_client=judge_client,
            judge_model=judge_model,
            retry_min_length=retry_min_length,
            retry_times=retry_times,
            override_eos_min_length=override_eos_min_length,
        )
    elif exp == "base_only":
        exp_single_model(
            client=client_base,
            dataset_name=dataset_name,
            num_samples=num_samples,
            model=base_model,
            max_token_total=max_token_total,
            input_data=input_data,
            output_data=output_data,
            rerun=rerun,
            exp_prefix=exp_prefix,
            num_threads=num_threads,
            temperature=base_temperature,
            top_p=base_top_p,
            model_type="base",
            output_root_dir=output_root_dir,
            not_doing_evaluation=not_doing_evaluation,
        )
    elif exp == "nudging_only":
        exp_single_model(
            client=client_nudging,
            dataset_name=dataset_name,
            num_samples=num_samples,
            model=nudging_model,
            max_token_total=max_token_total,
            input_data=input_data,
            output_data=output_data,
            rerun=rerun,
            exp_prefix=exp_prefix,
            num_threads=num_threads,
            temperature=nudging_temperature,
            top_p=nudging_top_p,
            model_type="nudging",
            output_root_dir=output_root_dir,
            not_doing_evaluation=not_doing_evaluation,
        )
    elif exp == "baseline":
        if client_proxy_base is None or client_proxy_chat is None:
            # Previously this surfaced as a NameError on the client variables.
            raise ValueError(
                "The baseline experiment requires --proxy_base_model "
                "(and --proxy_base_host / --proxy_chat_host) to be set"
            )
        exp_baseline(
            client_base=client_base,
            client_proxy_chat=client_proxy_chat,
            client_proxy_base=client_proxy_base,
            dataset_name=dataset_name,
            num_samples=num_samples,
            base_model=base_model,
            proxy_base_model=proxy_base_model,
            proxy_chat_model=proxy_chat_model,
            max_token_total=max_token_total,
            baseline_method=baseline_method,
            temperature=base_temperature,
            input_data=input_data,
            output_data=output_data,
            rerun=rerun,
            exp_prefix=exp_prefix,
            num_threads=num_threads,
            output_root_dir=output_root_dir,
            not_doing_evaluation=not_doing_evaluation,
        )
    else:
        raise ValueError(f"Unknown experiment {exp}")


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Every parsed option is passed to ``main`` as a keyword argument, so each
    option's ``dest`` must match a ``main`` parameter name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_name", type=str, default="gsm8k", help="The name of the dataset"
    )
    parser.add_argument(
        "--num_samples", type=int, default=None, help="The number of samples"
    )
    parser.add_argument(
        "--generations_per_sample",
        type=int,
        default=10,
        help="The number of generations per sample",
    )
    parser.add_argument(
        "--num_threads", type=int, default=20, help="The number of threads"
    )
    parser.add_argument(
        "--base_model", type=str, default=None, help="The base model to use"
    )
    parser.add_argument(
        "--nudging_model", type=str, default=None, help="The nudging model to use"
    )
    parser.add_argument(
        "--proxy_base_model", type=str, default=None, help="The proxy base model to use"
    )
    parser.add_argument(
        "--proxy_chat_model", type=str, default=None, help="The proxy chat model to use"
    )
    parser.add_argument(
        "--completion_token_num",
        type=int,
        default=16,
        help="The number of token to complete in each nudging round using the base model",
    )
    parser.add_argument(
        "--completion_token_num_nudging",
        type=int,
        default=16,
        help="The number of token to complete in each nudging round using the nudging model",
    )
    parser.add_argument(
        "--max_token_total", type=int, default=512, help="The maximum number of tokens"
    )
    parser.add_argument(
        "--base_temperature",
        type=float,
        default=0.6,
        help="The temperature for the base model",
    )
    parser.add_argument(
        "--base_top_p", type=float, default=0.9, help="The top p for the base model"
    )
    parser.add_argument(
        "--nudging_temperature",
        type=float,
        default=0.6,
        help="The temperature for the nudging model",
    )
    parser.add_argument(
        "--nudging_top_p",
        type=float,
        default=0.9,
        help="The top p for the nudging model",
    )
    parser.add_argument(
        "--baseline_method",
        type=str,
        choices=["ensemble", "proxy_tuning"],
        default="ensemble",
        help="The baseline method name",
    )
    parser.add_argument(
        "--top_prob_thres",
        type=float,
        default=0.3,
        help="The top-1 token probability threshold for top prob nudging",
    )
    parser.add_argument(
        "--exp",
        type=str,
        choices=["nudging_only", "base_only", "nudging", "baseline"],
        default="nudging_only",
        help="The experiment to run",
    )
    parser.add_argument(
        "--exp_prefix",
        type=str,
        default="",
        help="The prefix for the experiment, e.g. complex_fewshot_",
    )
    parser.add_argument("--split", type=str, default="test", help="The split to test")
    parser.add_argument(
        "--base_host", type=str, default=None, help="The base host for local models"
    )
    parser.add_argument(
        "--nudging_host",
        type=str,
        default=None,
        help="The nudging host for local models",
    )
    parser.add_argument(
        "--proxy_base_host",
        type=str,
        default=None,
        help="The proxy base host for local models",
    )
    parser.add_argument(
        "--proxy_chat_host",
        type=str,
        default=None,
        help="The proxy chat host for local models",
    )
    parser.add_argument(
        "--rerun", action="store_true", help="Whether to rerun the experiment"
    )
    parser.add_argument(
        "--print_intermediate_output",
        action="store_true",
        help="Whether to print intermediate output",
    )
    parser.add_argument(
        "--output_root_dir",
        type=str,
        default="./outputs",
        help="The root directory in which outputs are stored",
    )
    parser.add_argument(
        "--not_doing_evaluation",
        action="store_true",
        help="Skip evaluating the generated outputs",
    )
    parser.add_argument(
        "--nudging_method",
        type=str,
        default="top_prob",
        help="The nudging method to use",
    )
    parser.add_argument(
        "--judge_host", type=str, default=None, help="The host for the judge model"
    )
    parser.add_argument(
        "--judge_model",
        type=str,
        default=None,
        help="The model name for the judge model",
    )
    parser.add_argument(
        "--subset",
        type=str,
        default=None,
        help="The subset to use for datasets that support it (e.g., 'curated' for novelty-bench-curated, 'Human'/'GPT' for storyarc)",
    )
    parser.add_argument(
        "--retry_min_length",
        type=int,
        default=None,
        help="Minimum length threshold for retrying. If None/-1, no retry. Otherwise, retry if output length is below this threshold",
    )
    parser.add_argument(
        "--retry_times",
        type=int,
        default=3,
        help="Maximum number of retry attempts when output length is below retry_min_length",
    )
    parser.add_argument(
        "--override_eos_min_length",
        type=int,
        default=-1,
        help="Minimum length to override EOS token behavior. If -1, no override. Otherwise, continue generation if current length is below this threshold even when EOS is detected",
    )
    return parser


if __name__ == "__main__":
    args = vars(build_arg_parser().parse_args())

    main(**args)
