import json
import uuid

from sentence_transformers import SentenceTransformer

from memory.schemas import WorkflowExecutionEntry, SearcherExecutionMemory, BrowserExecutionMemory

from agents.state import OverallState
from agents.utils import get_user_question

class MemoryManager:
    """In-memory store of workflow execution records with embedding-based lookup.

    Records are kept in ``self.workflow_execution_memory`` (a plain list) and
    are unique per ``execution_id``. Questions and sub-questions are embedded
    with a ``SentenceTransformer`` model so that earlier executions can be
    retrieved by similarity to a new question / sub-question pair.
    """

    def __init__(self, embedding_model: str = "all-MiniLM-L6-v2"):
        """Create an empty memory and load the embedding model.

        Args:
            embedding_model: Model name handed to ``SentenceTransformer``.
        """
        self.workflow_execution_memory = []
        self.embedding_model = SentenceTransformer(embedding_model)

    def load_workflow_execution_memory(self, workflow_execution_memory_path: str):
        """Load entries from a JSON file and merge them into the memory.

        Every entry is validated before any of them is inserted, so a file
        containing a malformed entry leaves the current memory untouched.
        Entries whose ``execution_id`` already exists replace the stored one.

        Args:
            workflow_execution_memory_path: Path of the JSON file to read.

        Raises:
            ValueError: If the file's top-level value is not a list, or if an
                entry cannot be parsed into a ``WorkflowExecutionEntry``.
        """
        with open(workflow_execution_memory_path, "r", encoding="utf-8") as f:
            new_workflow_execution_memory_data = json.load(f)

        # Validate that the loaded data is a list
        if not isinstance(new_workflow_execution_memory_data, list):
            raise ValueError("New workflow execution memory must be a list")

        # Convert dictionaries to WorkflowExecutionEntry objects
        new_workflow_execution_memory = []
        for entry_data in new_workflow_execution_memory_data:
            try:
                entry = WorkflowExecutionEntry.model_validate(entry_data)
            except Exception as e:
                # Keep the original error attached as the cause for debugging.
                raise ValueError(f"Failed to parse workflow execution memory entry: {e}") from e
            new_workflow_execution_memory.append(entry)

        for entry in new_workflow_execution_memory:
            self._update_workflow_execution_memory(entry)

    def retrieve_workflow_execution_memory_by_execution_id(self, execution_id: str) -> "WorkflowExecutionEntry | None":
        """Return the stored entry with the given ``execution_id``, or ``None``."""
        for entry in self.workflow_execution_memory:
            if entry.execution_id == execution_id:
                return entry
        return None

    def retrieve_workflow_execution_memory_by_embedding(self, question: str, sub_question: str, alpha: float = 0.5, topk: int = 5) -> "list[WorkflowExecutionEntry]":
        """Retrieve the stored entries most similar to a question / sub-question pair.

        Each entry is scored as
        ``alpha * sim(question) + (1 - alpha) * sim(sub_question)``, where the
        similarities come from ``self.embedding_model.similarity``.

        Args:
            question: The original question to match against.
            sub_question: The sub-question to match against.
            alpha: Weight of the question similarity; the sub-question
                similarity is weighted by ``1 - alpha``.
            topk: Maximum number of entries to return.

        Returns:
            Up to ``topk`` entries sorted by descending score. Entries with
            equal scores keep their order in memory. An empty list is returned
            when ``topk`` is not positive or the memory is empty.
        """
        if topk <= 0 or not self.workflow_execution_memory:
            return []

        # embed the query question and sub-question
        question_embedding = self.embedding_model.encode(question)
        sub_question_embedding = self.embedding_model.encode(sub_question)

        # Similarity scores keyed by the entry's text, so entries sharing the
        # same question / sub-question text are only scored once.
        question_similarity_scores = {}
        sub_question_similarity_scores = {}

        scored_entries = []
        for entry in self.workflow_execution_memory:
            if entry.question not in question_similarity_scores:
                question_similarity_scores[entry.question] = self.embedding_model.similarity(
                    question_embedding, entry.question_embedding
                ).item()
            if entry.sub_question not in sub_question_similarity_scores:
                sub_question_similarity_scores[entry.sub_question] = self.embedding_model.similarity(
                    sub_question_embedding, entry.sub_question_embedding
                ).item()

            weighted_sum = (
                alpha * question_similarity_scores[entry.question]
                + (1 - alpha) * sub_question_similarity_scores[entry.sub_question]
            )
            scored_entries.append((weighted_sum, entry))

        # Sort on the score only: entries themselves are never compared, and
        # the sort is stable so ties keep memory order.
        scored_entries.sort(key=lambda pair: pair[0], reverse=True)
        return [entry for _, entry in scored_entries[:topk]]

    def _update_workflow_execution_memory(self, workflow_execution_memory: "WorkflowExecutionEntry"):
        """Insert an entry, replacing any stored entry with the same ``execution_id``.

        The new entry is always appended at the end of the list.
        """
        target_id = workflow_execution_memory.execution_id
        for index, entry in enumerate(self.workflow_execution_memory):
            if entry.execution_id == target_id:
                # Delete by position so removal never depends on entry equality.
                del self.workflow_execution_memory[index]
                break
        self.workflow_execution_memory.append(workflow_execution_memory)

    @staticmethod
    def _new_items(after, before):
        """Return the distinct items of ``after`` absent from ``before``, in first-seen order."""
        already_present = set(before)
        return [item for item in dict.fromkeys(after) if item not in already_present]

    def _extract_workflow_execution_memory(self, before_state: "OverallState", after_state: "OverallState", workflow_str: str) -> "WorkflowExecutionEntry":
        """Build a ``WorkflowExecutionEntry`` from the states around one workflow run.

        Args:
            before_state: State captured before the workflow ran.
            after_state: State captured after the workflow ran.
            workflow_str: Stored verbatim in the entry's ``workflow`` field.

        Returns:
            A new entry with a freshly generated ``execution_id`` and the
            embeddings of the question and sub-question.
        """
        # Extract workflow-level information
        execution_id = str(uuid.uuid4())
        question = get_user_question(after_state['messages'])
        sub_question = after_state['current_sub_question']
        before_summary = before_state.get('current_summary', '')
        summary = after_state.get('current_summary', '')
        can_answer_sub_question = after_state.get('sub_verified', False)
        can_answer_question = after_state.get('final_verified', False)

        # Extract searcher-level information
        before_searcher_state = before_state.get('searcher_state', {})
        after_searcher_state = after_state.get('searcher_state', {})
        # NOTE(review): despite the field name, this carries every search
        # result present in the after-state, not only those added by this run.
        new_search_results = list(after_searcher_state.get('search_results', []))
        searcher_execution_memory = SearcherExecutionMemory(
            search_count=after_searcher_state.get('search_count', 0) - before_searcher_state.get('search_count', 0),
            new_keywords_added=self._new_items(
                after_searcher_state.get('used_keywords', []),
                before_searcher_state.get('used_keywords', []),
            ),
            new_search_results=new_search_results
        )

        # Extract browser-level information
        before_browser_state = before_state.get('browser_state', {})
        after_browser_state = after_state.get('browser_state', {})
        # NOTE(review): as above, this is the full after-state reference list.
        new_references = list(after_browser_state.get('found_references', []))
        browser_execution_memory = BrowserExecutionMemory(
            visit_count=after_browser_state.get('visit_count', 0) - before_browser_state.get('visit_count', 0),
            new_visited_urls=self._new_items(
                after_browser_state.get('visited_urls', []),
                before_browser_state.get('visited_urls', []),
            ),
            new_references_found=new_references
        )

        # embed the question and sub-question
        question_embedding = self.embedding_model.encode(question)
        sub_question_embedding = self.embedding_model.encode(sub_question)

        # construct the workflow execution memory
        return WorkflowExecutionEntry(
            execution_id=execution_id,
            question=question,
            sub_question=sub_question,
            before_summary=before_summary,
            summary=summary,
            can_answer_sub_question=can_answer_sub_question,
            can_answer_question=can_answer_question,
            workflow=workflow_str,
            searcher_execution_memory=searcher_execution_memory,
            browser_execution_memory=browser_execution_memory,
            question_embedding=question_embedding,
            sub_question_embedding=sub_question_embedding
        )

    def update_workflow_execution_memory(self, before_state: "OverallState", after_state: "OverallState", workflow_str: str):
        """Record one workflow execution, derived from its before and after states."""
        workflow_execution_memory = self._extract_workflow_execution_memory(before_state, after_state, workflow_str)
        self._update_workflow_execution_memory(workflow_execution_memory)

    @staticmethod
    def _to_jsonable(obj):
        """``json`` ``default`` hook: convert one non-serializable object.

        ``json`` re-encodes whatever this returns, so nested values that are
        still not serializable are passed through this hook again.
        """
        # Pydantic v2 (model_dump), Pydantic v1 (dict), then array-like values
        # such as numpy arrays and scalars (tolist).
        for converter_name in ("model_dump", "dict", "tolist"):
            converter = getattr(obj, converter_name, None)
            if callable(converter):
                return converter()
        if hasattr(obj, '__dict__'):  # Regular object
            return vars(obj)
        # Last resort for anything else: store its string form.
        return str(obj)

    def save_workflow_execution_memory(self, workflow_execution_memory_path: str = "workflow_execution_memory.json"):
        """Write the whole memory to a JSON file.

        Args:
            workflow_execution_memory_path: Destination path; any existing
                file is overwritten.
        """
        # Serialize fully before opening the file so that a serialization
        # error cannot leave a truncated file behind.
        serialized = json.dumps(self.workflow_execution_memory, indent=2, default=self._to_jsonable)
        with open(workflow_execution_memory_path, "w", encoding="utf-8") as f:
            f.write(serialized)