import os
from typing import Optional, List, Dict, Any, Union, Literal, Tuple
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict, PydanticBaseSettingsSource
from pathlib import Path
class DatasetConfig(BaseModel):
    """Dataset selection, sampling and on-disk locations for a run."""

    name: str = Field(default="arguana", description="Dataset name")
    embedding_path: str = Field(
        default="./data/processed/embeddings/",
        description="Path to embeddings directory",
    )
    sample_size: int = Field(default=5000, description="Sample size for evaluation")
    align_dimension: bool = Field(default=True, description="Whether to align dimensions")
    # Both dataset lists start empty; a fresh list is created per instance.
    train_dataset_list: List[str] = Field(
        default_factory=list,
        description="List of training dataset names",
    )
    test_dataset_list: List[str] = Field(
        default_factory=list,
        description="List of evaluation dataset names",
    )
    cache_path: str = Field(default="./.cache/", description="Path to cache directory")
class ModelConfig(BaseModel):
    """Source/target embedding model names plus the list of supported models."""

    source_model: str = Field(default="mistral", description="Source embedding model name")
    target_model: str = Field(default="nv-embed", description="Target embedding model name")
    supported_models: List[str] = Field(
        default=[
            "fast-text",
            "glove",
            "gte",
            "mistral",
            "nv-embed",
            "openai",
        ],
        description="List of supported models",
    )
class LinearMapperConfig(BaseModel):
    """Hyperparameters for the ``linear`` mapper."""

    # Architecture / optimisation
    hidden_dim: int = Field(default=1024, description="Hidden dimension")
    num_epochs: int = Field(default=1000, description="Number of training epochs")
    batch_size: int = Field(default=320, description="Batch size")
    learning_rate: float = Field(default=0.0001, description="Learning rate")

    # Loss-related parameters
    triplet_margin: float = Field(default=0.5, description="Triplet margin")
    rank_margin: float = Field(default=0.1, description="Rank margin")
    lambda_: float = Field(default=1.0, description="Lambda parameter")
    hierarchy_k: int = Field(default=10, description="Hierarchy K parameter")
    hierarchy_weight_mode: str = Field(default="linear", description="Hierarchy weight mode")
class DiffusionMapperConfig(BaseModel):
    """Hyperparameters for the ``diffusion`` mapper."""

    # Noise schedule
    num_timesteps: int = Field(default=250, description="Number of diffusion timesteps")
    beta_start: float = Field(default=1e-4, description="Beta start value")
    beta_end: float = Field(default=0.02, description="Beta end value")

    # Network shape
    hidden_dim: int = Field(default=1024, description="Hidden dimension")
    num_layers: int = Field(default=3, description="Number of layers")

    # Training loop
    num_epochs: int = Field(default=1000, description="Number of training epochs")
    batch_size: int = Field(default=1024, description="Batch size")
    learning_rate: float = Field(default=1e-4, description="Learning rate")
class SPNTConfig(BaseModel):
    """Structure-preserving loss options for the ``spnt`` mapper."""

    lambda_struct: float = Field(
        default=1.0,
        description="Weight for structure-preserving loss",
    )
    struct_loss_type: str = Field(default="knn", description="Structure loss type")
    k_neighbors: int = Field(
        default=10,
        description="Number of neighbors for k-NN structure loss",
    )
    num_anchors: int = Field(
        default=100,
        description="Number of anchors for anchor-based structure loss",
    )
    num_projections: int = Field(
        default=50,
        description="Number of projections for sliced Wasserstein",
    )
class SimpleLinearMapperConfig(BaseModel):
    """Hyperparameters for the ``simple-linear`` mapper.

    Also used as the per-expert mapper configuration of ``GatingMoEConfig``.
    """

    # --- optimisation ---
    learning_rate: float = Field(default=1e-4, description="Learning rate")
    num_epochs: int = Field(default=50, description="Number of training epochs")
    batch_size: int = Field(default=4028, description="Batch size for training")
    gradient_clip: float = Field(default=1.0, description="Gradient clipping threshold")
    weight_decay: float = Field(default=1e-5, description="Weight decay for regularization")

    # --- LR scheduling / early stopping ---
    scheduler_patience: int = Field(default=5, description="Learning rate scheduler patience")
    scheduler_factor: float = Field(default=0.5, description="Learning rate scheduler factor")
    early_stopping_patience: int = Field(default=10, description="Early stopping patience")
    min_delta: float = Field(default=1e-6, description="Minimum change for early stopping")

    # --- device and architecture ---
    device: Optional[str] = Field(
        default=None,
        description="Device to use (None for auto-detection)",
    )
    layer_num: int = Field(default=2, description="Number of layers")
    activation: str = Field(default="relu", description="Activation function")
    dropout: float = Field(default=0.1, description="Dropout rate")
    hidden_dim: int = Field(default=512, description="Hidden dimension")

    # --- local distillation ---
    use_local_distill: bool = Field(default=False, description="Use local distillation")
    local_k: int = Field(default=50, description="Number of neighbors for local distillation")
    local_tau: float = Field(default=0.1, description="Temperature for local distillation")
    local_weight: float = Field(default=0.5, description="Weight for local distillation loss")
    faiss_use_float32: bool = Field(default=True, description="Use float32 for FAISS index")
    knn_recompute_epochs: int = Field(
        default=0,
        description="Epochs to recompute KNN for local distillation",
    )
    global_weight: float = Field(default=0.5, description="Weight for global model")
class GatingMoEConfig(BaseModel):
    """Options for the ``gating-moe`` (mixture-of-experts) mapper.

    ``moe_type`` selects the variant. The validators below only constrain
    ``num_levels``, ``branch_factor`` and ``lora_rank`` when the selected
    variant is one of the hierarchical ones; they read ``moe_type`` from the
    already-validated data, which works because ``moe_type`` is declared first.
    """

    moe_type: Literal["flat", "hierarchical", "hierarchical_lora"] = Field(
        default="flat",
        description="Type of MoE implementation to use",
    )

    # --- routing / clustering ---
    num_experts: int = Field(default=8, description="Number of experts for flat MoE")
    clustering_method: str = Field(
        default="kmeans",
        description="Clustering method for expert assignment",
    )
    distance_metric: str = Field(default="cosine", description="Distance metric for gating")
    random_state: int = Field(default=42, description="Random state for reproducibility")
    clustering_sample_size: int = Field(default=100000, description="Max samples for clustering")
    use_soft_routing: bool = Field(
        default=False,
        description="Use soft routing instead of hard assignment",
    )
    gating_temperature: float = Field(default=1.0, description="Temperature for soft routing")

    # --- hierarchy shape ---
    num_levels: int = Field(default=3, description="Number of levels in hierarchy")
    branch_factor: int = Field(default=4, description="Branch factor (children per node)")

    # --- LoRA ---
    lora_rank: int = Field(default=8, description="LoRA rank")
    lora_alpha: int = Field(default=16, description="LoRA alpha parameter")
    lora_dropout: float = Field(default=0.1, description="LoRA dropout rate")
    share_base_model: bool = Field(
        default=True,
        description="Share base model across LoRA experts",
    )

    mapper_config: SimpleLinearMapperConfig = Field(
        default_factory=SimpleLinearMapperConfig,
        description="Configuration for expert mappers",
    )

    @field_validator('num_levels')
    @classmethod
    def validate_num_levels(cls, v, info):
        """Require at least two levels when a hierarchical MoE type is selected."""
        selected_type = info.data.get('moe_type')
        is_hierarchical = selected_type in ['hierarchical', 'hierarchical_lora']
        if is_hierarchical and v < 2:
            raise ValueError("num_levels must be >= 2 for hierarchical MoE")
        return v

    @field_validator('branch_factor')
    @classmethod
    def validate_branch_factor(cls, v, info):
        """Require a branch factor of at least two for hierarchical MoE types."""
        selected_type = info.data.get('moe_type')
        is_hierarchical = selected_type in ['hierarchical', 'hierarchical_lora']
        if is_hierarchical and v < 2:
            raise ValueError("branch_factor must be >= 2 for hierarchical MoE")
        return v

    @field_validator('lora_rank')
    @classmethod
    def validate_lora_rank(cls, v, info):
        """Require a positive LoRA rank when ``moe_type`` is ``hierarchical_lora``."""
        uses_lora = info.data.get('moe_type') == 'hierarchical_lora'
        if uses_lora and v < 1:
            raise ValueError("lora_rank must be >= 1 for hierarchical_lora MoE")
        return v
class ProcrustesMapperConfig(BaseModel):
    """Options for the ``procrustes`` mapper."""

    approximate: bool = Field(default=False, description="Use approximate method")
    q: int = Field(default=1500, description="Q parameter")
    with_rotation: bool = Field(default=True, description="Include rotation")
    use_norm: bool = Field(default=True, description="Use normalization")
class LA2MMapperConfig(BaseModel):
    """Options for the ``la2m`` mapper."""

    # --- dimensionality / PCA ---
    d_prime: int = Field(default=10, description="D prime parameter")
    pca_mapping: bool = Field(default=True, description="Use PCA mapping")
    pca_dim: int = Field(default=14, description="PCA dimension")
    use_norm: bool = Field(default=False, description="Use normalization")

    # --- runtime ---
    device: str = Field(default="auto", description="Device to use")
    batch_size: int = Field(default=320, description="Batch size")
    verbose: bool = Field(default=True, description="Verbose output")

    # --- network / training ---
    # Annotated as a bare ``tuple``, so element types are not validated.
    hidden_dims: tuple = Field(default=(1024, 512), description="Hidden dimensions")
    alignment_strategy: str = Field(default="learnable", description="Alignment strategy")
    learning_rate: float = Field(default=1e-4, description="Learning rate")
    num_epochs: int = Field(default=1000, description="Number of epochs")
    loss_type: str = Field(default="combined", description="Loss type")
class EmbeddingConverterMapperConfig(BaseModel):
    """Options for the embedding-converter mapper (``MapperConfig.emb_conv``)."""

    # --- network shape ---
    hidden_dim: int = Field(default=1024, description="Hidden dimension")
    hidden_multiplier: float = Field(default=5.0, description="Hidden multiplier")
    num_hidden_layers: int = Field(default=3, description="Number of hidden layers")
    activation: str = Field(default="selu", description="Activation function")
    out_l2_normalize: bool = Field(default=True, description="Output L2 normalization")
    dropout: float = Field(default=0.0, description="Dropout rate")

    # --- runtime / training ---
    device: Optional[str] = Field(
        default=None,
        description="Device to use (None for auto-detection)",
    )
    num_epochs: int = Field(default=1000, description="Number of training epochs")
    batch_size: int = Field(default=320, description="Batch size")
    learning_rate: float = Field(default=0.0001, description="Learning rate")
class Vec2VecMapperConfig(BaseModel):
    """Options for the vec2vec mapper (``MapperConfig.vec2vec``).

    Each field is declared exactly once. An earlier revision declared
    ``weight_decay``, ``n_critic``, ``gp_lambda`` and ``grad_clip`` twice with
    identical values; the later declarations silently replaced the earlier
    ones. Field order and defaults are unchanged by removing the duplicates.
    """

    # --- architecture ---
    z_dim: int = Field(
        default=256,
        description="Latent dimension Z shared by both embedding spaces",
    )
    adapter_hidden: int = Field(
        default=1024,
        description="Hidden dimension of adapter MLPs",
    )
    adapter_blocks: int = Field(
        default=2,
        description="Number of residual blocks in adapters",
    )
    backbone_blocks: int = Field(
        default=4,
        description="Number of residual blocks in shared latent backbone",
    )
    disc_hidden: int = Field(
        default=1024,
        description="Hidden dimension of discriminators",
    )
    disc_depth: int = Field(
        default=3,
        description="Number of layers in discriminators",
    )
    dropout: float = Field(default=0.0, description="Dropout rate")

    # --- runtime / training loop ---
    device: Optional[str] = Field(
        default=None,
        description="Device string (None = auto)",
    )
    num_epochs: int = Field(default=100, description="Number of training epochs")
    batch_size: int = Field(default=4096, description="Batch size")
    # NOTE(review): ``learning_rate`` and ``lr_g`` carry the same description
    # and default; which one the trainer reads is not visible here - confirm.
    learning_rate: float = Field(
        default=1e-4,
        description="Learning rate for generators",
    )
    weight_decay: float = Field(default=0.0, description="Weight decay")
    n_critic: int = Field(
        default=5,
        description="Number of discriminator steps per generator step",
    )
    gp_lambda: float = Field(
        default=10.0,
        description="Gradient penalty coefficient",
    )
    grad_clip: float = Field(
        default=0.0,
        description="Gradient clipping threshold (0 = disabled)",
    )

    # --- loss weights ---
    lambda_rec: float = Field(default=1.0, description="Reconstruction loss weight")
    lambda_cc: float = Field(default=1.0, description="Cycle consistency loss weight")
    lambda_vsp: float = Field(
        default=1.0,
        description="Vector space preservation loss weight",
    )
    lambda_latent_adv: float = Field(
        default=1.0,
        description="Latent adversarial loss weight",
    )
    lambda_embed_adv: float = Field(
        default=1.0,
        description="Embedding adversarial loss weight",
    )

    # --- per-network learning rates ---
    lr_g: float = Field(default=1e-4, description="Learning rate for generators")
    lr_d: float = Field(default=1e-4, description="Learning rate for discriminators")

    l2_normalize_inputs: bool = Field(
        default=True,
        description="Whether to L2-normalize input embeddings",
    )
class MapperConfig(BaseModel):
    """Mapper selection plus one nested configuration per mapper family."""

    mapper_name: str = Field(default="gating-moe", description="Mapper name")
    # NOTE(review): no entry below obviously corresponds to the ``emb_conv``
    # and ``vec2vec`` sub-configurations - confirm whether that is intended.
    supported_mappers: List[str] = Field(
        default=[
            "linear",
            "ours",
            "procrustes",
            "diffusion",
            "la2m",
            "spnt",
            "spnt-diffusion",
            "gating-moe",
            "simple-linear",
        ],
        description="List of supported mappers",
    )
    loss_type: str = Field(default="cos", description="Loss type")

    # --- transformed-embedding cache and transform strategy ---
    transformed_cache_path: str = Field(
        default=".cache/transformed_embeddings/",
        description="Path to transformed embeddings cache",
    )
    transformed_batch_size: int = Field(
        default=10000,
        description="Batch size for transformed embeddings",
    )
    transform_strategy: Literal["cluster_then_route", "direct_route"] = Field(
        default="direct_route",
        description="Transform strategy: 'cluster_then_route' (cluster first, then route) or 'direct_route' (direct routing)",
    )
    transform_num_clusters: int = Field(
        default=16,
        description="Number of clusters for 'cluster_then_route' strategy",
    )

    # --- per-mapper sub-configurations, each built from its own defaults ---
    linear: LinearMapperConfig = Field(default_factory=LinearMapperConfig)
    simple_linear: SimpleLinearMapperConfig = Field(default_factory=SimpleLinearMapperConfig)
    diffusion: DiffusionMapperConfig = Field(default_factory=DiffusionMapperConfig)
    spnt: SPNTConfig = Field(default_factory=SPNTConfig)
    gating_moe: GatingMoEConfig = Field(default_factory=GatingMoEConfig)
    procrustes: ProcrustesMapperConfig = Field(default_factory=ProcrustesMapperConfig)
    la2m: LA2MMapperConfig = Field(default_factory=LA2MMapperConfig)
    emb_conv: EmbeddingConverterMapperConfig = Field(
        default_factory=EmbeddingConverterMapperConfig,
    )
    vec2vec: Vec2VecMapperConfig = Field(default_factory=Vec2VecMapperConfig)
class WandBConfig(BaseModel):
    """Weights & Biases logging options."""

    enabled: bool = Field(default=True, description="Enable WandB logging")
    project: str = Field(
        default="vector_translation_cross_zenml",
        description="WandB project name",
    )
    entity: Optional[str] = Field(default=None, description="WandB entity")
    tags: List[str] = Field(default_factory=list, description="WandB tags")
class LoggingConfig(BaseModel):
    """Logging level, record format and optional log file location."""

    level: str = Field(default="INFO", description="Logging level")
    format: str = Field(
        default="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        description="Log format",
    )
    file_path: Optional[str] = Field(default=None, description="Log file path")
class PathsConfig(BaseModel):
    """Filesystem locations for embeddings, analysis output and the config file."""

    embedding_path: str = Field(
        default="./data/processed/embeddings/",
        description="Path to embeddings directory",
    )
    analysis_path: str = Field(
        default="./output/analysis",
        description="Path to analysis output",
    )
    analysis_cache_path: str = Field(
        default="./output/analysis_cache",
        description="Path to analysis cache",
    )
    config_path: str = Field(
        default="./.vectortranslation/config.yaml",
        description="Path to config file",
    )
class SingleRunConfig(BaseModel):
    """Complete configuration for one run.

    Groups the nested section models with the optional credentials and
    endpoints. It is a plain ``BaseModel``, so nothing is read from the
    environment here; ``ManyToOneConfig`` and ``TransitivityConfig`` copy
    their shared values onto instances of this class.
    """

    # --- nested sections, each built from its own defaults ---
    dataset: DatasetConfig = Field(default_factory=DatasetConfig)
    model: ModelConfig = Field(default_factory=ModelConfig)
    mapper: MapperConfig = Field(default_factory=MapperConfig)
    wandb: WandBConfig = Field(default_factory=WandBConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    paths: PathsConfig = Field(default_factory=PathsConfig)

    # --- optional identifiers, credentials and endpoints ---
    run_entity: Optional[str] = Field(default=None, description="Run entity")
    hf_token: Optional[str] = Field(default=None, description="Hugging Face token")
    noco_url: Optional[str] = Field(default=None, description="NocoDB URL")
    noco_api_token: Optional[str] = Field(default=None, description="NocoDB API token")
    milvus_host: Optional[str] = Field(default=None, description="Milvus host")
    milvus_port: Optional[str] = Field(default=None, description="Milvus port")
    milvus_user: Optional[str] = Field(default=None, description="Milvus username")
    milvus_pwd: Optional[str] = Field(default=None, description="Milvus password")
class CrossTranslateConfig(BaseSettings):
    """Settings for a cross-dataset translation run.

    Values are resolved from, in decreasing priority: constructor arguments,
    environment variables (prefix ``APP_``, nested delimiter ``__``), the
    ``.env`` file, a YAML file (``settings.yaml`` unless overridden, see
    ``YamlConfigSettingsSource``), and finally file secrets.
    """

    # Pydantic v2 style configuration; replaces the deprecated inner
    # ``class Config`` while keeping exactly the same options.
    model_config = SettingsConfigDict(
        env_file=".env",
        env_prefix="APP_",
        env_nested_delimiter="__",
        env_file_encoding="utf-8",
        case_sensitive=False,
        validate_assignment=True,
        use_enum_values=True,
    )

    # --- legacy single-dataset fields ---
    train_dataset: Optional[str] = Field(None, description="Train dataset name (legacy)")
    test_dataset: Optional[str] = Field(None, description="Test dataset name (legacy)")

    # --- nested sections ---
    dataset: DatasetConfig = Field(default_factory=DatasetConfig)
    model: ModelConfig = Field(default_factory=ModelConfig)
    mapper: MapperConfig = Field(default_factory=MapperConfig)
    wandb: WandBConfig = Field(default_factory=WandBConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    paths: PathsConfig = Field(default_factory=PathsConfig)

    # --- optional identifiers, credentials and endpoints ---
    run_entity: Optional[str] = Field(None, description="Run entity")
    hf_token: Optional[str] = Field(None, description="Hugging Face token")
    noco_url: Optional[str] = Field(None, description="NocoDB URL")
    noco_api_token: Optional[str] = Field(None, description="NocoDB API token")
    milvus_host: Optional[str] = Field(None, description="Milvus host")
    milvus_port: Optional[str] = Field(None, description="Milvus port")
    milvus_user: Optional[str] = Field(None, description="Milvus username")
    milvus_pwd: Optional[str] = Field(None, description="Milvus password")

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources in priority order (first wins).

        The YAML source sits below the environment and ``.env`` sources and
        above file secrets.
        """
        yaml_settings = YamlConfigSettingsSource(settings_cls, yaml_path="settings.yaml")
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            yaml_settings,
            file_secret_settings,
        )
class Settings(BaseModel):
    """Top-level container holding the cross-dataset translation settings."""

    cross_translate: CrossTranslateConfig = Field(
        ...,
        description="Cross-dataset translation configuration",
    )

    class Config:
        # NOTE(review): ``env_file``, ``env_file_encoding`` and
        # ``case_sensitive`` are settings-source options, but this class
        # derives from ``BaseModel`` rather than ``BaseSettings`` - confirm
        # whether they are meant to have any effect here.
        validate_assignment = True
        use_enum_values = True
        case_sensitive = False
        env_file = ".env"
        env_file_encoding = "utf-8"
class YamlConfigSettingsSource(PydanticBaseSettingsSource):
    """Settings source that reads values from a YAML file.

    The file path is taken from the ``APP_SETTINGS_PATH`` environment variable
    when it is set and non-empty, otherwise from ``yaml_path``, and finally
    falls back to ``settings.yaml``. A missing or empty file contributes no
    values.
    """

    def __init__(self, settings: type[BaseSettings], yaml_path: str):
        """Bind the source to a settings class and resolve the YAML path.

        Args:
            settings: The settings class this source provides values for.
            yaml_path: Default YAML file path, used when ``APP_SETTINGS_PATH``
                is unset or empty.
        """
        super().__init__(settings)
        self.settings = settings
        env_override = os.getenv("APP_SETTINGS_PATH")
        self.yaml_path = Path(env_override or yaml_path or "settings.yaml")
        # Last document loaded from disk. ``get_field_value`` reads it, so it
        # must exist even when ``__call__`` has not run yet.
        self._data: Optional[Dict[str, Any]] = None

    def _read_yaml(self) -> Dict[str, Any]:
        """Load the YAML file and return its top-level mapping.

        Returns:
            The parsed mapping, or an empty dict when the file does not exist
            or is empty.

        Raises:
            ValueError: If the document's top level is not a mapping.
        """
        import yaml

        if not self.yaml_path.exists():
            return {}
        with self.yaml_path.open("r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f) or {}
        if not isinstance(loaded, dict):
            # A list or scalar root cannot be merged as settings; fail here
            # with a clear message instead of deep inside the settings merge.
            raise ValueError(
                f"Expected a mapping at the top level of {self.yaml_path}, "
                f"got {type(loaded).__name__}"
            )
        return loaded

    def __call__(self) -> dict[str, Any]:
        """Read the YAML file and return its contents as the source values."""
        self._data = self._read_yaml()
        return self._data

    def get_field_value(self, field, field_name: str) -> Tuple[Any, str, bool]:
        """Look up a single field in the YAML document.

        Returns:
            ``(value, field_name, True)`` when the key is present at the top
            level of the document, otherwise ``(None, field_name, False)``.
        """
        if self._data is None:
            # Load lazily so this method also works before ``__call__``.
            self._data = self._read_yaml()
        if field_name in self._data:
            return self._data[field_name], field_name, True
        return None, field_name, False
class ManyToOneConfig(BaseSettings):
    """Settings for a many-to-one experiment: several runs, one target model.

    Values are resolved from, in decreasing priority: constructor arguments,
    environment variables (prefix ``APP_``, nested delimiter ``__``), the
    ``.env`` file, a YAML file (``many_to_one.yaml`` unless overridden, see
    ``YamlConfigSettingsSource``), and finally file secrets.
    """

    # Pydantic v2 style configuration; replaces the deprecated inner
    # ``class Config`` while keeping exactly the same options.
    model_config = SettingsConfigDict(
        env_file=".env",
        env_prefix="APP_",
        env_nested_delimiter="__",
        env_file_encoding="utf-8",
        case_sensitive=False,
        validate_assignment=True,
        use_enum_values=True,
    )

    target_model: str = Field(..., description="Common target model name (e.g., 'nv-embed')")
    wandb: WandBConfig = Field(default_factory=WandBConfig)
    dataset: DatasetConfig = Field(default_factory=DatasetConfig)

    # --- optional credentials and endpoints shared by every run ---
    hf_token: Optional[str] = Field(None, description="Hugging Face token")
    noco_url: Optional[str] = Field(None, description="NocoDB URL")
    noco_api_token: Optional[str] = Field(None, description="NocoDB API token")
    milvus_host: Optional[str] = Field(None, description="Milvus host")
    milvus_port: Optional[str] = Field(None, description="Milvus port")
    milvus_user: Optional[str] = Field(None, description="Milvus username")
    milvus_pwd: Optional[str] = Field(None, description="Milvus password")

    runs: List[SingleRunConfig] = Field(
        default_factory=list,
        description="List of SingleRunConfig objects, each defining one source→target run",
    )

    def share_settings_to_runs(self):
        """Copy the shared credentials and WandB config onto every run.

        Every run receives the same ``wandb`` object (not a copy).

        NOTE(review): ``dataset`` and ``target_model`` are not propagated
        here, whereas ``TransitivityConfig.share_settings_to_runs`` does copy
        ``dataset`` - confirm this difference is intended.
        """
        shared = {
            "hf_token": self.hf_token,
            "noco_url": self.noco_url,
            "noco_api_token": self.noco_api_token,
            "milvus_host": self.milvus_host,
            "milvus_port": self.milvus_port,
            "milvus_user": self.milvus_user,
            "milvus_pwd": self.milvus_pwd,
            "wandb": self.wandb,
        }
        for run in self.runs:
            for attr_name, attr_value in shared.items():
                setattr(run, attr_name, attr_value)

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources in priority order (first wins).

        The YAML source sits below the environment and ``.env`` sources and
        above file secrets.
        """
        yaml_settings = YamlConfigSettingsSource(settings_cls, yaml_path="many_to_one.yaml")
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            yaml_settings,
            file_secret_settings,
        )
class TransitivityCaseConfig(BaseModel):
    """One transitivity case: three run configurations plus evaluation options.

    ``run_ab``, ``run_bc`` and ``run_ac`` are required; ``metric`` and
    ``k_list`` have defaults.
    """

    run_ab: SingleRunConfig
    run_bc: SingleRunConfig
    run_ac: SingleRunConfig
    metric: str = Field(default="cosine")
    # A fresh ``[10, 50, 100]`` list is built for each instance.
    k_list: List[int] = Field(default_factory=lambda: [10, 50, 100])
class TransitivityConfig(BaseSettings):
    """Settings for a transitivity experiment made of several cases.

    Values are resolved from, in decreasing priority: constructor arguments,
    environment variables (prefix ``APP_``, nested delimiter ``__``), the
    ``.env`` file, a YAML file (``transitivity.yaml`` unless overridden, see
    ``YamlConfigSettingsSource``), and finally file secrets.
    """

    # Pydantic v2 style configuration; replaces the deprecated inner
    # ``class Config`` while keeping exactly the same options.
    model_config = SettingsConfigDict(
        env_file=".env",
        env_prefix="APP_",
        env_nested_delimiter="__",
        env_file_encoding="utf-8",
        case_sensitive=False,
        validate_assignment=True,
        use_enum_values=True,
    )

    wandb: WandBConfig = Field(default_factory=WandBConfig)
    dataset: DatasetConfig = Field(default_factory=DatasetConfig)

    # --- optional credentials and endpoints shared by every run ---
    hf_token: Optional[str] = Field(None, description="Hugging Face token")
    noco_url: Optional[str] = Field(None, description="NocoDB URL")
    noco_api_token: Optional[str] = Field(None, description="NocoDB API token")
    milvus_host: Optional[str] = Field(None, description="Milvus host")
    milvus_port: Optional[str] = Field(None, description="Milvus port")
    milvus_user: Optional[str] = Field(None, description="Milvus username")
    milvus_pwd: Optional[str] = Field(None, description="Milvus password")

    cases: List[TransitivityCaseConfig] = Field(default_factory=list)

    def share_settings_to_runs(self):
        """Copy the shared WandB, dataset and credential values onto every run.

        Applies to ``run_ab``, ``run_bc`` and ``run_ac`` of each case. The
        ``wandb`` and ``dataset`` objects are shared by reference, not copied.
        """
        shared = {
            "wandb": self.wandb,
            "dataset": self.dataset,
            "hf_token": self.hf_token,
            "noco_url": self.noco_url,
            "noco_api_token": self.noco_api_token,
            "milvus_host": self.milvus_host,
            "milvus_port": self.milvus_port,
            "milvus_user": self.milvus_user,
            "milvus_pwd": self.milvus_pwd,
        }
        for case in self.cases:
            for run in (case.run_ab, case.run_bc, case.run_ac):
                for attr_name, attr_value in shared.items():
                    setattr(run, attr_name, attr_value)

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources in priority order (first wins).

        The YAML source sits below the environment and ``.env`` sources and
        above file secrets.
        """
        yaml_settings = YamlConfigSettingsSource(settings_cls, yaml_path="transitivity.yaml")
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            yaml_settings,
            file_secret_settings,
        )
