from __future__ import annotations

import math
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from python_src.precision import normalize_precision_model

# Default weights for the mutation operators, keyed by operator name.
# These three values sum to 1.0; ``normalize_operator_probs`` only accepts
# keys that appear here.
DEFAULT_OPERATOR_PROBS: Dict[str, float] = dict(
    insertion=0.15,
    deletion=0.35,
    reconnection=0.5,
)

# Default LM ``max_iters`` per compute format (see ``default_lm_max_iters``).
# Every reduced-precision format shares a budget of 20; fp64 gets 50.
DEFAULT_LM_MAX_ITERS_BY_PRECISION: Dict[str, int] = {
    **dict.fromkeys(("fp8_e4m3fn", "fp8_e5m2", "fp16", "fp32"), 20),
    "fp64": 50,
}


def _get_attr(obj: Any, name: str, default: Any) -> Any:
    """Read ``name`` from a mapping-or-object spec, falling back to ``default``.

    ``None`` yields ``default``. Any :class:`collections.abc.Mapping` (not only
    ``dict``) is read by key, so read-only views such as
    ``types.MappingProxyType`` work. Everything else is read by attribute.
    """
    if obj is None:
        return default
    # Checking Mapping rather than dict: a non-dict mapping would otherwise
    # fall through to getattr and silently return ``default`` for every key.
    if isinstance(obj, Mapping):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _has_explicit_attr(obj: Any, name: str) -> bool:
    """Return whether ``name`` was explicitly provided on ``obj``.

    - ``None`` -> ``False``.
    - Mappings -> key membership.
    - Objects exposing a set of explicitly-set field names (pydantic v2
      ``model_fields_set``, or pydantic v1 ``__fields_set__``) -> membership
      in that set, so fields left at their declared default are not explicit.
    - Anything else -> ``hasattr``.
    """
    if obj is None:
        return False
    if isinstance(obj, Mapping):
        return name in obj
    # Try the v2 name first. On v2 models ``__fields_set__`` is a deprecated
    # alias, so it is only consulted when ``model_fields_set`` is unavailable.
    for fields_attr in ("model_fields_set", "__fields_set__"):
        fields_set = getattr(obj, fields_attr, None)
        if isinstance(fields_set, set):
            return name in fields_set
    return hasattr(obj, name)


def default_lm_max_iters(precision_model: Any = None) -> int:
    """Return the default LM ``max_iters`` for ``precision_model``.

    The model is normalized with ``normalize_precision_model``. Its
    ``"compute_format"`` entry is then looked up in
    ``DEFAULT_LM_MAX_ITERS_BY_PRECISION``. Formats missing from that table
    fall back to 20.
    """
    fmt = normalize_precision_model(precision_model)["compute_format"]
    table = DEFAULT_LM_MAX_ITERS_BY_PRECISION
    iters = table[fmt] if fmt in table else 20
    return int(iters)


def normalize_operator_probs(raw: Optional[Dict[str, float]]) -> Dict[str, float]:
    """Merge user-supplied operator weights over the defaults and normalize them.

    Keys that are not in ``DEFAULT_OPERATOR_PROBS`` are ignored. Negative and
    non-finite (NaN / +-inf) weights are treated as 0. If no positive finite
    weight remains, or the total is not finite, a copy of
    ``DEFAULT_OPERATOR_PROBS`` is returned.

    Args:
        raw: Optional mapping of operator name -> weight. ``None`` or an empty
            mapping means "use the defaults".

    Returns:
        A new dict with exactly the keys of ``DEFAULT_OPERATOR_PROBS`` whose
        values are normalized to sum to 1.0 (up to float rounding).

    Raises:
        TypeError / ValueError: if a recognized key's value cannot be
            converted with ``float()``.
    """
    probs = dict(DEFAULT_OPERATOR_PROBS)
    if raw:
        for key, value in raw.items():
            if key in probs:
                probs[key] = float(value)
    # NaN would survive ``max(value, 0.0)``, and inf turns into NaN after
    # dividing by an infinite total, so both are treated as invalid weights.
    weights = {
        key: value if math.isfinite(value) and value > 0.0 else 0.0
        for key, value in probs.items()
    }
    total = sum(weights.values())
    # The sum of finite weights can still overflow to inf for huge inputs.
    if total <= 0 or not math.isfinite(total):
        return dict(DEFAULT_OPERATOR_PROBS)
    return {key: weight / total for key, weight in weights.items()}


@dataclass
class MutationConfig:
    """Mutation settings for the evolutionary search.

    ``operator_probs`` defaults to a fresh copy of ``DEFAULT_OPERATOR_PROBS``,
    so instances never share the module-level dict.
    """

    rate: float = 0.5
    steps_per_mutation: int = 10
    operator_probs: Dict[str, float] = field(default_factory=lambda: dict(DEFAULT_OPERATOR_PROBS))

    @classmethod
    def from_spec(cls, spec: Any) -> "MutationConfig":
        """Build a ``MutationConfig`` from a dict or attribute-style ``spec``.

        ``None`` yields all defaults. ``rate`` and ``steps_per_mutation`` are
        coerced with ``float`` / ``int`` and fall back to the class defaults
        when absent. ``operator_probs`` always goes through
        ``normalize_operator_probs``.
        """
        if spec is None:
            return cls()
        fallback = cls()
        rate = float(_get_attr(spec, "rate", fallback.rate))
        steps = int(_get_attr(spec, "steps_per_mutation", fallback.steps_per_mutation))
        probs = normalize_operator_probs(_get_attr(spec, "operator_probs", None))
        return cls(rate=rate, steps_per_mutation=steps, operator_probs=probs)


@dataclass
class CMAConfig:
    """Settings for the CMA optimizer stage, which is disabled by default.

    The field names suggest population size, iteration cap, initial step size
    and RNG seed. TODO confirm against the optimizer that consumes this config.
    """

    enabled: bool = False
    popsize: int = 128
    maxiter: int = 1000
    sigma0: float = 0.002
    seed: int = 1234567

    @classmethod
    def from_spec(cls, spec: Any) -> "CMAConfig":
        """Build a ``CMAConfig`` from a dict or attribute-style ``spec``.

        ``None`` yields all defaults. Each missing key or attribute falls back
        to the class default. Present values are coerced with
        ``bool`` / ``int`` / ``float``.
        """
        baseline = cls()
        if spec is None:
            return baseline
        enabled = bool(_get_attr(spec, "enabled", baseline.enabled))
        popsize = int(_get_attr(spec, "popsize", baseline.popsize))
        maxiter = int(_get_attr(spec, "maxiter", baseline.maxiter))
        sigma0 = float(_get_attr(spec, "sigma0", baseline.sigma0))
        seed = int(_get_attr(spec, "seed", baseline.seed))
        return cls(enabled=enabled, popsize=popsize, maxiter=maxiter, sigma0=sigma0, seed=seed)


@dataclass
class LMConfig:
    """Settings for the LM optimizer stage, which is enabled by default.

    Constructing the class directly uses ``max_iters=20``. ``from_spec``
    instead derives the default from the precision model via
    ``default_lm_max_iters``.
    """

    enabled: bool = True
    max_iters: int = 20
    pop_size: int = 16
    max_nfev: int = 100

    @classmethod
    def from_spec(cls, spec: Any, precision_model: Any = None) -> "LMConfig":
        """Build an ``LMConfig`` from a dict or attribute-style ``spec``.

        ``max_iters`` is taken from ``spec`` only if it was explicitly provided
        (see ``_has_explicit_attr``). Otherwise the precision-dependent default
        from ``default_lm_max_iters(precision_model)`` is used. The other
        fields fall back to the class defaults when absent.
        """
        default_iters = default_lm_max_iters(precision_model)
        if spec is None:
            return cls(max_iters=default_iters)
        if _has_explicit_attr(spec, "max_iters"):
            max_iters = int(_get_attr(spec, "max_iters", default_iters))
        else:
            max_iters = default_iters
        base = cls()
        return cls(
            enabled=bool(_get_attr(spec, "enabled", base.enabled)),
            max_iters=max_iters,
            pop_size=int(_get_attr(spec, "pop_size", base.pop_size)),
            max_nfev=int(_get_attr(spec, "max_nfev", base.max_nfev)),
        )


@dataclass
class NelderMeadConfig:
    """Settings for the Nelder-Mead optimizer stage, which is enabled by default.

    ``xatol`` / ``fatol`` are forwarded as-is. Presumably they are absolute
    tolerances on parameters and objective; TODO confirm with the consumer.
    """

    enabled: bool = True
    max_iters: int = 100
    xatol: float = 1e-6
    fatol: float = 1e-6

    @classmethod
    def from_spec(cls, spec: Any) -> "NelderMeadConfig":
        """Build a ``NelderMeadConfig`` from a dict or attribute-style ``spec``.

        ``None`` yields all defaults. Missing keys or attributes fall back to
        the class defaults. Present values are coerced with
        ``bool`` / ``int`` / ``float``.
        """
        baseline = cls()
        if spec is None:
            return baseline

        def pick(key: str) -> Any:
            # Fall back to the matching class default when ``key`` is absent.
            return _get_attr(spec, key, getattr(baseline, key))

        return cls(
            enabled=bool(pick("enabled")),
            max_iters=int(pick("max_iters")),
            xatol=float(pick("xatol")),
            fatol=float(pick("fatol")),
        )


@dataclass
class OptimizerConfig:
    """Bundle of per-optimizer settings: CMA, LM and Nelder-Mead."""

    cma: CMAConfig = field(default_factory=CMAConfig)
    lm: LMConfig = field(default_factory=LMConfig)
    nelder_mead: NelderMeadConfig = field(default_factory=NelderMeadConfig)

    @classmethod
    def from_spec(cls, spec: Any, precision_model: Any = None) -> "OptimizerConfig":
        """Build an ``OptimizerConfig`` from a dict or attribute-style ``spec``.

        Sub-specs are read from the ``"cma"``, ``"lm"`` and ``"nelder_mead"``
        entries. ``precision_model`` is only forwarded to
        ``LMConfig.from_spec``. With ``spec=None`` the LM section is still
        built through ``LMConfig.from_spec`` so its precision-dependent default
        applies; the other sections use their class defaults.
        """
        if spec is None:
            lm_config = LMConfig.from_spec(None, precision_model=precision_model)
            return cls(lm=lm_config)
        cma_config = CMAConfig.from_spec(_get_attr(spec, "cma", None))
        lm_config = LMConfig.from_spec(_get_attr(spec, "lm", None), precision_model=precision_model)
        nm_config = NelderMeadConfig.from_spec(_get_attr(spec, "nelder_mead", None))
        return cls(cma=cma_config, lm=lm_config, nelder_mead=nm_config)


@dataclass
class EvolutionConfig:
    """Settings for the evolutionary search loop.

    ``crossover``, ``selection`` and ``multiobjective`` are passed through
    from the spec without conversion.
    """

    # ``None`` is preserved as-is; presumably it means "let the caller decide"
    # — TODO confirm.
    population_size: Optional[int] = None
    # Spelling ("mantain") is part of the public field name and the spec key
    # read by ``from_spec``; keep it unchanged.
    num_mantain: int = 40
    mutation: MutationConfig = field(default_factory=MutationConfig)
    crossover: Optional[Any] = None
    selection: Optional[Any] = None
    multiobjective: Optional[Any] = None

    @classmethod
    def from_spec(cls, spec: Any) -> "EvolutionConfig":
        """Build an ``EvolutionConfig`` from a dict or attribute-style ``spec``.

        ``population_size`` is coerced to ``int`` unless it is ``None``.
        ``num_mantain`` is coerced to ``int``. ``mutation`` is built via
        ``MutationConfig.from_spec``. The remaining fields are copied
        unchanged and default to ``None``.
        """
        if spec is None:
            return cls()
        defaults = cls()
        raw_size = _get_attr(spec, "population_size", defaults.population_size)
        return cls(
            population_size=None if raw_size is None else int(raw_size),
            num_mantain=int(_get_attr(spec, "num_mantain", defaults.num_mantain)),
            mutation=MutationConfig.from_spec(_get_attr(spec, "mutation", None)),
            crossover=_get_attr(spec, "crossover", None),
            selection=_get_attr(spec, "selection", None),
            multiobjective=_get_attr(spec, "multiobjective", None),
        )


@dataclass
class SearchConfig:
    """Top-level search configuration: evolution settings plus optimizer settings."""

    evolution: EvolutionConfig = field(default_factory=EvolutionConfig)
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)

    @classmethod
    def from_spec(cls, spec: Any, precision_model: Any = None) -> "SearchConfig":
        """Build a ``SearchConfig`` from a dict or attribute-style ``spec``.

        Sub-specs are read from the ``"evolution"`` and ``"optimizer"``
        entries. ``precision_model`` is forwarded to
        ``OptimizerConfig.from_spec``. With ``spec=None`` the optimizer section
        is still built through ``OptimizerConfig.from_spec``; evolution uses
        its class default.
        """
        if spec is None:
            optimizer_config = OptimizerConfig.from_spec(None, precision_model=precision_model)
            return cls(optimizer=optimizer_config)
        evolution_config = EvolutionConfig.from_spec(_get_attr(spec, "evolution", None))
        optimizer_config = OptimizerConfig.from_spec(
            _get_attr(spec, "optimizer", None),
            precision_model=precision_model,
        )
        return cls(evolution=evolution_config, optimizer=optimizer_config)
