import re
import string
import json
def extract_answer(response_text: str, reasoning_type: str) -> "str | None":
    """Extract the model's final answer from a free-form response.

    Answer-marker patterns are tried in priority order; for each pattern the
    *last* match in the response is used. If no pattern yields a usable
    answer, the whole response is scanned as a fallback.

    Args:
        response_text: Raw text produced by the model.
        reasoning_type: ``"numerical"`` or ``"multiple_choice"``.

    Returns:
        * ``"numerical"`` -- the first number inside the matched answer span
          (or, as a fallback, the last number anywhere in the response), with
          thousands separators removed and sign / decimal point preserved,
          e.g. ``"-1234.5"``.
        * ``"multiple_choice"`` -- the digits of the matched answer span
          (e.g. ``"2"``) or, as a fallback, the last standalone uppercase
          letter in the response (e.g. ``"B"``).
        ``None`` when nothing usable is found.

    Raises:
        ValueError: If ``reasoning_type`` is not one of the supported values.
    """
    # One number token: optional sign, digits with optional thousands commas,
    # optional fractional part. Only ONE digit is required -- the previous
    # ``\d[\d,]*\.?\d+`` form silently rejected single-digit answers.
    number = r"-?\d[\d,]*(?:\.\d+)?"

    if reasoning_type == "numerical":
        patterns = [
            r"Answer:\s*([^a-zA-Z]*)",
            r"Answer:\s*(.*)",
            rf"(?:the|my)?\s*(?:final\s*)?(?:answer|solution|result)\s*(?:is)?\s*[:=]\s*\$?({number})",
            r"\\?boxed\{([^}]+)\}",
            rf"the\s+(?:final\s+)?answer\s+is\s+\$?({number})",
            rf"({number})\s*\.?\s*$",
        ]
    elif reasoning_type == "multiple_choice":
        patterns = [
            r"Answer:\s*(.*)",
            r"(?:the|my)?\s*(?:correct\s*)?(?:answer|option|choice|solution)\s*(?:is)?\s*[:=]?\s*\(?\s*([0-3])\s*\)?(?:\.|\))?",
        ]
    else:
        # Previously this fell through to an UnboundLocalError on `patterns`.
        raise ValueError(f"Unsupported reasoning_type: {reasoning_type!r}")

    for pattern in patterns:
        matches = re.findall(pattern, response_text, re.IGNORECASE)
        if not matches:
            continue
        captured = matches[-1]
        if reasoning_type == "numerical":
            # Keep the first number token intact instead of filtering to bare
            # digits: the digit filter dropped '-' and '.' ("-3.5" -> "35",
            # "3.0" -> "30") and glued separate numbers together.
            token = re.search(number, captured)
            answer = token.group(0).replace(",", "") if token else ""
        else:
            # Multiple-choice labels are digits; keep only digit characters.
            answer = "".join(ch for ch in captured if ch.isdigit())
        if answer:
            return answer

    # Fallbacks when no answer marker produced a usable value.
    if reasoning_type == "numerical":
        numbers = re.findall(number, response_text)
        if numbers:
            return numbers[-1].replace(",", "")
    else:
        # Case-sensitive: only standalone uppercase letters count as labels.
        letters = re.findall(r"\b([A-Z])\b", response_text)
        if letters:
            return letters[-1]
    return None

def extract_gt(answer: str) -> "str | None":
    """Extract the reference answer from a ground-truth solution string.

    Resolution order:

    1. If the text contains a ``###`` marker (GSM8K-style ``#### 42``),
       return everything after the last marker, stripped, with commas
       removed.
    2. Otherwise return the last number in the text, commas removed.
    3. Otherwise return the last standalone uppercase letter (a
       multiple-choice label such as ``"B"``).

    Returns ``None`` if none of the above is found.
    """
    if "###" in answer:
        # rsplit on the last "###" so "#### 42", "### 42" and "###42" all
        # work; splitting on "### " missed the no-space form even though the
        # guard above only checks for "###".
        tail = answer.rsplit("###", 1)[-1]
        return tail.strip().replace(",", "")

    # Fractional part is only taken when digits follow the '.', so a
    # sentence-ending period is not captured ("... 42." -> "42", not "42.").
    numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", answer)
    if numbers:
        return numbers[-1].replace(",", "")

    labels = re.findall(r"\b([A-Z])\b", answer)
    if labels:
        return labels[-1]
    return None

def extract_phrase(model_output: str) -> str:
    """Return the text following the first ``Phrase:`` label.

    The label is matched case-insensitively. Whitespace (including blank
    lines) right after the label is skipped, and capture stops at the end of
    that line. The result is stripped; ``None`` if no label is present.
    """
    found = re.search(r"Phrase:\s*(.*?)(?:\n|$)", model_output, re.IGNORECASE)
    return found.group(1).strip() if found else None

def extract_alist(model_output: str) -> str:
    """Return everything after the first ``Alist:`` label, stripped.

    The label is matched case-insensitively, and the capture runs to the end
    of the text across newlines. Returns ``None`` if no label is present.
    """
    found = re.search(r"Alist:\s*(.*)", model_output, re.IGNORECASE | re.DOTALL)
    if found is None:
        return None
    return found.group(1).strip()

def extract_alist_nl_pairs(model_output: str) -> tuple[str, str]:
    """Parse a plain-text ``Phrase: ...`` / ``Alist: ...`` pair.

    The phrase is everything between ``Phrase:`` and the first newline that
    is followed (after optional whitespace) by ``Alist:``. The alist is
    everything after that label. Both are stripped, and labels match
    case-insensitively.

    Returns ``(None, None)`` and prints a notice when the pair is not found.
    """
    found = re.search(
        r"Phrase:\s*(.*?)\n\s*Alist:\s*(.*)",
        model_output,
        re.IGNORECASE | re.DOTALL,
    )
    if found is None:
        print("No match found for alist and phrase in model output.")
        return None, None
    phrase_text, alist_text = (part.strip() for part in found.groups())
    return phrase_text, alist_text

def extract_alist_nl_pairs_from_json(json_string: str) -> tuple[str, dict]:
    """Parse a (possibly Markdown-fenced) JSON object holding a phrase/alist pair.

    Lines whose stripped text starts with a triple backtick (code-fence
    markers) are dropped before parsing. The capitalized keys ``"Phrase"`` /
    ``"Alist"`` are preferred; the lowercase ``"phrase"`` / ``"alist"`` keys
    are used when the capitalized key is missing or null.

    Returns:
        ``(phrase, alist)`` as found in the object (either may be ``None``).
        ``(None, None)`` when the text is not valid JSON or is valid JSON but
        not an object; a diagnostic is printed in both cases.
    """
    lines = json_string.strip().split("\n")
    kept = [line for line in lines if not line.strip().startswith("```")]
    try:
        data = json.loads("\n".join(kept))
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return None, None

    if not isinstance(data, dict):
        # Valid JSON but not an object (list, string, number, null): there are
        # no keys to read, and calling .get() would raise AttributeError.
        print(f"Expected a JSON object, got {type(data).__name__}")
        return None, None

    phrase = data.get("Phrase")
    if phrase is None:
        phrase = data.get("phrase")
    alist = data.get("Alist")
    if alist is None:
        alist = data.get("alist")
    return phrase, alist

def extract_generated_data(prompt_id: str, model_output: str, natural_language: str = None) -> tuple[str, dict]:
    """Extract a ``(phrase, alist)`` pair from model output according to ``prompt_id``.

    * ``"natural_language"`` -- phrase from the ``Phrase:`` line; no alist.
    * ``"alist_from_nl"`` -- phrase is ``natural_language``; alist from the
      ``Alist:`` section.
    * ``"alist_nl_pairs"`` / ``"alist_nl_pairs_subject_specific"`` -- parse
      the output as JSON first; if that yields nothing (unparseable, missing
      keys) or both fields are the string ``"null"``, fall back to the
      plain-text ``Phrase: ... / Alist: ...`` format.

    Returns:
        ``(phrase, alist)``. A dict phrase is replaced by its ``"phrase"``
        value (default ``""``); other falsy phrases become ``None``. A falsy
        alist becomes ``None``. For ``"alist_nl_pairs"``, non-string alists
        are serialized with ``json.dumps(indent=2)``. Otherwise the alist is
        returned unchanged.

    Raises:
        ValueError: If ``prompt_id`` is not recognized.
    """
    phrase = None
    alist = None
    if prompt_id == "natural_language":
        phrase = extract_phrase(model_output)
    elif prompt_id == "alist_from_nl":
        phrase = natural_language
        alist = extract_alist(model_output)
    elif prompt_id in ("alist_nl_pairs", "alist_nl_pairs_subject_specific"):
        phrase, alist = extract_alist_nl_pairs_from_json(model_output)
        # The JSON helper returns Python None (not "null") on failure, so the
        # old `== "null"` check alone never reached the plain-text fallback.
        if phrase in (None, "null") and alist in (None, "null"):
            phrase, alist = extract_alist_nl_pairs(model_output)
    else:
        raise ValueError(f"No extraction function found for prompt_id: {prompt_id}")

    if isinstance(phrase, dict):
        cleaned_phrase = phrase.get("phrase", "")
    else:
        cleaned_phrase = phrase if phrase else None

    if not alist:
        cleaned_alist = None
    elif prompt_id == "alist_nl_pairs" and not isinstance(alist, str):
        # Structured alists parsed from JSON are serialized to pretty JSON;
        # strings (e.g. from the plain-text fallback) pass through unchanged
        # rather than being double-encoded into a quoted JSON string literal.
        cleaned_alist = json.dumps(alist, indent=2)
    else:
        cleaned_alist = alist
    return cleaned_phrase, cleaned_alist