import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import importlib
import pandas as pd
from tqdm import tqdm
from methods.pyods import PYOD
from utils import *
from dataloader import *

# Classical detectors implemented in this repo: name -> (module path, class name).
# Any ML model name not listed here is handled by the PyOD wrapper instead.
ML_MODEL_DISPATCH = {
    'HSTree': ('methods.HSTree', 'HSTreeAnomalyDetector'),
    'Hotelling': ('methods.hotelling', 'Hotelling'),
}

# Deep-learning detectors: name -> (module path, class name). Every module
# lives at ``methods.<name>``, so only the class name is listed per model.
DL_MODEL_DISPATCH = {
    name: (f'methods.{name}', cls_name)
    for name, cls_name in (
        ('USAD', 'USADAnomalyDetector'),
        ('DAGMM', 'DAGMMAnomalyDetector'),
        ('LUAD', 'LUADAnomalyDetector'),
        ('lstmAE', 'LSTMAEAnomalyDetector'),
        ('lstmVAE', 'LSTMVAEAnomalyDetector'),
        ('OmniAnomaly', 'OmniAnomalyDetector'),
        ('DeepSVDD', 'DeepSVDDAnomalyDetector'),
    )
}

class AnomalyDetector:
    """Train and evaluate anomaly detectors on a single dataset.

    Classical ("ML") detectors are resolved through ``ML_MODEL_DISPATCH``,
    falling back to the PyOD wrapper. Deep-learning ("DL") detectors are
    resolved through ``DL_MODEL_DISPATCH``. Each evaluation returns the dict
    produced by ``cal_metric``.
    """

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        # (module_name, class_name) -> class object, filled lazily by _get_class.
        self.cls_cache = {}
        # frozen loader_config -> TimeSeriesLoader, filled lazily by _get_loader.
        self.loader_cache = {}

        # Both splits are loaded eagerly; ml_detector uses them directly.
        self.train_set = TimeSeriesDataset(dataset_name=self.dataset_name, train=True)
        self.test_set = TimeSeriesDataset(dataset_name=self.dataset_name, train=False)

    def _frozendict(self, d: dict):
        """Return a hashable, key-order-independent cache key for a config dict.

        ``None`` maps to the empty tuple. Nested dicts, lists/tuples and sets
        are frozen recursively, so configs that hold e.g. a list of window
        sizes can still serve as cache keys. A flat config with hashable
        values yields ``tuple(sorted(d.items()))``.
        """
        if d is None:
            return tuple()

        def freeze(value):
            if isinstance(value, dict):
                # Keys are unique, so sorting never needs to compare the values.
                return tuple(sorted((k, freeze(v)) for k, v in value.items()))
            if isinstance(value, (list, tuple)):
                return tuple(freeze(v) for v in value)
            if isinstance(value, (set, frozenset)):
                return frozenset(freeze(v) for v in value)
            return value

        return freeze(d)

    def _get_class(self, mod_name, cls_name):
        """Import ``mod_name`` and return its ``cls_name`` attribute, memoized per instance."""
        key = (mod_name, cls_name)
        if key not in self.cls_cache:
            module = importlib.import_module(mod_name)
            self.cls_cache[key] = getattr(module, cls_name)
        return self.cls_cache[key]

    def _get_loader(self, loader_config):
        """Return a cached ``TimeSeriesLoader`` for ``loader_config``.

        Configs that freeze to the same key share one loader, so identical
        loaders are not rebuilt. ``None`` is treated as an empty config,
        which matches how ``_frozendict`` keys it.
        """
        loader_config = loader_config or {}
        key = self._frozendict(loader_config)
        if key not in self.loader_cache:
            self.loader_cache[key] = TimeSeriesLoader(dataset_name=self.dataset_name, **loader_config)
        return self.loader_cache[key]

    def ml_detector(self, model_name):
        """Fit a classical detector on the train split and score the test split.

        Hyperparameters are read from the ``'ML_models'`` config for this
        dataset. When the lookup fails, the model's defaults are used.

        Args:
            model_name: a key of ``ML_MODEL_DISPATCH``, or any name accepted
                by the PyOD wrapper.

        Returns:
            The metrics dict from ``cal_metric(test_labels, scores)``.
        """
        cfg = ModelConfig('ML_models')
        try:
            model_config = cfg.get_param(self.dataset_name)[model_name]
        except Exception:
            # Best effort: a missing dataset/model entry falls back to defaults.
            model_config = {}
        # Guard against an entry that resolves to None (nothing to unpack).
        model_config = model_config or {}

        if model_name in ML_MODEL_DISPATCH:
            mod_name, cls_name = ML_MODEL_DISPATCH[model_name]
            Detector = self._get_class(mod_name, cls_name)
            model = Detector()

            # HSTree receives the training labels; Hotelling is fit on data only.
            if model_name == 'HSTree':
                model.fit(self.train_set.data, self.train_set.labels, **model_config)
            elif model_name == 'Hotelling':
                model.fit(self.train_set.data, **model_config)
            else:
                # Never score an unfitted model if a dispatch entry is added without a fit rule.
                raise ValueError(f"No fit rule for dispatched ML model {model_name!r}")
        else:
            model = PYOD(model_name=model_name, seed=get_global_seed())
            model.fit(self.train_set.data, self.train_set.labels, **model_config)

        anomaly_score = model.predict_score(self.test_set.data)
        return cal_metric(self.test_set.labels, anomaly_score)

    def dl_detector(self, model_name, force_retrain=False, auto_train=True):
        """Train a deep-learning detector if needed, then evaluate it on the test split.

        Args:
            model_name: a key of ``DL_MODEL_DISPATCH``. An unknown name raises
                ``KeyError`` before any config or loader work is done.
            force_retrain: always call ``fit`` with ``save_path`` set to the
                seed-specific path from ``build_save_path``.
            auto_train: call ``fit`` when nothing exists yet at that path.

        Returns:
            The metrics dict from ``cal_metric(loader.test_labels, scores)``.
        """
        # Resolve the class first so an unknown model fails fast.
        mod_name, cls_name = DL_MODEL_DISPATCH[model_name]
        Detector = self._get_class(mod_name, cls_name)

        cfg = ModelConfig(model_name)
        loader_config, model_config, train_config = cfg.resolve(self.dataset_name)

        loader = self._get_loader(loader_config)
        # Copy rather than mutate the dict returned by cfg.resolve, which may be shared.
        model_config = {**(model_config or {}), 'input_dim': loader.input_dim}
        model = Detector(loader, **model_config)

        model_path = build_save_path(model_name=model_name,
                                     dataset_name=self.dataset_name,
                                     seed=get_global_seed())

        need_train = force_retrain or (auto_train and not os.path.exists(model_path))
        if need_train:
            model.fit(**train_config, data_type="train", save_path=model_path)

        anomaly_scores = model.predict_score(data_type="test", load_path=model_path)
        return cal_metric(loader.test_labels, anomaly_scores)

    @staticmethod
    def _pick_deterministic_seed(seeds, preferred=42):
        """Choose the one seed on which seed-independent models are evaluated.

        Returns ``preferred`` if it is in ``seeds``, otherwise the first seed,
        or ``None`` when ``seeds`` is empty.
        """
        seeds = list(seeds)
        if preferred in seeds:
            return preferred
        return seeds[0] if seeds else None

    @staticmethod
    def _evaluate_row(evaluate, dataset, model_name, seed, **kwargs):
        """Run ``evaluate(model_name, **kwargs)`` and build one CSV row.

        If ``evaluate`` is ``None`` (the dataset failed to load) or it raises,
        the row gets ``'N/A'`` scores. In the raising case the error is printed.
        """
        if evaluate is None:
            return [dataset, model_name, seed, 'N/A', 'N/A']
        try:
            metrics = evaluate(model_name, **kwargs)
            return [dataset, model_name, seed, metrics['aucroc'], metrics['aucpr']]
        except Exception as e:
            print(f"Error with {model_name} on {dataset} with seed {seed}: {e}")
            return [dataset, model_name, seed, 'N/A', 'N/A']

    @staticmethod
    def run(datasets, ml_model_names, dl_model_names, seeds, save_path='./analysis/results/', force_retrain=False, auto_train=True):
        """Evaluate every (dataset, seed, model) combination and write a CSV.

        Rows go to ``<save_path>/model_performance.csv`` with the columns
        ``dataset, model_name, seed, auroc, aucpr``. A model that fails, or a
        dataset that fails to load, is reported and recorded with ``'N/A'``
        scores; the sweep is not aborted. Models in ``DETERMINISTIC_ML`` are
        evaluated once per dataset: on seed 42, or on the first seed when 42
        is not requested.
        """
        results = []
        header = ['dataset', 'model_name', 'seed', 'auroc', 'aucpr']

        os.makedirs(save_path, exist_ok=True)
        file_path = os.path.join(save_path, 'model_performance.csv')

        # Treated as seed-independent, so they are evaluated on a single seed.
        DETERMINISTIC_ML = {'HBOS', 'PCA', 'LOF', 'ABOD', 'Hotelling'}
        DET_SEED = 42

        # Materialize once: seeds is iterated again for every dataset.
        seeds = list(seeds)
        # Without this fallback, a seed list lacking DET_SEED would silently
        # never evaluate the deterministic models.
        det_seed = AnomalyDetector._pick_deterministic_seed(seeds, preferred=DET_SEED)

        for dataset in tqdm(datasets, desc="Datasets", leave=True):
            for seed in tqdm(seeds, desc="Seeds", leave=False):
                set_seed(seed)
                try:
                    detector = AnomalyDetector(dataset)
                except Exception as e:
                    # Handle a load failure like a model failure: N/A rows, keep sweeping.
                    print(f"Error loading {dataset} with seed {seed}: {e}")
                    detector = None
                ml_eval = detector.ml_detector if detector is not None else None
                dl_eval = detector.dl_detector if detector is not None else None

                for ml_model_name in tqdm(ml_model_names, desc="ML Models", leave=False):
                    if ml_model_name in DETERMINISTIC_ML and seed != det_seed:
                        continue
                    results.append(AnomalyDetector._evaluate_row(ml_eval, dataset, ml_model_name, seed))

                for dl_model_name in tqdm(dl_model_names, desc="DL Models", leave=False):
                    results.append(AnomalyDetector._evaluate_row(
                        dl_eval, dataset, dl_model_name, seed,
                        force_retrain=force_retrain, auto_train=auto_train))

        df = pd.DataFrame(results, columns=header)
        df.to_csv(file_path, index=False)
        print(f'[SAVED] {file_path}')


if __name__ == "__main__":
    # Full benchmark sweep: every dataset x seed x model combination.
    AnomalyDetector.run(
        datasets=['SMD', 'SMAP', 'MSL', 'SWaT', 'WADI', 'PSM'],
        ml_model_names=['HBOS', 'LODA', 'ABOD', 'PCA', 'LOF', 'Hotelling', 'IForest', 'HSTree', 'CBLOF'],
        dl_model_names=['USAD', 'DAGMM', 'LUAD', 'lstmAE', 'lstmVAE', 'OmniAnomaly', 'DeepSVDD'],
        seeds=list(range(41, 46)),  # 41, 42, 43, 44, 45
    )