import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import time
import json
import itertools
import numpy as np
from tqdm import tqdm
from methods.pyods import PYOD
from methods.hotelling import Hotelling
from methods.HSTree import HSTreeAnomalyDetector
from utils import *
from dataloader import TimeSeriesDataset

def _load_search_space(search_space_path, model_name):
    """Return the search-space entry for ``model_name`` from a JSON file.

    Args:
        search_space_path: Path of the JSON file holding one entry per model.
        model_name: Key to look up in the parsed JSON object.

    Returns:
        The value stored under ``model_name``, or an empty dict when the
        file has no such key.

    Raises:
        FileNotFoundError: If ``search_space_path`` does not exist.
    """
    file_is_missing = not os.path.exists(search_space_path)
    if file_is_missing:
        raise FileNotFoundError(f"Search space file not found: {search_space_path}")

    with open(search_space_path, 'r') as handle:
        every_model_space = json.load(handle)

    # A model without an entry falls back to an empty search space.
    model_space = every_model_space.get(model_name, {})
    return model_space

def _make_param_grid(space_dict):
    """Expand a mapping of candidate lists into every parameter combination.

    Example:
        {'n_estimators': [100, 200], 'bootstrap': [True]} ->
        [{'n_estimators': 100, 'bootstrap': True},
         {'n_estimators': 200, 'bootstrap': True}]

    A falsy ``space_dict`` (empty dict or None) yields a single empty
    parameter set, so callers still get exactly one trial.
    """
    if not space_dict:
        return [dict()]

    names = list(space_dict.keys())
    candidate_lists = [space_dict[name] for name in names]

    combinations = []
    for chosen in itertools.product(*candidate_lists):
        # Pair every parameter name with the value picked for this combination.
        combinations.append(dict(zip(names, chosen)))
    return combinations

def _init_model(model_name, init_params=None):
    """Instantiate the detector object used for ``model_name``.

    Args:
        model_name: Name of the model. The names listed below are wrapped by
            ``PYOD``; 'HSTree' and 'Hotelling' use their own classes.
        init_params: Optional dict of constructor keyword arguments. Only
            forwarded to ``Hotelling``; ignored for every other model.

    Returns:
        A freshly constructed model object.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name in ['IForest', 'HBOS', 'PCA', 'LODA', 'DeepSVDD', 'LOF', 'CBLOF', 'ABOD']:
        return PYOD(model_name)
    elif model_name == 'HSTree':
        return HSTreeAnomalyDetector()
    elif model_name == 'Hotelling':
        # ``init_params`` defaults to None, and ``**None`` raises TypeError,
        # so fall back to an empty kwargs dict.
        return Hotelling(**(init_params or {}))
    else:
        raise ValueError(f"Model {model_name} is not supported for hyperparameter search.")

def run_hyperparameter_search(dataset_name, train_set, test_set, model_name, search_space_path='./methods/configs/ml_search_space.json'):
    """Grid-search one model on one dataset and report the best AUC-ROC.

    Every parameter combination from the model's entry in the search-space
    file is fitted on ``train_set`` and scored on ``test_set``. The
    combination with the highest ``'aucroc'`` value returned by
    ``cal_metric`` is kept as the best one.

    Args:
        dataset_name: Label used in the progress bar and log lines.
        train_set: Object exposing ``data`` and ``labels``; ``data.shape[1]``
            is read and passed to DeepSVDD as ``n_features``.
        test_set: Object exposing ``data`` and ``labels``.
        model_name: Key into the search-space file and model selector for
            ``_init_model``.
        search_space_path: Path of the JSON search-space file.

    Returns:
        dict with keys ``best_score``, ``best_params``,
        ``avg_train_time_sec``, ``avg_infer_time_sec``, ``num_trials`` and
        ``score_list`` (one float per trial, in grid order).

    Raises:
        FileNotFoundError: Propagated from ``_load_search_space``.
        ValueError: If the model is unsupported, or if the search space
            expands to zero parameter combinations.
    """
    # 1) search space -> grid
    space_dict = _load_search_space(search_space_path, model_name)
    grid = _make_param_grid(space_dict)
    if not grid:
        # A parameter with an empty candidate list collapses the product to
        # zero trials. Fail with a clear message here instead of crashing
        # later while formatting the (then missing) average timings.
        raise ValueError(f"Search space for {model_name} yields no parameter combinations.")

    # Loop-invariant, so read it once rather than on every trial.
    input_dim = train_set.data.shape[1]

    best_score = -np.inf
    best_params = None
    train_time_list, infer_time_list = [], []
    score_list = []

    # The context manager closes the bar even when a trial raises.
    with tqdm(grid, desc=f"[{dataset_name}] {model_name} grid", leave=False) as pbar:
        for params in pbar:
            # Hotelling takes its hyperparameters at construction time; the
            # other models receive them through ``fit``.
            if model_name == 'Hotelling':
                model = _init_model(model_name, init_params=params)
            else:
                model = _init_model(model_name)

            # 2) fit -- perf_counter is monotonic, so the measured duration
            # is not affected by wall-clock adjustments.
            fit_start = time.perf_counter()
            if model_name == 'DeepSVDD':
                model.fit(train_set.data, train_set.labels, n_features=input_dim, **params)
            elif model_name == 'Hotelling':
                model.fit(train_set.data, train_set.labels)
            else:
                model.fit(train_set.data, train_set.labels, **params)
            t_fit = time.perf_counter() - fit_start

            # 3) predict scores
            infer_start = time.perf_counter()
            scores = model.predict_score(test_set.data)
            t_infer = time.perf_counter() - infer_start

            train_time_list.append(t_fit)
            infer_time_list.append(t_infer)

            # Store plain floats so the log line and the returned result agree.
            aucroc = float(cal_metric(test_set.labels, scores)['aucroc'])
            score_list.append(aucroc)

            pbar.set_postfix(score=f"{aucroc:.4f}", best=f"{max(best_score, aucroc):.4f}")

            # A NaN score never compares greater, so it is never selected.
            if aucroc > best_score:
                best_score = aucroc
                best_params = params

    # The grid is non-empty here, so both timing lists have at least one entry.
    result = {"best_score": float(best_score),
              "best_params": best_params,
              "avg_train_time_sec": float(np.mean(train_time_list)),
              "avg_infer_time_sec": float(np.mean(infer_time_list)),
              "num_trials": len(grid),
              "score_list": list(score_list)}

    tqdm.write(f"[{dataset_name} | {model_name}]")
    tqdm.write(f"Scores={score_list}")
    tqdm.write(f"Best score={result['best_score']:.4f}")
    tqdm.write(f"Best params={result['best_params']}")
    tqdm.write(f"Avg train time={result['avg_train_time_sec']:.4f}s")
    tqdm.write(f"Avg infer time={result['avg_infer_time_sec']:.4f}s")

    return result

def main():
    """Run the grid search for every dataset/model pair and save the results.

    For each dataset the train and test splits are loaded once and reused
    for all models. A failure inside one model's search is logged and stored
    as ``{"error": <message>}`` so the remaining models still run. The
    collected results are written as JSON under ``./analysis/results``.
    """
    datasets = ['SMD', 'SMAP', 'MSL', 'SWaT', 'WADI', 'PSM']
    models = ['IForest', 'HBOS', 'PCA', 'LODA', 'DeepSVDD', 'LOF', 'CBLOF', 'ABOD', 'Hotelling']
    search_space_path = './methods/configs/ml_search_space.json'
    result_save_path = './analysis/results'
    savefile_name = 'ML_hyperparameter_search_result.json'

    # Create the output directory before the search starts: a missing or
    # unwritable directory then fails immediately instead of after the whole
    # search, when it would discard every result.
    os.makedirs(result_save_path, exist_ok=True)

    search_result = {}
    for dataset_name in datasets:
        search_result[dataset_name] = {}
        train_set = TimeSeriesDataset(dataset_name=dataset_name, train=True)
        test_set = TimeSeriesDataset(dataset_name=dataset_name, train=False)
        for model_name in models:
            try:
                res = run_hyperparameter_search(dataset_name, train_set, test_set, model_name, search_space_path)
            except Exception as e:
                # Best-effort: record the failure and continue with the next model.
                tqdm.write(f"[{dataset_name} | {model_name}] ERROR: {e}")
                res = {"error": str(e)}
            search_result[dataset_name][model_name] = res

    print('Hyperparameter search completed for all models on all datasets.')
    with open(os.path.join(result_save_path, savefile_name), 'w') as f:
        json.dump(search_result, f, indent=2)


# Script entry point: set_seed(42) is called before the search starts.
# NOTE(review): set_seed is not defined or imported by name in this file; it
# presumably comes from the `from utils import *` wildcard import -- confirm.
if __name__ == "__main__":
    set_seed(42)
    main()
