import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
import pandas as pd
import importlib
from calflops import calculate_flops
from utils import *
from methods.pyods import PYOD
from dataloader import TimeSeriesDataset, TimeSeriesLoader

class ML_Flops:
    """Closed-form FLOP estimates for classical (non-deep) detectors.

    Every ``FLOPs_<Model>`` method returns a ``(flop_train, flop_inference)``
    tuple. All methods are static, so they can be called either on the class
    or on an instance; ``get`` resolves a model name to its estimator.
    """

    @staticmethod
    def get(model_name):
        """Return the ``FLOPs_<model_name>`` estimator function.

        Args:
            model_name: Model name, with or without the ``FLOPs_`` prefix;
                surrounding whitespace is ignored.

        Returns:
            The matching ``FLOPs_*`` function.

        Raises:
            KeyError: if no ``FLOPs_`` method exists for ``model_name``.
        """
        name = str(model_name).strip()
        if not name.startswith("FLOPs_"):
            name = "FLOPs_" + name
        if not hasattr(ML_Flops, name):
            raise KeyError("Unsupported ML FLOPs model: %s" % model_name)
        return getattr(ML_Flops, name)

    @staticmethod
    def _pairwise_sort_cost(n):
        """Return ``n*(n-1)*log2(n-1)``, or 0 when ``n <= 1``.

        For n <= 1 the product is 0 * log2(<= 0), which numpy evaluates to
        nan. Guarding it keeps the estimate finite for tiny inputs.
        """
        if n <= 1:
            return 0
        return n*(n-1)*np.log2(n-1)

    @staticmethod
    def FLOPs_HBOS(num_train, num_inference, input_dim, bins):
        """HBOS estimate with ``bins`` histogram bins per feature.

        train = 2*n_tr*d + 5*k*d ; inference = 3*n_inf*d + 2*k*d
        """
        n_tr, n_inf, d, k = num_train, num_inference, input_dim, bins
        flop_train = (2*n_tr*d) + (5*k*d)
        flop_inference = (3*n_inf*d) + (2*k*d)
        return flop_train, flop_inference

    @staticmethod
    def FLOPs_LODA(num_train, num_inference, input_dim, n_random_cuts, bins):
        """LODA estimate with ``n_random_cuts`` projections of ``bins`` bins each (uses sqrt(d) and log2(b) terms)."""
        n_tr, n_inf, d, k, b = num_train, num_inference, input_dim, n_random_cuts, bins
        flop_train = n_tr*k*(2*np.sqrt(d) + np.log2(b) - 1)
        flop_inference = n_inf*(2*k*np.sqrt(d) + k*np.log2(b) + k + 1)
        return flop_train, flop_inference

    @staticmethod
    def FLOPs_ABOD(num_train, num_inference, input_dim, n_neighbors):
        """ABOD estimate with ``n_neighbors`` neighbours.

        Train and inference share the same formula, applied to their own
        sample counts: pairwise distance term + n*(n-1)*log2(n-1) term +
        n*k*(k-1)*(d+2) + n. The log term is treated as 0 when n <= 1.
        """
        n_tr, n_inf, d, k = num_train, num_inference, input_dim, n_neighbors
        flop_train = 1.5*n_tr*(n_tr-1)*d + ML_Flops._pairwise_sort_cost(n_tr) + n_tr*k*(k-1)*(d+2) + n_tr
        flop_inference = 1.5*n_inf*(n_inf-1)*d + ML_Flops._pairwise_sort_cost(n_inf) + n_inf*k*(k-1)*(d+2) + n_inf
        return flop_train, flop_inference

    @staticmethod
    def FLOPs_PCA(num_train, num_inference, input_dim, n_components):
        """PCA estimate; ``n_components`` only enters the inference cost."""
        n_tr, n_inf, d, k = num_train, num_inference, input_dim, n_components
        flop_train = 2*n_tr*pow(d,2) + 2*n_tr*d + 3*pow(d,2)
        flop_inference = n_inf*k*(2*d-1) + 2*n_inf*d*k - n_inf*d + n_inf*(3*d-1)
        return flop_train, flop_inference

    @staticmethod
    def FLOPs_LOF(num_train, num_inference, input_dim, n_neighbors):
        """LOF estimate with ``n_neighbors`` neighbours.

        Same formula for train and inference, applied to their own sample
        counts. The n*(n-1)*log2(n-1) term is treated as 0 when n <= 1.
        """
        n_tr, n_inf, d, k = num_train, num_inference, input_dim, n_neighbors

        flop_train = 1.5*n_tr*(n_tr-1)*d + ML_Flops._pairwise_sort_cost(n_tr) + n_tr*k + n_tr*(k+1) + 2*n_tr*k
        flop_inference = 1.5*n_inf*(n_inf-1)*d + ML_Flops._pairwise_sort_cost(n_inf) + n_inf*k + n_inf*(k+1) + 2*n_inf*k
        return flop_train, flop_inference

    @staticmethod
    def FLOPs_Hotelling(num_train, num_inference, input_dim):
        """Hotelling T^2 estimate: train = 2*n*d^2 + 2*n*d + d^3, inference = n*(2*d^2 + 2*d - 1)."""
        n_tr, n_inf, d = num_train, num_inference, input_dim

        flop_train = 2*n_tr*pow(d,2) + 2*n_tr*d + pow(d,3)
        flop_inference = n_inf*(2*pow(d,2) + 2*d - 1)
        return flop_train, flop_inference

    @staticmethod
    def FLOPs_IForest(num_train, num_inference, n_estimators, max_samples):
        """Isolation Forest estimate with ``T = n_estimators`` trees of ``psi = max_samples`` samples.

        ``num_train`` is accepted for a uniform call signature, but the
        training cost depends only on T and psi. The inference term uses
        2*(ln(psi-1) + euler_gamma) - 2*(1 - 1/psi), the iForest
        average-path-length normaliser c(psi).
        """
        n_inf, T, psi = num_inference, n_estimators, max_samples

        flop_train = T*(2*psi*np.log2(psi))
        flop_inference = n_inf*(T*(2*(np.log(psi-1)+np.euler_gamma) - 2*(1-1/psi)) + (T+2))

        return flop_train, flop_inference

    @staticmethod
    def FLOPs_HSTree(num_train, num_inference, n_estimators, max_depth, ref_window_size):
        """Half-Space Trees estimate with T trees of depth h over a reference window of size psi.

        ``num_train`` is accepted for a uniform call signature but does not
        enter the formula.
        """
        n_inf, T, h, psi = num_inference, n_estimators, max_depth, ref_window_size

        flop_train = T*(psi*(h+1) + 5*(pow(2,h+1)-1))
        flop_inference = n_inf*T*(5*h+7)
        return flop_train, flop_inference

    @staticmethod
    def _cblof_fit_flops(data, num_samples, input_dim, model_params):
        """Fit a PYOD CBLOF on ``data`` and build the FLOP estimate from its fitted cluster statistics.

        Reads the clustering iteration count, the cluster sizes and the
        large-cluster labels from the fitted model.
        """
        k = model_params['n_clusters']

        cblof = PYOD('CBLOF')
        cblof.fit(data, **model_params)
        cblof.predict_score(data)

        fitted = cblof.model
        num_cluster_iter = fitted.clustering_estimator_.n_iter_
        large_labels = fitted.large_cluster_labels_
        num_instances_in_large_clusters = sum([fitted.cluster_sizes_[i] for i in large_labels])
        num_instances_in_small_clusters = num_samples - num_instances_in_large_clusters
        num_large_clusters = len(large_labels)

        return (num_cluster_iter*(3*k*input_dim*num_samples - num_samples + num_samples*input_dim)
                + 3*num_instances_in_small_clusters*num_large_clusters*input_dim
                + 3*num_instances_in_large_clusters*input_dim)

    @staticmethod
    def FLOPs_CBLOF(train_set, test_set, model_params):
        """CBLOF estimate derived from actually fitting the model.

        Args:
            train_set: Dataset exposing ``.data``, ``.train_len`` and ``.input_dim``.
            test_set: Dataset exposing ``.data``, ``.test_len`` and ``.input_dim``.
            model_params: Keyword arguments forwarded to the CBLOF fit; must
                contain ``'n_clusters'``.

        Returns:
            ``(flop_train, flop_inference)``.
        """
        flop_train = ML_Flops._cblof_fit_flops(
            train_set.data, train_set.train_len, train_set.input_dim, model_params)
        # NOTE(review): the inference estimate refits a fresh CBLOF on the test
        # set rather than scoring with the train-fitted clusters; confirm
        # this is intended.
        flop_inference = ML_Flops._cblof_fit_flops(
            test_set.data, test_set.test_len, test_set.input_dim, model_params)
        return flop_train, flop_inference


class model_flops:
    """Collect FLOP counts for one dataset.

    Deep detectors are measured with calflops (``flops``). Classical
    detectors use the analytical formulas in ``ML_Flops`` (``ml_flops``).
    """

    def __init__(self, dataset_name):
        """Bind the runner to ``dataset_name`` and register the deep-model dispatch table.

        Args:
            dataset_name: Dataset identifier forwarded to ModelConfig,
                TimeSeriesLoader and TimeSeriesDataset.
        """
        self.dataset_name = dataset_name

        # DL model name -> (module to import, detector class in that module).
        # Modules are imported lazily inside flops(), so only the requested
        # model is loaded.
        self.MODEL_DISPATCH = {'USAD':  ('methods.USAD',  'USADAnomalyDetector'),
                               'DAGMM': ('methods.DAGMM', 'DAGMMAnomalyDetector'),
                               'LUAD':  ('methods.LUAD',  'LUADAnomalyDetector'),
                               'lstmAE': ('methods.lstmAE', 'LSTMAEAnomalyDetector'),
                               'lstmVAE': ('methods.lstmVAE', 'LSTMVAEAnomalyDetector'),
                               'OmniAnomaly': ('methods.OmniAnomaly', 'OmniAnomalyDetector'),
                               'DeepSVDD': ('methods.DeepSVDD', 'DeepSVDDAnomalyDetector'),
                               'AnomalyTransformer': ('methods.AnomalyTransformer', 'AnomalyTransformerDetector'),
                               'TimesNet': ('methods.TimesNet', 'TimesNetAnomalyDetector')}

    def flops(self, model_name):
        """Measure FLOPs of a deep detector with calflops.

        Args:
            model_name: Key of ``MODEL_DISPATCH``.

        Returns:
            ``(train_epoch, inference_epoch, total, train_all_epochs)``:

            - ``train_epoch``: forward+backward FLOPs for one batch-1 input,
              times ``loader.train_len``.
            - ``inference_epoch``: forward FLOPs for one input, times
              ``loader.test_len``.
            - ``total``: ``train_epoch + inference_epoch`` (one training
              epoch only).
            - ``train_all_epochs``: ``train_epoch * train_config['epochs']``.

        Raises:
            KeyError: from the ``MODEL_DISPATCH`` lookup when ``model_name``
                is unregistered. Config resolution runs first and may fail
                earlier.

        Side effects:
            Sets ``self.loader_config``, ``self.model_config``,
            ``self.train_config`` and ``self.loader``.
        """
        config = ModelConfig(model_name)
        self.loader_config, self.model_config, self.train_config = config.resolve(self.dataset_name)
        self.loader = TimeSeriesLoader(dataset_name=self.dataset_name, **self.loader_config)

        self.model_config['input_dim'] = self.loader.input_dim

        input_dim = self.model_config['input_dim']
        window_size = self.loader_config['window_size']

        # Models that consume a whole window per forward pass. Their FLOPs are
        # divided by window_size below (presumably a per-timestep
        # normalisation; confirm against the loader's train_len/test_len units).
        if model_name == 'DAGMM' or model_name == 'DeepSVDD':
            input_shape = (1, 1 * input_dim * window_size)
        elif model_name == 'TimesNet':
            input_shape = (1, window_size, input_dim)
        else:
            input_shape = (1, 1, input_dim)

        model_kwargs = self.model_config

        mod_name, cls_name = self.MODEL_DISPATCH[model_name]
        module = importlib.import_module(mod_name)
        detector = getattr(module, cls_name)

        model = detector(self.loader, **model_kwargs)
        # Both measurements share these settings; only backprop inclusion differs.
        calflops_kwargs = dict(input_shape=input_shape, forward_mode='forward',
                               print_detailed=False, print_results=True, output_as_string=False)
        FLOPs_tr, _, _ = calculate_flops(model.model, include_backPropagation=True, **calflops_kwargs)
        FLOPs_inf, _, _ = calculate_flops(model.model, include_backPropagation=False, **calflops_kwargs)
        if model_name in ['DAGMM', 'DeepSVDD', 'TimesNet']:
            FLOPs_tr = FLOPs_tr / window_size
            FLOPs_inf = FLOPs_inf / window_size

        FLOPs_train_epoch = FLOPs_tr * self.loader.train_len
        FLOPs_inference_epoch = FLOPs_inf * self.loader.test_len
        FLOPs_total = FLOPs_train_epoch + FLOPs_inference_epoch
        return FLOPs_train_epoch, FLOPs_inference_epoch, FLOPs_total, FLOPs_train_epoch * self.train_config['epochs']

    def ml_flops(self, model_name):
        """Estimate FLOPs of a classical detector with the ``ML_Flops`` formulas.

        Args:
            model_name: Name resolvable by ``ML_Flops.get`` and present in the
                ``'ML_calflops'`` config for this dataset.

        Returns:
            ``(train_flops, inference_flops, total_flops)``.
        """
        # Copy so the sample counts / datasets injected below never leak back
        # into the config object returned by get_param.
        ml_model_config = dict(ModelConfig('ML_calflops').get_param(self.dataset_name)[model_name])
        ml_flops_func = ML_Flops.get(model_name)

        train_dataset = TimeSeriesDataset(self.dataset_name, train=True)
        test_dataset = TimeSeriesDataset(self.dataset_name, train=False)

        if model_name == "CBLOF":
            # CBLOF fits the real model, so it needs the datasets themselves.
            ml_model_config['train_set'] = train_dataset
            ml_model_config['test_set'] = test_dataset
        else:
            ml_model_config['num_train'] = train_dataset.train_len
            ml_model_config['num_inference'] = test_dataset.test_len
            # The tree-ensemble formulas take no input_dim parameter.
            if model_name not in ("IForest", "HSTree"):
                ml_model_config['input_dim'] = train_dataset.input_dim

        train_flops, inference_flops = ml_flops_func(**ml_model_config)
        total_flops = train_flops + inference_flops
        return train_flops, inference_flops, total_flops

    @staticmethod
    def save_flops(datasets, ml_model_names, dl_model_names, save_path='./analysis/results/'):
        """Compute FLOPs for every (dataset, model) pair and write ``model_flops.csv``.

        Failures for an individual model are printed and recorded as ``'N/A'``
        rows, so one broken model does not abort the sweep.

        Args:
            datasets: Dataset names to iterate.
            ml_model_names: Classical model names (see ``ml_flops``).
            dl_model_names: Deep model names (see ``flops``).
            save_path: Output directory, created if missing (with or without
                a trailing separator). An empty string writes to the current
                directory.
        """
        flops = []
        header = ['dataset', 'model_name', 'train_flops', 'inference_flops', 'total_flops', 'DL_per_flops']

        # Create the directory itself. os.path.dirname() would drop the last
        # component when save_path has no trailing separator.
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        file_path = os.path.join(save_path, 'model_flops.csv')

        for dataset in datasets:
            runner = model_flops(dataset)

            for ml_model_name in ml_model_names:
                try:
                    ml_train_flops, ml_inference_flops, ml_total_flops = runner.ml_flops(ml_model_name)
                    flops.append([dataset, ml_model_name, ml_train_flops, ml_inference_flops, ml_total_flops, ''])
                except Exception as e:
                    print(f'[ML][ERROR] dataset={dataset}, model={ml_model_name}: {e}')
                    flops.append([dataset, ml_model_name, 'N/A', 'N/A', 'N/A', ''])

            for dl_model_name in dl_model_names:
                try:
                    dl_train_flops, dl_inference_flops, dl_total_flops, dl_per_flops = runner.flops(dl_model_name)
                    flops.append([dataset, dl_model_name, dl_train_flops, dl_inference_flops, dl_total_flops, dl_per_flops])
                except Exception as e:
                    print(f'[DL][ERROR] dataset={dataset}, model={dl_model_name}: {e}')
                    flops.append([dataset, dl_model_name, 'N/A', 'N/A', 'N/A', 'N/A'])

        df = pd.DataFrame(flops, columns=header)

        df.to_csv(file_path, index=False)
        print(f'[SAVED] {file_path}')


if __name__ == "__main__":
    # Benchmark datasets to profile.
    datasets = ['SMD', 'SMAP', 'MSL', 'SWaT', 'WADI', 'PSM']

    # Classical detectors: FLOPs come from the analytical formulas in ML_Flops.
    ml_model_names = ['HBOS', 'LODA', 'ABOD', 'PCA', 'LOF', 'Hotelling',
                      'IForest', 'HSTree', 'CBLOF']

    # Deep detectors: FLOPs are measured with calflops.
    dl_model_names = ['USAD', 'DAGMM', 'LUAD', 'lstmAE', 'lstmVAE',
                      'OmniAnomaly', 'DeepSVDD', 'AnomalyTransformer', 'TimesNet']

    model_flops.save_flops(
        datasets=datasets,
        ml_model_names=ml_model_names,
        dl_model_names=dl_model_names,
    )