import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from openai import OpenAI

# Populate os.environ from a .env file (python-dotenv). This must run before the
# class below is defined: memory_extraction.__init__ reads its default argument
# values with os.getenv() at class-definition time, not per instantiation.
load_dotenv()


class memory_extraction:
    """Summarize GUI action sequences with an LLM.

    Produces a high-level goal, step-by-step low-level instructions, a task
    category, and an embedding of the goal. Chat calls go through
    ``self.client``; embeddings go through ``self.embedding_client``. Both are
    OpenAI-compatible clients.
    """

    # Closed set of categories accepted by extract_task_state(); any other
    # answer from the LLM is mapped to "Other".
    DEFAULT_TASK_CATEGORIES = (
        "Media Recording",
        "Media Playback",
        "Communication",
        "File Management",
        "Settings & Configuration",
        "Data Entry & Forms",
        "Navigation & Browsing",
        "Social Media",
        "Productivity & Tools",
        "Gaming",
        "Shopping & E-commerce",
        "Education & Learning",
        "Health & Fitness",
        "System Operation",
        "Other",
    )

    def __init__(
            self,
            provider: str = "openai",
            api_key: Optional[str] = os.getenv("memory_llm_api_key"),
            model: Optional[str] = os.getenv('memory_llm_model'),
            base_url: Optional[str] = os.getenv("memory_llm_base_url"),
            embedding_model: Optional[str] = os.getenv('embedding_llm_model'),
            storage_path: Optional[str] = os.getenv('memory_storage_path'),
    ):
        """
        Initialize LLM extractor.

        NOTE: the environment-based defaults are evaluated once, when this
        class is defined (import time), not on every instantiation.

        Args:
            provider: LLM provider label. Stored only; both clients are
                created with the OpenAI client class.
            api_key: API key for the chat-completion client.
            model: Chat model name.
            base_url: Custom API endpoint for the chat client (optional).
            embedding_model: Model name passed to the embeddings endpoint.
            storage_path: Storage location, kept as ``self.storage_path``
                (a Path, or None when unset). Not used by this class's methods.
        """
        self.provider = provider
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.embedding_model = embedding_model
        # Path(None) raises TypeError; keep an unset location as None instead
        # of crashing the constructor.
        self.storage_path = Path(storage_path) if storage_path is not None else None

        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
        # The embedding endpoint is configured separately from the chat endpoint.
        self.embedding_client = OpenAI(api_key=os.getenv('embedding_llm_api_key'),
                                       base_url=os.getenv('embedding_llm_base_url'))

    def extract_high_level_goal(self, action_sequence: List[Dict[str, Any]]) -> str:
        """Summarize an action sequence as a concise high-level goal.

        Args:
            action_sequence: Step dicts as accepted by _format_action_sequence().

        Returns:
            The ``high_level_goal`` value from the LLM's JSON reply.

        Raises:
            ValueError: If ``action_sequence`` is empty.
            RuntimeError: If the reply is not JSON or lacks ``high_level_goal``.
        """
        if not action_sequence:
            raise ValueError("action_sequence cannot be empty")

        sequence_description = self._format_action_sequence(action_sequence)

        # The formatted steps must be in the prompt, and the reply format must
        # match the key parsed below.
        prompt = (
            "Analyze the following GUI action sequence and generate a concise, "
            "high-level description of the overall goal.\n\n"
            f"Action sequence:\n{sequence_description}\n\n"
            "Respond only with a JSON object of the form:\n"
            '{"high_level_goal": "<concise description of the overall goal>"}'
        )

        try:
            response = self._call_llm(prompt)
            response = self._clean_json_response(response)
            response = json.loads(response)
            return response["high_level_goal"]
        except Exception as e:
            raise RuntimeError(f"Failed to generate action context: {str(e)}") from e

    def extract_low_level_instructions(self, action_sequence: List[Dict[str, Any]], high_level_goal: str) -> List[str]:
        """Generate step-by-step instructions for an action sequence using the LLM.

        The LLM is asked for ``{"low_level_instructions": [...]}``. A bare JSON
        list, or a JSON object holding a list under another key, is also
        accepted. If the reply is not usable JSON, it is split into one
        instruction per non-empty line, with leading "-" and quotes stripped.

        Args:
            action_sequence: Step dicts as accepted by _format_action_sequence().
            high_level_goal: Goal the instructions should accomplish.

        Returns:
            List of instruction strings (empty if the LLM returned nothing).

        Raises:
            ValueError: If either argument is empty.
        """
        if not action_sequence:
            raise ValueError("action_sequence cannot be empty")
        if not high_level_goal:
            raise ValueError("high_level_goal cannot be empty")

        sequence_description = self._format_action_sequence(action_sequence)

        prompt = (
            "Given the following high-level goal and detailed action sequence, "
            "convert it into clear step-by-step instructions.\n\n"
            f"High-level goal: {high_level_goal}\n\n"
            f"Action sequence:\n{sequence_description}\n\n"
            "Respond only with a JSON object of the form:\n"
            '{"low_level_instructions": ["<instruction 1>", "<instruction 2>"]}'
        )

        response = self._clean_json_response(self._call_llm(prompt))

        try:
            parsed = json.loads(response)
        except json.JSONDecodeError:
            parsed = None

        if isinstance(parsed, dict):
            # Prefer the requested key, but tolerate a renamed list-valued key.
            candidates = [parsed.get("low_level_instructions"), *parsed.values()]
            parsed = next((value for value in candidates if isinstance(value, list)), None)

        if isinstance(parsed, list):
            stripped = (str(item).strip() for item in parsed)
            return [item for item in stripped if item]

        # Fallback: treat the reply as plain text, one instruction per line.
        lines = [line.strip() for line in response.split('\n') if line.strip()]
        return [line.strip('- ').strip('"').strip("'") for line in lines]

    def extract_task_state(self, goal: str, high_level_goal: str, low_level_instructions: List[str]) -> str:
        """Classify a task into one of DEFAULT_TASK_CATEGORIES using the LLM.

        Args:
            goal: Original task goal.
            high_level_goal: Summarized goal (e.g. from extract_high_level_goal()).
            low_level_instructions: Instructions as a list (numbered in the
                prompt) or as a pre-formatted string (used verbatim).

        Returns:
            The canonical category name. An answer that matches a category
            case-/whitespace-insensitively is normalized to that category;
            any other answer is replaced with "Other" and a warning is printed.

        Raises:
            ValueError: If any argument is empty.
            RuntimeError: If the reply is not JSON or lacks ``task_category``.
        """
        if not goal:
            raise ValueError("goal cannot be empty")
        if not high_level_goal:
            raise ValueError("high_level_goal cannot be empty")
        if not low_level_instructions:
            raise ValueError("low_level_instructions cannot be empty")

        categories = self.DEFAULT_TASK_CATEGORIES

        if isinstance(low_level_instructions, list):
            instructions_text = "\n".join(
                f"{i}. {instruction}"
                for i, instruction in enumerate(low_level_instructions, 1)
            )
        else:
            instructions_text = low_level_instructions

        category_list = "\n".join(f"- {category}" for category in categories)
        prompt = (
            "Analyze the following task information and classify it into one "
            "of the predefined categories.\n\n"
            f"Task goal: {goal}\n"
            f"High-level goal: {high_level_goal}\n"
            f"Low-level instructions:\n{instructions_text}\n\n"
            f"Predefined categories:\n{category_list}\n\n"
            "Respond only with a JSON object of the form:\n"
            '{"task_category": "<exactly one of the predefined categories>"}'
        )

        try:
            response = self._call_llm(prompt)
            response = self._clean_json_response(response)
            response = json.loads(response)
            raw_category = response["task_category"]
        except json.JSONDecodeError as e:
            raise RuntimeError(f"Failed to parse API response: {str(e)}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to extract task state: {str(e)}") from e

        # Accept near-miss casing/whitespace but always return the canonical spelling.
        canonical = {category.casefold(): category for category in categories}
        task_category = canonical.get(str(raw_category).strip().casefold())
        if task_category is None:
            print(f"Warning: LLM returned unexpected category '{raw_category}'. Using 'Other'.")
            task_category = "Other"

        return task_category

    def store_embedding(self, high_level_goal: str) -> Optional[List[float]]:
        """Return the embedding of ``high_level_goal``, or None if embedding fails.

        Despite the name, nothing is persisted here; this only calls
        _get_embedding().

        Raises:
            ValueError: If ``high_level_goal`` is empty.
        """
        if not high_level_goal:
            raise ValueError("goal cannot be empty")

        try:
            return self._get_embedding(high_level_goal)
        except Exception as e2:
            # Best-effort by design: report the failure and return None.
            print(f"Failed to process memory {high_level_goal}: {e2}")
            return None

    def _get_embedding(self, text: str) -> List[float]:
        """Embed ``text`` (newlines replaced by spaces) with ``self.embedding_model``.

        Raises:
            ValueError: If ``text`` is empty after whitespace normalization.
            RuntimeError: If the embeddings API call fails.
        """
        text = text.replace("\n", " ").strip()

        if not text:
            raise ValueError("Text cannot be empty")

        try:
            response = self.embedding_client.embeddings.create(
                model=self.embedding_model,
                input=text
            )
            return response.data[0].embedding

        except Exception as e:
            raise RuntimeError(f"Failed to get embedding: {str(e)}") from e

    def _format_action_sequence(self, action_sequence: List[Dict[str, Any]]) -> str:
        """Render action steps as numbered, human-readable text for LLM prompts.

        Each step dict may contain ``step_number``, ``reason``, ``action`` (a
        dict with ``action_type`` plus type-specific fields such as ``index``,
        ``app_name``, ``text`` or ``direction``) and ``before_ui_elements``
        (a list of element dicts). Falsy or non-dict steps are skipped;
        missing or malformed fields fall back to defaults.

        Returns:
            The formatted steps joined by blank lines.
        """
        formatted_steps = []

        for step in action_sequence:
            if not step or not isinstance(step, dict):
                continue

            step_num = step.get("step_number", "?")
            action = step.get("action")
            reason = step.get("reason", "")
            ui_elements = step.get("before_ui_elements", [])

            # Tolerate malformed records: a non-dict action (None, a string, ...)
            # is treated as empty, and a non-list element list as no elements.
            if not isinstance(action, dict):
                action = {}
            if not isinstance(ui_elements, list):
                ui_elements = []

            action_type = action.get("action_type", "unknown")
            action_index = action.get("index")

            # ``index`` is treated as a position in ``before_ui_elements``.
            target_element = ""
            if isinstance(action_index, int) and 0 <= action_index < len(ui_elements):
                element = ui_elements[action_index]
                if element and isinstance(element, dict):
                    element_text = element.get("text") or element.get("content_description") or "unnamed element"
                    target_element = f" on '{element_text}'"

            if action_type == "open_app":
                app_name = action.get("app_name", "unknown app")
                action_desc = f"Open app: {app_name}"
            elif action_type == "click":
                action_desc = f"Click{target_element}"
            elif action_type == "input_text":
                text = action.get("text", "")
                action_desc = f"Input text: '{text}'{target_element}"
            elif action_type == "scroll":
                direction = action.get("direction", "unknown")
                action_desc = f"Scroll {direction}"
            elif action_type == "swipe":
                direction = action.get("direction", "")
                action_desc = f"Swipe {direction}"
            elif action_type == "back":
                action_desc = "Press back button"
            elif action_type == "home":
                action_desc = "Press home button"
            elif action_type == "enter":
                action_desc = "Press enter"
            else:
                action_desc = f"{action_type}{target_element}"

            step_desc = f"Step {step_num}: {action_desc}"
            if reason:
                step_desc += f"\n  Reason: {reason}"

            # Interactive elements are listed only when there are at most 10
            # of them, and then only the first 5 are named.
            key_elements = [
                elem for elem in ui_elements
                if elem and isinstance(elem, dict) and (elem.get("is_clickable") or elem.get("is_editable"))
            ]
            if key_elements and len(key_elements) <= 10:
                elements_desc = ", ".join(
                    elem.get("text") or elem.get("content_description") or "unnamed"
                    for elem in key_elements[:5]
                )
                step_desc += f"\n  Key UI elements: {elements_desc}"

            formatted_steps.append(step_desc)

        return "\n\n".join(formatted_steps)

    def _clean_json_response(self, response: Optional[str]) -> str:
        """Strip wrappers around a JSON payload in an LLM reply.

        Keeps only the text inside ``<answer>...</answer>`` when the opening
        tag is present, then removes a surrounding Markdown code fence (with an
        optional ``json`` language tag). A None reply is treated as "".
        """
        cleaned = (response or "").strip()

        open_tag, close_tag = "<answer>", "</answer>"
        answer_start = cleaned.find(open_tag)
        if answer_start != -1:
            cleaned = cleaned[answer_start + len(open_tag):]
            answer_end = cleaned.find(close_tag)
            if answer_end != -1:
                cleaned = cleaned[:answer_end]
            # Content after <answer> usually starts with a newline; strip it so
            # a following code fence is still detected below.
            cleaned = cleaned.strip()

        if cleaned.startswith("```"):
            parts = cleaned.split("```")
            if len(parts) >= 2:
                cleaned = parts[1]
                if cleaned[:4].lower() == "json":
                    cleaned = cleaned[4:]

        return cleaned.strip()

    def _call_llm(self, prompt: str, max_tokens: int = 2500) -> str:
        """Send ``prompt`` to the chat model and return the reply text.

        Best-effort: on any API error the error is printed and "" is returned.
        A reply whose content is None is also returned as "".
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are an expert at analyzing GUI interaction sequences and summarizing user goals. You provide concise, accurate descriptions of what users are trying to accomplish."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=max_tokens,
                temperature=0.1
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            print(f"OpenAI API call error: {e}")
            return ""
