from __future__ import annotations

import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from datetime import datetime, timezone

import numpy as np
import random

from src.utils import epsilon_greedy_select, boltzmann_select
from src.memorykit.memory import MemorySystem
from src.memorykit.note import (
    APINote,
    BaseMemoryNote,
    CodeExampleNote,
    ExperienceNote,
    NoteType,
    ExperienceType,
    ExperienceSource,
)
from src.models.agent_models import ActionType
from src.configs.memory import MemoryConfig

logger = logging.getLogger(__name__)

@dataclass(slots=True)
class RetrievalResult:
    """Structured response for memory lookups.

    The finder methods in this module fill ``ids`` and ``contents`` together,
    so ``contents[i]`` describes the note whose id is ``ids[i]``.
    """

    # Identifiers of the retrieved memory notes, in retrieval order.
    ids: List[str]
    # One payload per id; depending on the finder this is a dict, a formatted
    # string, or the note's raw content/memory field.
    contents: List[Any]

# Shared empty result. RetrievalResult is not frozen, so this instance and its
# lists are mutable: treat it as read-only and never append to it.
EMPTY = RetrievalResult(ids=[], contents=[])


class MemoryRetriever:
    """Centralizes all memory lookups used by agents.

    This keeps CodeAgent small and makes it easier to evolve retrieval logic
    (e.g. RL-aware selection, future caching, etc.) in one place.
    """

    def __init__(
        self,
        memory_system: MemorySystem,
        config: MemoryConfig,
    ):
        self.memory_system: MemorySystem = memory_system
        self.cfg: MemoryConfig = config
    
    @property
    def candidate_pool_multiplier(self)->float:
        return self.cfg.candidate_pool_multiplier
    
    @candidate_pool_multiplier.setter
    def candidate_pool_multiplier(self, value: float):
        self.cfg.candidate_pool_multiplier = value
        if self.cfg.candidate_pool_multiplier > self.cfg.max_candidate_pool_multiplier:
            self.cfg.candidate_pool_multiplier = self.cfg.max_candidate_pool_multiplier
    
    def _select_by_q_value(self, candidates: List[BaseMemoryNote], top_k: int, stage: ActionType) -> List[BaseMemoryNote]:
        """Select top_k memories from candidates using Q-value and exploration strategy.
        
        Args:
            candidates: List of candidate memories
            top_k: Number of top memories to select
            stage: Optional stage identifier ('draft' or 'optimize'). If None, defaults to 'draft'
        """
        if not candidates:
            return []
        
        # Limit candidate count to avoid excessive computation overhead
        max_candidates = min(len(candidates), top_k * 20)  # Consider at most top_k * 20 candidates
        candidates = candidates[:max_candidates]
        
        # If stage is not provided, default to draft
        if stage is None:
            raise ValueError("stage is required")
        
        logger.info(
            "[Q-value selection] Found %s candidate memories: %s, stage=%s",
            len(candidates), candidates, stage
        )
        
        # Get configuration parameters
        strategy = self.cfg.exploration_strategy
        
        # Get Q values for each candidate memory and sort
        # If concept drift adaptation is enabled, use time-weighted sort key
        if self.cfg.enable_concept_drift:
            sorted_candidates = sorted(
                candidates, 
                key=lambda x: x.concept_drift_sort_key(
                    decay_factor=self.cfg.drift_time_decay_factor,
                    recent_bonus=self.cfg.drift_recent_success_bonus,
                    recent_window_days=self.cfg.drift_recent_window_days,
                    stage=stage,
                ), 
                reverse=True
            )
            logger.info(
                "[Q-value selection] Using concept drift adaptation sorting, "
                "sorted top %s notes: %s",
                len(sorted_candidates), sorted_candidates
            )
        else:
            # Use the sort_key corresponding to the stage
            sorted_candidates = sorted(
                candidates, 
                key=lambda x: x.sort_key(stage=stage),
                reverse=True
            )
            logger.info(
                "[Q-value selection] Using standard Q-value sorting, "
                "sorted top %s notes: %s",
                len(sorted_candidates), sorted_candidates
            )
        
        # Select according to strategy
        if strategy == "epsilon_greedy":
            logger.info(
                "[Q-value selection] Using epsilon_greedy strategy, epsilon=%s, top_k=%s",
                self.cfg.epsilon, top_k
            )
            selected = epsilon_greedy_select(
                sorted_candidates, 
                top_k, 
                epsilon=self.cfg.epsilon
            )
        elif strategy == "boltzmann":
            logger.info(
                "[Q-value selection] Using boltzmann strategy, temperature=%s, default_value=%s, top_k=%s",
                self.cfg.temperature, self.cfg.q_init, top_k
            )
            selected = boltzmann_select(
                sorted_candidates,
                top_k,
                temperature=self.cfg.temperature,
                default_value=self.cfg.q_init
            )
        else:
            logger.warning("[Q-value selection] Unknown exploration strategy: %s, falling back to epsilon_greedy", strategy)
            selected = epsilon_greedy_select(
                sorted_candidates, 
                top_k, 
                epsilon=self.cfg.epsilon
            )
        
        # Record the finally selected memories and their Q values
        logger.info(
            "[Q-value selection] Selected %s memories: %s",
            len(selected), selected
        )
        
        return selected
    
    def _select_by_epsilon_greedy(self, candidates: List[BaseMemoryNote], top_k: int) -> List[BaseMemoryNote]:
        """Select top_k memories from candidates using ε-greedy strategy based on similarity score."""
        if not candidates:
            return []
        
        logger.info(
            "[Epsilon-greedy selection] Found %s candidate memories: %s",
            len(candidates), candidates
        )
        
        selected = epsilon_greedy_select(
            candidates, 
            top_k, 
            epsilon=self.cfg.epsilon
        )
        
        logger.info(
            "[Epsilon-greedy selection] Selected %s memories: %s",
            len(selected), selected
        )
        
        return selected
    
    def _select_by_random(self, candidates: List[BaseMemoryNote], top_k: int) -> List[BaseMemoryNote]:
        """Randomly select top_k memories from candidates."""
        if not candidates:
            return []
        
        logger.info(
            "[Random selection] Found %s candidate memories: %s",
            len(candidates), candidates
        )
        
        if len(candidates) <= top_k:
            selected = candidates
        else:
            selected = np.random.choice(candidates, size=top_k, replace=False).tolist()
        
        logger.info(
            "[Random selection] Selected %s memories: %s",
            len(selected), selected
        )
        
        return selected
    
    def retrieve(
        self,
        payload_filter: Dict[str, Any], 
        stage: ActionType,
        query: str = "",
        top_k: int=None,
        similarity_threshold: Optional[float] = None,
        with_similarity: bool = True,
        
    ) -> List[Any]:
        """Internal search wrapper that delegates to memory_system.
        
        Select memories based on value-based (Q-value + ε-greedy) or similarity-based search.
        """

        if with_similarity:
            if self.cfg.enable_value_driven:
                candidate_pool_size = int(top_k * self.cfg.candidate_pool_multiplier)
            elif self.cfg.experiment_violation_q:
                candidate_pool_size = int(top_k * self.cfg.max_candidate_pool_multiplier)
            else:
                candidate_pool_size = top_k

            candidate_memories = self.memory_system.search(
                query=query, 
                top_k=candidate_pool_size, 
                filter=payload_filter,
                score_threshold=similarity_threshold
            )
        else:
            candidate_memories = self.memory_system.get_by_filter(
                filter=payload_filter,
                with_vectors=False,
                with_payload=True
            )

        # If top_k is None, return all candidate memories
        if not top_k:
            top_k = len(candidate_memories)
        if self.cfg.enable_value_driven:
            # Value-based Filter: first retrieve a larger candidate pool, then select top_k using Q-value
            # Use Q-value and strategy to select top_k
            result_memories = self._select_by_q_value(candidate_memories, top_k, stage)
            logger.info(
                "[Retrieval] Value-based Filter: query=\"%r\"..., top_k=%s, candidate_pool_size=%s, exploration_strategy=%s, result_memories=%s",
                query[:50], top_k, len(candidate_memories), self.cfg.exploration_strategy, result_memories
            )
            return result_memories
        elif self.cfg.experiment_violation_q:
            # Ablation Study: violating Q-value usage in retrieval, using ε-greedy selection
            logger.info(
                "[Retrieval] Ablation Study: violating Q-value usage in retrieval, using epsilon_greedy selection, query=\"%r\"..., top_k=%s, candidate_pool_size=%s",
                query[:50], top_k, len(candidate_memories)
            )
            logger.info("[Retrieval] Ablation study: violating Q-value usage in retrieval, using epsilon_greedy selection")
            if candidate_memories and candidate_memories[0].note_type == NoteType.CODE_EXAMPLE.value and candidate_memories[0].is_correct is True:
                candidate_memories.sort(key=lambda x: x.sort_key_score_performance(), reverse=True)
            elif candidate_memories and candidate_memories[0].score is None:
                random.shuffle(candidate_memories)
                
            return self._select_by_epsilon_greedy(candidate_memories, top_k)
        
        else:
            # Without any filtering, directly return the top_k memories
            logger.info(
                "[Retrieval] Without any filtering, directly return the %s memories: %s for query: \"%r\"...",
                len(candidate_memories[:top_k]), candidate_memories[:top_k], query[:50]
            )
            return candidate_memories[:top_k]

    # ------------------------------------------------------------------ #
    # Shared public helpers
    # ------------------------------------------------------------------ #
    def find_correct_references(
        self,
        task: str,
        top_k: int,
    ) -> RetrievalResult:
        """Search for correct code examples that match the task description."""

        notes: List[CodeExampleNote] = self.retrieve(
            query=task,
            top_k=top_k,
            payload_filter={"note_type": NoteType.CODE_EXAMPLE.value, "is_correct": True},
            stage=ActionType.DRAFT,
        )
        ref_results: List[Dict[str, Any]] = []
        ref_ids: List[str] = []
        for note in notes:
            ref_ids.append(note.id)
            ref_results.append(
                {
                    "memory_id": note.id,
                    "op_name": note.op_name,
                    "arc_src": note.arc_src,
                    "code": note.code,
                    "performance": note.performance,
                    "optimized_degree_to_best": note.optimized_degree_to_best,
                    "optimized_degree_to_parent": note.optimized_degree_to_parent,
                    "optimized_degree_to_root": note.optimized_degree_to_root,
                }
            )
        return RetrievalResult(ids=ref_ids, contents=ref_results)

    def find_correct_references_by_op_name(
        self,
        op_name: str,
        top_k: int,
        stage: ActionType = ActionType.OPTIMIZE,
    ) -> RetrievalResult:
        """Search for correct code examples that match the op_name in stage optimize."""
        notes: List[CodeExampleNote] = self.retrieve(
            top_k=top_k,
            payload_filter={"$must": [
                {"note_type": NoteType.CODE_EXAMPLE.value},
                {"is_correct": True},
                {"op_name": op_name}
                ]},
            with_similarity=False,
            stage=stage,
        )

        ref_results: List[Dict[str, Any]] = []
        ref_ids: List[str] = []
        for note in notes:
            ref_ids.append(note.id)
            ref_results.append(
                {
                    "memory_id": note.id,
                    "op_name": note.op_name,
                    "arc_src": note.arc_src,
                    "code": note.code,
                    "performance": note.performance,
                    "optimized_degree_to_best": note.optimized_degree_to_best,
                    "optimized_degree_to_parent": note.optimized_degree_to_parent,
                    "optimized_degree_to_root": note.optimized_degree_to_root,
                }
            )
        return RetrievalResult(ids=ref_ids, contents=ref_results)
    
    def find_related_apis(
        self,
        descriptions: Sequence[str] | None,
        num_per_desc: int = 2,
        total_k: int = 4,
        similarity_threshold: Optional[float] = None,
        stage: ActionType = ActionType.DRAFT,
    ) -> RetrievalResult:
        """Retrieve API memories that match the provided descriptions."""
        api_notes: List[APINote] = []
        seen_ids: set[str] = set()
        for desc in descriptions or []:
            try:
                notes: List[APINote] = self.retrieve(
                    query=desc,
                    top_k=num_per_desc,
                    payload_filter={"note_type": NoteType.API.value},
                    similarity_threshold=similarity_threshold,
                    stage=stage,
                )
            except Exception:
                notes = []
                logger.exception("Error searching for API memory: %s", desc)

            for note in notes:
                note_id = getattr(note, "id", None)
                if note_id is None or note_id in seen_ids:
                    continue
                seen_ids.add(note_id)
                api_notes.append(note)

        if not api_notes:
            return RetrievalResult(ids=[], contents=[])

        
        if self.cfg.enable_value_driven:
            api_notes = self._select_by_q_value(api_notes, total_k, stage=stage)
        elif self.cfg.experiment_violation_q:
            api_notes.sort(key=lambda x: x.similarity_sort_key(), reverse=True)
            api_notes = epsilon_greedy_select(api_notes, total_k, epsilon=self.cfg.epsilon)
        else:
            api_notes.sort(key=lambda x: x.similarity_sort_key(), reverse=True)
            api_notes = api_notes[:total_k]

        api_results: List[str] = []
        api_ids: List[str] = []

        for note in api_notes:
            payload = []
            if note.memory:
                payload.append(f"Description: {note.memory}")
            if note.name:
                payload.append(f"API: {note.name}")
            if note.example:
                payload.append(f"Example: {note.example}")
            if note.prototypes:
                prototypes = "\n".join(
                    f"{p['prototype']} // {p['description']}" if p.get("description") else p.get("prototype", "")
                    for p in note.prototypes
                )
                payload.append(f"Prototypes: {prototypes}")
            api_results.append("\n".join(payload))
            api_ids.append(note.id)

        return RetrievalResult(ids=api_ids, contents=api_results)

    def find_experiences(
        self,
        op: str,
        task: str,
        specific_exp_k: int = 3,
        general_exp_k: int = 3,
        exclude_ids: Optional[List[str]] = None,
        similarity_threshold: Optional[float] = None,
        stage: ActionType = ActionType.DRAFT,
    ) -> RetrievalResult:
        """
        Retrieve all experiences for the same operator and general experiences from similar operators.
        Prioritize retrieving same-operator experiences, fill with successful experiences from similar operators if insufficient,
        and fill with failed experiences if successful experiences are still insufficient. Total count remains general_exp_k+specific_exp_k.
        
        Args:
            op: Operator name
            task: Task description
            specific_exp_k: Number of specific experiences
            general_exp_k: Number of general experiences
            exclude_ids: List of experience IDs to exclude (to avoid duplicates)
        """
        if exclude_ids is None:
            exclude_ids = []
        exclude_set = set(exclude_ids)
        
        all_exp_ids: List[str] = []
        all_exp_results: List[Dict[str, Any]] = []
        total_k = general_exp_k + specific_exp_k
        
        # 1. Prioritize retrieving same-operator experiences
        specific_exp_notes: List[ExperienceNote] = self.retrieve(
            top_k=specific_exp_k,
            payload_filter={
                "$must": [
                    {"note_type": NoteType.EXPERIENCE.value},
                    {"op_name": op}
                ]
            },
            with_similarity=False,
            stage=stage,
        )
        logger.info(f"specific_exp_notes for op:{op} found {len(specific_exp_notes)}")

        # Filter out existing experiences
        for exp_note in specific_exp_notes:
            if exp_note.id not in exclude_set and exp_note.id not in all_exp_ids:
                all_exp_ids.append(exp_note.id)
                all_exp_results.append({  
                    "content": exp_note.content,
                    "source": exp_note.source.value if exp_note.source else None
                })
                if len(all_exp_ids) >= specific_exp_k:
                    break
        
        needed_count = total_k - len(all_exp_ids)
        
        if needed_count > 0:
            general_success_code_examples: List[CodeExampleNote] = self.retrieve(
                query=task,
                top_k=1,
                payload_filter={
                    "$must": [
                        {"note_type": NoteType.CODE_EXAMPLE.value},
                        {"is_correct": True},
                        {"$not": {"op_name": op}}
                    ]
                },
                similarity_threshold=similarity_threshold,
                stage=stage,
            )

            general_success_exp_notes = []
            if general_success_code_examples:
                general_success_exp_notes: List[ExperienceNote] = self.retrieve(
                    payload_filter={
                        "$must": [
                            {"note_type": NoteType.EXPERIENCE},
                            {"op_name": general_success_code_examples[0].op_name},
                            {"experience_type": ExperienceType.General},
                            {"stage": stage}
                        ]
                    },
                    with_similarity=False,
                    stage=stage,
                )
            
            for exp_note in general_success_exp_notes:
                if exp_note.id in exclude_set or exp_note.id in all_exp_ids:
                    continue
                all_exp_ids.append(exp_note.id)
                all_exp_results.append({
                    "content": exp_note.content,
                    "source": exp_note.source.value if exp_note.source else None
                })
                if len(all_exp_ids) >= needed_count:
                    break

        return RetrievalResult(ids=all_exp_ids, contents=all_exp_results)

    def find_recent_error_attempts_with_experiences(
        self,
        op: str,
        top_k: int,
        stage: ActionType = ActionType.DRAFT,
    ) -> RetrievalResult:
        """Return recent incorrect code attempts with their associated experiences."""
        code_notes: List[CodeExampleNote] = self.retrieve(
            top_k=top_k,
            payload_filter={
                "$must": [
                    {"note_type": NoteType.CODE_EXAMPLE.value},
                    {"is_correct": False},
                    {"op_name": op}
                ]
            },
            with_similarity=False,
            stage=stage,
        )
        
        if not code_notes:
            return RetrievalResult(ids=[], contents=[])
        
        code_results: List[Dict[str, Any]] = []
        code_ids: List[str] = []

        for code_note in code_notes:
            code_ids.append(code_note.id)
            reviews = code_note.reviews or []
            exp_notes: List[ExperienceNote] = self.memory_system.get_by_ids(reviews)
            code_results.append({
                "memory_id": code_note.id,
                "op_name": code_note.op_name,
                "arc_src": code_note.arc_src,
                "code": code_note.code,
                "reviews": reviews,
                "exp_contents": [exp.content for exp in exp_notes],
            })

        return RetrievalResult(ids=code_ids, contents=code_results)

    def find_best_practice(
        self,
        category: str,
        action: str,
        k: int = 1,
        stage: ActionType = ActionType.OPTIMIZE,
    ) -> RetrievalResult:
        """Find best practice experiences for optimization."""
        exp_notes: List[ExperienceNote] = self.retrieve(
            payload_filter={
                "$must": [
                    {"note_type": NoteType.EXPERIENCE},
                    {"experience_type": ExperienceType.General},
                    {"source": ExperienceSource.BEST_PRACTICE},
                    {"stage": ActionType.OPTIMIZE}
                ],
            },
            top_k=k,
            with_similarity=False,
            stage=stage,
        )
        if not exp_notes:
            exp_notes: List[ExperienceNote] = self.retrieve(
                query=f"{category}: {action}",
                top_k=k,
                payload_filter={
                    "$must": [
                        {"note_type": NoteType.EXPERIENCE}, # Experience
                        {"experience_type": ExperienceType.General},
                        {"source": ExperienceSource.BEST_PRACTICE},
                        {"stage": stage}
                    ]
                },
                stage=stage,
            )
        return RetrievalResult(
            ids=[note.id for note in exp_notes],
            contents=[note.memory for note in exp_notes],
        )

    def find_optimized_references(
        self,
        attempt: Dict,
        task: str,
        top_k: int,
        similarity_threshold: Optional[float] = None,
        stage: ActionType = ActionType.OPTIMIZE,
    ) -> RetrievalResult:
        """Search for optimized code examples that match the task description."""
        ref_results: List[Dict[str, Any]] = []
        ref_ids: List[str] = []

        notes: List[CodeExampleNote] = self.retrieve(
            query=task,
            top_k=top_k,
            payload_filter={
                "$must": [
                    {"note_type": NoteType.CODE_EXAMPLE.value},
                    {"is_correct": True},
                    {"optimized_degree_to_root": {"$gt": max(attempt.get("optimized_degree_to_root") or 1.1, 1.1)}},
                    {"stage": stage},
                    {"$not": {"op_name": attempt.get("op_name")}}
                ]
            },
            stage=stage,
            similarity_threshold=similarity_threshold,
        )

        for note in notes:
            ref_ids.append(note.id)
            optimized_exp_notes: List[ExperienceNote] = self.memory_system.get_by_ids(note.reviews)
            optimized_exp_contents = []
            for exp_note in optimized_exp_notes:
                optimized_exp_contents.append(exp_note.content)

            ref_results.append(
                {
                    "memory_id": note.id,
                    "op_name": note.op_name,
                    "arc_src": note.arc_src,
                    "code": note.code,
                    "performance": note.performance,
                    "plan": note.plan,
                    "reviews": note.reviews,
                    "optimized_degree_to_root": note.optimized_degree_to_root,
                    "optimized_degree_to_parent": note.optimized_degree_to_parent,
                    "optimized_degree_to_best": note.optimized_degree_to_best,
                    "optimized_experiences": optimized_exp_contents,
                }
            )
        return RetrievalResult(ids=ref_ids, contents=ref_results)
    
    def find_optimization_history(
        self,
        memory_id: str,
        stage: ActionType = ActionType.OPTIMIZE,
    ) -> RetrievalResult:
        """Retrieve historical optimization experiences for a given memory ID."""
        history_plans, history_ids = [], []
        
        # Initial check: ensure the note corresponding to memory_id exists
        initial_notes: List[CodeExampleNote] = self.memory_system.get_by_ids([memory_id])
        if not initial_notes:
            return RetrievalResult(ids=[], contents=[])
        
        code_note: CodeExampleNote = initial_notes[0]
        max_depth = 10  # Prevent performance issues from overly deep trees
        
        depth = 0
        while (
            code_note 
            and code_note.parent 
            and code_note.stage == stage
            and depth < max_depth
        ):
            history_plans.append({
                "plan": code_note.plan,
                "reviews": code_note.reviews,
                "optimized_degree_to_parent": code_note.optimized_degree_to_parent,
                "optimized_degree_to_root": code_note.optimized_degree_to_root,
                "optimized_degree_to_best": code_note.optimized_degree_to_best,
            })
            history_ids.append(code_note.id)
            
            parent_notes: List[CodeExampleNote] = self.memory_system.get_by_ids([code_note.parent])
            if not parent_notes:
                break
            
            code_note = parent_notes[0]
            depth += 1

        return RetrievalResult(ids=history_ids, contents=history_plans)

    def find_failed_optimize_attempts(
        self,
        parent_memory_id,
        k: int = 1,
        stage: ActionType = ActionType.OPTIMIZE,
    ) -> RetrievalResult:
        """Find failed optimization attempts for a given task."""
        failed_results: List[Dict[str, Any]] = []
        failed_ids: List[str] = []
        code_note: CodeExampleNote = self.memory_system.get_by_ids([parent_memory_id])[0]

        if code_note.children:
            failed_code_notes: List[CodeExampleNote] = self.retrieve(
                top_k=k,
                with_similarity=False,
                payload_filter = {
                    "$must": [
                        {"note_type": NoteType.CODE_EXAMPLE.value},
                        {"op_name": code_note.op_name},
                        {"parent": parent_memory_id},
                        {
                            "$or": [
                                {"is_correct": False},
                                {"optimized_degree_to_parent": {"$lte": 1}}, # Buggy/failed optimization code example
                                {"compiled": False}, 
                            ]
                        },
                        {"stage": stage}
                    ]
                },
                stage=stage,
            )
            logger.info(f"failed_code_notes for op:{code_note.op_name} found {len(failed_code_notes)}")
            if failed_code_notes:
                failed_result = {}
                for note in failed_code_notes:
                    failed_ids.append(note.id)
                    failed_result["op_name"] = note.op_name
                    failed_result["code"] = note.code
                    failed_result["plan"] = note.plan
                    failed_result["optimized_degree_to_parent"] = note.optimized_degree_to_parent
                    failed_result["optimized_degree_to_root"] = note.optimized_degree_to_root
                    failed_result["optimized_degree_to_best"] = note.optimized_degree_to_best
                    failed_result["is_correct"] = note.is_correct
                    failed_result["reviews"] = note.reviews
                    failed_exp_notes: List[ExperienceNote] = self.memory_system.get_by_ids(note.reviews)
                    failed_exp_contents = []
                    for exp_note in failed_exp_notes:
                        failed_exp_contents.append(exp_note.content)
                    failed_result["failed_experiences"] = failed_exp_contents
                    failed_results.append(failed_result)

        return RetrievalResult(ids=failed_ids, contents=failed_results)
    
    def find_succ_optimize_attempts(
        self,
        parent_memory_id: str,
        stage: ActionType = ActionType.OPTIMIZE,
    ) -> RetrievalResult:
        """Find successful optimization attempts for a given task."""
        succ_results: List[Dict[str, Any]] = []
        succ_ids: List[str] = []
        code_note: CodeExampleNote = self.memory_system.get_by_ids([parent_memory_id])[0]

        if code_note.children:
            succ_code_notes: List[CodeExampleNote] = self.retrieve(
                with_similarity=False,
                payload_filter = {
                    "$must": [
                        {"note_type": NoteType.CODE_EXAMPLE.value},
                        {"op_name": code_note.op_name},
                        {"parent": parent_memory_id},
                        {"is_correct": True},
                        {"optimized_degree_to_parent": {"$gt": 1}}, # Successfully optimized code example
                        {"optimized_degree_to_root": {"$gt": 1}},
                        {"stage": stage}
                    ]
                },
                stage=stage,
            )
            logger.info(f"succ_code_notes for op:{code_note.op_name} found {len(succ_code_notes)}")
            if succ_code_notes:
                succ_result = {}
                for note in succ_code_notes:
                    succ_ids.append(note.id)
                    succ_result["op_name"] = note.op_name
                    succ_result["code"] = note.code
                    succ_result["plan"] = note.plan
                    succ_result["optimized_degree_to_parent"] = note.optimized_degree_to_parent
                    succ_result["optimized_degree_to_root"] = note.optimized_degree_to_root
                    succ_result["optimized_degree_to_best"] = note.optimized_degree_to_best
                    succ_result["is_correct"] = note.is_correct
                    succ_result["reviews"] = note.reviews
                    succ_exp_notes: List[ExperienceNote] = self.memory_system.get_by_ids(note.reviews)
                    succ_exp_contents = []
                    for exp_note in succ_exp_notes:
                        succ_exp_contents.append(exp_note.content)
                    succ_result["succ_experiences"] = succ_exp_contents
                    succ_results.append(succ_result)
        return RetrievalResult(ids=succ_ids, contents=succ_results)

    
    def find_failed_optimize_experiences(
        self,
        parent_memory_id: str,
        optimize_k: int = 3,
        draft_k: int = 3,
        exclude_ids: Optional[List[str]] = None,
        stage: ActionType = ActionType.OPTIMIZE,
    ) -> RetrievalResult:
        """
        Retrieve all experiences for the same operator and general experiences from similar operators.
        Prioritize retrieving same-operator experiences, fill with successful experiences from similar operators if insufficient,
        and fill with failed experiences if successful experiences are still insufficient. Total count remains general_exp_k+specific_exp_k.
        
        Args:
            op: Operator name
            task: Task description
            specific_exp_k: Number of specific experiences
            general_exp_k: Number of general experiences
            exclude_ids: List of experience IDs to exclude (to avoid duplicates)
            parent_memory_id: Parent memory id
        """
        if exclude_ids is None:
            exclude_ids = []
        exclude_set = set(exclude_ids)
        
        code_note: CodeExampleNote = self.memory_system.get_by_ids([parent_memory_id])[0]
        failed_optimize_exp_notes: List[ExperienceNote]= []
        failed_draft_exp_notes: List[ExperienceNote] = []

        if code_note.children:
            failed_optimize_code_notes: List[CodeExampleNote] = self.retrieve(
                with_similarity=False,
                payload_filter = {
                    "$must": [
                        {"note_type": NoteType.CODE_EXAMPLE.value},
                        {"op_name": code_note.op_name},
                        {"parent": parent_memory_id},
                        {"is_correct": True},
                        {"optimized_degree_to_parent": {"$lte": 1}}, # Failed optimization code example
                        {"stage": stage},
                    ]
                },
                stage=stage,
            )
            
            if failed_optimize_code_notes:
                
                for note in failed_optimize_code_notes:
                    exp_notes = self.memory_system.get_by_ids(note.reviews)
                    for exp_note in exp_notes:
                        if exp_note.id not in exclude_set:
                            failed_optimize_exp_notes.append(exp_note)

            if self.cfg.enable_value_driven:
                failed_optimize_exp_notes = self._select_by_q_value(failed_optimize_exp_notes, optimize_k, stage=stage)
            elif self.cfg.experiment_violation_q:
                failed_optimize_exp_notes = epsilon_greedy_select(failed_optimize_exp_notes, optimize_k, epsilon=self.cfg.epsilon)
            else:
                failed_optimize_exp_notes = failed_optimize_exp_notes[:optimize_k]


            failed_draft_code_notes: List[CodeExampleNote] = self.retrieve( # Compilation/correctness error code example
                with_similarity=False,
                payload_filter = {
                    "$must": [
                        {"note_type": NoteType.CODE_EXAMPLE.value},
                        {"op_name": code_note.op_name},
                        {"parent": parent_memory_id},
                        {
                            "$or": [
                                {"is_correct": False},
                                {"compiled": False},
                            ]
                        },
                        {"stage": stage}
                    ]
                },
                stage=stage,
            )
            if failed_draft_code_notes:
                
                for note in failed_draft_code_notes:
                    exp_notes = self.memory_system.get_by_ids(note.reviews)
                    for exp_note in exp_notes:
                        if exp_note.id not in exclude_set:
                            failed_draft_exp_notes.append(exp_note)
                if self.cfg.enable_value_driven:
                    failed_draft_exp_notes = self._select_by_q_value(failed_draft_exp_notes, draft_k, stage=stage)
                else:
                    failed_draft_exp_notes = self._select_by_random(failed_draft_exp_notes, draft_k)

        all_exp_notes = failed_optimize_exp_notes + failed_draft_exp_notes
        all_exp_ids = [exp.id for exp in all_exp_notes]
        all_exp_contents = [exp.content for exp in all_exp_notes]
        return RetrievalResult(ids=all_exp_ids, contents=all_exp_contents)
