import os
import json
import logging
import matplotlib.pyplot as plt

from dataset import build_dataloader
from model import (
    KMeansModel,
    IsolationForestModel,
    DBSCANModel,
    GMMModel,
    XGBoostModel,
    SVMModel,
    RandomForestModel,
)
from setting import *

# ===========================
# Logging Configuration
# ===========================

# Configure the root logger at import time: INFO and above, emitted as
# "timestamp | LEVEL | message". Per the logging docs, basicConfig() does
# nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# ===========================
# Output Directory
# ===========================

# Destination for everything this script writes (infer.json and the metric
# plots). The path is relative, so it resolves against the current working
# directory.
OUTPUT_DIR = os.path.join("results", "infer")
# Created eagerly at import time; exist_ok=True keeps repeated runs from failing.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ===========================
# Core Functions
# ===========================


def _set_metrics(pred_set, label_set):
    """Score one predicted collection against its ground-truth collection.

    Both arguments are turned into sets, so duplicates and ordering are
    ignored; ``None`` (or any other falsy value) counts as an empty set.

    Args:
        pred_set: Iterable of predicted items, or ``None``.
        label_set: Iterable of ground-truth items, or ``None``.

    Returns:
        A dict with the keys ``"prec"``, ``"rec"``, ``"f1"`` and
        ``"jaccard"``. Any ratio whose denominator would be zero is
        reported as ``0.0``.
    """
    predicted = set(pred_set) if pred_set else set()
    expected = set(label_set) if label_set else set()

    hits = len(predicted.intersection(expected))
    spurious = len(predicted.difference(expected))
    missed = len(expected.difference(predicted))

    n_predicted = hits + spurious
    n_expected = hits + missed
    n_union = hits + spurious + missed

    precision = hits / n_predicted if n_predicted > 0 else 0.0
    recall = hits / n_expected if n_expected > 0 else 0.0

    # Harmonic mean of precision and recall; undefined when both are zero.
    if (precision + recall) > 0:
        f1_score = (2 * precision * recall) / (precision + recall)
    else:
        f1_score = 0.0

    jaccard = hits / n_union if n_union > 0 else 0.0

    return {"prec": precision, "rec": recall, "f1": f1_score, "jaccard": jaccard}


def _avg_dict(dicts):
    """Average a sequence of metric dicts key by key.

    The keys of the first dict define the output; a key missing from a later
    dict contributes ``0.0`` to that key's mean, and keys that only appear in
    later dicts are ignored.

    Args:
        dicts: Sequence of dicts with numeric values, possibly empty or
            ``None``.

    Returns:
        A dict holding the arithmetic mean per key. For empty input, the
        four standard metrics (``prec``, ``rec``, ``f1``, ``jaccard``) are
        returned, all ``0.0``.
    """
    if not dicts:
        return {"prec": 0.0, "rec": 0.0, "f1": 0.0, "jaccard": 0.0}

    template = dicts[0]
    count = len(dicts)
    averaged = {}
    for name in template:
        # Plain left-to-right accumulation keeps float results reproducible.
        running = 0.0
        for entry in dicts:
            running += entry.get(name, 0.0)
        averaged[name] = running / count
    return averaged


def _aggregate_macro(op_details):
    """Macro-average the set metrics over every record in ``op_details``.

    Each record's ``"pred"`` and ``"label"`` entries are scored on their own
    with ``_set_metrics`` and the per-record scores are then averaged with
    ``_avg_dict``, so every record carries equal weight.

    Args:
        op_details: Iterable of dict-like records, or ``None``.

    Returns:
        The averaged metric dict (all zeros when there are no records).
    """
    per_record = []
    for record in op_details or []:
        per_record.append(_set_metrics(record.get("pred"), record.get("label")))
    return _avg_dict(per_record)



def _aggregate_micro(op_details):
    """Score the pooled predictions against the pooled labels.

    All ``"pred"`` entries are merged into one set and all ``"label"``
    entries into another; the two pooled sets are then scored once with
    ``_set_metrics``. Note that pooling is a set union, so the same item
    appearing in several records is counted only once.

    Args:
        op_details: Iterable of dict-like records, or ``None``.

    Returns:
        The metric dict for the pooled sets.
    """
    pooled_pred, pooled_true = set(), set()
    for record in op_details or []:
        pooled_pred.update(record.get("pred") or [])
        pooled_true.update(record.get("label") or [])
    return _set_metrics(pooled_pred, pooled_true)


def _aggregate_macro_positive_only(op_details):
    """Macro-average the set metrics over records that have a non-empty label.

    Records whose ``"label"`` entry is missing or falsy are skipped, so they
    neither contribute scores nor count towards the averaging denominator.

    Args:
        op_details: Iterable of dict-like records, or ``None``.

    Returns:
        The averaged metric dict, or all zeros when no record has a label.
    """
    scores = [
        _set_metrics(record.get("pred"), record.get("label"))
        for record in (op_details or [])
        if record.get("label")
    ]
    if scores:
        return _avg_dict(scores)
    return {"prec": 0.0, "rec": 0.0, "f1": 0.0, "jaccard": 0.0}


def _enrich_op_metrics(result_dict):
    """Normalise ``result_dict["op"]`` into per-aggregation-mode sub-dicts.

    The function mutates ``result_dict`` in place and also returns it.

    * If ``op`` already holds dicts under both ``"macro"`` and ``"micro"``,
      nothing is changed.
    * A non-empty ``op`` without a ``"macro"`` dict is treated as a flat
      metric dict and wrapped as ``{"macro": <flat dict>}``.
    * If ``"op_details"`` is non-empty, ``"macro"``, ``"micro"`` and
      ``"macro_positive_only"`` are (re)computed from it, replacing any
      existing values under those keys.

    A missing or ``None`` ``"op"`` entry is tolerated: an ``op`` dict is
    created on demand when there are details to aggregate, instead of
    failing with ``KeyError``/``AttributeError``.

    Args:
        result_dict: Evaluation result of one model; may contain ``"op"``
            and ``"op_details"``.

    Returns:
        The same ``result_dict`` object.
    """
    op = result_dict.get("op")
    if op is None:
        # Missing (or explicitly null) "op": start from an empty mapping, but
        # only store it below if there is actually something to put in it.
        op = {}

    if isinstance(op.get("macro"), dict) and isinstance(op.get("micro"), dict):
        return result_dict

    if op and not isinstance(op.get("macro"), dict):
        # Legacy flat layout: keep the reported numbers as the macro view.
        op = {"macro": op}
        result_dict["op"] = op

    op_details = result_dict.get("op_details")
    if op_details:
        # Ensures the key exists even when "op" was absent from the input.
        result_dict["op"] = op
        op["macro"] = _aggregate_macro(op_details)
        op["micro"] = _aggregate_micro(op_details)
        op["macro_positive_only"] = _aggregate_macro_positive_only(op_details)

    return result_dict


def _build_models():
    """Instantiate every benchmarked model from the ``setting`` constants, in run order."""
    return [
        (
            "KMeans",
            KMeansModel(
                n_clusters=KMEANS_N_CLUSTERS,
                init=KMEANS_INIT,
                max_iter=KMEANS_MAX_ITER,
                random_state=KMEANS_RANDOM_STATE,
            ),
        ),
        (
            "IsolationForest",
            IsolationForestModel(
                n_estimators=ISOLATIONFOREST_N_ESTIMATORS,
                max_samples=ISOLATIONFOREST_MAX_SAMPLES,
                contamination=ISOLATIONFOREST_CONTAMINATION,
                max_features=ISOLATIONFOREST_MAX_FEATURES,
                random_state=ISOLATIONFOREST_RANDOM_STATE,
            ),
        ),
        (
            "DBSCAN",
            DBSCANModel(
                eps=DBSCAN_EPS,
                min_samples=DBSCAN_MIN_SAMPLES,
                metric=DBSCAN_METRIC,
            ),
        ),
        (
            "GMM",
            GMMModel(
                n_components=GMM_N_COMPONENTS,
                covariance_type=GMM_COVARIANCE_TYPE,
                max_iter=GMM_MAX_ITER,
                random_state=GMM_RANDOM_STATE,
            ),
        ),
        (
            "XGBoost",
            XGBoostModel(
                n_estimators=XGBOOST_N_ESTIMATORS,
                max_depth=XGBOOST_MAX_DEPTH,
                learning_rate=XGBOOST_LEARNING_RATE,
                subsample=XGBOOST_SUBSAMPLE,
                colsample_bytree=XGBOOST_COLSAMPLE_BYTREE,
                reg_lambda=XGBOOST_REG_LAMBDA,
                random_state=XGBOOST_RANDOM_STATE,
            ),
        ),
        (
            "SVM",
            SVMModel(
                C=SVM_C,
                kernel=SVM_KERNEL,
                gamma=SVM_GAMMA,
            ),
        ),
        (
            "RandomForest",
            RandomForestModel(
                n_estimators=RANDOMFOREST_N_ESTIMATORS,
                max_depth=RANDOMFOREST_MAX_DEPTH,
                max_features=RANDOMFOREST_MAX_FEATURES,
                bootstrap=RANDOMFOREST_BOOTSTRAP,
                random_state=RANDOMFOREST_RANDOM_STATE,
            ),
        ),
    ]


def run_task(task: str, batch_size: int = 32, split_ratios=None):
    """Fit and evaluate every benchmarked model on one task.

    Data loaders are built once via ``build_dataloader`` and shared by all
    models. Each model is fitted on the training loader, run on the test
    loader, evaluated on the test loader, and its evaluation result is passed
    through ``_enrich_op_metrics`` before being stored.

    Args:
        task: Task identifier forwarded to ``build_dataloader``.
        batch_size: Batch size forwarded to ``build_dataloader``. Defaults to
            32, the value that was previously hard-coded.
        split_ratios: Three split ratios forwarded to ``build_dataloader``.
            ``None`` selects ``[0.7, 0.2, 0.1]``, the previous hard-coded
            value. NOTE(review): presumably train/valid/test, matching the
            order of the returned loaders; confirm against
            ``build_dataloader``.

    Returns:
        Dict mapping model name to its enriched evaluation result, in the
        order the models were run.
    """
    if split_ratios is None:
        split_ratios = [0.7, 0.2, 0.1]

    results = {}
    logging.info("===== Start Task: %s =====", task)

    # The validation loader is returned by build_dataloader but not used by
    # any of the models here.
    train_loader, _valid_loader, test_loader = build_dataloader(
        task, batch_size, list(split_ratios)
    )

    for name, model in _build_models():
        logging.info("[%s] Training started", name)
        model.fit(train_loader)
        # The return value is intentionally left unused, as in the original
        # flow. NOTE(review): presumably evaluate() relies on state set by
        # predict(); confirm in the model classes before removing this call.
        model.predict(test_loader)
        results[name] = _enrich_op_metrics(model.evaluate(test_loader))
        logging.info("[%s] Training & Evaluation finished", name)

    logging.info("===== Task %s Finished =====", task)
    return results


def _save_bar_row(models, panels, suptitle, out_path, figsize):
    """Draw one bar chart per ``(title, values)`` panel in a single row and save it."""
    # squeeze=False always yields a 2-D axes array, so a single panel works too.
    fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)
    try:
        for ax, (title, values) in zip(axes[0], panels):
            ax.bar(models, values)
            ax.set_title(title)
            ax.set_xticks(range(len(models)))
            ax.set_xticklabels(models, rotation=45, ha="right")
        fig.suptitle(suptitle)
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Close this specific figure even if drawing or saving fails, so
        # figures do not accumulate in pyplot's global registry.
        plt.close(fig)


def plot_metrics(task_name: str, task_results: dict):
    """Save bar charts comparing all models of one task.

    Writes into ``OUTPUT_DIR``:

    * ``{task_name}_step.png``: one panel per step metric
      (``acc``, ``prec``, ``rec``, ``f1``).
    * ``{task_name}_op_{mode}.png``: one panel per op metric
      (``prec``, ``rec``, ``f1``, ``jaccard``) for each aggregation mode in
      ``macro``, ``micro``, ``macro_positive_only``. A mode other than
      ``macro`` is skipped when no model provides it. For ``macro``, a model
      whose ``op`` is still a flat metric dict is plotted from that flat
      dict.

    Missing values are plotted as NaN. A model result without a ``"step"``
    or ``"op"`` entry is treated as having no metrics rather than raising.

    Args:
        task_name: Used in figure titles and output file names.
        task_results: Mapping of model name to that model's result dict.
    """
    nan = float("nan")
    models = list(task_results)

    step_metrics = ["acc", "prec", "rec", "f1"]
    step_panels = [
        (
            f"Step - {metric}",
            [(task_results[mdl].get("step") or {}).get(metric, nan) for mdl in models],
        )
        for metric in step_metrics
    ]
    _save_bar_row(
        models,
        step_panels,
        f"{task_name} - Step Metrics",
        os.path.join(OUTPUT_DIR, f"{task_name}_step.png"),
        (16, 5),
    )

    op_modes = ["macro", "micro", "macro_positive_only"]
    op_metrics = ["prec", "rec", "f1", "jaccard"]
    # Looked up once; the same per-model "op" dicts are reused for every mode.
    ops = [task_results[mdl].get("op") or {} for mdl in models]

    for mode in op_modes:
        has_mode = any(isinstance(op.get(mode), dict) for op in ops)
        # "macro" is always drawn because flat op dicts can stand in for it.
        if not has_mode and mode != "macro":
            continue

        panels = []
        for met in op_metrics:
            vals = []
            for op in ops:
                section = op.get(mode)
                if isinstance(section, dict):
                    vals.append(section.get(met, nan))
                elif mode == "macro":
                    # Flat layout: metrics sit directly under "op".
                    vals.append(op.get(met, nan))
                else:
                    vals.append(nan)
            panels.append((f"Op-{mode} - {met}", vals))

        _save_bar_row(
            models,
            panels,
            f"{task_name} - Op Metrics ({mode})",
            os.path.join(OUTPUT_DIR, f"{task_name}_op_{mode}.png"),
            (20, 5),
        )


def main():
    """Run both tasks, dump the combined results as JSON, then plot them.

    The JSON file (``infer.json`` inside ``OUTPUT_DIR``) is written before
    any plotting starts. A completion message is both logged and printed
    after the JSON dump and again after the plots are saved.
    """
    final_results = {task: run_task(task) for task in ["vertical", "horizontal"]}

    json_path = os.path.join(OUTPUT_DIR, "infer.json")
    with open(json_path, "w") as f:
        json.dump(final_results, f, indent=4)

    json_saved_msg = f"All tasks finished. Results saved to {json_path}"
    logging.info(json_saved_msg)
    print(json_saved_msg)

    for task in final_results:
        plot_metrics(task, final_results[task])

    plots_saved_msg = f"All tasks finished. Results saved to {OUTPUT_DIR}"
    logging.info(plots_saved_msg)
    print(plots_saved_msg)


# Run the full benchmark only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()
