import re
import json

def parse_reason_and_action_alfworld(text):
    """Split an ALFWorld agent response into its reason and action.

    Args:
        text: Raw model output, expected to contain ``REASON: ... ACTION: ...``.

    Returns:
        A ``(reason, action)`` tuple.
        ``reason`` is the stripped text between ``REASON:`` and ``ACTION:``.
        If that pair is absent, it is the whole input.
        ``action`` is the rest of the line after the first ``ACTION:``.
        Leading whitespace, including blank lines, is skipped. The action is
        lower-cased, and every character other than a-z, 0-9, space and '/'
        is removed. It is ``""`` when there is no ``ACTION:``.
    """
    found_reason = re.search(r"REASON:\s*(.*?)\s*ACTION:", text, re.DOTALL)
    found_action = re.search(r"ACTION:\s*([^\n]+)", text)

    if found_reason:
        reason = found_reason.group(1).strip()
    else:
        reason = text

    raw_action = found_action.group(1).strip() if found_action else ""
    # Lower-case first so the character whitelist only needs to name a-z;
    # this drops trailing punctuation and other stray symbols from the command.
    action = re.sub(r'[^a-z0-9 /]', '', raw_action.lower())

    return reason, action

def parse_reason_and_action_webshop(text):
    """Split a WebShop agent response into its reason and action.

    Args:
        text: Raw model output, expected to contain ``REASON: ... ACTION: ...``.

    Returns:
        A ``(reason, action)`` tuple.
        ``reason`` is the stripped text between ``REASON:`` and ``ACTION:``.
        If that pair is absent, it is the whole input.
        ``action`` is the stripped rest of the line after the first
        ``ACTION:``, with no further cleanup. It is ``""`` when absent.
    """
    reason, action = text, ""

    reason_hit = re.search(r"REASON:\s*(.*?)\s*ACTION:", text, re.DOTALL)
    if reason_hit is not None:
        reason = reason_hit.group(1).strip()

    action_hit = re.search(r"ACTION:\s*([^\n]+)", text)
    if action_hit is not None:
        action = action_hit.group(1).strip()

    return reason, action

def parse_reason_and_action_intercode(text):
    """Split an InterCode agent response into its reason and action.

    Literal backslash-n character pairs in ``text`` are first converted to
    real newlines. The action may be written in one of two ways:

    * Wrapped in a Markdown code fence. An optional language tag, such as
      ``bash`` or ``sql``, may appear alone on the opening fence line.
    * Given as plain text. Everything after ``ACTION:`` is then returned.

    Args:
        text: Raw model output.

    Returns:
        A ``(reason, action)`` tuple of stripped strings.
        ``reason`` falls back to the whole text (after newline conversion)
        when no ``REASON: ... ACTION:`` pair is found.
        ``action`` falls back to ``""``.
    """
    text = text.replace("\\n", "\n")
    reason_pattern = r"REASON:\s*(.*?)\s*ACTION:"

    # Two patterns for action: one with Markdown formatting and one without.
    # The language tag counts as a tag only when it is alone on the opening
    # fence line, as with a CommonMark info string. Otherwise a one-line
    # fence such as ```ls -la``` would lose its first word ("ls") to the tag.
    action_pattern_with_markdown = (
        r"ACTION:\s*```(?:[A-Za-z0-9_+-]*[ \t]*\n)?\s*(.*?)\s*```"
    )
    action_pattern_without_markdown = r"ACTION:\s*(.+)"

    reason_match = re.search(reason_pattern, text, re.DOTALL)

    # Try matching the action with Markdown first.
    action_match = re.search(action_pattern_with_markdown, text, re.DOTALL)

    # If there is no Markdown match, use the plain pattern. With DOTALL it
    # captures everything after ACTION:, including later lines.
    if not action_match:
        action_match = re.search(action_pattern_without_markdown, text, re.DOTALL)

    reason = reason_match.group(1).strip() if reason_match else text
    action = action_match.group(1).strip() if action_match else ""

    return reason, action

def parse_reason_and_action_intercode_sql(text):
    """Parse an InterCode-SQL response and trim the action to one statement.

    This delegates to ``parse_reason_and_action_intercode``. It keeps only the
    first line of the action, and within that line only the text up to and
    including the first ``;`` (if there is one).

    Returns:
        A ``(reason, action)`` tuple.
    """
    reason, raw_action = parse_reason_and_action_intercode(text)
    first_line = raw_action.partition('\n')[0]
    # partition returns an empty separator when there is no ';', so the line
    # is then kept unchanged. Otherwise the semicolon is re-appended.
    statement, semicolon, _rest = first_line.partition(';')
    return reason, statement + semicolon

# def parse_search_action_webshop(action):
#     # Regex to capture the format search followed by optional space and any text not enclosed in brackets
#     pattern = r'^search\s*(?!\[)(.*)$'
    
#     # This function checks the pattern and formats accordingly
#     def replace_func(match):
#         content = match.group(1).strip()
#         return f"search[{content}]"
    
#     # Replace using the pattern and function
#     return re.sub(pattern, replace_func, action)

def preprocess_json_string(response):
    """Apply lossy, best-effort repairs to a JSON string before parsing.

    The steps run in this order:

    1. Double every backslash, so that each one is read as a literal.
    2. Collapse each run of whitespace into a single space.
    3. Drop trailing commas before ``}`` and before ``]``.
    4. Insert a missing comma between a string and the ``"trajectory"`` key.

    NOTE(review): another ``preprocess_json_string`` is defined later in this
    module. It rebinds this name at import time, so callers never reach this
    version or its trajectory-comma repair. Consider renaming or merging
    the two definitions.
    """
    cleaned = response.replace("\\", "\\\\")
    for pattern, replacement in (
        (r'\s+', ' '),
        (r',\s*}', '}'),
        (r',\s*]', ']'),
        (r'"\s*"trajectory"', '", "trajectory"'),
    ):
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned

def parse_corrected_trajectory(response):
    """Extract the ``"trajectory"`` list from the first ```json block.

    The JSON object must sit between a line reading ```` ```json ```` and a
    closing ```` ``` ```` line. It is cleaned with ``preprocess_json_string``,
    a missing comma before the ``"trajectory"`` key is re-inserted, and the
    result is decoded.

    Args:
        response: Raw model output.

    Returns:
        The list of dicts stored under ``"trajectory"``. Returns ``None``, after
        printing the error, when any of these holds:

        * no fenced JSON object is found;
        * the JSON cannot be decoded;
        * the key is missing;
        * the value is not a list of dicts.
    """
    try:
        json_str_match = re.search(r'```json\n(\{.*?\})\n```', response, re.DOTALL)

        if json_str_match is None:
            raise ValueError("No JSON content found")
        json_str = json_str_match.group(1)

        json_str = preprocess_json_string(json_str)
        # Re-insert a missing comma before the "trajectory" key. This repair is
        # also in the first preprocess_json_string definition in this module.
        # A later definition with the same name shadows that one, so the repair
        # is applied here explicitly. Running it twice is harmless.
        json_str = re.sub(r'"\s*"trajectory"', '", "trajectory"', json_str)
        parsed_dict = json.loads(json_str)
        corrected_trajectory = parsed_dict['trajectory']

        if not isinstance(corrected_trajectory, list) or not all(isinstance(item, dict) for item in corrected_trajectory):
            raise ValueError("corrected_trajectory is not a list of dicts")
    except (json.JSONDecodeError, KeyError, ValueError) as e:
        corrected_trajectory = None
        print(f"An error occurred: {e}")

    return corrected_trajectory

def preprocess_json_string(response):
    """Apply lossy, best-effort repairs to a JSON string before parsing.

    The steps run in this order:

    1. Turn escaped double quotes (backslash + ``"``) into single quotes.
    2. Double every remaining backslash, so that each one is read as a literal.
    3. Collapse each run of whitespace into a single space.
    4. Drop trailing commas before ``}`` and before ``]``.
    """
    # Quotes must be handled before backslashes are doubled. Otherwise the
    # backslash-quote pair would no longer be recognised.
    cleaned = response.replace('\\"', "'").replace("\\", "\\\\")
    cleaned = re.sub(r'\s+', ' ', cleaned)
    for trailing_comma, closer in ((r',\s*}', '}'), (r',\s*]', ']')):
        cleaned = re.sub(trailing_comma, closer, cleaned)
    return cleaned

def parse_json(response):
    """Decode the first ```json fenced block in *response*.

    The block body is cleaned with ``preprocess_json_string`` before decoding.

    Args:
        response: Raw model output.

    Returns:
        The decoded Python object. Returns ``None``, after printing a message,
        if no fenced block is found or the JSON cannot be decoded.
    """
    fenced = re.search(r'```json\n([\s\S]*?)\n```', response)
    if not fenced:
        print("No JSON data found")
        return None

    cleaned = preprocess_json_string(fenced.group(1))
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError as err:
        print("Error decoding JSON:", err)
        return None

def extract_task_from_observation_alfworld(observation: str) -> str:
    """Return the text after the first ``"Your task is to: "`` marker.

    The text runs up to, but not including, the next newline. Returns ``""``
    when the marker does not appear.
    """
    _before, marker, after = observation.partition("Your task is to: ")
    if not marker:
        return ""
    return after.split("\n", 1)[0]

def substitute_placeholders(config, template, value):
    """Recursively replace *template* with *value* in every string in *config*.

    Args:
        config: A nested structure of dicts, lists and strings. Values of any
            other type are returned unchanged.
        template: The substring to replace.
        value: The replacement text.

    Returns:
        The substituted structure, with each kind of value handled differently:

        * Dicts are updated in place and the same dict object is returned.
        * Lists are rebuilt as new lists.
        * Strings are returned as new strings.
    """
    if isinstance(config, str):
        return config.replace(template, value)
    if isinstance(config, list):
        return [substitute_placeholders(entry, template, value) for entry in config]
    if isinstance(config, dict):
        # Only values are reassigned, so iterating the keys while writing is safe.
        for name in config:
            config[name] = substitute_placeholders(config[name], template, value)
    return config

def parse_corrected_reason_and_action_alfworld(text):
    """Split a corrected ALFWorld response into its reason and action.

    This mirrors ``parse_reason_and_action_alfworld`` but looks for the
    ``CORRECTED_REASON:`` and ``CORRECTED_ACTION:`` markers.

    Returns:
        A ``(reason, action)`` tuple.
        ``reason`` is the stripped text between the two markers.
        If that pair is absent, it is the whole input.
        ``action`` is the rest of the line after the first
        ``CORRECTED_ACTION:``. It is lower-cased, and every character other
        than a-z, 0-9, space and '/' is removed. It is ``""`` when absent.
    """
    reason_hit = re.search(r"CORRECTED_REASON:\s*(.*?)\s*CORRECTED_ACTION:", text, re.DOTALL)
    action_hit = re.search(r"CORRECTED_ACTION:\s*([^\n]+)", text)

    reason = text if reason_hit is None else reason_hit.group(1).strip()

    if action_hit is None:
        action = ""
    else:
        action = action_hit.group(1).strip()

    # Normalise to the environment's command vocabulary: lower case only,
    # with stray punctuation stripped.
    action = re.sub(r'[^a-z0-9 /]', '', action.lower())
    return reason, action