

import json
import os
import math

from typing import List, Dict, Optional, Any, Literal, Union


from src.configs.memory import MemoryConfig
from src.providers import BaseLLM,BaseEmbedder,OpenAILLM, OpenAIEmbedder # TODO: Implement multiple backends
from src.memorykit.vec_dbs import QdrantVecDB, VecDBItem
from src.memorykit import BaseMemoryNote, APINote, CodeExampleNote, ExperienceNote
from src.models.agent_models import ActionType
import traceback
import threading
import logging


logger = logging.getLogger(__name__)


class MemorySystem:
    """Store of memory notes backed by a Qdrant vector database.

    Notes (``BaseMemoryNote`` and its subclasses) are embedded with ``embedder`` and
    stored in the collection described by ``config.vector_db_config``. Every call into
    ``self.vector_db`` goes through ``self._db_lock``. Individual DB operations are
    therefore serialized, but multi-step methods such as ``update_q_values`` are not
    atomic as a whole.
    """

    # Numeric payload fields normalized to 6 decimal places when DB items become notes.
    _ROUNDED_PAYLOAD_FIELDS = ("score", "q_value", "reward_ma")

    def __init__(self, embedder: BaseEmbedder, config: MemoryConfig):
        """Create the memory system.

        Args:
            embedder: Embedding backend used to vectorize note text and queries.
            config: Memory configuration; ``config.vector_db_config`` is handed to
                ``QdrantVecDB``.
        """
        self.config: MemoryConfig = config
        self.embedder: BaseEmbedder = embedder
        # TODO: plug in a dedicated retriever (get_retriever()).
        self.vector_db = QdrantVecDB(config.vector_db_config)
        # Serializes every access to self.vector_db (re-entrant lock).
        self._db_lock = threading.RLock()

    @staticmethod
    def _round6(value: Any) -> Any:
        """Return finite ints/floats rounded to 6 decimals as ``float``; other values unchanged."""
        if isinstance(value, (int, float)) and math.isfinite(value):
            return round(float(value), 6)
        return value

    def _convert_to_note(self, vec_item: VecDBItem | dict[str, Any]) -> BaseMemoryNote:
        """Convert a VecDBItem or a plain dict into a BaseMemoryNote subclass.

        The concrete class is chosen from ``payload["note_type"]``:
        ``"api"`` -> APINote, ``"code-example"`` -> CodeExampleNote,
        ``"experience"`` -> ExperienceNote, anything else -> BaseMemoryNote.

        The payload is copied, so the caller's dict and the item are never mutated.
        Its ``score``/``q_value``/``reward_ma`` are rounded to 6 decimals. When
        ``config.enable_value_driven`` is set, a missing ``q_value``/``reward_ma`` is
        filled from ``config.q_init``/``config.reward_ma_init``.

        ``id`` and ``score`` come from the VecDBItem. Copies stored inside the payload
        (e.g. from an earlier ``model_dump()``) are used only as fallbacks and are never
        passed twice to the note constructor. For a plain-dict input, the dict's own
        ``id`` takes precedence.

        Conversions are logged at DEBUG level; failures are logged with a stack trace.

        Raises:
            ValueError: if ``vec_item`` has no ``payload`` or the note cannot be built.
        """
        logger.debug("_convert_to_note called with type=%s", type(vec_item))

        from_plain_dict = isinstance(vec_item, dict)
        if from_plain_dict:
            logger.debug("Wrapping plain dict into VecDBItem")
            vec_item = VecDBItem(payload=vec_item)

        # Basic validation
        if not hasattr(vec_item, "payload"):
            logger.error("vec_item has no 'payload' attribute: %r", vec_item)
            raise ValueError("vec_item must have a 'payload' attribute")

        # Work on a copy so the caller's dict / the item's payload is left untouched.
        payload = dict(vec_item.payload or {})
        note_type = payload.get("note_type", "")

        for field in self._ROUNDED_PAYLOAD_FIELDS:
            if field in payload:
                payload[field] = self._round6(payload[field])

        if self.config.enable_value_driven:
            if payload.get("q_value") is None:
                payload["q_value"] = self._round6(self.config.q_init)
            if payload.get("reward_ma") is None:
                payload["reward_ma"] = self._round6(self.config.reward_ma_init)

        # id/score are passed explicitly below. Strip them from the kwargs so that a
        # payload produced by model_dump() doesn't raise
        # "got multiple values for keyword argument".
        payload_id = payload.pop("id", None)
        payload_score = payload.pop("score", None)
        item_id = getattr(vec_item, "id", None)
        if from_plain_dict:
            note_id = payload_id if payload_id is not None else item_id
        else:
            note_id = item_id if item_id is not None else payload_id
        vec_score = self._round6(getattr(vec_item, "score", None))
        note_score = vec_score if vec_score is not None else payload_score

        try:
            note_cls = {
                "api": APINote,
                "code-example": CodeExampleNote,
                "experience": ExperienceNote,
            }.get(note_type, BaseMemoryNote)
            note = note_cls(id=note_id, score=note_score, **payload)
            logger.debug(
                "Converted vec_item id=%s note_type='%s' -> %s",
                getattr(vec_item, "id", None),
                note_type,
                note.__class__.__name__,
            )
            note.update_state()
            return note
        except Exception as e:
            # Log full exception with traceback to help debugging
            logger.exception(
                "Failed to convert vec_item to note (id=%s, note_type=%s): %s",
                getattr(vec_item, "id", None),
                note_type,
                e,
            )
            raise ValueError(e) from e

    def add(self, memories: list[BaseMemoryNote | dict[str, Any]]) -> None:
        """Embed and store memories.

        Args:
            memories: Note objects or plain dicts. Dicts are converted via
                ``_convert_to_note``. An empty list is a no-op.

        Raises:
            ValueError: if a dict cannot be converted into a note, or if the embedder
                returns a different number of vectors than notes (``zip(strict=True)``).
        """
        if not memories:
            # Nothing to embed; avoid an empty embedding request and an empty DB write.
            return

        # Normalize inputs: accept either Note instances or plain dicts
        memory_items: List[BaseMemoryNote] = [
            self._convert_to_note(m) if isinstance(m, dict) else m for m in memories
        ]
        embeddings = self.embedder.embed([item.memory for item in memory_items])
        vec_db_items = [
            VecDBItem(id=item.id, payload=item.model_dump(), vector=emb)
            for item, emb in zip(memory_items, embeddings, strict=True)
        ]
        with self._db_lock:
            self.vector_db.add(vec_db_items)

    def update(self, memory_id: str, new_memory: BaseMemoryNote | dict[str, Any]) -> None:
        """Replace the stored memory ``memory_id`` with ``new_memory``, re-embedding its text.

        Note: when ``new_memory`` is a note object, its ``id`` attribute is set to
        ``memory_id`` in place.
        """
        memory_item = (
            self._convert_to_note(new_memory) if isinstance(new_memory, dict) else new_memory
        )
        memory_item.id = memory_id

        vec_db_item = VecDBItem(
            id=memory_item.id,
            payload=memory_item.model_dump(),
            vector=self._embed_one_sentence(memory_item.memory),
        )
        with self._db_lock:
            self.vector_db.update(memory_id, vec_db_item)

    def search(self, query: str, top_k: int, **kwargs) -> list[BaseMemoryNote]:
        """Search memories by embedding similarity to ``query``.

        Args:
            query: Query string to search for.
            top_k: Number of results to request from the vector DB.
            **kwargs: Additional arguments:
                - filter: Payload filter dictionary.
                - score_threshold: Minimum similarity score threshold
                  (0.0-1.0 for cosine similarity).

        Returns:
            Notes ordered by descending score. Items without a score come last. On any
            failure the error is logged with its traceback and an empty list is returned.
        """
        try:
            query_vector = self._embed_one_sentence(query)
            filter_items = kwargs.get("filter", None)
            score_threshold = kwargs.get("score_threshold", None)

            with self._db_lock:
                search_results: List[VecDBItem] = self.vector_db.search(
                    query_vector, top_k, filter=filter_items, score_threshold=score_threshold
                )
            # Highest score first. A missing score sorts last instead of raising a
            # TypeError that would discard every result.
            search_results = sorted(
                search_results,
                key=lambda item: item.score if item.score is not None else float("-inf"),
                reverse=True,
            )
            result_memories = [self._convert_to_note(item) for item in search_results]
        except Exception:
            logger.exception("Vector DB search failed, fallback to empty experience notes")
            result_memories = []
        return result_memories

    def get(self, memory_id: str) -> BaseMemoryNote:
        """Get a memory by its ID.

        Raises:
            ValueError: if no memory with ``memory_id`` exists or it cannot be converted.
        """
        with self._db_lock:
            result = self.vector_db.get_by_id(memory_id)
        if result is None:
            raise ValueError(f"Memory with ID {memory_id} not found")
        return self._convert_to_note(result)

    def get_by_ids(self, memory_ids: list[str]) -> list[BaseMemoryNote]:
        """Get memories by their IDs.

        Args:
            memory_ids: List of memory IDs to retrieve. An empty list returns ``[]``
                without querying the DB.

        Returns:
            Notes for the IDs the vector DB returned. IDs that are not found are
            presumably just absent from the result; this depends on
            ``vector_db.get_by_ids``.
        """
        if not memory_ids:
            return []
        with self._db_lock:
            db_items = self.vector_db.get_by_ids(memory_ids)
        return [self._convert_to_note(db_item) for db_item in db_items]

    def get_by_filter(self, filter: dict[str, Any], scroll_limit: int = 100, with_vectors: bool = False, with_payload: bool | list[str] = True) -> list[BaseMemoryNote]:
        """Get memories matching a payload filter.

        Args:
            filter: Filter dictionary forwarded to the vector DB.
            scroll_limit: Maximum number of items to retrieve per scroll request.
            with_vectors: Whether to return vectors.
            with_payload: Whether to return the payload (or which payload keys).

        Returns:
            Notes matching the filter.
        """
        with self._db_lock:
            db_items = self.vector_db.get_by_filter(filter, scroll_limit, with_vectors, with_payload)
        return [self._convert_to_note(db_item) for db_item in db_items]

    def get_all(self) -> list[BaseMemoryNote]:
        """Return every stored memory as a note."""
        with self._db_lock:
            all_items = self.vector_db.get_all()
        return [self._convert_to_note(item) for item in all_items]

    def delete(self, memory_ids: list[str]) -> None:
        """Delete the memories with the given IDs. An empty list is a no-op."""
        if not memory_ids:
            return
        with self._db_lock:
            self.vector_db.delete(memory_ids)

    def delete_all(self) -> None:
        """Delete all memories by dropping and re-creating the collection."""
        with self._db_lock:
            self.vector_db.delete_collection(self.vector_db.config.collection_name)
            self.vector_db.create_collection()

    def load(self, dir: str) -> None:
        """Load memories from ``dir/<config.memory_filename>`` into the vector DB.

        The file must contain a JSON list whose entries are readable by
        ``VecDBItem.from_dict``; ``dump`` writes entries via ``VecDBItem.to_dict``.
        No embedding call is made.

        Best-effort: a missing file, malformed JSON, or any other error is logged and
        the method returns without raising.
        """
        try:
            memory_file = os.path.join(dir, self.config.memory_filename)

            if not os.path.exists(memory_file):
                logger.warning(f"Memory file not found: {memory_file}")
                return

            with open(memory_file, encoding="utf-8") as f:
                memories = json.load(f)

            if not isinstance(memories, list):
                logger.error(
                    f"Expected a JSON list in memory file {memory_file}, "
                    f"got {type(memories).__name__}"
                )
                return

            vec_db_items = [VecDBItem.from_dict(m) for m in memories]
            if vec_db_items:
                with self._db_lock:
                    self.vector_db.add(vec_db_items)
            logger.info(f"Loaded {len(memories)} memories from {memory_file}")

        except FileNotFoundError:
            # The file can disappear between the exists() check and open().
            logger.error(f"Memory file not found in directory: {dir}")
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding JSON from memory file: {e}")
        except Exception as e:
            logger.exception(f"An error occurred while loading memories: {e}")

    def dump(self, dir: str, filename: Optional[str] = None) -> None:
        """Write all memories (payloads and vectors) as a JSON list to ``dir/filename``.

        ``filename`` defaults to ``config.memory_filename``; ``dir`` is created if
        needed. The data is written to ``<file>.tmp`` first and then moved into place
        with ``os.replace``. A failure mid-write (e.g. a non-JSON-serializable payload
        value) therefore leaves any previous dump intact.

        Raises:
            Exception: any error is logged and re-raised.
        """
        try:
            with self._db_lock:
                all_vec_db_items = self.vector_db.get_all(with_vectors=True, with_payload=True)
            json_memories = [memory.to_dict() for memory in all_vec_db_items]

            os.makedirs(dir, exist_ok=True)
            if filename is None:
                filename = self.config.memory_filename
            memory_file = os.path.join(dir, filename)
            tmp_file = f"{memory_file}.tmp"
            try:
                with open(tmp_file, "w", encoding="utf-8") as f:
                    json.dump(json_memories, f, indent=4, ensure_ascii=False)
                os.replace(tmp_file, memory_file)
            except Exception:
                try:
                    os.remove(tmp_file)
                except OSError:
                    pass  # temp file may not exist; surface the original error instead
                raise

            logger.info(f"Dumped {len(all_vec_db_items)} memories to {memory_file}")

        except Exception as e:
            logger.error(f"An error occurred while dumping memories: {e}")
            raise

    def drop(
        self,
    ) -> None:
        """Delete the collection without re-creating it (unlike ``delete_all``)."""
        with self._db_lock:
            self.vector_db.delete_collection(self.vector_db.config.collection_name)

    def _embed_one_sentence(self, sentence: str) -> list[float]:
        """Embed a single sentence and return its vector."""
        return self.embedder.embed([sentence])[0]

    def _persist_note_metadata(self, note: BaseMemoryNote) -> None:
        """Persist the note's payload without re-embedding it.

        Sends ``vector=None``. This presumably relies on ``vector_db.update`` leaving
        the stored vector untouched in that case — confirm against QdrantVecDB.
        """
        vec_db_item = VecDBItem(
            id=note.id,
            payload=note.model_dump(),
            vector=None,
        )
        with self._db_lock:
            self.vector_db.update(note.id, vec_db_item)

    def update_q_values(
        self,
        note_ids: List[str],
        note_rewards: Dict[str, Any],
        *,
        stage: ActionType,
        alpha: Optional[float] = None,
        init_q: Optional[float] = None,
        reward_ma_beta: float = 0.1,
        clip_error: Optional[float] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Batch-update the stage-specific Q values of several memories.

        Each found note is updated with ``note.update_q`` and persisted via
        ``_persist_note_metadata``. A failure on one note is logged and recorded, and
        does not stop the others. IDs not returned by the DB are logged and skipped.

        Args:
            note_ids: IDs of the memories to update.
            note_rewards: Reward per note, keyed by note ID. A missing key makes that
                note's update fail (recorded in its result entry).
            stage: Stage identifier (e.g. 'draft' or 'optimize').
            alpha: Learning rate; defaults to ``config.q_learning_rate``.
            init_q: Initial Q value; defaults to ``config.q_init``.
            reward_ma_beta: Beta for the reward moving average.
            clip_error: Optional error clipping value.

        Returns:
            Dict keyed by note ID. On success the entry holds ``old_q``, ``old_visits``,
            ``old_reward_ma`` and ``reward``, merged with the dict ``update_q`` returns.
            On failure it is ``{"error": <message>}``.
        """
        if not note_ids:
            return {}

        # Use default values from config
        if alpha is None:
            alpha = self.config.q_learning_rate
        if init_q is None:
            init_q = self.config.q_init

        # Batch get all notes
        notes = self.get_by_ids(note_ids)
        logger.info(
            f"[Q value update] Starting batch Q value update: note_ids={note_ids}, stage={stage}, "
            f"alpha={alpha}, init_q={init_q}, reward_ma_beta={reward_ma_beta}, clip_error={clip_error}, note_rewards={note_rewards}"
        )
        found_ids = {note.id for note in notes}
        missing_ids = [note_id for note_id in note_ids if note_id not in found_ids]
        if missing_ids:
            logger.warning(f"[Q value update] Notes not found, skipped: {missing_ids}")

        update_results = {}

        for note in notes:
            try:
                # Record Q value before update (using stage-specific fields)
                old_q = note.get_q_value(stage)
                fields = note._get_q_fields(stage)
                old_visits = getattr(note, fields['q_visits'])
                old_reward_ma = note.get_reward_ma(stage)
                reward = note_rewards[note.id]

                logger.info(
                    f"[Q value update] note_id={note.id}, stage={stage}, "
                    f"Before update: q_value={old_q}, q_visits={old_visits}, reward_ma={old_reward_ma}, "
                    f"Current reward={reward}, init_alpha={alpha}, init_q={init_q}"
                )
                update_results[note.id] = {
                    "old_q": old_q,
                    "old_visits": old_visits,
                    "old_reward_ma": old_reward_ma,
                    "reward": reward,
                }

                # Update Q value
                result = note.update_q(
                    reward=reward,
                    stage=stage,
                    alpha=alpha,
                    init_q=init_q,
                    reward_ma_beta=reward_ma_beta,
                    clip_error=clip_error,
                )
                update_results[note.id].update(result)

                # Record Q value after update
                new_reward_ma = note.get_reward_ma(stage)
                logger.info(
                    f"[Q value update] note_id={note.id}, stage={stage}, "
                    f"After update: q_value={result['new_q']}, new_q_visits={result['new_q_visits']}, "
                    f"reward_ma={new_reward_ma}, error={result['error']}, used_alpha={result['used_alpha']}"
                )

                # Persist updated note
                self._persist_note_metadata(note)

            except Exception as e:
                # logger.exception already attaches the traceback; log it once.
                logger.exception(f"Failed to update Q value for note {note.id}: {e}")
                update_results[note.id] = {"error": str(e)}

        return update_results


    
