# script/processors/claritree.py
from __future__ import annotations
import numpy as np
from typing import Optional, Dict, Any
from clari_tree import (
    Greedy, CLARITree,
    GreedyConst,   CLARITreeConst,
)

from .base import Processor, FitArtifacts

# ===== Defaults & helpers =====
DEFAULTS: Dict[str, Any] = {
    # Forwarded as ``verbose`` to every tree constructor.
    "verbose": False,
    # Forwarded as ``lambda_`` to every tree constructor.
    "cost_complexity": 0.0,
    # Forwarded as ``stride``; only the linear-leaf mapping passes it on.
    "stride": 1,
    # Forwarded as ``kappa``; only the linear-leaf mapping passes it on.
    "ridge_penalty": 0.0,
}
# NOTE: the ``depth`` default (4) lives in ``_get_depth``, not here.

def _get_depth(hp: Dict[str, Any], default: int = 4) -> int:
    return int(hp.get("depth", default))

def _get_stride(hp: Dict[str, Any]) -> int:
    """Return ``hp["stride"]`` coerced with ``int()``, defaulting to ``DEFAULTS["stride"]``."""
    fallback = DEFAULTS["stride"]
    stride = hp.get("stride", fallback)
    return int(stride)

def _get_verbose(hp: Dict[str, Any]) -> bool:
    return bool(hp.get("verbose", DEFAULTS["verbose"]))

def _map_hparams_linear(hp: Dict[str, Any]) -> Dict[str, Any]:
    """Translate processor hyper-parameters into linear-leaf tree kwargs.

    Key mapping:
        depth           -> depth
        cost_complexity -> lambda_  (float)
        ridge_penalty   -> kappa    (float)
        stride          -> stride   (int)
        verbose         -> verbose  (bool)

    Missing keys fall back to ``DEFAULTS`` (or the ``_get_depth`` default).
    """
    depth = _get_depth(hp)
    lambda_ = float(hp.get("cost_complexity", DEFAULTS["cost_complexity"]))
    kappa = float(hp.get("ridge_penalty", DEFAULTS["ridge_penalty"]))
    stride = _get_stride(hp)
    verbose = _get_verbose(hp)
    return {
        "depth": depth,
        "lambda_": lambda_,
        "kappa": kappa,
        "stride": stride,
        "verbose": verbose,
    }

def _map_hparams_const(hp: Dict[str, Any]) -> Dict[str, Any]:
    """Translate processor hyper-parameters into constant-leaf tree kwargs.

    Only ``depth``, ``cost_complexity`` (forwarded as ``lambda_``) and
    ``verbose`` are passed on.  ``stride`` and ``ridge_penalty`` are not
    forwarded by this mapping.
    """
    depth = _get_depth(hp)
    lambda_ = float(hp.get("cost_complexity", DEFAULTS["cost_complexity"]))
    verbose = _get_verbose(hp)
    return {"depth": depth, "lambda_": lambda_, "verbose": verbose}

# ===== Base processor =====
class _BaseProcessor(Processor):
    """Shared ``fit``/``predict`` implementation for all tree processors.

    Subclasses supply ``name`` and ``build``.
    """

    def fit(self, model, X: np.ndarray, y: np.ndarray) -> FitArtifacts:
        """Fit ``model`` on ``(X, y)`` and wrap it in ``FitArtifacts``.

        ``complexity`` is the model's leaf count as reported by ``n_leaves``
        (either a method or a plain attribute), or ``None`` when the model
        does not expose it or it cannot be evaluated.
        """
        model.fit(X, y)
        return FitArtifacts(model=model, complexity=self._leaf_count(model), extras={})

    @staticmethod
    def _leaf_count(model) -> Optional[int]:
        """Best-effort read of ``model.n_leaves`` as a Python ``int`` (or ``None``)."""
        try:
            n_leaves = getattr(model, "n_leaves", None)
            if n_leaves is None:
                return None
            # Accept both ``n_leaves()`` methods and ``n_leaves`` attributes/properties;
            # previously a non-callable value raised TypeError and was silently lost.
            value = n_leaves() if callable(n_leaves) else n_leaves
            # Coerce numpy integer scalars etc. to match the Optional[int] contract.
            return int(value)
        except Exception:
            # Deliberately best-effort: a missing/invalid leaf count must not
            # abort a fit that already succeeded.
            return None

    def predict(self, model, X: np.ndarray) -> np.ndarray:
        """Delegate prediction to ``model.predict``."""
        return model.predict(X)

# ===== Linear-leaf trees =====
class GreedyProcessor(_BaseProcessor):
    """Processor that builds a ``Greedy`` tree (linear-leaf family)."""

    name = "greedy"

    def build(self, **hparams):
        """Return a new, unfitted ``Greedy`` model.

        ``hparams`` override ``DEFAULTS`` key by key.  The merged dict is then
        turned into constructor kwargs by ``_map_hparams_linear``.
        """
        settings = dict(DEFAULTS)
        settings.update(hparams)
        return Greedy(**_map_hparams_linear(settings))

class CLARITreeProcessor(_BaseProcessor):
    """Processor that builds a ``CLARITree`` (linear-leaf family)."""

    name = "clari_tree"

    def build(self, **hparams):
        """Return a new, unfitted ``CLARITree`` model.

        ``hparams`` override ``DEFAULTS`` key by key.  The merged dict is then
        turned into constructor kwargs by ``_map_hparams_linear``.
        """
        settings = dict(DEFAULTS)
        settings.update(hparams)
        return CLARITree(**_map_hparams_linear(settings))

# ===== Constant-leaf trees =====
class GreedyConstProcessor(_BaseProcessor):
    """Processor that builds a ``GreedyConst`` tree (constant-leaf family)."""

    name = "greedy_const"

    def build(self, **hparams):
        """Return a new, unfitted ``GreedyConst`` model.

        ``hparams`` override ``DEFAULTS`` key by key.  The merged dict is then
        turned into constructor kwargs by ``_map_hparams_const``, which does
        not forward ``stride`` or ``ridge_penalty``.
        """
        settings = dict(DEFAULTS)
        settings.update(hparams)
        return GreedyConst(**_map_hparams_const(settings))

class CLARITreeConstProcessor(_BaseProcessor):
    """Processor that builds a ``CLARITreeConst`` (constant-leaf family)."""

    name = "clari_tree_const"

    def build(self, **hparams):
        """Return a new, unfitted ``CLARITreeConst`` model.

        ``hparams`` override ``DEFAULTS`` key by key.  The merged dict is then
        turned into constructor kwargs by ``_map_hparams_const``, which does
        not forward ``stride`` or ``ridge_penalty``.
        """
        settings = dict(DEFAULTS)
        settings.update(hparams)
        return CLARITreeConst(**_map_hparams_const(settings))
