import re
from collections import Counter
from scipy import integrate
import random

def confidence_criteria(answers : list, conf_thresh=1.0):
    """
    Evaluate the confidence level of the most common answer in a list of answers.

    With ``a`` votes for the most common answer and ``b`` votes for the
    runner-up, the reported probability is the share of the curve
    ``x**a * (1 - x)**b`` on [0, 1] that lies above ``x = 0.5``, i.e. the
    upper tail of a Beta(a + 1, b + 1) distribution.

    Args:
        answers (list): A list of (hashable) answers to evaluate.
        conf_thresh (float, optional): The confidence threshold for the most common answer. Defaults to 1.0.

    Returns:
        dict: A dictionary containing:
              - 'most_common': The most common answer found in the list
                (None when ``answers`` is empty).
              - 'prob': The calculated probability that the most common answer is correct
                (-1 when ``answers`` is empty).
              - 'conf': A boolean indicating whether 'prob' is at least ``conf_thresh``.
    """
    if len(answers) == 0:
        return {
            'most_common': None,
            'prob': -1,
            'conf': False,
        }

    # Imported locally because the module only pulls in scipy.integrate at file level.
    from scipy import special

    top_two = Counter(answers).most_common(2)
    top_answer, top_count = top_two[0]
    runner_up_count = top_two[1][1] if len(top_two) > 1 else 0

    # P(X > 0.5) for X ~ Beta(a + 1, b + 1) equals I_0.5(b + 1, a + 1), where
    # I is the regularized incomplete beta function. Evaluating it in closed
    # form avoids the 0/0 that numerical integration produced once
    # x**a * (1 - x)**b underflowed for large vote counts.
    prob = float(special.betainc(runner_up_count + 1.0, top_count + 1.0, 0.5))

    return {
        'most_common': top_answer,
        'prob': prob,
        'conf': prob >= conf_thresh,
    }

def get_prefix(answers, conf_thresh, gens, prefix, prefix_level, frq:int = 5):
    """
    Generate a list of potential prefixes for future inference based on the confidence of previous answers.

    Only the last ``frq`` generations are inspected. A generation contributes
    a prefix when it contains the marker "answer is" and its corresponding
    entry in ``answers`` equals the most common answer. The prefix is built
    from the text before the marker: it is split into sentences (on a period
    followed by whitespace, or on a newline) and at most ``prefix_level + 1``
    leading sentences are kept, never including the last delimited sentence.
    A contributed prefix may therefore be an empty string.

    Args:
        answers (list): List of previous answers; ``answers[k]`` is expected to
            belong to ``gens[k]``.
        conf_thresh (float): Confidence threshold to consider when updating prefixes.
        gens (list): List of generated text segments.
        prefix (str): Current prefix.
        prefix_level (int): Current depth level of the prefix.
        frq (int, optional): Number of recent generations to consider. Defaults to 5.

    Returns:
        tuple: ``(prefix, prefix_level)`` unchanged when the confidence
        threshold is not met; otherwise ``(prefix_list, prefix_level + 1)``
        where ``prefix_list`` is the list of newly built prefixes.
    """
    confidence = confidence_criteria(answers, conf_thresh)
    if not confidence['conf']:
        return prefix, prefix_level

    # Absolute position of the first generation in the window. Clamping at 0
    # keeps gens[k] paired with answers[k] when fewer than `frq` generations
    # exist; an unclamped offset would go negative and wrap around `answers`.
    start = max(len(gens) - frq, 0)

    prefix_list = []
    for index, gen in enumerate(gens[start:], start=start):
        parts = gen.split("answer is")
        if len(parts) <= 1:
            continue  # no answer marker in this generation
        if answers[index] != confidence['most_common']:
            continue  # this generation disagrees with the majority

        # The capturing group makes re.split keep the delimiters, so the
        # result alternates text, delimiter, text, ...; re-attach each
        # delimiter to the text before it. Trailing text with no delimiter
        # is dropped.
        pieces = re.split(r'(\.\s|\.\n|\n)', parts[0])
        sentences = [text + delim for text, delim in zip(pieces[0::2], pieces[1::2])]

        # Keep up to prefix_level + 1 sentences but always leave out the last one.
        keep = max(0, min(prefix_level + 1, len(sentences) - 1))
        prefix_list.append("".join(sentences[:keep]))

    return prefix_list, prefix_level + 1

def sample_prefix(prefix_list : list[str]) -> str:
    """
    Pick one prefix uniformly at random.

    Args:
        prefix_list (list[str]): Candidate prefixes.

    Returns:
        str: One element of ``prefix_list``, or "" when the list is empty.
    """
    if len(prefix_list) == 0:
        return ""
    return random.choice(prefix_list)

def integrate_answer(answers: list):
    """
    Integrate multiple answers to determine the most common one.

    ``None`` entries are treated as missing and ignored; other falsy values
    such as 0 or "" are counted as real answers.

    Args:
        answers (list): List of (hashable) answers to integrate.

    Returns:
        The most common non-None answer, or None if no valid answers exist.
    """
    # Identity test: only the None singleton marks a missing answer.
    valid_answers = [answer for answer in answers if answer is not None]
    if not valid_answers:
        return None
    return Counter(valid_answers).most_common(1)[0][0]

def extract_answer(text, type):
    """
    Pull the answer that follows the first "answer is" marker out of a generated text.

    Args:
        text (str): The generated text containing the answer.
        type (str): Kind of answer to extract: 'float' or 'str'.

    Returns:
        float or str: For 'float', the first number after the marker (thousands
        separators removed). For 'str', the rest of the marker's line with
        punctuation removed and surrounding whitespace stripped. None when the
        marker (or a number, for 'float') is not found, or when ``type`` is
        neither 'float' nor 'str'.
    """
    if type == 'float':
        # Skip any non-numeric lead-in (e.g. a currency sign), then capture an
        # optionally signed number that may contain commas and a decimal point.
        found = re.search(r"answer is\s*[^0-9-+]*([-+]?[0-9,]*\.?[0-9]+)", text)
        if found is None:
            return None
        digits = found.group(1).replace(',', '')
        return float(digits)

    if type == 'str':
        found = re.search(r'answer is\s*(.*)', text)
        if found is None:
            return None
        # Drop every character that is neither a word character nor whitespace.
        without_punctuation = re.sub(r'[^\w\s]', '', found.group(1))
        return without_punctuation.strip()

    return None