import os
import json
import logging
import traceback
import optuna

from setting import HPO_SEARCH_SPACES, RESULTS_DIR
from model import (
    KMeansModel,
    IsolationForestModel,
    DBSCANModel,
    GMMModel,
    XGBoostModel,
    SVMModel,
    RandomForestModel,
)
from dataset import build_dataloader

# Logging: INFO level and above, one "<time> - <LEVEL> - <message>" line per record.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# All files written by this script live in a "train" sub-directory of the
# configured results root; create it at import time so later writes can rely on it.
_TRAIN_RESULTS_DIR = os.path.join(RESULTS_DIR, "train")
os.makedirs(_TRAIN_RESULTS_DIR, exist_ok=True)
RESULTS_DIR = _TRAIN_RESULTS_DIR

# Cache of built dataloaders, keyed by (task, batch_size, split) so a request
# with different loader settings never receives loaders built for another
# configuration.
DATALOADERS = {}


def get_dataloader(task: str, batch_size=32, split=(0.7, 0.2, 0.1)):
    """
    Get (or build) the train/valid/test dataloaders for a task.

    Loaders are built once per distinct ``(task, batch_size, split)``
    combination and reused on subsequent calls with the same arguments.

    Args:
        task (str): Task name, forwarded to ``build_dataloader``.
        batch_size (int): Batch size for data loading.
        split (Sequence[float]): Train/valid/test split ratios, forwarded to
            ``build_dataloader`` as ``train_valid_test_rate``.

    Returns:
        tuple: (train_loader, valid_loader, test_loader)
    """
    # Include every argument that affects the built loaders in the cache key;
    # normalise ``split`` to a tuple so list arguments are hashable.
    key = (task, batch_size, tuple(split))
    if key not in DATALOADERS:
        logger.info("Initializing dataloader: %s", task)
        train_loader, valid_loader, test_loader = build_dataloader(
            task, batch_size=batch_size, train_valid_test_rate=split
        )
        DATALOADERS[key] = (train_loader, valid_loader, test_loader)
    return DATALOADERS[key]


# Registry of model name -> model class. objective() looks classes up by name,
# and main() sweeps over every entry in insertion order.
MODEL_REGISTRY = dict(
    KMeans=KMeansModel,
    IsolationForest=IsolationForestModel,
    DBSCAN=DBSCANModel,
    GMM=GMMModel,
    XGBoost=XGBoostModel,
    SVM=SVMModel,
    RandomForest=RandomForestModel,
)


def objective(trial, task: str, model_name: str):
    """
    Objective function for Optuna optimization.

    Samples hyperparameters for ``model_name`` from ``HPO_SEARCH_SPACES``,
    fits the model on the training split and scores it on the *validation*
    split. The test split is deliberately not used here, so that it stays
    held out from hyperparameter selection.

    Args:
        trial (optuna.Trial): Optuna trial object used to sample parameters.
        task (str): Task name passed to ``get_dataloader``.
        model_name (str): Key into ``MODEL_REGISTRY``.

    Returns:
        float: ``-(0.7 * step_f1 + 0.3 * op_f1)`` on the validation split
        (negated because the study minimizes). Returns NaN if building,
        fitting or evaluating the model raises; Optuna records a NaN result
        as a failed trial, so it can never be selected as the best trial.

    Raises:
        KeyError: If ``model_name`` is not registered in ``MODEL_REGISTRY``.
    """
    ModelClass = MODEL_REGISTRY[model_name]
    search_space = HPO_SEARCH_SPACES.get(model_name, {})
    # Each entry is called with the trial; presumably it wraps a
    # trial.suggest_* call — confirm against setting.HPO_SEARCH_SPACES.
    params = {k: fn(trial) for k, fn in search_space.items()}

    train_loader, valid_loader, _ = get_dataloader(task)

    try:
        model = ModelClass(**params)
        model.fit(train_loader)
        eval_results = model.evaluate(valid_loader)

        # Missing metric groups/keys count as an F1 of 0.0.
        step_f1 = eval_results.get("step", {}).get("f1", 0.0)
        op_f1 = eval_results.get("op", {}).get("f1", 0.0)

        score = 0.7 * step_f1 + 0.3 * op_f1
        return -score

    except Exception as e:
        # Best-effort per trial: log and mark the trial failed instead of
        # aborting the whole study. NaN (unlike a large sentinel value) is not
        # stored as a COMPLETE trial, so a failure can never become best_params.
        logger.error(
            "Training/evaluation failed - task=%s, model=%s, error=%s",
            task, model_name, e,
            exc_info=True,
        )
        return float("nan")


def run_hpo(task: str, model_name: str, n_trials: int = 30):
    """
    Run hyperparameter optimization for a given task and model.

    The result is cached as JSON at ``RESULTS_DIR/<task>_<model_name>.json``.
    If a readable cache file already exists, its contents are returned and no
    study is run. A cache file that cannot be read or parsed is logged and
    ignored, and the study is run again and overwrites it.

    Args:
        task (str): Task name.
        model_name (str): Model type (key into ``MODEL_REGISTRY``).
        n_trials (int): Number of Optuna trials.

    Returns:
        tuple: (best_params, best_score), where best_score is the minimized
        objective value.

    Raises:
        ValueError: Raised by Optuna's ``study.best_params`` when no trial
            completed (e.g. every trial failed); nothing is cached then.
    """
    result_path = os.path.join(RESULTS_DIR, f"{task}_{model_name}.json")

    if os.path.exists(result_path):
        try:
            with open(result_path, "r", encoding="utf-8") as f:
                cached = json.load(f)
            best_params, best_score = cached["best_params"], cached["best_score"]
        except (OSError, ValueError, KeyError, TypeError) as e:
            # A truncated/corrupt cache must not block re-running the study.
            logger.warning("Ignoring unreadable cached result %s: %s", result_path, e)
        else:
            logger.info("Existing result found, skipping HPO: %s", result_path)
            return best_params, best_score

    study = optuna.create_study(direction="minimize")
    study.optimize(lambda trial: objective(trial, task, model_name), n_trials=n_trials)

    best_params, best_score = study.best_params, study.best_value
    result = {"best_params": best_params, "best_score": best_score}

    # Write to a temporary file and atomically swap it into place, so an
    # interrupted run never leaves a partial JSON that would later be treated
    # as a cache hit. UTF-8 is explicit because ensure_ascii=False may emit
    # characters the platform default encoding cannot represent.
    tmp_path = f"{result_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    os.replace(tmp_path, result_path)

    logger.info("Result saved: %s", result_path)
    return best_params, best_score


def main(tasks=("vertical", "horizontal"), n_trials: int = 20):
    """
    Run HPO for every registered model on every task and save a summary.

    Each (task, model) pair goes through ``run_hpo``, which has its own
    per-pair cache. The combined results are then written to
    ``RESULTS_DIR/all_results.json``. A failure for one pair is logged and
    skipped so the remaining pairs still run; failed pairs are listed at the end.

    Args:
        tasks (Iterable[str]): Task names to optimize. Defaults to
            ``("vertical", "horizontal")``.
        n_trials (int): Number of Optuna trials per (task, model) pair.
    """
    tasks = list(tasks)  # iterated twice below; materialize generators
    results = {task: {} for task in tasks}
    failed = []

    for task in tasks:
        logger.info("=== Processing task: %s ===", task)

        for model_name in MODEL_REGISTRY:
            logger.info("--- Optimizing model: %s (task=%s) ---", model_name, task)
            try:
                best_params, best_score = run_hpo(task, model_name, n_trials=n_trials)
            except Exception as e:
                # Best-effort sweep: record the failure and move on to the next model.
                failed.append(f"{task}/{model_name}")
                logger.error(
                    "HPO failed - task=%s, model=%s, error=%s",
                    task, model_name, e,
                    exc_info=True,
                )
                continue
            results[task][model_name] = {
                "best_params": best_params,
                "best_score": best_score,
            }

    output_path = os.path.join(RESULTS_DIR, "all_results.json")
    try:
        # Explicit UTF-8: ensure_ascii=False may emit characters that the
        # platform default encoding (e.g. on Windows) cannot represent.
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4, ensure_ascii=False)
        logger.info("All results saved to %s", output_path)
    except Exception as e:
        logger.error("Failed to save all results: %s", e, exc_info=True)

    if failed:
        logger.warning("HPO failed for %d run(s): %s", len(failed), ", ".join(failed))
    logger.info("All tasks completed ✅")


# Script entry point: run the full HPO sweep only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
