from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
import numpy as np

try:
    from sklearn.cluster import KMeans
    from sklearn.mixture import GaussianMixture
except Exception:
    KMeans = None
    GaussianMixture = None


@dataclass
class ClusterConfig:
    """Settings applied to every per-operator clustering fit.

    Read by ``_fit_one_operator`` when building each ``SoftClusterer``.
    """

    # Back-end: "gmm" or "kmeans"; any other value yields a constant-0.5 clusterer.
    method: str = "gmm"
    # Requested number of components; clipped to [1, number of samples] at fit time.
    k: int = 2
    # NOTE(review): not read by the fitting code visible here; presumably consumed elsewhere.
    metric: str = "inc"
    # When True, samples are passed through np.log1p before fitting.
    log_transform: bool = False
    # Seed forwarded to the sklearn estimators.
    random_state: int = 42
    # Cumulative component weight that the selected "center" components must reach.
    center_mass_alpha: float = 0.7


class SoftClusterer:
    """Soft membership scorer for 1-D samples against a fitted clustering model.

    ``prob_center`` returns, per sample, the total probability assigned to the
    components listed in ``center_indices``. Supported back-ends:

    * ``"gmm"``    -- posterior responsibilities from a fitted ``GaussianMixture``.
    * ``"kmeans"`` -- softmax over ``-(x - c)^2 / (2 * scale^2)`` for each center
      ``c`` (a Gaussian kernel with one shared width ``scale``).
    * anything else, or a missing model -- a constant 0.5 for every sample.

    NOTE(review): fitting may apply ``log1p`` to the samples
    (``ClusterConfig.log_transform``). This class does not, so callers presumably
    must pass ``x`` already in that space -- confirm at call sites.
    """

    def __init__(self,
                 method: str,
                 centers: Optional[np.ndarray] = None,
                 gmm: Optional["GaussianMixture"] = None,
                 center_indices: Optional[List[int]] = None,
                 scale: float = 1.0):
        self.method = method
        self.centers = centers
        self.gmm = gmm
        # Explicit None check: `center_indices or []` would silently turn a falsy
        # but valid sequence such as np.array([0]) into [] (and raise on longer
        # NumPy arrays).
        self.center_indices = [] if center_indices is None else [int(i) for i in center_indices]
        # A vanishing or non-positive width would blow up the kmeans softmax.
        self.scale = float(scale) if scale > 1e-9 else 1.0

    def prob_center(self, x: np.ndarray) -> np.ndarray:
        """Return each sample's probability of belonging to a center component.

        Args:
            x: Samples of shape ``(n, 1)``. A 1-D array ``(n,)`` or a scalar is
                reshaped to ``(n, 1)``.

        Returns:
            Float array of shape ``(n,)``. It is 0.5 everywhere when no usable
            model or no center components are configured.
        """
        x = np.asarray(x, dtype=float)
        if x.ndim < 2:
            # 1-D input would broadcast against the (1, k) centers row into a
            # wrong (1, k) matrix, and sklearn rejects 1-D input outright.
            x = x.reshape(-1, 1)
        n = x.shape[0]

        if self.method == "gmm" and self.gmm is not None:
            proba = self.gmm.predict_proba(x)  # (n, k)
        elif self.method == "kmeans" and self.centers is not None:
            proba = self._kmeans_proba(x)      # (n, k)
        else:
            return np.full((n,), 0.5, dtype=float)

        if not self.center_indices:
            return np.full((n,), 0.5, dtype=float)
        return proba[:, self.center_indices].sum(axis=1)

    def _kmeans_proba(self, x: np.ndarray) -> np.ndarray:
        """Return the row-wise softmax of scaled negative squared distances to each center; ``x`` is (n, 1)."""
        d2 = (x - self.centers.reshape(1, -1)) ** 2  # (n, k)
        logits = -d2 / (2.0 * (self.scale ** 2))
        # Subtract the row max before exp() for numerical stability.
        logits = logits - logits.max(axis=1, keepdims=True)
        expv = np.exp(logits)
        return expv / expv.sum(axis=1, keepdims=True)


def _select_center_indices_by_mass(weights: np.ndarray, alpha: float) -> List[int]:
    """Pick the heaviest components until their cumulative weight reaches ``alpha``.

    Components are visited in descending weight order. The component that makes
    the running total reach ``alpha`` is included, so at least one index is always
    returned. If the total never reaches ``alpha``, every index is returned.

    Args:
        weights: 1-D component weights. Callers here pass mixture weights or
            per-cluster sample fractions.
        alpha: Cumulative-mass threshold.

    Returns:
        Component indices as Python ints, heaviest first.

    Raises:
        ValueError: If ``weights`` is empty.
    """
    w = np.asarray(weights, dtype=float).reshape(-1)
    if w.size == 0:
        # Nothing to select from; fail with a clear message instead of deep in argmax.
        raise ValueError("weights must contain at least one component")
    order = np.argsort(w)[::-1]
    cum = np.cumsum(w[order])
    hits = np.flatnonzero(cum >= alpha)
    # Include the component whose addition first reaches alpha; otherwise take all.
    stop = int(hits[0]) + 1 if hits.size else w.size
    return [int(i) for i in order[:stop]]


def _fit_kmeans(X: np.ndarray, n_clusters: int, cfg: ClusterConfig) -> SoftClusterer:
    """Fit KMeans on column vector ``X`` and wrap it as a soft ``"kmeans"`` clusterer."""
    km = KMeans(n_clusters=n_clusters, n_init="auto", random_state=cfg.random_state)
    labels = km.fit_predict(X)
    centers = km.cluster_centers_.reshape(-1)
    # Cluster weight = fraction of samples assigned to that cluster.
    counts = np.bincount(labels, minlength=n_clusters).astype(float)
    total = counts.sum()
    weights = counts / total if total > 0 else counts
    center_indices = _select_center_indices_by_mass(weights, cfg.center_mass_alpha)
    # Softmax width: overall sample spread, or 1.0 when the data is (near-)constant.
    std = float(np.std(X))
    scale = std if std > 1e-9 else 1.0
    return SoftClusterer(method="kmeans", centers=centers,
                         center_indices=center_indices, scale=scale)


def _fit_one_operator(samples: List[float], cfg: ClusterConfig) -> Optional[SoftClusterer]:
    """Fit a soft clusterer to one operator's 1-D samples.

    Args:
        samples: Sample values (a list or any 1-D array-like of floats).
        cfg: Clustering settings.

    Returns:
        ``None`` when there are no samples. Otherwise a ``SoftClusterer``:
        a GMM when ``cfg.method == "gmm"`` and at least two samples are available
        and GaussianMixture is importable, a KMeans clusterer when
        ``cfg.method == "kmeans"`` (or as the GMM fallback), and a constant
        ``"none"`` clusterer when sklearn is missing or the method is unknown.
    """
    if samples is None:
        return None
    X = np.array(samples, dtype=float).reshape(-1, 1)
    # Size check instead of `not samples`: NumPy arrays are ambiguous (or, with
    # one zero element, falsy) under truth testing.
    if X.shape[0] == 0:
        return None
    if cfg.log_transform:
        X = np.log1p(X)

    n = X.shape[0]
    k_eff = max(1, min(cfg.k, n))

    if cfg.method == "gmm":
        if GaussianMixture is not None and n >= 2:
            gmm = GaussianMixture(n_components=k_eff, random_state=cfg.random_state)
            gmm.fit(X)
            center_indices = _select_center_indices_by_mass(gmm.weights_, cfg.center_mass_alpha)
            return SoftClusterer(method="gmm", gmm=gmm, center_indices=center_indices)
        # Too few samples for a mixture (or no GMM available): degrade to a
        # single KMeans cluster, whose only component is the center.
        if KMeans is not None:
            return _fit_kmeans(X, 1, cfg)
        return SoftClusterer(method="none", center_indices=[0])

    if cfg.method == "kmeans":
        if KMeans is None:
            return SoftClusterer(method="none", center_indices=[0])
        return _fit_kmeans(X, k_eff, cfg)

    # Any other method string (including "none") disables clustering.
    return SoftClusterer(method="none", center_indices=[0])



def fit_pid_clusters(samples_by_op: Dict[str, List[float]], cfg: ClusterConfig) -> Dict[str, SoftClusterer]:
    """Fit one soft clusterer per operator name.

    Args:
        samples_by_op: Mapping from operator name to its raw samples.
        cfg: Clustering settings applied to every operator.

    Returns:
        Mapping from operator name to fitted clusterer, in input order. Operators
        for which ``_fit_one_operator`` returns None (e.g. empty samples) are
        omitted.
    """
    fitted = ((op, _fit_one_operator(values, cfg)) for op, values in samples_by_op.items())
    return {op: model for op, model in fitted if model is not None}
