from typing import Union, Literal

from src.configs.base import BaseConfig
from pydantic import Field, model_validator

from src.configs.utils import extract_timestamp_suffix
from pathlib import Path

import logging
logger = logging.getLogger(__name__)


class BaseVecDBConfig(BaseConfig):
    """Settings shared by every vector database backend.

    Backend-specific subclasses (e.g. ``QdrantVecDBConfig``) layer connection
    options on top; the fields here describe the collection itself.
    """

    # Required: the collection name is normalized by the validator below.
    collection_name: str = Field(..., description="Name of the collection")
    vector_dimension: int | None = Field(default=None, description="Dimension of the vectors")
    distance_metric: Literal["cosine", "euclidean", "dot"] | None = Field(
        default=None,
        description="Distance metric for vector similarity calculation. Options: 'cosine', 'euclidean', 'dot'",
    )

    @model_validator(mode="after")
    def normalize_collection_name(self) -> "BaseVecDBConfig":
        """Append ``_<timestamp>`` to ``collection_name`` unless it already has a suffix.

        ``extract_timestamp_suffix`` decides whether a suffix is present; when it
        returns ``None`` the name is rewritten as ``f"{name}_{self.timestamp}"``.
        ``self.timestamp`` is not declared in this class — presumably it comes
        from ``BaseConfig`` (TODO confirm).

        Returns:
            The same (possibly mutated) config instance, as pydantic requires.
        """
        if extract_timestamp_suffix(self.collection_name) is not None:
            # A suffix was detected: keep the name exactly as given.
            return self
        base_name = self.collection_name
        self.collection_name = f"{base_name}_{self.timestamp}"
        return self

class QdrantVecDBConfig(BaseVecDBConfig):
    """Vector database settings for Qdrant.

    A connection target is given through ``host``/``port`` or a local ``path``.
    If none of the three is supplied, ``path`` falls back to a ``qdrant``
    directory under the current working directory (with a warning logged).
    """

    host: str | None = Field(default=None, description="Host for Qdrant")
    port: int | None = Field(default=None, description="Port for Qdrant")
    path: str | None = Field(default=None, description="Path for Qdrant")

    @model_validator(mode="after")
    def set_default_path(self):
        """Default ``path`` to ``<cwd>/qdrant`` when host, port and path are all unset.

        Returns:
            The same (possibly mutated) config instance, as pydantic requires.
        """
        connection_opts = (self.host, self.port, self.path)
        if any(opt is not None for opt in connection_opts):
            # The caller picked a connection target explicitly; leave it alone.
            return self
        fallback = Path.cwd() / "qdrant"
        logger.warning(
            "No host, port, or path provided for Qdrant. Defaulting to local path: %s",
            fallback,
        )
        self.path = str(fallback)
        return self

class MemoryConfig(BaseConfig):
    """Configuration for the memory service.

    Groups the settings that control how memories are built and stored, how
    they are retrieved, and, when ``enable_value_driven`` is on, how the
    value-driven (RL) Q-value updates, exploration, pruning and concept-drift
    weighting behave.
    """

    # Database related settings
    # QdrantVecDBConfig is listed first in the union; it subclasses
    # BaseVecDBConfig, so many inputs validate against both members —
    # TODO confirm which member pydantic's union mode selects for dict input.
    # NOTE(review): `collection_name` is a required field on BaseVecDBConfig, so
    # calling QdrantVecDBConfig() with no arguments (the default_factory below)
    # will fail validation unless BaseConfig supplies that value — confirm that
    # `MemoryConfig()` without an explicit vector_db_config actually works.
    vector_db_config: Union[QdrantVecDBConfig, BaseVecDBConfig] = Field(
        default_factory=QdrantVecDBConfig,
        description="Vector database configuration. Defaults to a Qdrant configuration",
    )

    max_keywords: int = Field(default=8, description="Maximum number of keywords to extract")
    memory_confidence: float = Field(default=100.0, description="Default confidence score for new memories")
    memory_filename: str = Field(
        default="memory.json",
        description="Filename for storing memories",
    )
    add_similarity_threshold: float = Field(default=0.90, description="Similarity threshold when deciding to merge/add memories")
    enable_value_driven: bool = Field(default=True, description="Enable value-driven (RL) components")
    save_all_correct_attempts: bool = Field(default=False, description="Save all correct attempts as memories")

    # Novelty assessment settings
    enable_novelty_check: bool = Field(default=False, description="Enable novelty assessment: check similarity before adding new successful experiences. If highly similar to existing experience, credit reward to existing one instead of adding new")
    novelty_similarity_threshold: float = Field(default=0.85, ge=0.0, le=1.0, description="Similarity threshold for novelty check. If new experience similarity >= threshold with existing experience, consider it non-novel")
    novelty_reward_share: float = Field(default=0.5, ge=0.0, le=1.0, description="When non-novel experience is found, share this fraction of reward to the existing similar experience (0.0-1.0)")

    # Retrieval related defaults
    error_code_top_k: int = Field(default=1, ge=1, description="Top-k for incorrect code attempts")
    error_exp_top_k: int = Field(default=3, ge=1, description="Top-k for error experiences")
    api_top_k_per_desc: int = Field(default=2, ge=1, description="API search top-k per description")
    api_max_results: int = Field(default=4, ge=1, description="Max APIs returned per query")
    compile_exp_top_k: int = Field(default=2, ge=1, description="Top-k debug/compile experiences per error description")
    compile_api_max_results: int = Field(default=8, ge=1, description="Max APIs returned for compile errors")

    # RL related defaults (used when value-driven enabled).
    # Rates/factors that act as probabilities or interpolation weights are
    # bounded to [0, 1]. Similarity thresholds are intentionally left unbounded
    # because a "dot" distance metric can produce scores outside [0, 1].
    q_init: float = Field(default=0.0, description="Initial Q value")
    reward_ma_init: float = Field(default=0.0, description="Initial reward moving average")
    q_learning_rate: float = Field(default=0.1, ge=0.0, le=1.0, description="Q learning rate (alpha)")
    q_discount: float = Field(default=0.99, ge=0.0, le=1.0, description="Discount factor for Q updates (gamma)")
    exploration_strategy: Literal["epsilon_greedy", "boltzmann"] = Field(
        default="epsilon_greedy",
        description="Exploration strategy: 'epsilon_greedy' (ε-greedy), 'boltzmann' (Boltzmann exploration)"
    )
    epsilon: float = Field(default=0.1, ge=0.0, le=1.0, description="Exploration rate for ε-greedy selection (only used when exploration_strategy='epsilon_greedy')")
    temperature: float = Field(default=1.0, gt=0.0, description="Temperature parameter for Boltzmann exploration (higher = more exploration, only used when exploration_strategy='boltzmann')")
    tau: float = Field(default=0.35, description="Similarity threshold used for unknown detection")
    success_reward: float = Field(default=1.0, description="Reward applied when a memory leads to success")
    failure_reward: float = Field(default=-1.0, description="Penalty when usage fails")
    # A pool size below 1 is meaningless; matches the ge=1 on the other top-k fields.
    value_selector_topk: int = Field(default=5, ge=1, description="Candidate pool size before exploration selection")
    candidate_pool_multiplier: float = Field(default=1.0, ge=1, description="Multiplier for candidate pool size in value-based retrieval. The system will retrieve top_k * candidate_pool_multiplier candidates, then select top_k based on Q-value")
    max_candidate_pool_multiplier: float = Field(default=15.0, ge=1, description="Max multiplier for candidate pool size in value-based retrieval. The system will retrieve top_k * min(candidate_pool_multiplier, max_candidate_pool_multiplier) candidates, then select top_k based on Q-value")
    sim_threshold: float = Field(default=0.80, description="Similarity threshold when retrieving memories")
    reward_ma_beta: float = Field(default=0.1, ge=0.0, le=1.0, description="Beta value for reward moving average")
    clip_error: float = Field(default=1.0, description="Error clip value")

    # Utility pruning settings
    enable_utility_pruning: bool = Field(default=False, description="Enable utility-based pruning: periodically remove memories with low Q values or rarely retrieved")
    pruning_q_threshold: float = Field(default=-0.5, description="Q value threshold for pruning. Memories with Q value below this will be considered for removal")
    pruning_retrieval_threshold: int = Field(default=0, ge=0, description="Minimum retrieval count threshold. Memories retrieved fewer times than this will be considered for removal")
    pruning_age_days: int = Field(default=30, ge=1, description="Minimum age in days before a memory can be pruned. Prevents removing newly added memories")
    pruning_batch_size: int = Field(default=100, ge=1, description="Number of memories to evaluate in each pruning batch")

    # Concept drift adaptation settings
    enable_concept_drift: bool = Field(default=False, description="Enable concept drift adaptation: prioritize recent successful experiences based on timestamps")
    drift_time_decay_factor: float = Field(default=0.1, gt=0.0, description="Time decay factor for concept drift. Higher values give more weight to recent experiences. Applied as: weight = exp(-decay_factor * days_old)")
    drift_recent_success_bonus: float = Field(default=0.2, ge=0.0, description="Bonus weight added to recent successful experiences (within drift_recent_window_days)")
    drift_recent_window_days: int = Field(default=7, ge=1, description="Time window (in days) for considering an experience as 'recent'")

    # Optimization threshold settings
    optimization_threshold: float = Field(default=1.0, description="Optimization threshold for considering a code example as optimized")

    # Ablation-study switches (not intended for normal runs)
    experiment_violation_q: bool = Field(default=False, description="Ablation study: violate Q-value usage in retrieval")
