from __future__ import annotations

import json
import math
from pathlib import Path
from typing import Dict, Iterable, List, Sequence

try:
    import numpy as np
except ImportError:  # pragma: no cover - optional dependency
    np = None

try:
    import xgboost as xgb
except ImportError:  # pragma: no cover - optional dependency
    xgb = None

from .feature_store import ZCFeatureStore


class ZCPredictor:
    """Serve accuracy predictions from zero-cost features via XGBoost or linear weights.

    Two model kinds are supported, selected by :meth:`set_model`:

    * ``"xgboost"``: an ``xgb.Booster`` evaluated on a DMatrix whose columns
      follow ``feature_order``.
    * ``"linear"``: a ``{"weights": {feature_name: weight}}`` mapping. The raw
      dot product of features and weights is mapped to an accuracy through an
      affine ``scale * score + bias`` calibration. The calibration is fitted by
      least squares on the dataset exported from the feature store and cached
      per dataset until the model changes.
    """

    def __init__(
        self,
        data_path: str | Path | None = None,
        model_path: str | Path | None = None,
        booster: object | None = None,
        feature_order: Sequence[str] | None = None,
    ) -> None:
        """Create a predictor backed by a ``ZCFeatureStore``.

        Args:
            data_path: Forwarded unchanged to ``ZCFeatureStore``.
            model_path: Optional model file installed via :meth:`load_model`.
            booster: Optional in-memory model installed via :meth:`set_model`.
                If ``model_path`` is also given, it is loaded afterwards and wins.
            feature_order: Feature column order used for every prediction;
                defaults to the store's ``zc_names``.
        """
        self.feature_store = ZCFeatureStore(data_path)
        # Always copy, so the predictor never shares (and can never mutate)
        # the store's own zc_names list.
        source_order = feature_order if feature_order is not None else self.feature_store.zc_names
        self.feature_order: List[str] = list(source_order)
        self.model_type: str | None = None
        self.model: object | None = None
        self.linear_weights: Dict[str, float] | None = None
        # dataset name -> (scale, bias) fitted for the *current* linear weights.
        self._linear_cache: Dict[str, tuple[float, float]] = {}

        if booster is not None:
            self.set_model(booster)
        if model_path is not None:
            self.load_model(model_path)

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------
    def set_model(self, booster: object) -> None:
        """Install ``booster`` as the active model.

        Args:
            booster: An ``xgb.Booster``, or a dict with a ``"weights"`` mapping
                from feature name to weight.

        Raises:
            TypeError: If ``booster`` is neither supported model kind.
            ValueError: If the model references features the store does not
                know. The previously installed model stays active in this case.
        """
        if _is_xgboost_booster(booster):
            # Validate before mutating so a failure leaves the old model intact.
            self._validate_feature_order()
            self.model_type = "xgboost"
            self.model = booster
        elif isinstance(booster, dict) and "weights" in booster:
            # Copy so later mutation of the caller's dict cannot drift away from
            # the calibration cached for these weights.
            weights = dict(booster["weights"])
            self._validate_weight_keys(weights)
            self.model_type = "linear"
            self.linear_weights = weights
        else:
            raise TypeError("Unsupported predictor model type")
        # Cached calibrations were fitted for the previous model; drop them so
        # the next linear prediction refits against the new weights.
        self._linear_cache.clear()

    def load_model(self, model_path: str | Path) -> object:
        """Load a model from ``model_path`` and install it via :meth:`set_model`.

        A ``.json`` file whose top level is a dict with a ``"weights"`` key is
        treated as a linear weight model. Any other file, including an
        XGBoost model saved in XGBoost's JSON format, is loaded with
        ``xgb.Booster.load_model``.

        Returns:
            The parsed JSON payload for linear models, otherwise the booster.

        Raises:
            RuntimeError: If a booster must be loaded but xgboost is not installed.
            TypeError: If a ``.json`` payload is not a weight model and xgboost
                is not installed (raised by :meth:`set_model`).
        """
        path = Path(model_path)
        if path.suffix.lower() == ".json":
            with path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
            is_linear = isinstance(payload, dict) and "weights" in payload
            if is_linear or xgb is None:
                # Without xgboost there is no other interpretation; set_model
                # raises TypeError for anything that is not a weight mapping.
                self.set_model(payload)
                return payload
            # Otherwise fall through: assume XGBoost's own JSON model format.

        if xgb is None:
            raise RuntimeError("xgboost is not installed; unable to load booster model")

        booster = xgb.Booster()
        booster.load_model(str(path))
        self.set_model(booster)
        return booster

    # ------------------------------------------------------------------
    # Prediction API
    # ------------------------------------------------------------------
    def predict_architecture(self, dataset: str, arch: str) -> float:
        """Return the predicted accuracy of a single ``arch`` on ``dataset``."""
        predictions = self.predict_batch(dataset, [arch])
        return float(predictions[0])

    def predict_batch(self, dataset: str, archs: Iterable[str]) -> np.ndarray:
        """Predict accuracies for every architecture in ``archs``.

        Feature vectors are fetched from the feature store in ``feature_order``.
        Returns a numpy array, or a plain list on the linear path when numpy
        is unavailable.
        """
        feature_matrix = [
            list(map(float, self.feature_store.get_feature_vector(dataset, arch, self.feature_order)))
            for arch in archs
        ]
        return self._predict_from_features(dataset, feature_matrix)

    def get_training_matrix(self, dataset: str) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(features, accuracies)`` exported from the feature store.

        Both are converted to float arrays when numpy is available; otherwise
        the store's ``"features"`` / ``"accuracies"`` values are returned as-is.
        """
        exported = self.feature_store.export_dataset(dataset, self.feature_order)
        features = exported["features"]
        accuracies = exported["accuracies"]
        if np is not None:
            features = np.asarray(features, dtype=float)
            accuracies = np.asarray(accuracies, dtype=float)
        return features, accuracies

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _predict_from_features(self, dataset: str, features: Iterable[Sequence[float]]) -> np.ndarray:
        """Run the active model on pre-built feature rows (``feature_order`` columns).

        Raises:
            RuntimeError: If no model is installed or a required dependency is missing.
        """
        if self.model_type == "xgboost":
            if xgb is None or np is None or self.model is None:
                raise RuntimeError("XGBoost booster requires xgboost and numpy")
            rows = list(features)
            if not rows:
                # np.asarray([]) is 1-D, which DMatrix does not accept as a
                # feature matrix; an empty batch simply has no predictions.
                return np.empty(0, dtype=float)
            dmatrix = xgb.DMatrix(np.asarray(rows, dtype=float), feature_names=self.feature_order)
            return self.model.predict(dmatrix)

        if self.model_type == "linear":
            if self.linear_weights is None:
                raise RuntimeError("Linear weight model not initialised")
            # NOTE(review): a feature in feature_order without a weight raises
            # KeyError here; weights for features outside feature_order are ignored.
            weights_vec = [float(self.linear_weights[name]) for name in self.feature_order]
            raw_scores = [
                sum(f * w for f, w in zip(feature_row, weights_vec))
                for feature_row in features
            ]
            scale, bias = self._get_linear_calibration(dataset, weights_vec)
            predictions = [scale * score + bias for score in raw_scores]
            if np is not None:
                return np.asarray(predictions, dtype=float)
            return predictions

        raise RuntimeError("Predictor model not loaded")

    def _validate_feature_order(self) -> None:
        """Raise ValueError if ``feature_order`` names a feature the store lacks."""
        known = set(self.feature_store.zc_names)
        missing = [name for name in self.feature_order if name not in known]
        if missing:
            raise ValueError(f"Requested feature order contains unknown features: {missing}")

    def _validate_weight_keys(self, weights: Dict[str, float] | None = None) -> None:
        """Raise ValueError if a weight refers to a feature the store lacks.

        Args:
            weights: Mapping to check; defaults to the installed ``linear_weights``.

        Raises:
            RuntimeError: If no mapping is given and none is installed.
        """
        if weights is None:
            weights = self.linear_weights
        if weights is None:
            raise RuntimeError("Linear weight model not initialised")
        known = set(self.feature_store.zc_names)
        unknown = [k for k in weights if k not in known]
        if unknown:
            raise ValueError(f"Weight model references unknown features: {unknown}")

    def _get_linear_calibration(self, dataset: str, weights_vec: Sequence[float]) -> tuple[float, float]:
        """Fit (and cache) the affine map from raw linear scores to accuracies.

        The cache is keyed by dataset only; :meth:`set_model` clears it whenever
        the model changes.
        """
        if dataset in self._linear_cache:
            return self._linear_cache[dataset]

        features, accuracies = self.get_training_matrix(dataset)
        if np is not None and isinstance(features, np.ndarray):
            scores = features @ np.asarray(weights_vec, dtype=float)
            accuracies_array = accuracies
        else:
            scores = [
                sum(float(f) * float(w) for f, w in zip(feature_row, weights_vec))
                for feature_row in features
            ]
            accuracies_array = [float(a) for a in accuracies]
        scale, bias = _fit_affine(scores, accuracies_array)
        self._linear_cache[dataset] = (scale, bias)
        return scale, bias


def _fit_affine(xs: Sequence[float], ys: Sequence[float]) -> tuple[float, float]:
    """Least-squares fit of ``y ~= scale * x + bias``.

    Pairs in which either value is NaN or infinite are discarded before fitting.

    Returns:
        ``(scale, bias)``. With no usable pairs this is ``(0.0, 0.0)``. When the
        retained ``x`` values have (near-)zero spread the slope is undefined,
        so the fit degenerates to ``(0.0, mean(y))``.
    """
    clean_x: list[float] = []
    clean_y: list[float] = []
    for raw_x, raw_y in zip(xs, ys):
        if not (math.isfinite(raw_x) and math.isfinite(raw_y)):
            continue
        clean_x.append(float(raw_x))
        clean_y.append(float(raw_y))

    count = len(clean_x)
    if count == 0:
        return 0.0, 0.0

    mean_x = sum(clean_x) / count
    mean_y = sum(clean_y) / count
    deviations_x = [value - mean_x for value in clean_x]
    spread = sum(dev ** 2 for dev in deviations_x)
    if spread <= 1e-12:
        # Constant predictor: best constant fit is the mean of y.
        return 0.0, mean_y

    covariance = sum(dev * (value - mean_y) for dev, value in zip(deviations_x, clean_y))
    slope = covariance / spread
    intercept = mean_y - slope * mean_x
    return float(slope), float(intercept)


def _is_xgboost_booster(obj: object) -> bool:
    """Return True if ``obj`` is an ``xgb.Booster``; always False when xgboost is absent."""
    return xgb is not None and isinstance(obj, xgb.Booster)


# Public API: only the predictor class; the underscore-prefixed helpers stay private.
__all__ = [
    "ZCPredictor",
]
