from typing import List, Dict, Any, Optional
import logging
import os
import traceback
import json
import shutil
import math
from datetime import datetime, timezone, timedelta
from src.configs.config import project_root_path
from threading import Lock
from src.utils import (
    trim_long_string,
    normalize_op_name,
    safe_read_json,
    read_file,
    calculate_optimized_degree,

)
from src.providers.base import BaseLLM, LLMOutput
from src.prompts.memory_prompt import (
    generate_draft_round_succ_summary_template,
    generate_draft_round_err_summary_template,
    generate_optimize_round_succ_summary_template,
    generate_optimize_round_failed_summary_template,
)
from src.memorykit.note import (
    ExperienceNote,
    CodeExampleNote,
    ExperienceType,
    ExperienceSource,
    APINote,
    NoteType,
)

from src.models.agent_models import ActionType
from src.prompts.utils import read_example_files
from src.configs.memory import MemoryConfig
from src.memorykit.memory import MemorySystem
from examples.dataset import category2exampleop
logger = logging.getLogger(__name__)

# Constants
# Number of most similar experiences to search during the novelty check.
NOVELTY_CHECK_TOP_K = 5
# Default beta value for the reward moving average.
DEFAULT_REWARD_MA_BETA = 0.1


def clip(x: float, lo: float, hi: float) -> float:
    """Clamp ``x`` into the closed interval ``[lo, hi]``.

    The upper bound is applied first and the lower bound second, so ``lo``
    wins when the bounds are passed in the wrong order (``lo > hi``).
    """
    capped = min(hi, x)
    return max(lo, capped)

    
class MemoryCurator:
    def __init__(self, memory_system: MemorySystem, llm: BaseLLM, config: MemoryConfig):
        """Store the collaborators and start with empty per-operator bookkeeping."""
        # Collaborators supplied by the caller.
        self.memory_system: MemorySystem = memory_system
        self.llm: BaseLLM = llm
        self.config: MemoryConfig = config

        # op_name -> raw rewards recorded during the optimization stage;
        # the history is what later rewards are z-score normalized against.
        self.optimize_stage_rewards: Dict[str, List[float]] = {}

        # op_name -> performance values tracked across optimization rounds.
        # `best_performance` is accessed under `block_performance_lock`,
        # `root_performance` under `root_performance_lock`.
        self.best_performance: Dict[str, Optional[float]] = {}
        self.root_performance: Dict[str, Optional[float]] = {}
        self.block_performance_lock: Lock = Lock()
        self.root_performance_lock: Lock = Lock()

    
    # ======================================================== #
    # run time memory management
    # ======================================================== #
    def save_running_info(self, save_dir):
        """Persist the per-operator running statistics as JSON files.

        Writes ``root_performance.json``, ``best_performance.json`` and
        ``optimize_stage_rewards.json`` into ``save_dir``. The operation is
        best effort: any failure is logged (with traceback) and swallowed.

        Args:
            save_dir: Target directory; created if it does not exist.
        """
        snapshots = (
            ("root_performance.json", self.root_performance),
            ("best_performance.json", self.best_performance),
            ("optimize_stage_rewards.json", self.optimize_stage_rewards),
        )
        try:
            os.makedirs(save_dir, exist_ok=True)
            for filename, payload in snapshots:
                target_path = os.path.join(save_dir, filename)
                # Dump into a sibling temp file and swap it in afterwards, so
                # a failure in the middle of json.dump cannot leave a
                # truncated file where a previously valid one used to be.
                tmp_path = target_path + ".tmp"
                with open(tmp_path, "w", encoding="utf-8") as f:
                    json.dump(payload, f, indent=4)
                os.replace(tmp_path, target_path)
            logger.info(f"Saved running info to directory: {save_dir}")
        except Exception as e:
            # logger.exception already appends the active traceback, so a
            # single call is enough (no separate traceback.format_exc()).
            logger.exception(f"Failed to save running info: {e}")

    def load_running_info(self, save_dir):
        """Restore the per-operator running statistics from JSON files.

        Reads ``root_performance.json``, ``best_performance.json`` and
        ``optimize_stage_rewards.json`` from ``save_dir``. A missing file
        resets the corresponding attribute to an empty dict. The operation is
        best effort: the first failure is logged (with traceback) and
        swallowed, leaving attributes that were not yet loaded untouched.

        Args:
            save_dir: Directory that holds the JSON files.
        """

        def _read_json_or_empty(filename):
            # A missing file means "no state recorded yet" rather than an error.
            path = os.path.join(save_dir, filename)
            if not os.path.exists(path):
                return {}
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)

        try:
            logger.info(f"Loading running info from directory: {save_dir}")
            self.root_performance = _read_json_or_empty("root_performance.json")
            self.best_performance = _read_json_or_empty("best_performance.json")
            self.optimize_stage_rewards = _read_json_or_empty("optimize_stage_rewards.json")
            logger.info(f"Loaded running info from directory: {save_dir}")
            if self.root_performance:
                logger.info(f"Root performance: {self.root_performance}")
            if self.best_performance:
                logger.info(f"Best performance: {self.best_performance}")
            if self.optimize_stage_rewards:
                logger.info(f"Optimize stage rewards: {self.optimize_stage_rewards}")
        except Exception as e:
            # logger.exception already appends the active traceback, so a
            # single call is enough (no separate traceback.format_exc()).
            logger.exception(f"Failed to load running info: {e}")

    def summarize_and_update(self, op_name: str, arc_src: str, round_info: Dict):
        """Summarize one finished round and, if enabled, update Q values.

        Dispatches on ``round_info['action']`` to the draft or optimize
        summarizer, then forwards the outcome to
        ``update_q_values_from_verification`` when
        ``config.enable_value_driven`` is set.

        Args:
            op_name: Name of the operator the round worked on.
            arc_src: Source text passed through to the summarizers.
            round_info: Round record; ``'action'`` and ``'retrieved_note_ids'``
                are read here, the rest is consumed by the summarizers.

        Returns:
            A dict with the Q-value update results, or an empty dict when
            value-driven updates are disabled.

        Raises:
            ValueError: If the action is neither DRAFT nor OPTIMIZE.
        """
        action = round_info.get('action')
        round_correct = False
        round_optimized = False
        reward_degree = None  # Degree fed into the reward (vs. best child if known, else vs. parent)

        if action == ActionType.DRAFT:
            round_correct = self._summarize_draft_round(
                op=op_name,
                arc_src=arc_src,
                round_info=round_info,
            )
        elif action == ActionType.OPTIMIZE:
            optimize_outcome = self._summarize_optimize_round(
                op=op_name,
                arc_src=arc_src,
                round_info=round_info,
            )
            round_correct, round_optimized, degree_vs_parent, degree_vs_best = optimize_outcome
            if degree_vs_best is None:
                reward_degree = degree_vs_parent
            else:
                reward_degree = degree_vs_best
        else:
            logger.warning(f"Unknown action: {action}")
            raise ValueError(f"Unknown action: {action}")

        status_info = {}
        if not self.config.enable_value_driven:
            return status_info

        q_update_results = self.update_q_values_from_verification(
            retrieved_note_ids=round_info.get('retrieved_note_ids', []),
            stage=action,
            is_correct=round_correct,
            is_optimized=round_optimized,
            optimized_degree=reward_degree,
            op_name=op_name,  # Lets optimize-stage rewards be grouped per operator
        )
        status_info.update(q_update_results)
        return status_info

    def update_q_values_from_verification(
        self,
        retrieved_note_ids: List[str],
        stage: ActionType,
        is_correct: bool = False,
        is_optimized: bool = False,
        optimized_degree: Optional[float] = None,
        *,
        op_name: Optional[str] = None,
        note_rewards: Optional[Dict[str, Any]] = None,
        alpha: Optional[float] = None,
        init_q: Optional[float] = None,
        reward_ma_beta: Optional[float] = None,
        clip_error: Optional[float] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Update Q values of retrieved memories based on verification results.

        When ``note_rewards`` is not supplied, one reward is derived from the
        round outcome and assigned to every retrieved note:

        * DRAFT: ``config.success_reward`` if ``is_correct`` else
          ``config.failure_reward``.
        * OPTIMIZE: ``tanh(log(optimized_degree))`` for a positive degree,
          otherwise ``config.failure_reward``; if ``op_name`` is given the
          reward is additionally recorded and normalized per operator.

        Args:
            retrieved_note_ids: List of retrieved memory IDs.
            stage: Stage the round belongs to (DRAFT or OPTIMIZE).
            is_correct: Whether the round produced correct code.
            is_optimized: Whether the round counted as an optimization
                (only used for logging here).
            optimized_degree: Optimization degree used for the OPTIMIZE reward.
            op_name: Task name for grouping reward records (optimization stage).
            note_rewards: Explicit reward per note_id; when given, no reward
                is computed from the verification outcome.
            alpha: Learning rate (if None, use default value from config).
            init_q: Initial Q value (if None, use default value from config).
            reward_ma_beta: Beta value for reward moving average (if None,
                use default value from config).
            clip_error: Error clipping value (if None, use default value
                from config).

        Returns:
            Dict[str, Dict[str, Any]]: Dictionary with note_id as key and
            detailed update result information as value. Empty when there is
            nothing to update or the update failed.

        Raises:
            ValueError: If a reward has to be computed for an unknown stage.
        """
        if not retrieved_note_ids:
            return {}

        # Calculate reward based on verification results
        if note_rewards is None:
            if stage == ActionType.DRAFT:
                reward_value = self.config.success_reward if is_correct else self.config.failure_reward
            elif stage == ActionType.OPTIMIZE:
                # math.log raises on non-positive input, so a missing or
                # non-positive degree is scored as a failure instead.
                if optimized_degree is not None and optimized_degree > 0:
                    reward_value = math.tanh(math.log(optimized_degree))
                else:
                    reward_value = self.config.failure_reward

                # Record raw reward in optimization stage and normalize it per task
                if op_name is not None:
                    reward_value = self._record_and_normalize_reward(reward_value, op_name, add_to_list=True)
            else:
                raise ValueError(f"Unknown stage: {stage}")

            note_rewards = {note_id: reward_value for note_id in retrieved_note_ids}
            logger.info(
                f"[Q-value update triggered] Calculated reward based on verification results: "
                f"is_correct={is_correct}, is_optimized={is_optimized}, "
                f"step_optimized_degree={optimized_degree}, reward_value={reward_value},"
                f"note_ids={retrieved_note_ids}"
            )
        else:
            logger.info(
                f"[Q-value update triggered] Using provided rewards: note_rewards={note_rewards}, "
                f"note_ids={retrieved_note_ids}"
            )

        # Fall back to config defaults only when the caller passed nothing;
        # `is not None` keeps an explicit 0 / 0.0 instead of discarding it.
        alpha = alpha if alpha is not None else self.config.q_learning_rate
        init_q = init_q if init_q is not None else self.config.q_init
        reward_ma_beta = reward_ma_beta if reward_ma_beta is not None else self.config.reward_ma_beta
        clip_error = clip_error if clip_error is not None else self.config.clip_error

        logger.info(
            f"[Q-value update triggered] Preparing to update Q values: "
            f"alpha={alpha}, init_q={init_q}, reward_ma_beta={reward_ma_beta}, "
            f"clip_error={clip_error}, memories to update={len(retrieved_note_ids)}"
        )

        # Batch update Q values
        try:
            update_results = self.memory_system.update_q_values(
                note_ids=retrieved_note_ids,
                note_rewards=note_rewards,
                stage=stage,
                alpha=alpha,
                init_q=init_q,
                reward_ma_beta=reward_ma_beta,
                clip_error=clip_error,
            )
            # default=str: reward values are typed Any, and a value that is
            # not JSON serializable must not turn a completed update into a
            # reported failure just because of this log line.
            logger.info(
                f"[Q-value update completed] Successfully updated Q values for {len(update_results)} memories. "
                f"Reward details:\n{json.dumps(note_rewards, indent=4, default=str)}"
            )
            return update_results
        except Exception as e:
            # logger.exception already appends the active traceback.
            logger.exception(f"Failed to update Q values: {e}")
            return {}
    
    def _record_and_normalize_reward(self, reward: float, op_name: str, add_to_list: bool = True) -> float:
        """Z-score a reward against the operator's history and optionally record it.

        The mean and (population) standard deviation are computed from the
        rewards recorded *before* this call; the current reward is appended
        only afterwards. The raw reward is returned unchanged when fewer than
        two rewards are on record or when the deviation is below 1e-6.

        Args:
            reward: Raw reward value.
            op_name: Task name used to group the reward history.
            add_to_list: Whether to append the raw reward to the history.

        Returns:
            ``(reward - mean) / std`` when normalization is possible,
            otherwise the raw reward.
        """
        logger.info(f"Recording reward for op {op_name}: {reward}")
        history = self.optimize_stage_rewards.setdefault(op_name, [])
        sample_count = len(history)
        result = reward

        if sample_count < 2:
            # Too little history for a meaningful mean/std.
            logger.debug(
                f"[Reward normalization] Task {op_name} has insufficient rewards ({sample_count}), "
                f"skipping normalization, using raw reward: {reward:.6f}"
            )
        else:
            mean = sum(history) / sample_count
            squared_deviation = sum((past - mean) ** 2 for past in history)
            std = math.sqrt(squared_deviation / sample_count)

            if std < 1e-6:
                # A (near-)constant history would blow the z-score up.
                logger.debug(
                    f"[Reward normalization] Task {op_name} has too small standard deviation ({std:.6f}), "
                    f"skipping normalization, using raw reward: {reward:.6f}"
                )
            else:
                result = (reward - mean) / std
                logger.info(
                    f"[Reward normalization] Task {op_name}: "
                    f"raw={reward:.4f}, mean={mean:.4f}, std={std:.4f}, "
                    f"normalized={result:.4f} (total {sample_count} rewards)"
                )

        # The history always stores the raw reward, never the normalized one.
        if add_to_list:
            history.append(reward)
        return result
    
    def update_best_code(self, new_note: CodeExampleNote):
        """Keep a single correct code example per operator, replacing it when beaten.

        Searches for an existing correct code-example note of the same
        operator. If none is found, ``new_note`` is added. Otherwise the
        existing note is overwritten in place when ``new_note`` reports a
        ``performance['mean']`` and that mean is smaller than the stored one
        (or the stored note has no mean).

        Args:
            new_note: Candidate code example from the current round.
        """
        search_results = self.memory_system.search(
            query=new_note.memory, top_k=1,
            filter={"note_type": "code-example", "op_name": new_note.op_name, "is_correct": True}
        )

        # The result is normalised BEFORE logging: the non-list case is
        # handled here, so calling len() on the raw result first could raise
        # for exactly the shapes this branch exists for (e.g. None).
        if isinstance(search_results, list):
            note = search_results[0] if search_results else None
            result_count = len(search_results)
        else:
            note = search_results
            result_count = 1 if search_results else 0
        logger.debug(f"Search results for op {new_note.op_name}: {result_count} results")

        if not note:
            logger.info(f"Adding best code result to memory for op {new_note.op_name}")
            self.memory_system.add([new_note])
            return

        current_performance = getattr(note, "performance", None)
        current_mean = current_performance.get("mean") if isinstance(current_performance, dict) else None

        performance_mean = None
        if isinstance(new_note.performance, dict):
            performance_mean = new_note.performance.get('mean', None)

        # Smaller mean wins; a stored note without a mean is always replaced
        # by a candidate that has one.
        if performance_mean is not None and (
            current_mean is None or performance_mean < current_mean
        ):
            logger.info(f"Updating best code result to memory for op {new_note.op_name}")
            note.code = new_note.code
            note.reviews = new_note.reviews
            note.performance = new_note.performance
            note.is_compilable = new_note.is_compilable
            note.is_correct = new_note.is_correct
            note.arc_src = new_note.arc_src
            self.memory_system.update(note.id, new_memory=note)

    def _summarize_draft_round(self, op: str, arc_src: str, round_info: Dict):
        """Summarize a draft round into experience notes plus a code-example note.

        A round counts as successful when it both compiled and passed the
        correctness check. The LLM is asked for a success or an error
        summary accordingly; every returned entry becomes an
        ``ExperienceNote``. The round's code is stored as a
        ``CodeExampleNote``: failed code is always added, successful code is
        either added or routed through ``update_best_code`` depending on
        ``config.save_all_correct_attempts``.

        Args:
            op: Operator name.
            arc_src: Source text the round worked on.
            round_info: Round record; reads ``compiled``, ``correctness``,
                ``plan``, ``code``, ``error_attempt``, ``performance``,
                ``error_exps``, ``compile_info`` and ``correctness_info``.

        Returns:
            The value of ``compiled and correct`` (truthy on success).
        """
        compiled = round_info.get('compiled', False)
        correct = round_info.get('correctness', False)
        plan = round_info.get('plan', '')
        code = round_info.get('code', '')
        performance = round_info.get('performance', {})
        succeeded = compiled and correct

        if succeeded:
            error_attempt = round_info.get('error_attempt', '')
            prompt = generate_draft_round_succ_summary_template(
                op=op,
                arc_src=arc_src,
                verified_code=code,
                verified_plan=plan,
                error_attempt=error_attempt,
                error_exps=round_info.get('error_exps', []),
            )
            # A success that came after a failed attempt is a debug success.
            source = ExperienceSource.DEBUG_SUCCESS if error_attempt else ExperienceSource.DRAFT_SUCCESS
        else:
            err_msg = round_info.get("compile_info", "") or round_info.get("correctness_info", "")
            err_msg = trim_long_string(err_msg) if err_msg else ""
            prompt = generate_draft_round_err_summary_template(
                op=op,
                arc_src=arc_src,
                err_code=code,
                err_plan=plan,
                err_msg=err_msg,
            )
            source = ExperienceSource.BASIC_FAILURE

        logger.info(f"Draft summary query preview for {op}:\n{prompt}")
        llm_output: LLMOutput = self.llm.generate_single(prompt, extract_first_block=True, language_type="json")
        # NOTE(review): assumes response_txt is already a parsed list of
        # dicts with 'content' and 'type' keys -- confirm against
        # BaseLLM.generate_single.
        response_json = llm_output.response_txt
        logger.info(f"LLM raw summary response for {op}:\n{response_json}")

        # The entry type is compared against the enum's *value* in both
        # branches. The failure branch used to compare against the enum
        # member itself, which disagreed with the success branch and with the
        # optimize-round summarizer.
        experience_notes: List[ExperienceNote] = [
            ExperienceNote(
                memory=entry['content'],
                content=entry['content'],
                stage=ActionType.DRAFT,
                experience_type=(
                    ExperienceType.General
                    if entry['type'] == ExperienceType.General.value
                    else ExperienceType.OperatorSpecific
                ),
                source=source,
                op_name=op,
            )
            for entry in response_json
        ]

        # Novelty check: if enabled, credit existing similar experiences
        # instead of storing near-duplicates.
        if self.config.enable_novelty_check:
            experience_notes = self._check_novelty_and_update(experience_notes)

        code_note = CodeExampleNote(
            memory=op + "\n" + arc_src,
            code=code,
            reviews=[note.id for note in experience_notes],
            op_name=op,
            stage=ActionType.DRAFT,
            arc_src=arc_src,
            is_compilable=compiled,
            is_correct=correct if succeeded else False,
            performance=performance,
        )

        if succeeded and not self.config.save_all_correct_attempts:
            self.update_best_code(code_note)
        else:
            self.memory_system.add([code_note])

        # Only add notes that weren't filtered out by novelty check
        notes_to_add = [note for note in experience_notes if note is not None]
        if notes_to_add:
            self.memory_system.add(notes_to_add)

        return succeeded

    def _summarize_optimize_round(self, op: str, arc_src: str, round_info: Dict):
        """Summarize an optimize round and store its code-example note.

        Computes three optimization degrees for the round's code (relative
        to its parent, to the operator's root performance and to the
        operator's best performance so far), decides whether the round
        counts as correct / optimized, asks the LLM for a failure summary
        where appropriate, and writes the resulting notes to memory.

        Args:
            op: Operator name.
            arc_src: Source text the round worked on.
            round_info: Round record; reads ``parent``, ``compiled``,
                ``correctness``, ``code``, ``plan``, ``performance``,
                ``compile_info`` and ``correctness_info``.

        Returns:
            Tuple ``(is_correct, is_optimized, step_optimized_degree,
            optimized_degree_to_best_child)``; either degree may be None.
        """
        stage = ActionType.OPTIMIZE
        compiled = round_info.get('compiled', False)
        correct = round_info.get('correctness', False)
        current_code = round_info.get('code', '')
        current_plan = round_info.get('plan', '')
        current_performance = round_info.get("performance", {})

        # `or {}` also covers a 'parent' key that is present but None.
        parent_info = round_info.get('parent') or {}
        parent_memory_id = parent_info.get("memory_id", "")
        parent_code = parent_info.get("code", "")
        parent_performance = parent_info.get("performance", {})

        experience_notes: List[ExperienceNote] = []
        prompt = None
        exp_source = None

        if isinstance(current_performance, dict):
            current_performance_mean = current_performance.get("mean", None)
        else:
            current_performance_mean = None
        if isinstance(parent_performance, dict):
            parent_performance_mean = parent_performance.get("mean", None)
        else:
            parent_performance_mean = None

        step_optimized_degree = calculate_optimized_degree(current_performance, parent_performance)
        optimized_degree_to_root = None
        optimized_degree_to_best_child = None
        if current_performance_mean is not None and current_performance_mean > 0:
            # First sighting of an operator seeds both baselines with the
            # parent's mean, falling back to the current mean.
            baseline = parent_performance_mean if parent_performance_mean is not None else current_performance_mean

            with self.block_performance_lock:
                if op not in self.best_performance or self.best_performance[op] is None:
                    self.best_performance[op] = baseline

            with self.root_performance_lock:
                if op not in self.root_performance or self.root_performance[op] is None:
                    self.root_performance[op] = baseline

                if self.root_performance[op] is not None and self.root_performance[op] > 0:
                    optimized_degree_to_root = self.root_performance[op] / current_performance_mean

            with self.block_performance_lock:
                if self.best_performance[op] is not None and self.best_performance[op] > 0:
                    optimized_degree_to_best_child = self.best_performance[op] / current_performance_mean
                    # Smaller mean is better: remember a new best.
                    if current_performance_mean < self.best_performance[op]:
                        self.best_performance[op] = current_performance_mean

        is_optimized = False
        is_correct = False

        if compiled and correct:
            is_correct = True
            threshold = self.config.optimization_threshold
            # The step degree is guarded against None before comparing:
            # the caller treats it as optional, and comparing None with a
            # number raises TypeError.
            step_known = step_optimized_degree is not None

            if optimized_degree_to_best_child is not None:
                is_optimized = optimized_degree_to_best_child > threshold
                # Not beating the current best is normal exploration, so a
                # failure summary is only produced when the round also
                # failed to improve on its own parent.
                summarize_failure = (
                    not is_optimized
                    and step_known
                    and step_optimized_degree <= threshold
                )
            else:
                is_optimized = step_known and step_optimized_degree > threshold
                summarize_failure = not is_optimized

            if summarize_failure:
                prompt = generate_optimize_round_failed_summary_template(
                    op=op,
                    arc_src=arc_src,
                    current_code=current_code,
                    current_plan=current_plan,
                    parent_code=parent_code,
                    parent_performance=parent_performance,
                    current_performance=current_performance,
                )
                exp_source = ExperienceSource.OPTIMIZE_FAILURE
        else:
            # Compilation or correctness failure: summarize the error, if any.
            err_msg = round_info.get("compile_info", "") or round_info.get("correctness_info", "")
            err_msg = trim_long_string(err_msg) if err_msg else ""
            exp_source = ExperienceSource.BASIC_FAILURE
            if err_msg:
                prompt = generate_draft_round_err_summary_template(
                    op=op,
                    arc_src=arc_src,
                    err_code=current_code,
                    err_plan=current_plan,
                    err_msg=err_msg,
                )

        if prompt:
            logger.info(f"Optimize summary query preview for {op}:\n{prompt}")
            llm_output: LLMOutput = self.llm.generate_single(prompt, extract_first_block=True, language_type="json")
            # NOTE(review): assumes response_txt is already a parsed list of
            # dicts with 'content' and 'type' keys -- confirm against
            # BaseLLM.generate_single.
            response_json = llm_output.response_txt
            logger.info(f"LLM raw summary response for {op}:\n{response_json}")

            for entry in response_json:
                experience_notes.append(
                    ExperienceNote(
                        memory=entry['content'],
                        content=entry['content'],
                        stage=stage,
                        experience_type=(
                            ExperienceType.General
                            if entry['type'] == ExperienceType.General.value
                            else ExperienceType.OperatorSpecific
                        ),
                        source=exp_source,
                        op_name=op,
                    )
                )
            if self.config.enable_novelty_check:
                experience_notes = self._check_novelty_and_update(experience_notes)

        def _initial_q(degree):
            # Unknown degree -> neutral start; otherwise the tanh(log(.))
            # reward scaled by the learning rate. Both degrees are ratios of
            # positive means, so the logarithm is defined.
            if degree is None:
                return 0
            return math.tanh(math.log(degree)) * self.config.q_learning_rate

        code_note = CodeExampleNote(
            memory=op + "\n" + arc_src,
            code=current_code,
            reviews=[note.id for note in experience_notes],
            op_name=op,
            plan=current_plan,
            arc_src=arc_src,
            is_compilable=compiled,
            is_correct=is_correct,
            optimize_q_value=_initial_q(optimized_degree_to_best_child),
            draft_q_value=_initial_q(optimized_degree_to_root),
            optimized_degree_to_parent=step_optimized_degree,
            optimized_degree_to_best=optimized_degree_to_best_child,
            optimized_degree_to_root=optimized_degree_to_root,
            performance=current_performance,
            parent=parent_memory_id,
            stage=stage,
        )

        # Link the new note into its parent's children before storing it.
        parent_note: CodeExampleNote = self.memory_system.get(parent_memory_id)
        if parent_note:
            parent_note.children.append(code_note.id)
            self.memory_system.update(parent_note.id, new_memory=parent_note)
        self.memory_system.add([code_note])

        # add experience notes
        notes_to_add = [note for note in experience_notes if note is not None]
        if notes_to_add:
            self.memory_system.add(notes_to_add)

        return is_correct, is_optimized, step_optimized_degree, optimized_degree_to_best_child

    def _check_novelty_and_update(
        self,
        new_experience_notes: List[ExperienceNote],
    ) -> List[Optional[ExperienceNote]]:
        """Drop experiences that duplicate stored ones, crediting the stored note instead.

        A newly summarized experience that closely matches an existing one
        suggests the existing experience is important and recurring. Rather
        than storing a duplicate, the most similar existing note receives a
        positive reward (only when ``config.enable_value_driven`` is set) so
        that it is retrieved more readily, and the new note is discarded.

        Args:
            new_experience_notes: List of newly generated experiences.

        Returns:
            A new list with the notes that should still be added. Notes
            judged similar to an existing experience are omitted from the
            list (they are not replaced by None). When the novelty check is
            disabled the input list is returned unchanged.
        """

        if not self.config.enable_novelty_check:
            return new_experience_notes
        threshold = self.config.novelty_similarity_threshold
        reward_share = self.config.novelty_reward_share

        # Similar experiences always earn a positive reward: their
        # reappearance is itself the signal that they matter.
        reward = self.config.success_reward * reward_share

        filtered_notes = []
        credited_count = 0

        for note in new_experience_notes:
            # Nothing to compare against for an empty memory text.
            if not note.memory:
                filtered_notes.append(note)
                continue

            # Search for similar experiences
            similar_experiences: List[ExperienceNote] = self.memory_system.search(
                query=note.memory,
                top_k=NOVELTY_CHECK_TOP_K,  # Module constant instead of a repeated literal
                filter={"note_type": NoteType.EXPERIENCE.value, "experience_type": note.experience_type.value, "op_name": note.op_name},
                score_threshold=threshold
            )
            logger.info(f"Similar experiences for {note.memory}: {len(similar_experiences)} results")

            if not similar_experiences:
                filtered_notes.append(note)
                continue

            # Found similar experience, select the most similar one
            most_similar: ExperienceNote = similar_experiences[0]
            try:
                # Use experience's stage, if not available default to draft
                experience_stage = note.stage if hasattr(note, 'stage') and note.stage else ActionType.DRAFT
                if self.config.enable_value_driven:
                    self.memory_system.update_q_values(
                        note_ids=[most_similar.id],
                        note_rewards={most_similar.id: reward},
                        stage=experience_stage,
                    )
                logger.info(
                    f"Novelty check: Found the most similar experience: {most_similar.id}, "
                    f"similarity={getattr(most_similar, 'score', 0.0):.3f}). content: {most_similar.content} "
                    f"Similar experience indicates importance, crediting positive reward {reward:.3f} "
                    f"to increase Q-value and improve retrieval priority. Skipping new experience."
                )
                credited_count += 1
            except Exception as e:
                # If crediting fails, keep the new note rather than lose it.
                logger.error(f"Failed to update Q value for similar experience {most_similar.id}: {e}")
                filtered_notes.append(note)

        if credited_count > 0:
            logger.info(
                f"Novelty check: {credited_count}/{len(new_experience_notes)} experiences "
                f"were similar to existing ones. Credited positive rewards to increase their Q-values "
                f"and improve retrieval priority, instead of adding duplicate experiences."
            )

        return filtered_notes
    
    
    # ======================================================== #
    # initialize memory
    # ======================================================== #
    def memory_init(self):
        """Bring the memory system into a usable state.

        Looks for ``<vector_db_config.path>/all_memory/<memory_filename>``.
        When that file is present, the memory system is loaded from the
        ``all_memory`` directory. When it is absent, or when loading raises,
        a new memory is built through ``_initialize_new_memory``.
        """
        backup_root = os.path.join(
            self.config.vector_db_config.path,
            'all_memory'
        )
        saved_memory_file = os.path.join(backup_root, self.config.memory_filename)

        # Guard clause: nothing has been saved yet, so build from scratch.
        if not os.path.exists(saved_memory_file):
            logger.info("Memory backup not found, initializing new memory...")
            self._initialize_new_memory(backup_root)
            return

        try:
            self.memory_system.load(backup_root)
            logger.info("Memory loaded successfully from backup directory")
        except Exception as load_error:
            # A backup that cannot be loaded falls back to a fresh memory.
            logger.error(f"Failed to load memory: {load_error}")
            logger.exception(traceback.format_exc())
            logger.info("Initializing new memory instead...")
            self._initialize_new_memory(backup_root)
    
    def _initialize_new_memory(self, backup_dir: str):
        """Build a fresh memory and persist it.

        Runs the seeding steps in order (``memory_add_select_shot``,
        ``memory_add_add_shot``, ``memory_add_apis``) and then dumps the
        memory system to ``backup_dir``. Any exception raised along the way
        is logged and swallowed, so a failed initialization does not
        propagate to the caller.

        Args:
            backup_dir: Directory the memory system is dumped to.
        """
        try:
            # No debugger hook here: initialization must be able to run
            # unattended (services, CI, batch jobs).
            self.memory_add_select_shot()
            self.memory_add_add_shot()
            self.memory_add_apis()
            self.memory_system.dump(dir=backup_dir)
            logger.info(f"New memory initialized and saved to {backup_dir}")
        except Exception as e:
            logger.error(f"Failed to initialize new memory: {e}")
            logger.exception(traceback.format_exc())

    def memory_add_best_practice(self):
        """Seed the memory system with best-practice experiences.

        Reads a mapping from ``data/best_practices_with_keys.json`` and stores
        one general, optimize-stage ``ExperienceNote`` per entry. The key
        becomes the note's abstract, the value its content, and the two joined
        by a newline form the note's ``memory`` text. If reading the file
        raises, the error is logged and nothing is added.
        """
        best_practice_path = "data/best_practices_with_keys.json"

        try:
            practices_by_abstract = safe_read_json(best_practice_path, default={})
            logger.info(f"Loaded {len(practices_by_abstract)} best practice entries from {best_practice_path}")
        except Exception as e:
            logger.error(f"Failed to load best practices from {best_practice_path}: {e}")
            return

        # One note per (abstract, practice) pair, built in file order.
        practice_notes = [
            ExperienceNote(
                memory=summary_key+'\n'+practice_text,
                content=practice_text,
                abstract=summary_key,
                stage=ActionType.OPTIMIZE,
                experience_type=ExperienceType.General,
                source=ExperienceSource.BEST_PRACTICE
            )
            for summary_key, practice_text in practices_by_abstract.items()
        ]

        if not practice_notes:
            logger.warning("No valid best practices to add.")
            return

        self.memory_system.add(practice_notes)
        logger.info(f"Successfully added {len(practice_notes)} best practices to memory system.")


    def memory_add_apis(self):
        """Load API entries from JSON and store each one as an ``APINote``.

        Every entry is read with the keys ``operator_name``,
        ``functional_description``, ``usage_example``, ``function_prototypes``,
        ``parameters`` and ``return_value``. The note's ``memory`` text is
        assembled from the name, the description (when non-empty) and the
        example code (when ``usage_example`` has a ``code`` key).

        An entry that fails for any reason (missing key, unexpected shape, or
        an error raised by ``memory_system.add``) is logged and skipped, so
        one bad record does not abort the remaining ones.
        """
        # NOTE(review): an earlier assignment to
        # ".../above_148.filter.en.json" was dead code (it was overwritten on
        # the very next line). The path that was actually read is kept here;
        # confirm which dataset is intended.
        apis_path = "data/cleaned/apis/ascendc_dev_guide_sections/above_148.failure.en.json"
        apis = safe_read_json(apis_path, default=[])
        len_apis = len(apis)
        added = 0
        logger.info(f"[info] loading api: apis={len_apis}")
        for processed, api in enumerate(apis, start=1):
            # Resolve the display name defensively up front: the error handler
            # below must never raise while reporting a malformed entry.
            api_name = api.get('operator_name') if isinstance(api, dict) else None
            try:
                memory_str = f"API: {api['operator_name']}\n"
                example_str = None
                if api['functional_description']:
                    memory_str += f"Description: {api['functional_description']}\n"
                if api['usage_example'] and 'code' in api['usage_example'].keys():
                    example_str = api['usage_example']['code']
                    memory_str += f"Example: {example_str}"

                self.memory_system.add([
                    APINote(
                        name=api['operator_name'] if api['operator_name'] else "",
                        memory=memory_str,
                        example=example_str,
                        prototypes=api['function_prototypes'],
                        parameters=api['parameters'],
                        return_value=api['return_value']
                    )
                ])
                added += 1
                # "remaining" counts entries not yet processed, so failed
                # entries are not reported as still pending.
                logger.info(f"[info] api added for {api_name}; remaining: {len_apis - processed}")
            except Exception:
                logger.error(f"[error] add(api) failed for {api_name}: {traceback.format_exc()}")
        logger.info(f"[info] apis loaded: {added} api notes.")
  
    def memory_add_select_shot(self):
        """Run every category example operator through the draft-round summarizer.

        For each operator name in ``category2exampleop`` this reads
        ``src/prompts/cuda_model_<op>.py`` and
        ``src/prompts/ascendc_new_model_<op>.py`` under the project root and
        passes them to ``_summarize_draft_round`` as a round that compiled and
        was correct, with a placeholder plan.
        """
        # Only the operator names are needed; the category keys are unused.
        for op_name in category2exampleop.values():
            reference_path = os.path.join(
                project_root_path, 'src', f"prompts/cuda_model_{op_name}.py"
            )
            ascendc_path = os.path.join(
                project_root_path, 'src', f"prompts/ascendc_new_model_{op_name}.py"
            )
            reference_src = read_file(reference_path)
            ascendc_src = read_file(ascendc_path)
            normalized_op = normalize_op_name(op_name)

            draft_round = {
                "compiled": True,
                "correctness": True,
                "code": ascendc_src,
                "plan": "There is no plan provided here. Please analyze it based on the code."
            }
            self._summarize_draft_round(
                op=normalized_op,
                arc_src=reference_src,
                round_info=draft_round,
            )

    def memory_add_add_shot(self):
        """Run the ``add`` operator example through the draft-round summarizer.

        The example pair is obtained from ``read_example_files('ascendc', 'add')``
        and passed to ``_summarize_draft_round`` as a round that compiled and
        was correct, with a placeholder plan.
        """
        op_name = "add"
        normalized_op = normalize_op_name(op_name)
        reference_src, ascendc_src = read_example_files('ascendc', op_name)

        draft_round = {
            "compiled": True,
            "correctness": True,
            "code": ascendc_src,
            "plan": "There is no plan provided here. Please analyze it based on the code."
        }
        self._summarize_draft_round(
            op=normalized_op,
            arc_src=reference_src,
            round_info=draft_round,
        )
    
    # ======================================================== #
    # utility pruning
    # ======================================================== #
    def prune_low_utility_memories(
        self,
        *,
        q_threshold: Optional[float] = None,
        retrieval_threshold: Optional[int] = None,
        age_days: Optional[int] = None,
        batch_size: Optional[int] = None,
        dry_run: bool = False,
    ) -> Dict[str, Any]:# TODO: Need to modify to prune based on stage
        """Review the memory repository and remove low-utility memories.

        A memory is selected for removal when its Q value is below
        ``q_threshold`` OR its retrieval count is below
        ``retrieval_threshold``. Memories whose ``created_at`` is a datetime
        newer than ``age_days`` days ago are skipped before either check.
        Memories whose ``created_at`` is not a ``datetime`` get no age
        protection.

        Args:
            q_threshold: Q value threshold; memories below it are selected.
                ``None`` uses the config value. An explicit ``0`` / ``0.0``
                is honoured as given.
            retrieval_threshold: Retrieval count threshold; memories below it
                are selected. ``None`` uses the config value; ``0`` is
                honoured as given.
            age_days: Minimum age in days before a memory may be removed.
                ``None`` uses the config value; ``0`` is honoured as given.
            batch_size: Number of memories evaluated per batch. ``None`` or a
                non-positive value uses the config value. Batching does not
                change which memories are selected.
            dry_run: If True, nothing is deleted; only statistics are returned.

        Returns:
            Dictionary containing statistics: {
                'evaluated': Total number of memories fetched (including protected ones),
                'pruned': Number of memories actually removed (0 on dry run or failed delete),
                'low_q_count': Number of memories with too low Q values,
                'low_retrieval_count': Number of memories with too few retrievals,
                'too_young_count': Number of memories protected because they are too new,
                'pruned_ids': IDs removed; on a dry run, the IDs that would be removed
            }
        """
        empty_stats: Dict[str, Any] = {
            'evaluated': 0,
            'pruned': 0,
            'low_q_count': 0,
            'low_retrieval_count': 0,
            'too_young_count': 0,
            'pruned_ids': []
        }

        if not self.config.enable_utility_pruning:
            logger.info("Utility pruning is disabled in config")
            return empty_stats

        # Fall back to config only when the caller passed nothing. Using
        # ``x or default`` here would silently discard legitimate zeros.
        if q_threshold is None:
            q_threshold = self.config.pruning_q_threshold
        if retrieval_threshold is None:
            retrieval_threshold = self.config.pruning_retrieval_threshold
        if age_days is None:
            age_days = self.config.pruning_age_days
        if batch_size is None or batch_size <= 0:
            batch_size = self.config.pruning_batch_size
        # range() rejects a zero step and yields nothing for a negative one.
        batch_size = max(1, batch_size)

        # Get all memories
        all_memories = self.memory_system.get_all()
        total_count = len(all_memories)

        if total_count == 0:
            logger.info("No memories to prune")
            return empty_stats

        # Memories created after this instant are protected from pruning.
        min_creation_time = datetime.now(timezone.utc) - timedelta(days=age_days)

        pruned_ids = []
        low_q_count = 0
        low_retrieval_count = 0
        too_young_count = 0

        for start in range(0, total_count, batch_size):
            for note in all_memories[start:start + batch_size]:
                # Check age: don't delete memories that are too new
                created_at = note.created_at
                if isinstance(created_at, datetime):
                    if created_at.utcoffset() is None:
                        # Comparing naive with aware datetimes raises TypeError.
                        # NOTE(review): naive timestamps are assumed to be UTC - confirm.
                        created_at = created_at.replace(tzinfo=timezone.utc)
                    if created_at > min_creation_time:
                        too_young_count += 1
                        continue

                reason = []

                # Check Q value; a missing Q value counts as the lowest possible.
                q_val = note.q_value if note.q_value is not None else float('-inf')
                if q_val < q_threshold:
                    reason.append(f"low_q({q_val:.3f})")
                    low_q_count += 1

                # Check retrieval count; a missing or None count is treated as 0.
                retrieval_count = getattr(note, 'retrieval_count', 0) or 0
                if retrieval_count < retrieval_threshold:
                    reason.append(f"low_retrieval({retrieval_count})")
                    low_retrieval_count += 1

                # Any recorded reason marks the memory for deletion.
                if reason:
                    pruned_ids.append(note.id)
                    if not dry_run:
                        logger.debug(f"Pruning memory {note.id}: {', '.join(reason)}")

        # Execute deletion
        if pruned_ids and not dry_run:
            try:
                self.memory_system.delete(pruned_ids)
                logger.info(
                    f"Pruned {len(pruned_ids)} memories out of {total_count} evaluated. "
                    f"Low Q: {low_q_count}, Low retrieval: {low_retrieval_count}, "
                    f"Too young (protected): {too_young_count}"
                )
            except Exception as e:
                logger.error(f"Failed to prune memories: {e}")
                logger.exception(traceback.format_exc())
                # Nothing was removed, so report no pruned IDs.
                pruned_ids = []
        elif pruned_ids and dry_run:
            logger.info(
                f"Dry run: Would prune {len(pruned_ids)} memories out of {total_count} evaluated. "
                f"Low Q: {low_q_count}, Low retrieval: {low_retrieval_count}, "
                f"Too young (protected): {too_young_count}"
            )

        return {
            'evaluated': total_count,
            'pruned': 0 if dry_run else len(pruned_ids),
            'low_q_count': low_q_count,
            'low_retrieval_count': low_retrieval_count,
            'too_young_count': too_young_count,
            'pruned_ids': pruned_ids,
        }

            