import os
import yaml

from typing import Dict, List, Optional, Union, Any
from pydantic import Field, field_validator, model_validator

from src.configs.base import BaseConfig
from src.configs.provider import (
    OpenAILLMConfig,
    AzureEmbedderConfig,
    OpenAIEmbedderConfig,
    AzureLLMConfig,
    BaseLLMConfig,
    BaseEmbedderConfig,
)
from src.configs.memory import MemoryConfig
from src.configs.verify import VerifyConfig
from src.memorykit.note import NoteType

from src.models.selector_models import SelectorStrategy

class SelectorConfig(BaseConfig):
    """Settings for the record selector.

    ``strategy`` chooses how records are selected. The remaining fields are
    the per-strategy parameters named in their descriptions: ``epsilon`` for
    e-greedy, ``top_k`` for top_k_random and ``temperature`` for softmax.
    Their bounds are enforced by pydantic when the model is validated.
    """

    strategy: SelectorStrategy = Field(
        SelectorStrategy.LATEST,
        description="Record selection strategy",
    )

    # Restricted to the closed interval [0, 1].
    epsilon: float = Field(
        0.1,
        ge=0.0,
        le=1.0,
        description="e_greedy exploration rate",
    )
    # Must be at least 1.
    top_k: int = Field(
        5,
        ge=1,
        description="K for top_k_random",
    )
    # Must be strictly greater than zero.
    temperature: float = Field(
        1.0,
        gt=0.0,
        description="softmax temperature coefficient",
    )
    

class PlanAgentConfig(BaseConfig):
    """Configuration for the planning agent.

    Only the API, CODE_EXAMPLE and EXPERIENCE note types may be enabled for
    retrieval; any other ``NoteType`` member is rejected at validation time.
    """

    # Memory note types the plan agent may retrieve. Defaults to every type
    # accepted by ``validate_enabled_memory_types``.
    enabled_memory_types: List[NoteType] = Field(
        default_factory=lambda: [
            NoteType.API,
            NoteType.CODE_EXAMPLE,
            NoteType.EXPERIENCE,
        ],
        description="Enabled memory note types for retrieval."
    )

    @field_validator("enabled_memory_types")
    @classmethod
    def validate_enabled_memory_types(cls, v: List[NoteType]) -> List[NoteType]:
        """Restrict enabled memory types to API, CODE_EXAMPLE and EXPERIENCE.

        Args:
            v: The field value after pydantic has coerced each entry to
                ``NoteType`` (this is an "after" validator).

        Returns:
            ``v`` unchanged, or a new empty list when ``v`` is empty.

        Raises:
            ValueError: If ``v`` contains a ``NoteType`` outside the allowed set.
        """
        if not v:
            return []

        valid_types = {NoteType.API, NoteType.CODE_EXAMPLE, NoteType.EXPERIENCE}
        invalid = [t for t in v if t not in valid_types]
        if invalid:
            # Order by member *value*, not by member: sorting Enum members
            # directly raises TypeError unless NoteType subclasses an orderable
            # type, and pydantic does not turn a TypeError into a
            # ValidationError. key=str keeps this safe for any value type.
            valid_values = sorted((t.value for t in valid_types), key=str)
            invalid_values = [t.value for t in invalid]
            raise ValueError(
                f"Invalid types: {invalid_values}. Valid types are: {valid_values}"
            )

        return v
    

class CodeAgentConfig(BaseConfig):
    """Configuration for the code agent's execution stages and memory usage."""

    # Stages the agent runs. Each entry must be one of the stage names accepted
    # by ``validate_enabled_stages``, and the list may not be empty.
    enabled_stages: List[str] = Field(
        default_factory=lambda: ["draft", "debug_compile", "debug_correct", "optimize"],
        description="List of enabled execution stages for the agent"
    )
    # Per-stage attempt budgets. No lower bound is enforced on these values.
    max_draft_attempts: int = Field(
        default=3,
        description="Maximum number of draft attempts."
    )
    max_debug_compile_attempts: int = Field(
        default=5,
        description="Maximum number of debug/compile attempts."
    )
    max_debug_correct_attempts: int = Field(
        default=5,
        description="Maximum number of debug/correct attempts."
    )
    max_optimize_attempts: int = Field(
        default=2,
        description="Maximum number of optimization attempts."
    )

    # Stored as NoteType *values* (plain strings), unlike PlanAgentConfig,
    # which stores NoteType members.
    enabled_memory_types: List[str] = Field(
        default_factory=lambda: [
            NoteType.API.value,
            NoteType.CODE_EXAMPLE.value,
            NoteType.EXPERIENCE.value,
        ],
        description="Memory note types enabled for retrieval/injection",
    )

    @field_validator("enabled_stages")
    @classmethod
    def validate_enabled_stages(cls, v: List[str]) -> List[str]:
        """Validate that enabled stages are non-empty and contain valid stage names.

        Raises:
            ValueError: If ``v`` is empty or contains an unknown stage name.
        """
        if not v:
            raise ValueError("enabled_stages cannot be empty")

        valid_stages = {"draft", "debug_compile", "debug_correct", "optimize"}
        for stage in v:
            if stage not in valid_stages:
                # sorted() makes the message deterministic; a raw set of str
                # prints in hash-randomized order.
                raise ValueError(
                    f"Invalid stage: {stage}. Valid stages are: {sorted(valid_stages)}"
                )

        return v

    @field_validator("enabled_memory_types")
    @classmethod
    def validate_enabled_memory_types(cls, v: List[str]) -> List[str]:
        """Validate that every entry is the value of some ``NoteType`` member.

        An empty list is allowed and returned as-is (memory retrieval disabled
        for all types).

        Returns:
            A new list with the same entries, in the same order.

        Raises:
            ValueError: On the first entry that is not a ``NoteType`` value.
        """
        if not v:
            return v

        valid_types = {nt.value for nt in NoteType}
        for t in v:
            if t not in valid_types:
                raise ValueError(
                    f"Invalid type: {t}. Valid types are: {sorted(valid_types, key=str)}"
                )
        return list(v)
    

class AgentConfig(BaseConfig):
    """Top-level agent configuration that aggregates all sub-configs.

    This model is intended to be serialized to/from disk so the full agent
    runtime configuration can be saved and restored: JSON via
    ``save_to_file``/``load_from_file`` and YAML via
    ``save_to_yaml``/``from_yaml_file``.
    """
    # Basic identity
    # NOTE(review): these path defaults are evaluated once, when this class is
    # defined (i.e. at import time), from the process's current working
    # directory. Overriding ``project_root_path`` in a config file does NOT
    # re-derive ``logs_path``/``src_base_path``/``ref_impl_base_path``; set
    # those explicitly as well if the root differs.
    project_root_path: str = os.getcwd()
    logs_path: str = os.path.join(project_root_path, 'logs')
    src_base_path: str = os.path.join(project_root_path, 'src')
    ref_impl_base_path: str = os.path.join(project_root_path, 'reference')
    seed_num: int = 1024
    user_id: str = Field(default="default_user", description="User identifier")
    exp_name: str = Field(default="test", description="Experiment name")

    code_agent_config: CodeAgentConfig = Field(
        default_factory=CodeAgentConfig,
        description="Code agent configuration"
    )

    # Provider configs (LLM and embedder)
    llm_config: Union[OpenAILLMConfig, AzureLLMConfig, BaseLLMConfig] = Field(
        default_factory=OpenAILLMConfig,
        description="Configuration for the LLM provider"
    )
    embedder_config: Union[OpenAIEmbedderConfig, AzureEmbedderConfig, BaseEmbedderConfig] = Field(
        default_factory=OpenAIEmbedderConfig,
        description="Configuration for the embedding provider"
    )

    # Memory subsystem configuration
    memory_config: MemoryConfig = Field(
        default_factory=MemoryConfig,
        description="Memory service configuration"
    )

    verify_config: VerifyConfig = Field(
        default_factory=VerifyConfig,
        description="Verification configuration"
    )

    # NOTE: "seletor" is a misspelling of "selector". The field name is also
    # the key written to and read from saved JSON/YAML configs, so renaming it
    # would break existing config files.
    seletor_config: SelectorConfig = Field(
        default_factory=SelectorConfig,
        description="Record selection configuration"
    )

    # Runtime settings
    outer_iters: int = Field(default=25, description="")
    outer_start: int = Field(default=1, description="")
    max_iters: int = Field(default=16, description="")
    debug_mode: bool = Field(default=False, description="Enable debug mode")
    max_workers: int = Field(default=8,  description="")
    log_level: str = Field(default="INFO", description="Logging level")

    @field_validator("log_level")
    @classmethod
    def validate_log_level(cls, v: str) -> str:
        """Normalize the log level to upper case and check it is an accepted value.

        Raises:
            ValueError: If the upper-cased level is not a standard level name.
        """
        allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
        v = v.upper()
        if v not in allowed:
            raise ValueError(f"Invalid log level. Must be one of: {allowed}")
        return v

    @model_validator(mode="after")
    def unify_timestamp(self) -> "AgentConfig":
        """Unify timestamps of all sub-configs to the main config's timestamp.

        This method ensures that AgentConfig and its sub-config objects use the
        same timestamp value. If the user specifies a timestamp in YAML, all
        sub-configs use that value. Otherwise the default timestamp of this
        config is used, and all sub-configs are unified to it.
        """
        # ``timestamp`` is not declared in this class; presumably it is
        # inherited from BaseConfig (TODO confirm).
        main_timestamp = self.timestamp

        # Unify timestamps of all sub-configs
        sub_configs = [
            self.code_agent_config,
            self.llm_config,
            self.embedder_config,
            self.memory_config,
            self.verify_config,
            self.seletor_config,
        ]

        for sub_config in sub_configs:
            # Provider configs are not guaranteed to be BaseConfig subclasses,
            # so only touch those that are.
            if isinstance(sub_config, BaseConfig):
                sub_config.timestamp = main_timestamp
                # Only one level of nesting is handled, and only for a
                # sub-config exposing ``vector_db_config`` (e.g. memory_config).
                if hasattr(sub_config, 'vector_db_config') and isinstance(sub_config.vector_db_config, BaseConfig):
                    sub_config.vector_db_config.timestamp = main_timestamp

        return self

    def save_to_file(self, filepath: str) -> None:
        """Serialize the configuration to JSON and save it to `filepath`."""
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(self.model_dump_json(indent=2, ensure_ascii=False))

    @classmethod
    def load_from_file(cls, filepath: str) -> "AgentConfig":
        """Load configuration from a JSON file and return an `AgentConfig` instance."""
        with open(filepath, "r", encoding="utf-8") as f:
            return cls.model_validate_json(f.read())

    def save_to_yaml(self, filepath: str) -> None:
        """Serialize the config to YAML and save it to `filepath`.

        ``model_dump(mode="json")`` reduces every value to plain data (enum
        members become their values, paths become strings, tuples become
        lists). ``yaml.safe_dump`` cannot represent arbitrary Python objects
        such as Enum members (e.g. ``SelectorConfig.strategy``). A python-mode
        dump would therefore raise ``yaml.representer.RepresenterError``
        unless BaseConfig happens to configure ``use_enum_values``.
        """
        data = self.model_dump(mode="json")
        with open(filepath, "w", encoding="utf-8") as f:
            yaml.safe_dump(data, f, allow_unicode=True, sort_keys=False)

    @classmethod
    def from_yaml_file(cls, file_name: str, base_dir: str | None = None) -> "AgentConfig":
        """Load config from a YAML file and return an AgentConfig instance.

        Args:
            file_name: YAML file name, joined onto ``base_dir``.
            base_dir: Directory containing the file. When omitted (or empty),
                the relative path ``src/configs/presets`` is used, which is
                resolved against the current working directory.
        """
        if base_dir:
            filepath = os.path.join(base_dir, file_name)
        else:
            filepath = os.path.join('src', 'configs', 'presets', file_name)

        with open(filepath, "r", encoding="utf-8") as f:
            # safe_load: config files are parsed without constructing
            # arbitrary Python objects.
            data = yaml.safe_load(f)
        return cls.model_validate(data)

if __name__ == "__main__":
    # Smoke test: load the bundled default preset and echo it back as YAML.
    cfg_from_yaml = AgentConfig.from_yaml_file('agent.default.yaml')
    print(type(cfg_from_yaml.llm_config))
    print("\n--- Dump (YAML) ---")
    # mode="json" reduces enums and other non-primitive values to plain data
    # that yaml.safe_dump is able to represent.
    print(yaml.safe_dump(cfg_from_yaml.model_dump(mode="json"), allow_unicode=True, sort_keys=False))