import re
from string import ascii_uppercase

def remove_letters(input_string):
    """
    Return input_string with every ASCII letter (a-z, A-Z) deleted.
    All other characters (digits, punctuation, whitespace, non-ASCII) are kept.
    """
    ascii_letter = re.compile(r'[a-zA-Z]')
    return ascii_letter.sub('', input_string)

def parse_cot_explanation_deprecated(response):
    """
    Parse CoT explanation to steps (line-based variant).

    Everything from the first "Final Answer:" marker onwards is discarded;
    when the marker is absent the whole response is used. Each remaining
    line that contains the substring "Step" is returned unchanged, in order.

    response: string, raw response text
    returns list of strings, one per line containing "Step"
    """
    # partition() keeps the full text when the marker is missing, whereas
    # slicing with str.find()'s -1 result would drop the last character.
    reasoning = response.partition("Final Answer:")[0].strip()
    return [line for line in reasoning.split("\n") if "Step" in line]

def parse_cot_explanation(response_text):
    """
    Parse CoT explanation to extract structured steps using regular expressions.

    Everything from the first "Final Answer:" marker onwards is discarded;
    when the marker is absent the whole response is used. The remaining text
    is cut at every "Step <number>:" label, so a step may span several lines.
    Each returned step keeps its own "Step <number>:" label. Text that
    precedes the first label is not returned.

    response_text: string, raw response text
    returns list of strings, one per labelled step (empty list if none found)
    """
    # partition() keeps the full text when the marker is missing, whereas
    # slicing with str.find()'s -1 result would drop the last character.
    reasoning = response_text.partition("Final Answer:")[0].strip()

    # Each step runs from its own label up to the next label (or end of text).
    step_starts = [match.start() for match in re.finditer(r"Step \d+:", reasoning)]
    step_ends = step_starts[1:] + [len(reasoning)]

    return [reasoning[start:end].strip() for start, end in zip(step_starts, step_ends)]


def parse_mcq_answer(response, final_answer_str):
    """
    Parse the chosen option of a multiple-choice answer from a response.

    If final_answer_str occurs in the response, the answer is the first
    non-whitespace character that follows its first occurrence. Otherwise the
    text after the first "Final Answer:" is searched for an uppercase option
    written as "(X)" or " X.", and the option mentioned first is returned.

    response: string, raw response text
    final_answer_str: string, phrase that directly precedes the answer
    returns a single-character string, or None if no answer can be found
    """
    # An empty phrase is "in" every string but cannot be split on, so it is
    # treated as absent and the "Final Answer:" fallback is used instead.
    if final_answer_str and final_answer_str in response:
        tail = response.split(final_answer_str)[1].strip()
        # Nothing after the phrase (e.g. a truncated response): no answer
        # rather than an IndexError.
        return tail[0] if tail else None

    marker = "Final Answer:"
    if marker not in response:
        return None
    tail = response[response.find(marker) + len(marker):]

    # Take the option that appears first in the text; scanning the alphabet
    # instead would turn "(B), not (A)" into "A".
    match = re.search(r"\(([A-Z])\)| ([A-Z])\.", tail)
    if match is None:
        return None
    return match.group(1) or match.group(2)

def parse_tags(response, tag):
    """
    Parse answer from api response, given a tag
    e.g. <FIN> 42.1 </FIN> should return 42.1
    return answer as int or float if possible, otherwise as string

    ASCII letters, '$' and ',' are stripped from the tagged text before
    conversion, so the string fallback contains only what is left after
    that stripping.

    response: string, raw response text
    tag: string, tag name without angle brackets
    returns int, float or string; an empty string if the opening tag is
    missing. If only the closing tag is missing, the text after the opening
    tag up to the end of the response is used.
    """
    tag_start = f"<{tag}>"
    tag_end = f"</{tag}>"

    start = response.find(tag_start)
    if start == -1:
        # No opening tag: slicing from find()'s -1 would pick up arbitrary
        # text that could still parse as a number.
        return ""
    start += len(tag_start)

    # Only a closing tag after the opening tag can delimit the answer.
    end = response.find(tag_end, start)
    if end == -1:
        end = len(response)

    answer = response[start:end].strip()
    answer = remove_letters(answer).replace('$', '').replace(',', '')  # remove any letters and dollar signs
    try:
        return int(answer)
    except ValueError:  # uses the fact that e.g. int("1.1") throws a ValueError
        try:
            return float(answer)
        except ValueError:
            return answer
        
def get_answer_token_idx(tokens_list, tag=None):
    """
    Get the token index position of answer in a list of tokens from the api response

    The answer's character offset in the joined tokens is taken from
    get_answer_idx when a tag is given, otherwise from the start of the
    first number found by parse_number_from_string.

    tokens_list: list of strings (tokens)
    tag: string, tag to search for (optional)
    returns int, index of the token that contains the answer's first
    character, or None if the offset lies at or beyond the end of the tokens
    """
    if tag:
        answer_idx = get_answer_idx(tokens_list, tag)
    else:
        answer_idx = parse_number_from_string(''.join(tokens_list))[1]

    # Walk the tokens with a running character count. The walk is bounded by
    # the token list, so an offset past the end (empty list, tag at the very
    # end) yields None instead of looping forever.
    chars_seen = 0
    for token_idx, token in enumerate(tokens_list):
        chars_seen += len(token)
        if answer_idx < chars_seen:
            return token_idx
    return None

def get_answer_idx(answer, tag):
    """
    Get idx of start of answer in unparsed answer string
    answer: unparsed answer string, or a sequence of strings that is joined first
    tag: string, tag name without angle brackets
    returns int, position of the first character after the opening <tag>,
    or -1 if the opening tag is not present
    """
    tokens_str = answer if isinstance(answer, str) else ''.join(answer)
    open_tag = f"<{tag}>"
    tag_pos = tokens_str.find(open_tag)
    if tag_pos == -1:
        # Adding the tag length to find()'s -1 would produce a plausible
        # looking but meaningless offset.
        return -1
    return tag_pos + len(open_tag)
    
def parse_number_from_string(s):
    """
    Find the first unsigned number in a string.

    Matches plain digit runs ('123', '1234'), comma-grouped thousands
    ('1,234', '1,234,567') and an optional decimal part ('56.78').

    s: string to search
    returns (number, start_index): number is a float if it has a decimal
    part, otherwise an int; start_index is the position of its first digit.
    Returns (None, 0) if no number is found.
    """
    # The comma-grouped alternative requires at least one ",ddd" group.
    # If it could match zero groups, "\d{1,3}" alone would succeed on the
    # first three digits of an un-grouped number and "1234" would parse as 123.
    match = re.search(r'(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?', s)
    if match is None:
        return None, 0

    # Remove commas and convert to float or int
    number_str = match.group().replace(',', '')
    value = float(number_str) if '.' in number_str else int(number_str)
    return value, match.start()