import logging
import warnings

import numpy as np
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.cluster import KMeans, DBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.ensemble import IsolationForest, RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    jaccard_score,
)

# Silence every warning category for the whole process at import time.
warnings.filterwarnings("ignore")

# Root logging is configured once at import time so every model reports in a
# uniform "<time> - <level> - <message>" layout.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


class BaseModel:
    """
    Shared plumbing that adapts a scikit-learn-style estimator to batched
    step/operation data.

    Each batch is a mapping. ``"ops_tensor"`` has shape (B, T, F): for every
    step, a sequence of per-operation feature rows, where a row containing any
    NaN is treated as padding. ``fit`` also reads ``"step_labels"`` (one label
    per step) and ``evaluate`` additionally reads ``"op_labels"`` of shape
    (B, D), in which -100 marks padding.

    Subclasses assign ``self.model`` and set ``self.is_unsupervised = True``
    when the estimator emits cluster / outlier ids instead of class labels.
    """

    def __init__(self, name: str):
        self.name = name
        self.model = None  # estimator supplied by the subclass
        self.is_unsupervised = False

    def _pool_features(self, ops_tensor):
        """
        Mean-pool operation-level features into step-level representations.

        Operations whose feature row contains any NaN are excluded from the
        mean; a step with no valid operations pools to all zeros.

        Args:
            ops_tensor (np.ndarray): Array-like of shape (B, T, F).

        Returns:
            np.ndarray: Pooled step-level features of shape (B, F).
        """
        ops = np.asarray(ops_tensor)
        valid = ~np.isnan(ops).any(axis=-1)  # (B, T)
        masked = ops.copy()
        masked[~valid] = 0.0
        # clip(min=1) avoids 0/0 for steps whose operations are all padding.
        counts = valid.sum(axis=1, keepdims=True).clip(min=1)
        return masked.sum(axis=1) / counts

    def fit(self, train_loader):
        """
        Fit the wrapped estimator on mean-pooled step features.

        Args:
            train_loader: Iterable of batches providing ``"ops_tensor"`` and
                ``"step_labels"``.

        Returns:
            self

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        features, labels = [], []
        for batch in train_loader:
            features.append(self._pool_features(batch["ops_tensor"]))
            labels.append(np.asarray(batch["step_labels"]))
        if not features:
            raise ValueError(f"{self.name}: train_loader yielded no batches to fit on")
        X = np.vstack(features)
        y = np.concatenate(labels)
        if hasattr(self.model, "fit"):
            try:
                self.model.fit(X, y)
            except TypeError:
                # Fallback for estimators whose fit() does not accept labels.
                self.model.fit(X)
        return self

    def _predict_pooled(self, pooled):
        """
        Return ``(binary predictions, scores)`` for one batch of pooled features.

        For every model the convention is: label 1 is the positive class and a
        higher score means "more likely positive".
        """
        if hasattr(self.model, "predict"):
            preds = self.model.predict(pooled)
        else:
            # No out-of-sample predict() (e.g. DBSCAN): cluster this batch alone.
            preds = self.model.fit_predict(pooled)

        # IsolationForest emits -1 for outliers / +1 for inliers, and its
        # decision_function is *lower* for outliers.
        is_iforest = self.is_unsupervised and self.name == "IsolationForest"
        if self.is_unsupervised:
            if is_iforest:
                # Outliers are the positive class, consistent with the branch
                # below, where DBSCAN noise (-1) also maps to 1.
                preds = np.where(preds == -1, 1, 0)
            else:
                # Cluster 0 is read as the negative class; any other id
                # (including DBSCAN noise, -1) is positive.
                preds = np.where(preds == 0, 0, 1)

        if hasattr(self.model, "decision_function"):
            scores = self.model.decision_function(pooled)
            if is_iforest:
                # Flip so higher scores point at the positive (outlier) class.
                scores = -scores
        elif hasattr(self.model, "predict_proba"):
            proba = self.model.predict_proba(pooled)
            scores = proba[:, 1] if proba.shape[1] > 1 else proba[:, 0]
        else:
            scores = np.zeros_like(preds, dtype=float)
        return preds, scores

    def predict(self, dataloader):
        """
        Predict step-level labels/scores and broadcast them to operations.

        Args:
            dataloader: Iterable of batches providing ``"ops_tensor"``.

        Returns:
            dict: ``step_preds`` / ``step_scores`` of shape (N,), and
            ``op_preds`` / ``op_scores`` flattened per batch to (sum of B*T,),
            where every operation repeats its step's value.
        """
        step_preds, step_scores = [], []
        op_preds, op_scores = [], []

        for batch in dataloader:
            ops_tensor = batch["ops_tensor"]
            preds, scores = self._predict_pooled(self._pool_features(ops_tensor))
            step_preds.append(preds)
            step_scores.append(scores)

            _, n_ops, _ = np.shape(ops_tensor)
            op_preds.append(np.repeat(preds[:, None], n_ops, axis=1).ravel())
            op_scores.append(np.repeat(scores[:, None], n_ops, axis=1).ravel())

        return {
            "step_preds": np.concatenate(step_preds),
            "step_scores": np.concatenate(step_scores),
            "op_preds": np.concatenate(op_preds),
            "op_scores": np.concatenate(op_scores),
        }

    @staticmethod
    def _split_op_preds(op_pred_all, shapes):
        """
        Cut concatenated op-level predictions back into per-batch (B, D) matrices.

        Args:
            op_pred_all (np.ndarray): Either flattened (sum of B*D,) or 2-D (N, >=D).
            shapes (list[tuple[int, int]]): Per-batch ``(B, D)`` label shapes.

        Raises:
            ValueError: If flattened predictions do not cover exactly the label
                positions (e.g. ``ops_tensor`` and ``op_labels`` disagree on the
                op-axis length), which would otherwise misalign silently.
        """
        pred_mats = []
        ptr = 0
        if op_pred_all.ndim == 1:
            expected = sum(B * D for B, D in shapes)
            if op_pred_all.size != expected:
                raise ValueError(
                    f"op_preds has {op_pred_all.size} entries but op_labels cover "
                    f"{expected} positions; ops_tensor and op_labels must share "
                    f"the same op-axis length"
                )
            for B, D in shapes:
                pred_mats.append(op_pred_all[ptr:ptr + B * D].reshape(B, D))
                ptr += B * D
        else:
            # 2-D predictions: split along axis 0 by batch size, crop to D.
            for B, D in shapes:
                pred_mats.append(op_pred_all[ptr:ptr + B, :D])
                ptr += B
        return pred_mats

    @staticmethod
    def _prf(tp, fp, fn):
        """Precision, recall, F1 and Jaccard from raw counts (0.0 when undefined)."""
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
        jac = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0
        return {"prec": prec, "rec": rec, "f1": f1, "jaccard": jac}

    @staticmethod
    def _mean_metrics(dicts):
        """Key-wise mean of metric dicts; all zeros when ``dicts`` is empty."""
        if not dicts:
            return {"prec": 0.0, "rec": 0.0, "f1": 0.0, "jaccard": 0.0}
        n = len(dicts)
        return {k: sum(d.get(k, 0.0) for d in dicts) / n for k in dicts[0]}

    @classmethod
    def _op_level_metrics(cls, op_details):
        """
        Aggregate per-sample op index sets into macro / micro / positive-only metrics.

        Args:
            op_details (list[dict]): One ``{"pred": set, "label": set}`` per sample.

        Returns:
            dict: ``{"macro": ..., "micro": ..., "macro_positive_only": ...}``.
            Micro sums tp/fp/fn over samples. Index sets are never unioned
            across samples, because index k of one sample is unrelated to
            index k of another.
        """
        per_sample, positive_only = [], []
        tp_total = fp_total = fn_total = 0
        for detail in op_details:
            pred, true = set(detail["pred"]), set(detail["label"])
            tp, fp, fn = len(pred & true), len(pred - true), len(true - pred)
            metrics = cls._prf(tp, fp, fn)
            per_sample.append(metrics)
            if true:
                positive_only.append(metrics)
            tp_total += tp
            fp_total += fp
            fn_total += fn
        return {
            "macro": cls._mean_metrics(per_sample),
            "micro": cls._prf(tp_total, fp_total, fn_total),
            "macro_positive_only": cls._mean_metrics(positive_only),
        }

    def evaluate(self, dataloader):
        """
        Evaluate model performance on given data.

        The dataloader is iterated exactly once and its batches are cached in
        memory, so predictions and labels stay aligned even when the loader
        shuffles or is a one-shot iterator.

        Returns:
            dict:
              {
                "step": {acc, prec, rec, f1},
                "op": {
                  "macro": {prec, rec, f1, jaccard},
                  "micro": {prec, rec, f1, jaccard},
                  "macro_positive_only": {prec, rec, f1, jaccard}
                },
                "op_details_len": <int>
              }
        """
        batches = list(dataloader)
        preds = self.predict(batches)
        y_pred_step = np.asarray(preds["step_preds"])  # (N,)
        op_pred_all = np.asarray(preds["op_preds"])  # (N, D) or flattened (N*D,)

        y_true_parts, true_mats, masks = [], [], []
        for batch in batches:
            y_true_parts.append(np.asarray(batch["step_labels"]))  # (B,)
            op_labels = np.asarray(batch["op_labels"])  # (B, D), -100 = padding
            valid = op_labels != -100
            true_mats.append(np.where(valid, op_labels, 0).astype(int))
            masks.append(valid)
        y_true_step = np.concatenate(y_true_parts, axis=0)  # (N,)

        pred_mats = self._split_op_preds(op_pred_all, [mat.shape for mat in true_mats])

        step_acc = accuracy_score(y_true_step, y_pred_step)
        step_prec = precision_score(y_true_step, y_pred_step, zero_division=0)
        step_rec = recall_score(y_true_step, y_pred_step, zero_division=0)
        step_f1 = f1_score(y_true_step, y_pred_step, zero_division=0)

        op_details = []
        for true_mat, pred_mat, valid_mat in zip(true_mats, pred_mats, masks):
            for true_row, pred_row, valid in zip(true_mat, pred_mat, valid_mat):
                # Only non-padded positions are compared; indices refer to the
                # compacted (padding-free) row.
                op_details.append({
                    "pred": set(np.flatnonzero(pred_row[valid] == 1).tolist()),
                    "label": set(np.flatnonzero(true_row[valid] == 1).tolist()),
                })

        op_metrics = self._op_level_metrics(op_details)
        op_macro = op_metrics["macro"]
        op_micro = op_metrics["micro"]
        op_macro_pos = op_metrics["macro_positive_only"]

        logger.info("=== Evaluation for %s ===", self.name)
        logger.info(
            "Step-level -> acc: %.4f, prec: %.4f, rec: %.4f, f1: %.4f",
            step_acc, step_prec, step_rec, step_f1
        )
        logger.info(
            "Op-level   -> MACRO:    prec: %.4f, rec: %.4f, f1: %.4f, jaccard: %.4f",
            op_macro["prec"], op_macro["rec"], op_macro["f1"], op_macro["jaccard"]
        )
        logger.info(
            "Op-level   -> MICRO:    prec: %.4f, rec: %.4f, f1: %.4f, jaccard: %.4f",
            op_micro["prec"], op_micro["rec"], op_micro["f1"], op_micro["jaccard"]
        )
        logger.info(
            "Op-level   -> POS-ONLY: prec: %.4f, rec: %.4f, f1: %.4f, jaccard: %.4f",
            op_macro_pos["prec"], op_macro_pos["rec"], op_macro_pos["f1"], op_macro_pos["jaccard"]
        )

        return {
            "step": dict(acc=step_acc, prec=step_prec, rec=step_rec, f1=step_f1),
            "op": {
                "macro": op_macro,
                "micro": op_micro,
                "macro_positive_only": op_macro_pos,
            },
            "op_details_len": len(op_details),
        }


class KMeansModel(BaseModel):
    """
    K-means clustering baseline.

    BaseModel's prediction logic maps cluster id 0 to label 0 and every other
    cluster id to label 1.
    """

    def __init__(
        self,
        n_clusters=8,
        init="k-means++",
        n_init=10,
        max_iter=300,
        tol=1e-4,
        random_state=None,
    ):
        super().__init__("KMeans")
        self.is_unsupervised = True
        kmeans_kwargs = dict(
            n_clusters=n_clusters,
            init=init,
            n_init=n_init,
            max_iter=max_iter,
            tol=tol,
            random_state=random_state,
        )
        self.model = KMeans(**kmeans_kwargs)


class IsolationForestModel(BaseModel):
    """Unsupervised Isolation Forest outlier detector wrapped for BaseModel."""

    def __init__(
        self,
        n_estimators=100,
        max_samples="auto",
        contamination="auto",
        max_features=1.0,
        bootstrap=False,
        random_state=None,
    ):
        # The name is load-bearing: BaseModel's prediction logic compares
        # self.name against "IsolationForest" to decode the -1/+1 output.
        super().__init__("IsolationForest")
        self.is_unsupervised = True
        forest_params = {
            "n_estimators": n_estimators,
            "max_samples": max_samples,
            "contamination": contamination,
            "max_features": max_features,
            "bootstrap": bootstrap,
            "random_state": random_state,
        }
        self.model = IsolationForest(**forest_params)


class DBSCANModel(BaseModel):
    """
    Density-based clustering baseline.

    scikit-learn's DBSCAN has no ``predict`` method, so BaseModel falls back
    to ``fit_predict`` on each batch it is asked to predict.
    """

    def __init__(self, eps=0.5, min_samples=5, metric="euclidean"):
        super().__init__("DBSCAN")
        self.is_unsupervised = True
        self.model = DBSCAN(metric=metric, min_samples=min_samples, eps=eps)


class GMMModel(BaseModel):
    """
    Gaussian mixture baseline.

    Mixture component ids are treated as cluster ids by BaseModel
    (component 0 maps to label 0, any other component to label 1).
    """

    def __init__(
        self,
        n_components=2,
        covariance_type="full",
        tol=1e-3,
        max_iter=100,
        random_state=None,
    ):
        super().__init__("GMM")
        self.is_unsupervised = True
        mixture_kwargs = dict(
            random_state=random_state,
            max_iter=max_iter,
            tol=tol,
            covariance_type=covariance_type,
            n_components=n_components,
        )
        self.model = GaussianMixture(**mixture_kwargs)


class XGBoostModel(BaseModel):
    """
    Supervised gradient-boosted tree classifier (XGBoost).

    Step scores come from ``predict_proba`` through BaseModel.
    """

    def __init__(
        self,
        n_estimators=100,
        max_depth=6,
        learning_rate=0.3,
        subsample=1.0,
        colsample_bytree=1.0,
        reg_lambda=1.0,
        random_state=0,
        use_label_encoder=False,
        eval_metric="logloss",
    ):
        super().__init__("XGBoost")
        # NOTE(review): recent XGBoost releases no longer use
        # `use_label_encoder`; confirm against the installed version.
        booster_params = {
            "n_estimators": n_estimators,
            "max_depth": max_depth,
            "learning_rate": learning_rate,
            "subsample": subsample,
            "colsample_bytree": colsample_bytree,
            "reg_lambda": reg_lambda,
            "random_state": random_state,
            "use_label_encoder": use_label_encoder,
            "eval_metric": eval_metric,
        }
        self.model = XGBClassifier(**booster_params)


class SVMModel(BaseModel):
    """
    Supervised support-vector classifier.

    ``probability=True`` enables ``predict_proba``, but BaseModel prefers
    ``decision_function`` for scores whenever the estimator provides it.
    """

    def __init__(
        self,
        C=1.0,
        kernel="rbf",
        degree=3,
        gamma="scale",
        probability=True,
        random_state=None,
    ):
        super().__init__("SVM")
        svc_kwargs = dict(
            C=C,
            kernel=kernel,
            degree=degree,
            gamma=gamma,
            probability=probability,
            random_state=random_state,
        )
        self.model = SVC(**svc_kwargs)


class RandomForestModel(BaseModel):
    """Supervised random-forest classifier; scores come from ``predict_proba``."""

    def __init__(
        self,
        n_estimators=100,
        max_depth=None,
        max_features="sqrt",
        bootstrap=True,
        random_state=None,
    ):
        super().__init__("RandomForest")
        self.model = RandomForestClassifier(
            random_state=random_state,
            bootstrap=bootstrap,
            max_features=max_features,
            max_depth=max_depth,
            n_estimators=n_estimators,
        )


if __name__ == "__main__":
    # Smoke-test every wrapper end to end: fit, predict, evaluate.
    from dataset import build_dataloader

    # valid_loader is returned by build_dataloader but not used here.
    train_loader, valid_loader, test_loader = build_dataloader(
        "vertical", batch_size=16, train_valid_test_rate=[0.6, 0.2, 0.2]
    )

    models = [
        KMeansModel(),
        IsolationForestModel(),
        DBSCANModel(),
        GMMModel(),
        XGBoostModel(),
        SVMModel(),
        RandomForestModel(),
    ]

    for model in models:
        logger.info("-" * 60)
        logger.info("Testing model: %s", model.name)
        try:
            model.fit(train_loader)
            preds = model.predict(test_loader)
            logger.info("Prediction outputs shapes/types for %s:", model.name)
            for key, value in preds.items():
                logger.info("  %s: %s (type: %s)", key, value.shape, type(value))
            model.evaluate(test_loader)
        except Exception:
            # Keep going so one failing model does not abort the whole run;
            # the message covers fit as well, since it is inside this try.
            logger.exception("Error during fit/predict/evaluate for %s", model.name)

    logger.info("All done.")
