import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from tqdm import tqdm
from methods.pyods import PYOD
from utils import *
from dataloader import *
from time import time

from methods.USAD import USADAnomalyDetector


class RealTimeEstimator:
    """Measure training and inference time of anomaly detectors on one dataset.

    ML models are wrapped by ``PYOD`` and fitted on the ``TimeSeriesDataset``
    splits loaded at construction time. DL models are driven through a
    ``TimeSeriesLoader`` built from the model's own resolved config.

    Args:
        dataset_name: Dataset identifier, forwarded to the dataset/loader
            classes and used to look up per-dataset model configs.
    """

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        # Loaded once and shared by every ML model timed via ``ml_real_time``.
        self.train_set = TimeSeriesDataset(dataset_name=self.dataset_name, train=True)
        self.test_set = TimeSeriesDataset(dataset_name=self.dataset_name, train=False)

    @staticmethod
    def _timed(fn):
        """Call the zero-argument callable ``fn``; return ``(elapsed_seconds, result)``.

        ``time.perf_counter`` is used instead of ``time.time`` because it is
        monotonic and high-resolution, so system clock adjustments cannot
        distort (or make negative) the measured duration.
        """
        from time import perf_counter

        start = perf_counter()
        result = fn()
        return perf_counter() - start, result

    def ml_real_time(self, model_name):
        """Time fitting and scoring of a PyOD-backed model.

        Args:
            model_name: Key of the model in the per-dataset ``ML_models``
                config; also passed to ``PYOD`` as the model name.

        Returns:
            dict with ``'train_time'`` and ``'inference_time'`` in seconds.
        """
        cfg = ModelConfig('ML_models')
        model_config = cfg.get_param(self.dataset_name)[model_name]

        model = PYOD(model_name=model_name, seed=get_global_seed())

        train_time, _ = self._timed(
            lambda: model.fit(self.train_set.data, self.train_set.labels, **model_config)
        )
        # Only the scoring duration matters here; the scores are discarded.
        infer_time, _ = self._timed(lambda: model.predict_score(self.test_set.data))
        return {'train_time': train_time, 'inference_time': infer_time}

    def dl_real_time(self, model_name):
        """Time training and scoring of a deep-learning detector.

        Args:
            model_name: Name of the model config passed to ``ModelConfig``.

        Returns:
            dict with ``'train_time'`` and ``'inference_time'`` in seconds.
        """
        cfg = ModelConfig(model_name)
        loader_config, model_config, train_config = cfg.resolve(self.dataset_name)

        loader = TimeSeriesLoader(dataset_name=self.dataset_name, **loader_config)
        # Build a new mapping rather than mutating the one returned by
        # resolve(), in case ModelConfig hands out a shared/cached object.
        model_config = {**model_config, 'input_dim': loader.input_dim}

        # NOTE(review): USADAnomalyDetector is constructed for every
        # ``model_name``; timing another DL model requires dispatching on the
        # name here — confirm intended.
        model = USADAnomalyDetector(loader, **model_config)

        train_time, _ = self._timed(lambda: model.fit(**train_config, data_type='train'))
        infer_time, _ = self._timed(lambda: model.predict_score(data_type='test'))
        return {'train_time': train_time, 'inference_time': infer_time}

    def run(self, ml_model_names=(), dl_model_names=()):
        """Time every requested model and collect the results in a DataFrame.

        A model that raises is reported on stderr and omitted from the result,
        so one failing model does not abort the whole benchmark.

        Args:
            ml_model_names: Iterable of names timed with ``ml_real_time``.
            dl_model_names: Iterable of names timed with ``dl_real_time``.

        Returns:
            DataFrame with columns ``model_name``, ``train_time`` and
            ``inference_time`` (seconds), one row per model that succeeded,
            ML models first.
        """
        groups = (
            ('ML', ml_model_names, self.ml_real_time),
            ('DL', dl_model_names, self.dl_real_time),
        )
        results = []
        for kind, names, timer in groups:
            for name in tqdm(names, desc=f"{kind} Models", leave=False):
                try:
                    result = timer(name)
                except Exception as e:
                    # tqdm.write keeps the message from corrupting the active
                    # progress bar; repr() keeps the exception type visible
                    # (a bare KeyError would otherwise print only the key).
                    tqdm.write(f"Error with {kind} model {name}: {e!r}", file=sys.stderr)
                    continue
                results.append([name, result['train_time'], result['inference_time']])

        return pd.DataFrame(results, columns=['model_name', 'train_time', 'inference_time'])
        
    
if __name__ == "__main__":
    # Benchmark a fixed set of ML and DL detectors on SMD and report timings.
    set_seed(42)

    dataset_name = 'SMD'
    dl_model_names = ['USAD']
    ml_model_names = ['ABOD', 'HBOS', 'IForest']

    estimator = RealTimeEstimator(dataset_name)
    real_time_df = estimator.run(ml_model_names, dl_model_names)
    # Emit the timing table; without this the benchmark results are discarded.
    print(real_time_df.to_string(index=False))