"""
Configuration management using pydantic.

Loads and validates configuration from config.yml.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Any, Union
from pathlib import Path
import os

from pydantic import BaseModel, Field, field_validator
import yaml


# Module-level cache for the active configuration. It stays None until
# load_config() assigns it; get_config() triggers that load on first use.
_config: Optional['Config'] = None


class GraphConfig(BaseModel):
    """Graph generation configuration.

    Holds the list of graph family names, the list of ``d`` values, and one
    nested option group each for the Erdos-Renyi, tree and star families.
    """
    # Family names; the nested option groups below cover three of them.
    families: List[str] = ["erdos_renyi", "chain", "star", "complete", "tree"]
    # Values of d (presumably node counts -- confirm against the generator).
    d_values: List[int] = [5, 10, 20, 50, 100]

    # Options for the "erdos_renyi" family.
    class ErdosRenyiConfig(BaseModel):
        # Presumably the edge probabilities to sweep -- confirm.
        p_values: List[float] = [0.1, 0.2, 0.3]
        # Presumably the edge probability used for a dense variant -- confirm.
        p_dense: float = 0.5

    # Options for the "tree" family.
    class TreeConfig(BaseModel):
        # Presumably the number of children per internal node -- confirm.
        branching_factor: int = 2

    # Options for the "star" family.
    class StarConfig(BaseModel):
        # Presumably selects whether edges point away from the center -- confirm.
        center_is_parent: bool = True

    # Each option group is built from its own defaults when not supplied.
    erdos_renyi: ErdosRenyiConfig = Field(default_factory=ErdosRenyiConfig)
    tree: TreeConfig = Field(default_factory=TreeConfig)
    star: StarConfig = Field(default_factory=StarConfig)


class SEMConfig(BaseModel):
    """SEM parameter configuration.

    Three (min, max) pairs for ``beta`` -- a main pair plus "weak" and
    "strong" variants -- one (min, max) pair for ``sigma``, and the sign
    distribution name. Only the main beta pair and the sigma pair are
    checked for positivity; min/max ordering is not enforced by this class.
    """
    # Main beta range.
    beta_min: float = 0.3
    beta_max: float = 0.6
    # "Weak" beta range (not covered by the positivity validator).
    beta_min_weak: float = 0.1
    beta_max_weak: float = 0.2
    # "Strong" beta range (not covered by the positivity validator).
    beta_min_strong: float = 0.5
    beta_max_strong: float = 1.0
    # Sigma range; the defaults make it a single value.
    sigma_min: float = 1.0
    sigma_max: float = 1.0
    # One of 'random', 'positive' or 'negative'.
    sign_distribution: str = "random"

    @field_validator('beta_min', 'beta_max', 'sigma_min', 'sigma_max')
    @classmethod
    def validate_positive(cls, value: float) -> float:
        """Reject zero and negative values for the main beta and sigma bounds."""
        if not value <= 0:
            return value
        raise ValueError("Must be positive")

    @field_validator('sign_distribution')
    @classmethod
    def validate_sign(cls, value: str) -> str:
        """Accept only the three recognised sign distribution names."""
        allowed = ('random', 'positive', 'negative')
        if value in allowed:
            return value
        raise ValueError("Must be 'random', 'positive', or 'negative'")


class PCConfig(BaseModel):
    """PC algorithm configuration.

    ``alpha`` must lie strictly between 0 and 1; the other fields are not
    range-checked by this class.
    """
    # Validated to be in the open interval (0, 1).
    alpha: float = 0.05
    # Optional cap; None is the default (presumably "no limit" -- confirm).
    max_cond_set_size: Optional[int] = None
    # Small numeric threshold; its exact use is defined by the consumer.
    ci_threshold: float = 1e-10

    @field_validator('alpha')
    @classmethod
    def validate_alpha(cls, value: float) -> float:
        """Return ``value`` when 0 < value < 1, otherwise raise ValueError."""
        if 0 < value < 1:
            return value
        raise ValueError("Alpha must be in (0, 1)")


class FisherConfig(BaseModel):
    """Fisher dimension computation configuration.

    ``method`` is restricted to 'direct' or 'curvature'; the two integer
    limits are not range-checked by this class.
    """
    # Computation method name: 'direct' or 'curvature'.
    method: str = "direct"
    # Presumably an upper bound on MEC enumeration -- confirm with the consumer.
    max_mec_enumeration: int = 1000
    # Presumably the sample count used when sampling a MEC -- confirm.
    mec_sample_size: int = 100

    @field_validator('method')
    @classmethod
    def validate_method(cls, value: str) -> str:
        """Accept only the two supported method names."""
        supported = ('direct', 'curvature')
        if value in supported:
            return value
        raise ValueError("Method must be 'direct' or 'curvature'")


class Exp1Config(BaseModel):
    """Experiment 1 configuration.

    Every field has a default; the ``description`` default states what the
    experiment is meant to verify.
    """
    # Flag marking the experiment as enabled.
    enabled: bool = True
    # Human-readable summary of the experiment.
    description: str = "Verify F([G]) predicts empirical sample complexity"
    # validate_config() warns when this is below 10.
    graphs_per_family: int = 500
    # Values of n to sweep (presumably sample sizes -- confirm with the runner).
    n_values: List[int] = [100, 200, 500, 1000, 2000, 5000]
    # validate_config() warns when this is below 10.
    n_trials: int = 50
    # Presumably a success-rate cutoff in [0, 1]; not range-checked here.
    success_threshold: float = 0.9
    # Single d value for this experiment (presumably the node count -- confirm).
    d: int = 10
    # Presumably the correlation the results are expected to reach -- confirm.
    expected_correlation: float = 0.8


class Exp2Config(BaseModel):
    """Experiment 2 configuration.

    Every field has a default; the ``description`` default states what the
    experiment is meant to verify.
    """
    # Flag marking the experiment as enabled.
    enabled: bool = True
    # Human-readable summary of the experiment.
    description: str = "Verify theoretical bounds are tight"
    # Single d value for this experiment (presumably the node count -- confirm).
    d: int = 10
    graphs_per_family: int = 100
    # Values of n to sweep (presumably sample sizes -- confirm with the runner).
    n_values: List[int] = [100, 200, 500, 1000, 2000, 5000]
    n_trials: int = 30
    # Presumably a success-rate cutoff in [0, 1]; not range-checked here.
    success_threshold: float = 0.9
    # Expected ratio window; presumably empirical vs. theoretical -- confirm.
    expected_ratio_min: float = 0.5
    expected_ratio_max: float = 2.0
    # Presumably the expected share of results inside that window -- confirm.
    expected_within_factor: float = 0.8


class Exp3Config(BaseModel):
    """Experiment 3 configuration.

    Every field has a default; the ``description`` default states what the
    experiment is meant to compare.
    """
    # Flag marking the experiment as enabled.
    enabled: bool = True
    # Human-readable summary of the experiment.
    description: str = "Compare Fisher dimension with other proxies"
    # d values to sweep (presumably node counts -- confirm).
    d_values: List[int] = [10, 20]
    graphs_per_d: int = 500
    # Values of n to sweep (presumably sample sizes -- confirm with the runner).
    n_values: List[int] = [100, 200, 500, 1000, 2000, 5000]
    n_trials: int = 30
    # Presumably a success-rate cutoff in [0, 1]; not range-checked here.
    success_threshold: float = 0.9
    # Presumably the rank statistic expected for the Fisher proxy -- confirm.
    expected_fisher_rank: float = 0.9
    # Names of the proxies to compare; the names are not validated here.
    proxies: List[str] = [
        "fisher_dimension", "graph_density", "max_in_degree",
        "avg_markov_blanket", "num_v_structures", "mec_size",
        "curvature_estimate"
    ]


class Exp4Config(BaseModel):
    """Experiment 4 configuration.

    Every field has a default; the ``description`` default states what the
    experiment is meant to verify.
    """
    # Flag marking the experiment as enabled.
    enabled: bool = True
    # Human-readable summary of the experiment.
    description: str = "Verify log(d) scaling"
    d_values: List[int] = [5, 10, 15]  # Moderate sizes for tractable PC
    graphs_per_d: int = 30
    # Values of n to sweep (presumably sample sizes -- confirm with the runner).
    n_values: List[int] = [100, 200, 500, 1000, 2000, 5000]
    n_trials: int = 15
    # Presumably a success-rate cutoff in [0, 1]; not range-checked here.
    success_threshold: float = 0.9
    # Family names for this experiment; a subset of GraphConfig's default list.
    families: List[str] = ["erdos_renyi", "chain", "tree"]


class Exp5Config(BaseModel):
    """Experiment 5 configuration.

    Every field has a default; the ``description`` default states what the
    experiment is meant to verify.
    """
    # Flag marking the experiment as enabled.
    enabled: bool = True
    # Human-readable summary of the experiment.
    description: str = "Verify lower bound from Theorem 4.2"
    # Single d value for this experiment (presumably the node count -- confirm).
    d: int = 8
    # Presumably the number of graph pairs to test -- confirm with the runner.
    num_pairs: int = 100
    # Values of n to sweep (presumably sample sizes -- confirm with the runner).
    n_values: List[int] = [50, 100, 200, 500, 1000, 2000]
    n_trials: int = 50
    # Presumably the statistical power to reach; not range-checked here.
    target_power: float = 0.8
    # Presumably the expected share of cases exceeding the bound -- confirm.
    expected_fraction_exceeding: float = 0.9


class BenchmarkGraphInfo(BaseModel):
    """Benchmark graph information.

    All four fields are required; none has a default.
    """
    # Short identifier of the benchmark graph.
    name: str
    # Free-text description of the benchmark.
    description: str
    # Presumably the number of nodes -- confirm against the benchmark loader.
    d: int
    # Presumably the number of edges -- confirm against the benchmark loader.
    edges: int


class Exp6Config(BaseModel):
    """Experiment 6 configuration.

    Every field has a default, including a built-in list of four benchmark
    graph entries.
    """
    # Flag marking the experiment as enabled.
    enabled: bool = True
    # Human-readable summary of the experiment.
    description: str = "Analyze benchmark graphs"
    # Presumably how many SEMs are drawn per benchmark graph -- confirm.
    n_sems_per_graph: int = 10
    # Values of n to sweep (presumably sample sizes -- confirm with the runner).
    n_values: List[int] = [500, 1000, 2000, 5000, 10000]
    n_trials: int = 20
    # Presumably a success-rate cutoff in [0, 1]; not range-checked here.
    success_threshold: float = 0.9
    # Default benchmark entries; supplying this field replaces the whole list.
    graphs: List[BenchmarkGraphInfo] = [
        BenchmarkGraphInfo(name="sachs", description="Protein signaling", d=11, edges=17),
        BenchmarkGraphInfo(name="child", description="Medical diagnosis", d=20, edges=25),
        BenchmarkGraphInfo(name="alarm", description="Medical monitoring", d=37, edges=46),
        BenchmarkGraphInfo(name="insurance", description="Insurance risk", d=27, edges=52),
    ]
    # Same two names FisherConfig.method accepts; not validated in this class.
    methods: List[str] = ["direct", "curvature"]


class ExperimentsConfig(BaseModel):
    """All experiment configurations.

    One field per experiment. Each is created from its model's defaults when
    the corresponding key is not supplied.
    """
    exp1_correlation: Exp1Config = Field(default_factory=Exp1Config)
    exp2_bounds: Exp2Config = Field(default_factory=Exp2Config)
    exp3_proxies: Exp3Config = Field(default_factory=Exp3Config)
    exp4_scaling: Exp4Config = Field(default_factory=Exp4Config)
    exp5_lower_bound: Exp5Config = Field(default_factory=Exp5Config)
    exp6_benchmark: Exp6Config = Field(default_factory=Exp6Config)


class OutputConfig(BaseModel):
    """Output configuration.

    Directory settings are plain strings; nothing in this class creates the
    directories or checks that they exist.
    """
    # Root output directory; the other default directories sit beneath it.
    base_dir: str = "results"
    figures_dir: str = "results/figures"
    tables_dir: str = "results/tables"
    logs_dir: str = "results/logs"
    checkpoints_dir: str = "results/checkpoints"
    # Figure file format name; not validated in this class.
    figure_format: str = "pdf"
    figure_dpi: int = 300
    # Flag for writing checkpoints.
    save_checkpoints: bool = True
    # Checkpoint interval; the unit is defined by the consumer -- confirm.
    checkpoint_frequency: int = 100


class ComputationConfig(BaseModel):
    """Computation configuration."""
    # Optional worker count; None is the default and its meaning is decided
    # by the consumer -- confirm.
    n_jobs: Optional[int] = None
    # Flag for displaying progress output.
    show_progress: bool = True
    # Integer verbosity level; the scale is not defined in this module.
    verbosity: int = 1
    # Memory limit; the unit is not stated here (presumably MB -- confirm).
    memory_limit: int = 4096


class LoggingConfig(BaseModel):
    """Logging configuration.

    Values are stored as given; this class does not configure the logging
    system itself.
    """
    # Log level name; not validated in this class.
    level: str = "INFO"
    # Log file path as a plain string.
    log_file: str = "results/logs/experiment.log"
    # Flag for including timestamps.
    include_timestamps: bool = True
    # %-style format string using the asctime, name, levelname and message keys.
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"


class Config(BaseModel):
    """Main configuration class.

    Top-level model built by load_config(). Each section is created from its
    own defaults when absent, so ``Config()`` with no arguments is valid.
    """
    # Seed value; how it is applied is up to the consumer.
    random_seed: int = 42
    graph: GraphConfig = Field(default_factory=GraphConfig)
    sem: SEMConfig = Field(default_factory=SEMConfig)
    pc: PCConfig = Field(default_factory=PCConfig)
    fisher: FisherConfig = Field(default_factory=FisherConfig)
    experiments: ExperimentsConfig = Field(default_factory=ExperimentsConfig)
    output: OutputConfig = Field(default_factory=OutputConfig)
    computation: ComputationConfig = Field(default_factory=ComputationConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)


# Singular-name alias: ExperimentConfig is the same class object as
# ExperimentsConfig, not a separate model.
ExperimentConfig = ExperimentsConfig


def load_config(path: Union[str, Path] = "config.yml") -> Config:
    """
    Load configuration from a YAML file and cache it as the module-wide config.

    A missing file or an empty document yields a Config made entirely of
    defaults. The module-level cache is only replaced once the new Config
    has been built successfully.

    Args:
        path: Path to config file

    Returns:
        Config object

    Raises:
        TypeError: If the YAML document's top level is not a mapping.
        pydantic.ValidationError: If a value fails model validation.
    """
    global _config

    path = Path(path)

    if not path.exists():
        # Use defaults if no config file
        _config = Config()
        return _config

    # YAML files are UTF-8 by convention; do not depend on the platform's
    # default text encoding.
    with open(path, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)

    if not config_dict:
        # Empty document (or an empty top-level value): fall back to defaults.
        loaded = Config()
    elif not isinstance(config_dict, dict):
        # Same exception type the ** expansion would raise, but with a
        # message that points at the actual problem.
        raise TypeError(
            f"Top level of {path} must be a mapping, "
            f"got {type(config_dict).__name__}"
        )
    else:
        loaded = Config(**config_dict)

    _config = loaded
    return _config


def get_config() -> Config:
    """
    Return the cached configuration, loading it on first use.

    When nothing has been loaded yet, load_config() is called with its
    default path and the result is cached.

    Returns:
        Config object
    """
    global _config

    if _config is not None:
        return _config

    _config = load_config()
    return _config


def save_config(config: Config, path: Union[str, Path] = "config.yml") -> None:
    """
    Save configuration to YAML file.

    The document is written with the safe dumper so that it can always be
    read back by load_config(), which uses the safe loader.

    Args:
        config: Config object
        path: Output path

    Raises:
        yaml.YAMLError: If the dumped config contains a value the safe
            dumper cannot represent. The target file is left untouched.
    """
    path = Path(path)

    # Serialize before opening the file: opening with 'w' truncates it, so a
    # serialization failure must not be able to wipe an existing config.
    text = yaml.safe_dump(
        config.model_dump(),
        default_flow_style=False,
        sort_keys=False,
    )

    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)


def config_to_dict(config: Config) -> Dict[str, Any]:
    """Return the config as a dictionary, as produced by ``model_dump()``."""
    as_dict = config.model_dump()
    return as_dict


def override_config(
    config: Config,
    overrides: Dict[str, Any]
) -> Config:
    """
    Build a new config from ``config`` with ``overrides`` merged in.

    The base config is dumped to a dictionary, the overrides are merged into
    it recursively, and a fresh Config is constructed from the result. The
    base config object itself is not modified.

    Args:
        config: Base config
        overrides: Values to override; nested dictionaries are merged
            key by key, any other value replaces the existing one

    Returns:
        New Config with overrides applied
    """
    merged = config.model_dump()

    def deep_update(target: dict, patch: dict) -> dict:
        # Recurse only when both sides hold a dict; otherwise overwrite.
        for key, incoming in patch.items():
            if not isinstance(incoming, dict):
                target[key] = incoming
                continue
            existing = target.get(key)
            if isinstance(existing, dict):
                target[key] = deep_update(existing, incoming)
            else:
                target[key] = incoming
        return target

    deep_update(merged, overrides)

    return Config(**merged)


def validate_config(config: Config) -> List[str]:
    """
    Validate configuration for common issues.

    Nothing is raised and the config is not modified; every finding is
    reported as a message in the returned list.

    Checks performed:
        * each SEM (min, max) pair -- main, weak and strong beta, and sigma --
          has min <= max;
        * experiment 1 has at least 10 graphs per family and 10 trials;
        * no output directory, including the checkpoints directory, has a
          ``..`` path component.

    Args:
        config: Config to validate

    Returns:
        List of warning messages (empty when nothing was found)
    """
    # Named "issues" rather than "warnings" to avoid shadowing the stdlib
    # module of that name.
    issues: List[str] = []

    # Check every (min, max) range in the SEM section
    sem = config.sem
    range_pairs = (
        ("beta_min", "beta_max"),
        ("beta_min_weak", "beta_max_weak"),
        ("beta_min_strong", "beta_max_strong"),
        ("sigma_min", "sigma_max"),
    )
    for low_name, high_name in range_pairs:
        if getattr(sem, low_name) > getattr(sem, high_name):
            issues.append(f"sem.{low_name} > sem.{high_name}")

    # Check experiment parameters
    exp1 = config.experiments.exp1_correlation
    if exp1.graphs_per_family < 10:
        issues.append("exp1: graphs_per_family is very low")

    if exp1.n_trials < 10:
        issues.append("exp1: n_trials is very low")

    # Check output directories
    output_dirs = [
        config.output.base_dir,
        config.output.figures_dir,
        config.output.tables_dir,
        config.output.logs_dir,
        config.output.checkpoints_dir,
    ]
    for dir_path in output_dirs:
        # Compare whole path components so a name that merely contains two
        # dots (e.g. "results..v2") is not flagged. Both separator styles
        # are handled regardless of the platform.
        components = dir_path.replace("\\", "/").split("/")
        if ".." in components:
            issues.append(f"Potentially unsafe path: {dir_path}")

    return issues
