from __future__ import annotations

import hashlib
import json
import math
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from python_src.transforms import parse_affine_transform, parse_io_transform


class FunctionSpec(BaseModel):
    """Source of the target function.

    Exactly one of ``builtin``, ``expr`` or ``code`` has to be supplied.
    ``code`` is declared as a field but is always rejected for now.
    """

    model_config = ConfigDict(extra="forbid")

    builtin: Optional[str] = None
    expr: Optional[str] = None
    code: Optional[str] = None

    @field_validator("builtin", "expr")
    @classmethod
    def validate_non_empty(cls, value: Optional[str]) -> Optional[str]:
        """Reject strings that are empty or consist only of whitespace."""
        if value is None:
            return value
        if value.strip() == "":
            raise ValueError("FunctionSpec strings must be non-empty.")
        return value

    @model_validator(mode="after")
    def validate_one_source(self) -> "FunctionSpec":
        """Ensure a single source is set and that it is not ``code``."""
        supplied = [
            source
            for source in (self.builtin, self.expr, self.code)
            if source is not None
        ]
        if len(supplied) != 1:
            raise ValueError("Exactly one of builtin/expr/code must be set.")
        if self.code is not None:
            raise ValueError("FunctionSpec.code is not supported yet.")
        return self


class TargetSpec(BaseModel):
    """Named target: a function plus the list of its variable names."""

    model_config = ConfigDict(extra="forbid")

    name: str
    function: FunctionSpec
    vars: List[str] = Field(default_factory=lambda: ["x"])

    @field_validator("vars")
    @classmethod
    def validate_vars(cls, value: List[str]) -> List[str]:
        """Require at least one variable name and no repeated names."""
        if len(value) == 0:
            raise ValueError("TargetSpec.vars must be non-empty.")
        seen = set()
        for var_name in value:
            if var_name in seen:
                raise ValueError("TargetSpec.vars entries must be unique.")
            seen.add(var_name)
        return value


class PrecisionSpec(BaseModel):
    """Numeric formats and rounding options for input, compute and output.

    Unset formats cascade: ``compute_format`` falls back to ``input_format``
    and ``output_format`` falls back to ``compute_format``.
    """

    model_config = ConfigDict(extra="forbid")

    input_format: str = "fp32"
    compute_format: Optional[str] = None
    output_format: Optional[str] = None
    rounding: str = "rne"
    fp8_variant: Optional[str] = None
    flush_subnormals: bool = False
    overflow: str = "inf_or_saturate"

    @model_validator(mode="after")
    def normalize_formats(self) -> "PrecisionSpec":
        """Fill in the format fields that were left as ``None``."""
        if self.compute_format is None:
            self.compute_format = self.input_format
        if self.output_format is None:
            self.output_format = self.compute_format
        if self.fp8_variant is not None:
            return self

        # Take the first format, in input/compute/output order, that names
        # an fp8 flavour; leave fp8_variant untouched when none does.
        candidates = (self.input_format, self.compute_format, self.output_format)
        first_fp8 = next(
            (name for name in candidates if name and name.startswith("fp8_")),
            None,
        )
        if first_fp8 is not None:
            self.fp8_variant = first_fp8
        return self


class MetricSpec(BaseModel):
    """Error metric selection together with its threshold and denominator epsilon."""

    model_config = ConfigDict(extra="forbid")

    type: str = "rel"
    threshold: float = 1e-6
    denom_eps: float = 1e-6

    @field_validator("type")
    @classmethod
    def validate_metric_type(cls, value: str) -> str:
        """Accept only the metric kinds ``abs``, ``rel`` and ``ulp``."""
        # Kept in sorted order so the error message lists them alphabetically.
        supported = ("abs", "rel", "ulp")
        if value in supported:
            return value
        raise ValueError(f"MetricSpec.type must be one of {list(supported)}.")


class IntervalSpec(BaseModel):
    """A one-dimensional interval with optionally open endpoints.

    ``start`` must be strictly less than ``end``. Infinite bounds are
    accepted; NaN bounds are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    start: float
    end: float
    start_open: bool = False
    end_open: bool = False

    @model_validator(mode="after")
    def validate_interval(self) -> "IntervalSpec":
        """Check that the bounds are ordered and comparable.

        Raises:
            ValueError: if either bound is NaN, or if ``start >= end``.
        """
        # Every comparison against NaN is False, so a plain ``start >= end``
        # test would silently accept a NaN bound; reject it explicitly.
        if math.isnan(self.start) or math.isnan(self.end):
            raise ValueError("IntervalSpec.start and IntervalSpec.end must not be NaN.")
        if self.start >= self.end:
            raise ValueError("IntervalSpec.start must be less than IntervalSpec.end.")
        return self


class PieceSpec(BaseModel):
    """One piece of the domain: an interval plus optional transform/strategy dicts."""

    model_config = ConfigDict(extra="forbid")

    piece_id: str
    interval: IntervalSpec
    excluded_points: List[float] = Field(default_factory=list)
    transform: Optional[Dict[str, Any]] = None
    strategy: Optional[Dict[str, Any]] = None

    @field_validator("transform")
    @classmethod
    def validate_transform(cls, value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Run ``parse_io_transform`` on a supplied transform; return it unchanged."""
        if value is not None:
            # Called only for its validation effect; the parsed result is discarded.
            parse_io_transform(value)
        return value

    @model_validator(mode="after")
    def validate_strategy(self) -> "PieceSpec":
        """Validate the extra keys a ``mapped`` strategy needs.

        Strategies that are not dicts, or whose mode is anything other than
        ``mapped`` (mode defaults to ``search``), are accepted as they are.
        """
        strategy = self.strategy
        if not isinstance(strategy, dict):
            return self

        mode = str(strategy.get("mode", "search")).strip().lower()
        if mode != "mapped":
            return self

        source_piece_id = strategy.get("source_piece_id")
        if source_piece_id is None or not str(source_piece_id).strip():
            raise ValueError("mapped strategy requires non-empty source_piece_id.")

        # Input transform is checked before output transform.
        for key in ("input_transform", "output_transform"):
            parse_affine_transform(
                strategy.get(key),
                field_name=f"strategy.{key}[{self.piece_id}]",
            )
        return self


class DomainSpec(BaseModel):
    """The approximation domain: a non-empty list of uniquely identified pieces."""

    model_config = ConfigDict(extra="forbid")

    pieces: List[PieceSpec] = Field(min_length=1)

    @model_validator(mode="after")
    def validate_piece_ids(self) -> "DomainSpec":
        """Reject domains in which two pieces share a ``piece_id``."""
        seen_ids = set()
        for piece in self.pieces:
            if piece.piece_id in seen_ids:
                raise ValueError("DomainSpec.pieces must have unique piece_id values.")
            seen_ids.add(piece.piece_id)
        return self


class EdgeFocusSpec(BaseModel):
    """Edge-focus sampling switch and the ratio associated with it."""

    model_config = ConfigDict(extra="forbid")

    enabled: bool = False
    ratio: float = 0.1

    @field_validator("ratio")
    @classmethod
    def validate_ratio(cls, value: float) -> float:
        """Require ``ratio`` to lie in the closed interval [0, 1].

        Raises:
            ValueError: if ``value`` is outside [0, 1] or is NaN.
        """
        # The chained comparison is False for NaN, so NaN is rejected too;
        # ``value < 0 or value > 1`` would let it through.
        if not (0 <= value <= 1):
            raise ValueError("EdgeFocusSpec.ratio must be in [0, 1].")
        return value


class SamplingSpec(BaseModel):
    """Sampling parameters: sample count, mode, edge focus, focus points and seed."""

    model_config = ConfigDict(extra="forbid")

    n_data: int = 5000
    mode: str = "uniform"
    edge_focus: EdgeFocusSpec = Field(default_factory=EdgeFocusSpec)
    focus_points: List[float] = Field(default_factory=list)
    focus_radius: float = 1e-3
    seed: int = 1234567

    @field_validator("n_data")
    @classmethod
    def validate_n_data(cls, value: int) -> int:
        """Require at least one sample."""
        if value < 1:
            raise ValueError("SamplingSpec.n_data must be positive.")
        return value


class MutationConfig(BaseModel):
    """Mutation settings: overall rate, steps per mutation and operator weights.

    ``operator_probs`` is not range-checked by this model.
    """

    model_config = ConfigDict(extra="forbid")

    rate: float = 0.5
    steps_per_mutation: int = 20
    operator_probs: Dict[str, float] = Field(
        default_factory=lambda: {
            "insertion": 0.15,
            "deletion": 0.35,
            "reconnection": 0.5,
        }
    )

    @field_validator("rate")
    @classmethod
    def validate_rate(cls, value: float) -> float:
        """Require ``rate`` to lie in the closed interval [0, 1].

        Raises:
            ValueError: if ``value`` is outside [0, 1] or is NaN.
        """
        # The chained comparison is False for NaN, so NaN is rejected too;
        # ``value < 0 or value > 1`` would let it through.
        if not (0 <= value <= 1):
            raise ValueError("MutationConfig.rate must be in [0, 1].")
        return value

    @field_validator("steps_per_mutation")
    @classmethod
    def validate_steps(cls, value: int) -> int:
        """Require a strictly positive step count."""
        if value <= 0:
            raise ValueError("MutationConfig.steps_per_mutation must be positive.")
        return value


class CrossoverConfig(BaseModel):
    """Crossover settings: rate and method name."""

    model_config = ConfigDict(extra="forbid")

    rate: float = 0.2
    method: str = "subgraph"

    @field_validator("rate")
    @classmethod
    def validate_rate(cls, value: float) -> float:
        """Require ``rate`` to lie in the closed interval [0, 1].

        Raises:
            ValueError: if ``value`` is outside [0, 1] or is NaN.
        """
        # The chained comparison is False for NaN, so NaN is rejected too;
        # ``value < 0 or value > 1`` would let it through.
        if not (0 <= value <= 1):
            raise ValueError("CrossoverConfig.rate must be in [0, 1].")
        return value


class SelectionConfig(BaseModel):
    """Selection settings: method name and its integer parameter ``k``."""

    model_config = ConfigDict(extra="forbid")

    method: str = "tournament"
    k: int = 3

    @field_validator("k")
    @classmethod
    def validate_k(cls, value: int) -> int:
        """Require ``k`` to be at least 1."""
        if value >= 1:
            return value
        raise ValueError("SelectionConfig.k must be positive.")


class MultiObjectiveConfig(BaseModel):
    """Multi-objective search settings; unknown keys are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Whether multi-objective handling is switched on (on by default).
    enabled: bool = True
    # Objective names; defaults to ["error", "ops"]. Not validated here.
    objectives: List[str] = Field(default_factory=lambda: ["error", "ops"])
    # Algorithm identifier; free-form string, not validated here.
    algo: str = "nsga2"


class EvolutionConfig(BaseModel):
    """Evolutionary-search settings grouping the mutation, crossover,
    selection and multi-objective sub-configs with two size parameters."""

    model_config = ConfigDict(extra="forbid")

    population_size: Optional[int] = None
    num_mantain: int = 40
    mutation: MutationConfig = Field(default_factory=MutationConfig)
    crossover: CrossoverConfig = Field(default_factory=CrossoverConfig)
    selection: SelectionConfig = Field(default_factory=SelectionConfig)
    multiobjective: MultiObjectiveConfig = Field(default_factory=MultiObjectiveConfig)

    @field_validator("population_size", "num_mantain")
    @classmethod
    def validate_positive_ints(cls, value: Optional[int]) -> Optional[int]:
        """Let ``None`` through; otherwise require a strictly positive integer."""
        if value is not None and value < 1:
            raise ValueError("EvolutionConfig values must be positive.")
        return value


class CMAConfig(BaseModel):
    """Settings for the ``cma`` optimizer stage (disabled by default)."""

    model_config = ConfigDict(extra="forbid")

    enabled: bool = False
    popsize: int = 128
    maxiter: int = 1000
    sigma0: float = 0.002
    seed: int = 1234567

    @field_validator("popsize", "maxiter")
    @classmethod
    def validate_positive_ints(cls, value: int) -> int:
        """Require ``popsize`` and ``maxiter`` to be at least 1."""
        if value >= 1:
            return value
        raise ValueError("CMAConfig values must be positive.")


class LMConfig(BaseModel):
    """Settings for the ``lm`` optimizer stage (enabled by default)."""

    model_config = ConfigDict(extra="forbid")

    enabled: bool = True
    max_iters: int = 50
    pop_size: int = 16
    max_nfev: int = 100

    @field_validator("max_iters", "pop_size", "max_nfev")
    @classmethod
    def validate_positive_ints(cls, value: int) -> int:
        """Require each of the three integer limits to be at least 1."""
        if value < 1:
            raise ValueError("LMConfig values must be positive.")
        return value


class NelderMeadConfig(BaseModel):
    """Settings for the ``nelder_mead`` optimizer stage (enabled by default)."""

    model_config = ConfigDict(extra="forbid")

    enabled: bool = True
    max_iters: int = 100
    xatol: float = 1e-6
    fatol: float = 1e-6

    @field_validator("max_iters")
    @classmethod
    def validate_positive_ints(cls, value: int) -> int:
        """Require ``max_iters`` to be at least 1."""
        if value >= 1:
            return value
        raise ValueError("NelderMeadConfig.max_iters must be positive.")


class OptimizerConfig(BaseModel):
    """Container for the three optimizer sub-configs; unknown keys are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Each sub-config is built from its own defaults when omitted.
    cma: CMAConfig = Field(default_factory=CMAConfig)
    lm: LMConfig = Field(default_factory=LMConfig)
    nelder_mead: NelderMeadConfig = Field(default_factory=NelderMeadConfig)


class SearchConfig(BaseModel):
    """Top-level search settings: evolution plus optimizer configuration."""

    model_config = ConfigDict(extra="forbid")

    # Both sections are built from their own defaults when omitted.
    evolution: EvolutionConfig = Field(default_factory=EvolutionConfig)
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)


class VerifyPassSpec(BaseModel):
    """Verification-pass stop condition: level, metric name and threshold."""

    model_config = ConfigDict(extra="forbid")

    level: int = 2
    metric: str = "ulp_p99"
    threshold: float = 1.0

    @field_validator("level")
    @classmethod
    def validate_level(cls, value: int) -> int:
        """Require ``level`` to be at least 1."""
        if value < 1:
            raise ValueError("VerifyPassSpec.level must be positive.")
        return value


class NoImproveSpec(BaseModel):
    """No-improvement stop condition: round count and minimum delta."""

    model_config = ConfigDict(extra="forbid")

    rounds: int = 20
    min_delta: float = 1e-9

    @field_validator("rounds")
    @classmethod
    def validate_rounds(cls, value: int) -> int:
        """Require ``rounds`` to be at least 1."""
        if value >= 1:
            return value
        raise ValueError("NoImproveSpec.rounds must be positive.")


class StopCriteria(BaseModel):
    """Stop conditions: optional verify pass, wall-time cap, no-improvement
    settings, and limits on piece count and piece width."""

    model_config = ConfigDict(extra="forbid")

    verify_pass: Optional[VerifyPassSpec] = None
    max_wall_time_s: int = 172800
    no_improve: NoImproveSpec = Field(default_factory=NoImproveSpec)
    max_pieces: int = 20
    min_piece_width: float = 1e-6

    @field_validator("max_wall_time_s", "max_pieces")
    @classmethod
    def validate_positive_ints(cls, value: int) -> int:
        """Require ``max_wall_time_s`` and ``max_pieces`` to be at least 1."""
        if value < 1:
            raise ValueError("StopCriteria values must be positive.")
        return value


class ApproxSpec(BaseModel):
    """Root request model tying together target, precision, metric, domain,
    sampling, search configuration and stop criteria."""

    model_config = ConfigDict(extra="forbid")

    schema_version: str = "0.1"
    request_id: Optional[str] = None
    target: TargetSpec
    precision_model: PrecisionSpec = Field(default_factory=PrecisionSpec)
    metric: MetricSpec = Field(default_factory=MetricSpec)
    domain: DomainSpec
    sampling: SamplingSpec = Field(default_factory=SamplingSpec)
    search_config: SearchConfig = Field(default_factory=SearchConfig)
    stop_criteria: StopCriteria = Field(default_factory=StopCriteria)

    @field_validator("schema_version")
    @classmethod
    def validate_schema_version(cls, value: str) -> str:
        """Accept only the single supported schema version string."""
        if value == "0.1":
            return value
        raise ValueError("Only schema_version '0.1' is supported.")


def _close_open_endpoints(piece: PieceSpec) -> None:
    """Turn finite open endpoints of ``piece.interval`` into closed ones, in place."""
    interval = piece.interval
    if interval.start_open and math.isfinite(interval.start):
        # Step to the next representable float above the excluded start.
        interval.start = math.nextafter(interval.start, math.inf)
        interval.start_open = False
    if interval.end_open and math.isfinite(interval.end):
        # Step to the next representable float below the excluded end.
        interval.end = math.nextafter(interval.end, -math.inf)
        interval.end_open = False
    # Written as ``not (a < b)`` so that a NaN bound is rejected as well.
    if not (interval.start < interval.end):
        raise ValueError(
            f"IntervalSpec invalid after normalization for piece_id={piece.piece_id}."
        )


def _normalize_operator_probs(mutation: MutationConfig) -> None:
    """Validate ``mutation.operator_probs`` and rescale it to sum to 1, in place."""
    probs = mutation.operator_probs
    if not probs:
        return
    for key, value in probs.items():
        if value < 0:
            raise ValueError(f"MutationConfig.operator_probs[{key}] must be non-negative.")
        # NaN and +inf pass the sign check but would turn every rescaled
        # weight into NaN (or 0), so they are rejected here.
        if not math.isfinite(value):
            raise ValueError(f"MutationConfig.operator_probs[{key}] must be finite.")
    total = sum(probs.values())
    if total <= 0:
        raise ValueError("MutationConfig.operator_probs must have a positive sum.")
    if not math.isclose(total, 1.0, rel_tol=1e-6, abs_tol=1e-12):
        mutation.operator_probs = {key: value / total for key, value in probs.items()}


def _lm_max_iters_was_provided(spec: ApproxSpec) -> bool:
    """Return True when the whole ``search_config.optimizer.lm.max_iters`` chain was explicitly set."""
    search = spec.search_config
    return (
        "search_config" in spec.model_fields_set
        and "optimizer" in search.model_fields_set
        and "lm" in search.optimizer.model_fields_set
        and "max_iters" in search.optimizer.lm.model_fields_set
    )


def _derive_request_id(spec: ApproxSpec) -> str:
    """Return the first 16 hex digits of the SHA-256 of the spec's canonical JSON dump."""
    payload = spec.model_dump(mode="json", exclude={"request_id"})
    # Sorted keys and fixed separators keep the serialized form stable.
    raw = json.dumps(payload, sort_keys=True, separators=(",", ":"), ensure_ascii=True)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]


def normalize_spec(spec: ApproxSpec) -> ApproxSpec:
    """Return a normalized deep copy of ``spec``; the input is left untouched.

    Normalization steps, in order:

    1. Finite open interval endpoints are moved one float inward and marked
       closed.
    2. Mutation operator probabilities are validated and rescaled to sum to 1.
    3. ``lm.max_iters`` is replaced by ``default_lm_max_iters(precision_model)``
       unless it was explicitly provided.
    4. A missing or empty ``request_id`` is derived from a hash of the
       normalized spec.

    Args:
        spec: The validated specification to normalize.

    Returns:
        A new ``ApproxSpec`` with the normalizations applied.

    Raises:
        ValueError: if an interval is empty (or has a NaN bound) after its
            endpoints are closed, or if an operator probability is negative
            or non-finite, or if the probabilities do not sum to a positive
            value.
    """
    # Imported at call time rather than module level — presumably to avoid
    # an import cycle with python_src.config; confirm before moving it.
    from python_src.config import default_lm_max_iters

    normalized = spec.model_copy(deep=True)

    for piece in normalized.domain.pieces:
        _close_open_endpoints(piece)

    _normalize_operator_probs(normalized.search_config.evolution.mutation)

    # Explicit-field tracking is read from the caller's spec, not the copy.
    if not _lm_max_iters_was_provided(spec):
        normalized.search_config.optimizer.lm.max_iters = default_lm_max_iters(
            normalized.precision_model
        )

    if not normalized.request_id:
        normalized.request_id = _derive_request_id(normalized)

    return normalized
