"""
Configuration system using OmegaConf for the Dense Retrieval training framework.
"""

from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List
from omegaconf import OmegaConf


@dataclass
class DataConfig:
    """Dataset identity and file locations for an experiment.

    All required path fields default to the empty string; the two optional
    collection paths default to ``None``.
    """

    # Dataset name for logging and caching.
    dataset_name: str = "unknown"

    # Main collection plus the train/dev query and qrels files.
    collection_path: str = ""
    queries_train_path: str = ""
    qrels_train_path: str = ""
    queries_dev_path: str = ""
    qrels_dev_path: str = ""

    # Switch for the positive-only collection; its path is only relevant
    # when the switch is True.
    use_positive_only_collection: bool = False
    collection_positive_only_path: Optional[str] = None

    # Separate collection for evaluation. Optional: per the original note the
    # training collection is used when this is left unset (the fallback itself
    # is not implemented in this class).
    eval_collection_path: Optional[str] = None


@dataclass
class ModelConfig:
    """Encoder and embedding settings for the retrieval model."""

    encoder_name: str = "distilbert-base-uncased"

    # Pooling options: "cls", "mean".
    pooling_strategy: str = "cls"

    embedding_dim: int = 768
    normalize_embeddings: bool = True

    # Length limits for queries and documents.
    max_query_length: int = 64
    max_doc_length: int = 128

    # Text prepended to queries / documents, e.g. "query: " and "passage: "
    # for e5 models. Empty by default.
    query_prefix: str = ""
    document_prefix: str = ""


@dataclass
class LossConfig:
    """Settings for the training loss and the negatives it uses."""

    name: str = "infonce"
    temperature: float = 1.0

    # Which sources of negatives are enabled.
    use_mined_negatives: bool = False
    use_sampled_negatives: bool = False
    use_inbatch_negatives: bool = True

    # Whether to gather in-batch negatives across GPUs in DDP.
    gather_across_gpus: bool = True


@dataclass
class TrainingConfig:
    """Hyperparameters and scheduling options for the training loop."""

    # --- Epochs, batching and per-query example limits ---
    num_epochs: int = 40
    per_gpu_batch_size: int = 400
    gradient_accumulation_steps: int = 4
    max_positives_per_query: int = 1
    max_mined_negatives_per_query: int = 1
    max_sampled_negatives_per_query: int = 0

    # --- Optimizer ---
    optimizer: str = "adamw"
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    warmup_steps: int = 1000

    # --- Mixed precision ---
    fp16: bool = True

    # --- Checkpointing and early stopping ---
    # Save a checkpoint every N epochs.
    checkpoint_frequency: int = 1
    early_stopping_patience: int = 5
    # Metric to use for early stopping.
    early_stopping_metric: str = "ndcg@10"

    # --- Intermediate evaluation within epochs (initial epochs only) ---
    # Enable step-wise evaluation.
    epoch_intermediate_eval: bool = False
    # Number of initial epochs to apply this to.
    num_epoch_intermediate_eval: int = 1
    # Percentages at which to evaluate; a factory gives each instance its
    # own list.
    epoch_eval_percentages: List[int] = field(default_factory=lambda: [25, 50, 75])

    # --- Evaluation control ---
    # Enable/disable evaluation during training.
    enable_evaluation: bool = True

    # --- Logging ---
    # Log training loss every N steps.
    log_steps: int = 100


@dataclass
class EmbeddingGenerationConfig:
    """Controls when and how embeddings are generated."""

    enabled: bool = True

    # Generate embeddings every N epochs (0, frequency, 2*frequency, ...).
    frequency: int = 1

    generate_queries: bool = True
    generate_documents: bool = True

    # Batch size for embedding generation.
    batch_size: int = 7000

    # Caching switches. When enabled, samplers use embeddings cached during
    # the training forward pass of the previous epoch (epoch i reads the cache
    # written in epoch i-1). IDs the samplers need but that are missing from
    # the cache (e.g., additional positives) are encoded explicitly on demand.
    cache_query_embeddings: bool = False
    cache_document_embeddings: bool = False


@dataclass
class BatchSamplerConfig:
    """Selects the batch sampling strategy and how often it runs."""

    # Strategy name; options include "random", "hobit", etc.
    name: str = "random"

    enabled: bool = True

    # Run the batch sampler every N epochs.
    frequency: int = 1

    # Sampler-specific arguments; the factory gives every instance its own
    # dict.
    args: Dict[str, Any] = field(default_factory=dict)


@dataclass
class NegativeSamplerConfig:
    """Selects the negative sampling strategy and how often it runs."""

    # Strategy name; options include "random", "ance", etc.
    name: str = "random"

    # Disabled by default.
    enabled: bool = False

    # Run the negative sampler every N epochs.
    frequency: int = 1

    # Number of negative samples per query.
    num_samples: int = 1

    # Sampler-specific arguments; the factory gives every instance its own
    # dict.
    args: Dict[str, Any] = field(default_factory=dict)


@dataclass
class LoggingConfig:
    """Log level and output locations for logging."""

    # Level name, given as a string.
    log_level: str = "INFO"

    # Log file name and TensorBoard directory.
    log_file: str = "training.log"
    tensorboard_dir: str = "log/tensorboard"


@dataclass
class ExperimentConfig:
    """Top-level configuration bundling every section of an experiment.

    Holds the experiment's name, directory and seed, plus one instance of each
    section config. Every section uses ``default_factory``, so each
    ExperimentConfig is created with its own fresh sub-config objects.
    """

    # --- Experiment identity ---
    experiment_name: str = ""
    experiment_dir: str = ""
    seed: int = 42

    # --- Section configs, each defaulting to that section's own defaults ---
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    embedding_generation: EmbeddingGenerationConfig = field(default_factory=EmbeddingGenerationConfig)
    batch_sampler: BatchSamplerConfig = field(default_factory=BatchSamplerConfig)
    negative_sampler: NegativeSamplerConfig = field(default_factory=NegativeSamplerConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)


def load_config(config_path: str) -> ExperimentConfig:
    """
    Read a YAML file and turn it into an ExperimentConfig.

    The YAML content is merged on top of the structured ExperimentConfig
    schema, so the schema supplies the defaults and the file supplies the
    overrides.

    Args:
        config_path: Path to the YAML configuration file

    Returns:
        ExperimentConfig object built from the merged configuration
    """
    # Values read from disk; these take precedence in the merge below.
    overrides = OmegaConf.load(config_path)

    # Schema first, file second: the later argument wins on conflicts.
    merged = OmegaConf.merge(OmegaConf.structured(ExperimentConfig), overrides)

    # Materialize the merged config as a dataclass instance.
    return OmegaConf.to_object(merged)


def save_config(config: ExperimentConfig, save_path: str) -> None:
    """
    Save configuration to a YAML file.

    The parent directory of ``save_path`` is created when it does not exist
    yet, so a config can be written straight into a fresh experiment
    directory.

    Args:
        config: ExperimentConfig object
        save_path: Path where to save the YAML file
    """
    import os  # local import keeps this function self-contained

    # Convert to OmegaConf first, so an invalid config fails before anything
    # is touched on disk.
    omega_config = OmegaConf.structured(config)

    # OmegaConf.save opens the target path directly and does not create
    # missing directories, so make sure the parent exists beforehand.
    parent_dir = os.path.dirname(os.path.abspath(save_path))
    os.makedirs(parent_dir, exist_ok=True)

    # Save to YAML
    OmegaConf.save(omega_config, save_path)


def create_example_config() -> ExperimentConfig:
    """
    Build a fully spelled-out sample configuration for reference.

    Each section is constructed explicitly and then assembled into a single
    ExperimentConfig. The sample uses MS MARCO file paths, a
    distilbert-base-uncased encoder, the "infonce" loss, an enabled "random"
    batch sampler and a disabled "random" negative sampler.

    Returns:
        Example ExperimentConfig object
    """
    data_section = DataConfig(
        dataset_name="msmarco",
        collection_path="./data/MSMARCO/collection.tsv",
        queries_train_path="./data/MSMARCO/queries.train.tsv",
        qrels_train_path="./data/MSMARCO/qrels.train.tsv",
        queries_dev_path="./data/MSMARCO/queries.dev.tsv",
        qrels_dev_path="./data/MSMARCO/qrels.dev.tsv",
        use_positive_only_collection=False,
    )

    model_section = ModelConfig(
        encoder_name="distilbert-base-uncased",
        pooling_strategy="cls",
        embedding_dim=768,
        normalize_embeddings=True,
        max_query_length=64,
        max_doc_length=128,
    )

    loss_section = LossConfig(
        name="infonce",
        temperature=1.0,
        use_mined_negatives=False,
        use_sampled_negatives=False,
        use_inbatch_negatives=True,
    )

    training_section = TrainingConfig(
        # Epochs, batching and per-query limits
        num_epochs=40,
        per_gpu_batch_size=400,
        gradient_accumulation_steps=4,
        max_positives_per_query=1,
        max_mined_negatives_per_query=1,
        max_sampled_negatives_per_query=0,
        # Optimizer
        optimizer="adamw",
        learning_rate=1e-4,
        weight_decay=0.01,
        adam_epsilon=1e-8,
        max_grad_norm=1.0,
        warmup_steps=1000,
        # Mixed precision
        fp16=True,
        # Checkpointing, early stopping and logging
        checkpoint_frequency=1,
        early_stopping_patience=5,
        early_stopping_metric="ndcg@10",
        log_steps=100,
    )

    embedding_section = EmbeddingGenerationConfig(
        enabled=True,
        frequency=1,
        generate_queries=True,
        generate_documents=True,
        batch_size=7000,
    )

    batch_sampler_section = BatchSamplerConfig(
        name="random",
        enabled=True,
        frequency=1,
        args={},
    )

    negative_sampler_section = NegativeSamplerConfig(
        name="random",
        enabled=False,
        frequency=1,
        num_samples=1,
        args={},
    )

    logging_section = LoggingConfig(
        log_level="INFO",
        log_file="training.log",
        tensorboard_dir="log/tensorboard",
    )

    return ExperimentConfig(
        experiment_name="msmarco-random_qb-random_ns",
        experiment_dir="./experiments/msmarco-random_qb",
        seed=42,
        data=data_section,
        model=model_section,
        loss=loss_section,
        training=training_section,
        embedding_generation=embedding_section,
        batch_sampler=batch_sampler_section,
        negative_sampler=negative_sampler_section,
        logging=logging_section,
    )


if __name__ == "__main__":
    # Round-trip demo: write the example config to disk, then read it back.
    save_config(create_example_config(), "example_config.yaml")
    print("Example config saved to example_config.yaml")

    # Reload the file that was just written and show one field from it.
    loaded_config = load_config("example_config.yaml")
    print(f"\nLoaded config experiment name: {loaded_config.experiment_name}")
