import tqdm
import sys
import datetime
import math
import importlib
import json
import os

def is_equal(answer, target, type) -> bool:
    """
    Compare the generated answer with the target answer based on the specified type.

    Args:
        answer: The answer generated by the model.
        target: The target or correct answer.
        type (str): The type of the answer ('float' or 'str'). The name shadows
            the builtin ``type`` but is kept so keyword callers keep working.

    Returns:
        bool: True if the answers match based on the type, False otherwise
        (including when ``type`` is not one of the supported values).
    """
    if type == 'float':
        # Floats are compared with math.isclose using its default tolerances
        # (rel_tol=1e-09, abs_tol=0.0) rather than exact equality.
        return math.isclose(answer, target)
    if type == 'str':
        # Membership test: for strings this is True when `answer` occurs
        # anywhere inside `target` (substring match).
        return answer in target
    # Unsupported comparison type: report "no match" explicitly instead of
    # implicitly returning None, so the result always honours `-> bool`.
    return False

class TqdmPrintWrapper:
    """
    Iterate over an iterable while displaying a tqdm progress bar on stdout.

    Messages can be printed during iteration via ``write`` without breaking the
    progress bar display. The bar is closed automatically once the wrapped
    iterator is exhausted.

    NOTE(review): if the caller stops iterating early (e.g. ``break``), the bar
    is never closed by this class.

    Args:
        iterable: The iterable to wrap and iterate over.
        total (int, optional): The total number of iterations expected (used by
            tqdm). If omitted, ``len(iterable)`` is used when the iterable
            supports ``len``; otherwise the bar is shown without a total.

    Methods:
        __iter__(): Returns the wrapper itself as the iterator.
        __next__(): Advances to the next item in the iterable, updating the progress bar.
        write(message): Prints a message without breaking the progress bar display.
    """
    def __init__(self, iterable, total=None):
        self.iterable = iterable
        if total is None:
            # Mirror tqdm's own behaviour: sized iterables get a real total so
            # the bar can show a percentage/ETA; unsized ones (generators,
            # iterators) keep total=None.
            try:
                total = len(iterable)
            except TypeError:
                pass
        self.total = total
        self.iter_obj = iter(iterable)
        self.pbar = tqdm.tqdm(total=total, file=sys.stdout)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            value = next(self.iter_obj)
        except StopIteration:
            # Iteration is complete: close the progress bar, then re-raise so
            # the enclosing loop terminates normally.
            self.pbar.close()
            raise
        # Advance the progress bar by one item and hand the value back.
        self.pbar.update(1)
        return value

    def write(self, message):
        # tqdm's write prints above the bar so the bar display is not corrupted.
        self.pbar.write(message)
        sys.stdout.flush()

def get_current_time_string():
    """
    Build a compact timestamp from the current local time.

    Returns:
        str: Zero-padded month, day, hour (24h) and minute concatenated as
        "MMDDHHMM", e.g. "03141530" for March 14 at 15:30.
    """
    return datetime.datetime.now().strftime("%m%d%H%M")

def get_math_prompt(dataset_name):
    """
    Retrieve the appropriate math prompt based on the dataset name.

    The datasets 'gsm', 'svamp', 'asdiv' and 'multi_arith' share the generic
    ``prompts.math_prompt`` module; any other dataset loads
    ``prompts.<dataset_name>_prompt``.

    Args:
        dataset_name (str): The name of the dataset.

    Returns:
        The ``MATH_PROMPT`` attribute of the selected prompt module
        (presumably a str — depends on the prompt module's contents).

    Raises:
        ModuleNotFoundError: If the selected prompt module cannot be found.
        AttributeError: If the module does not define ``MATH_PROMPT``.
    """
    generic_datasets = ('gsm', 'svamp', 'asdiv', 'multi_arith')
    if dataset_name in generic_datasets:
        # Common datasets share one generic math prompt.
        prompt_module_name = "prompts.math_prompt"
    else:
        # Any other dataset is expected to provide its own prompt module.
        prompt_module_name = f"prompts.{dataset_name}_prompt"
    prompt_module = importlib.import_module(prompt_module_name)
    return prompt_module.MATH_PROMPT

def read_file(dataset, data_dir='datasets'):
    """
    Read examples from a dataset file in either JSONL or JSON format.

    ``<data_dir>/<dataset>.jsonl`` is preferred; if it does not exist,
    ``<data_dir>/<dataset>.json`` is read instead.

    Args:
        dataset (str): The name of the dataset (file name without extension).
        data_dir (str, optional): Directory holding the dataset files.
            Defaults to 'datasets'.

    Returns:
        list: For JSONL, one decoded object per non-blank line. For JSON, the
        value stored under the top-level 'examples' key (expected to be a list).

    Raises:
        FileNotFoundError: If neither the .jsonl nor the .json file exists.
        KeyError: If the .json file has no top-level 'examples' key.
    """
    jsonl_path = os.path.join(data_dir, f'{dataset}.jsonl')
    if os.path.exists(jsonl_path):
        # `with` guarantees the handle is closed; JSON text is UTF-8.
        with open(jsonl_path, encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line), which json.loads rejects.
            return [json.loads(line) for line in f if line.strip()]

    json_path = os.path.join(data_dir, f'{dataset}.json')
    with open(json_path, encoding='utf-8') as f:
        return json.load(f)['examples']