import json
import re
import ast
import typing
from typing import Dict, Any

from langchain_core.messages import AIMessage

from agentS.consts import STATELESS_ACTIONS, ANSWER_KEYS, NO_BID_ACTIONS, ARGS_KEY, STATE_KEY, ID_KEY, EMPTY_STATE_ID, \
    UNKNOWN_ID, FINAL_ANSWER


class Agents_Output_Handlers:
    """Post-process raw agent responses into graph-state updates.

    Every public handler is a ``@staticmethod`` taking the current ``state``
    dict, the agent's response message ``result`` and the agent ``name``.
    Handlers may mutate ``state`` in place. They return a new dict built as
    ``{**state, "messages": [...], "sender": ...}``.
    """

    # ------------------------------------------------------------------ ranker

    @staticmethod
    def handle_ranker_agent(state, result, name):
        """Parse the ranker's response and store the ranked elements in ``state['elements']``.

        Parsing is tried in order:
        1. strict JSON;
        2. a Python literal (``ast.literal_eval``);
        3. a blank-line separated ``Key: value`` section format.

        Methods 1 and 2 only accept a dict. From a parsed dict the
        ``'Relevant Elements'`` entry is used, falling back to the dict's first
        value. If nothing parses, the cleaned text itself is stored.

        Args:
            state: Graph state dict; ``state['elements']`` is overwritten.
            result: Agent message; only ``result.content`` (a string) is read.
            name: Sender name recorded in the returned dict.

        Returns:
            ``{**state, "messages": [result], "sender": name}``.
        """
        handlers = Agents_Output_Handlers
        cleaned_json_data = handlers._clean_ranker_output(result.content)

        # Method 1: Direct JSON parsing
        try:
            parsed_result = json.loads(cleaned_json_data)
        except json.JSONDecodeError:
            parsed_result = None  # Move to next method if this fails
        # A bare JSON list/scalar has no named fields (and no .get), so it falls through.
        if isinstance(parsed_result, dict):
            return handlers._ranker_response(
                state, result, name, handlers._select_ranked_elements(parsed_result))

        # Method 2: Evaluate as a Python literal (e.g. single-quoted dicts)
        try:
            parsed_result = ast.literal_eval(cleaned_json_data)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # TypeError covers literals such as "{[1]: 2}" (unhashable dict key).
            parsed_result = None  # Move to next method if this fails
        if isinstance(parsed_result, dict):
            return handlers._ranker_response(
                state, result, name, handlers._select_ranked_elements(parsed_result))

        # Method 3: Custom parsing for the "Key: value" section format
        try:
            parsed_result = handlers._parse_ranker_sections(cleaned_json_data)
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            print(f'\n Error in custom parsing: {str(e)}')
        else:
            if parsed_result:
                return handlers._ranker_response(
                    state, result, name, handlers._select_ranked_elements(parsed_result))

        # If all methods fail, return the original string
        print(f'\n Error transforming to JSON: {cleaned_json_data}')
        return handlers._ranker_response(state, result, name, cleaned_json_data)

    @staticmethod
    def _clean_ranker_output(content: str) -> str:
        """Strip Markdown code fences, a leading ``json`` tag and ``//`` line comments."""
        # Remove ``` fences (optionally tagged json) rather than every "json"
        # substring, so values such as "json_button" survive.
        text = re.sub(r'```(?:json)?', '', content, flags=re.IGNORECASE)
        # A bare "json" tag at the very start without a surrounding fence.
        text = re.sub(r'^\s*json\b', '', text, flags=re.IGNORECASE)
        # Drop // comments, but not the "//" of URL schemes like https:// (preceded by ':').
        return re.sub(r'(?<!:)//.*', '', text)

    @staticmethod
    def _select_ranked_elements(parsed_result: Dict[str, Any]) -> Any:
        """Return ``'Relevant Elements'`` if present, else the first value (``None`` if empty)."""
        if 'Relevant Elements' in parsed_result:
            return parsed_result['Relevant Elements']
        return next(iter(parsed_result.values()), None)

    @staticmethod
    def _parse_ranker_sections(text: str) -> Dict[str, Any]:
        """Parse blank-line separated ``Key: value`` sections into a dict.

        Only the four known keys are kept. Raises ``ValueError`` if a JSON-valued
        section (``Relevant Elements`` / ``Element Priorities``) is malformed.
        """
        parsed = {}
        for section in re.split(r'\n\s*\n', text):
            if ':' not in section:
                continue
            key, value = (part.strip() for part in section.split(':', 1))
            if key in ("Relevant Elements", "Element Priorities"):
                parsed[key] = json.loads(value)
            elif key in ("Rationale", "Next Steps Suggestion"):
                parsed[key] = value.strip('"')
        return parsed

    @staticmethod
    def _ranker_response(state, result, name, elements):
        """Store ``elements`` in ``state`` and build the handler's return dict."""
        state['elements'] = elements
        return {**state, "messages": [result], "sender": name}

    # ------------------------------------------------------------------ action

    @staticmethod
    def handle_action_agent(state, result, name):
        """Validate and normalise the tool calls emitted by the action agent.

        Calls whose name is in ``NO_BID_ACTIONS`` are only pre-processed. For
        ``human_in_the_loop`` calls, the answer keys are moved under
        ``args['state']``. For ``update_policy`` calls, the reason is recorded
        in ``state``. Every other call gets its element ID normalised and
        checked:

        - an ``answer`` call ends the run with a FINAL ANSWER message;
        - an invalid element ID writes an error to ``state['observation']``
          (and to ``state['pointer_env'].feedback`` if present). If it is the
          only remaining call, an error message is returned; otherwise the call
          is dropped from ``result.tool_calls``;
        - an element containing ``'FINAL ANSWER'`` ends the run.

        Returns:
            The state update dict; ``result`` (possibly with fewer tool calls)
            is returned when no early exit applies.
        """
        handlers = Agents_Output_Handlers
        if not result.tool_calls:
            return {**state, "messages": [result], "sender": name}

        # Iterate over a snapshot: invalid calls are removed from result.tool_calls
        # below, and removing from the list being iterated would skip the next call.
        for call in list(result.tool_calls):
            if call["name"] in NO_BID_ACTIONS:
                if 'human_in_the_loop' in call["name"]:
                    handlers._move_answer_keys_into_state(call)
                if call["name"] == "update_policy":
                    handlers._record_update_policy_reason(state, call)
                continue

            call = handlers.update_call_with_fixed_element(call)
            action_state = call["args"].get("state", call["args"])

            if call["name"] == "answer":
                result.content = f"FINAL ANSWER \n {action_state}"
                return {**state, "messages": [result], "sender": "END"}

            element_predicted = handlers.extract_element_predicted(action_state)

            if not handlers.is_valid_element(element_predicted):
                error_message = f"ERROR: You have predicted to perform {call['name']} of {element_predicted} as the next action. Please provide the element ID!"
                state['observation'] = error_message
                if 'pointer_env' in state:
                    state['pointer_env'].feedback.append(error_message)
                if len(result.tool_calls) == 1:
                    return {**state, "messages": [AIMessage(content=state['observation'], name="call_tool")],
                            "sender": "ActionAgent"}
                # Drop this call and keep processing the remaining ones.
                result.tool_calls.remove(call)
                continue

            if isinstance(element_predicted, str) and 'FINAL ANSWER' in element_predicted:
                state['observation'] = element_predicted
                return {**state, "messages": [AIMessage(content=element_predicted, name="END")], "sender": "END"}

        return {**state, "messages": [result], "sender": name}

    @staticmethod
    def _move_answer_keys_into_state(call: Dict[str, Any]) -> None:
        """Move any top-level ``ANSWER_KEYS`` args into ``call['args']['state']`` (created if missing)."""
        args = call["args"]
        answer_state = args.setdefault('state', {})
        for key in ANSWER_KEYS:
            if key in args:
                answer_state[key] = args.pop(key)

    @staticmethod
    def _record_update_policy_reason(state, call: Dict[str, Any]) -> None:
        """Copy the update_policy reason into ``state['update_policy_reason']`` (best effort).

        The reason is read from ``args['state']`` when present, else from ``args``.
        The value under ``'reason'`` is preferred, falling back to the first
        key's value. Problems are printed, never raised.
        """
        try:
            args = call['args']
            if 'state' in args:
                source, label = args['state'], "call['args']['state']"
            else:
                source, label = args, "call['args']"
            reason_key = 'reason' if 'reason' in source else next(iter(source), None)
            if reason_key:
                state['update_policy_reason'] = source[reason_key]
            else:
                print(f"Warning: No valid reason key found in {label}: {source}")
        except Exception as e:
            print(f"Error in update_policy handling: {str(e)}")
            print(f"call['args']: {call.get('args') if isinstance(call, dict) else call}")

    @staticmethod
    def extract_element_predicted(action_state: typing.Union[Dict[str, Any], Any]) -> Any:
        """Return the element the action targets.

        For a dict, ``action_state[ID_KEY]`` is preferred. ``update_call_with_fixed_element``
        stores the normalised ID there, and may move it to the end of the dict
        when renaming an ID-like key. Otherwise the dict's first value is
        returned. An empty dict or a non-dict value is returned unchanged.
        """
        if not isinstance(action_state, dict):
            return action_state
        if ID_KEY in action_state:
            return action_state[ID_KEY]
        return next(iter(action_state.values()), action_state)

    # ----------------------------------------------------------------- planner

    @staticmethod
    def handle_planner_agent(state, result, name):
        """Store the planner's new policy and report the update as environment feedback.

        Sets ``state['policy']`` to ``result.content`` and appends an
        ``Action_feedback`` JSON line to ``state['env'].feedback`` when
        ``state`` has an ``env``. It then resets ``state['update_policy_reason']``
        to ``""``.
        """
        state['policy'] = result.content

        feedback = {
            "action": "update_policy",
            "status": "success",
            "message": "Successfully updated the policy.",
            # "" is the reset value written below, so treat it as "no reason" too.
            "update_reason": state.get("update_policy_reason") or "No reason provided"
        }

        env = state.get('env')
        if env is not None:
            env.feedback.append(f'Action_feedback: {json.dumps(feedback)}')

        state['update_policy_reason'] = ""

        return {
            **state,
            "messages": [result],
            "sender": name
        }

    # --------------------------------------------------------- element helpers

    @staticmethod
    def update_call_with_fixed_element(call: Dict[str, Any]) -> Dict[str, Any]:
        """Normalise ``call[ARGS_KEY][STATE_KEY]`` in place so it carries an ``ID_KEY``.

        - non-dict state: replaced by ``{ID_KEY: str(state)}`` (stored as-is);
        - missing/empty state: replaced by ``{ID_KEY: EMPTY_STATE_ID}``;
        - dict without ``ID_KEY``: a key whose lower-cased name contains
          ``ID_KEY`` is renamed to ``ID_KEY``. Otherwise the first value is
          copied there as a string (``UNKNOWN_ID`` if the first key is ``None``).

        For dict states the ID is then passed through ``element_fixing_attempt``.
        Returns the same ``call`` object. Unexpected errors are printed, and the
        call is returned as far as it was processed.
        """
        try:
            if ARGS_KEY in call and isinstance(call[ARGS_KEY], dict):
                args = call[ARGS_KEY]
                state = args.get(STATE_KEY, {})

                if not isinstance(state, dict):
                    # The ID formatting step below only applies to dict states;
                    # returning here avoids testing ID_KEY against the stale scalar.
                    args[STATE_KEY] = {ID_KEY: str(state)}
                    return call
                if not state:
                    args[STATE_KEY] = {ID_KEY: EMPTY_STATE_ID}
                    return call

                if ID_KEY not in state:
                    # Try to find an ID-like key or use the first item
                    id_key = next((k for k in state if isinstance(k, str) and ID_KEY in k.lower()), None)
                    if id_key is not None:
                        state[ID_KEY] = state.pop(id_key)
                    else:
                        first_key, first_value = next(iter(state.items()))
                        state[ID_KEY] = str(first_value) if first_key is not None else UNKNOWN_ID

                # Ensure the ID is properly formatted
                state[ID_KEY] = Agents_Output_Handlers.element_fixing_attempt(state[ID_KEY])

            return call
        except Exception as e:
            print(f"Error in update_call_with_fixed_element: {str(e)}")
            print(f"Call structure: {call}")
            return call  # Return the original call if we can't process it

    @staticmethod
    def is_valid_element(element: Any) -> bool:
        """Return True for an int, a decimal-digit string, a letter+digits string, or a FINAL_ANSWER string."""
        if isinstance(element, int):
            return True
        if isinstance(element, str):
            # isdecimal (not isdigit): '²' is a digit but not a usable integer ID.
            if element.isdecimal():
                return True
            if re.match(r'^[a-zA-Z]\d+$', element):
                return True
            if FINAL_ANSWER in element:
                return True
        return False

    @staticmethod
    def element_fixing_attempt(element: Any) -> typing.Union[str, int, None]:
        """Coerce a raw element ID into the form produced by ``process_element``.

        Accepts an int or str, an integral float (``5.0`` -> ``5``), a dict (its
        first value is used) or a list (its first item is used). Returns
        ``None`` for empty containers and any other input.
        """
        if isinstance(element, (int, str)):
            return Agents_Output_Handlers.process_element(str(element))
        if isinstance(element, float):
            # JSON numbers such as 5.0 still denote element 5; other floats are not IDs.
            return int(element) if element.is_integer() else None
        if isinstance(element, dict):
            # If it's a dict, try to process the first value
            first_value = next(iter(element.values()), None)
            if first_value is None and not element:
                return None
            return Agents_Output_Handlers.process_element(str(first_value))
        if isinstance(element, list) and element:
            # If it's a list, try to process the first item
            return Agents_Output_Handlers.process_element(str(element[0]))
        return None

    @staticmethod
    def process_element(element: str) -> typing.Union[str, int, None]:
        """Normalise an element string.

        Surrounding brackets, braces, parentheses and whitespace are stripped.
        Then:
        - a decimal number becomes an ``int``;
        - a letter followed by digits is kept as-is;
        - anything containing ``'FINAL ANSWER'`` becomes exactly ``'FINAL ANSWER'``;
        - any other text is returned stripped.
        """
        # Remove any surrounding brackets, braces, or whitespace
        element = element.strip('[]{}() \t\n\r')

        # isdecimal guarantees int() succeeds (isdigit would accept e.g. '²')
        if element.isdecimal():
            return int(element)

        # A letter followed by digits is already in canonical form
        if re.match(r'^([a-zA-Z])(\d+)$', element):
            return element

        # NOTE(review): uses the literal 'FINAL ANSWER' while is_valid_element uses
        # the FINAL_ANSWER constant — confirm the two agree.
        if 'FINAL ANSWER' in element:
            return 'FINAL ANSWER'

        # If we can't process it, return the original element
        return element
