import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
import time
from tqdm import tqdm
from utils import *
from dataloader import TimeSeriesDataset

class Node:
    """A single node of a Half-Space Tree.

    Internal nodes route an instance ``x`` to ``right`` when
    ``x[split_attrib] > split_value`` and to ``left`` otherwise; leaves have
    ``left is right is None``.

    Attributes:
        left, right: child nodes, or ``None`` for a leaf.
        r: mass (instance count) recorded during the reference window.
        l: mass accumulated during the current (latest) window.
        split_attrib: index of the feature this node splits on.
        split_value: threshold used for the split.
        k: depth of the node (root is 0); set from the ``depth`` argument.
    """

    # A forest holds t * (2**(h+1) - 1) nodes (~1.6M for t=25, h=15); slots
    # drop the per-instance __dict__, which dominates memory at that scale.
    __slots__ = ('left', 'right', 'r', 'l', 'split_attrib', 'split_value', 'k')

    def __init__(self, left=None, right=None, r=0, l=0, split_attrib=0, split_value=0.0, depth=0):
        self.left = left
        self.right = right
        self.r = r
        self.l = l
        self.split_attrib = split_attrib
        self.split_value = split_value
        self.k = depth

class HSTree:
    """Half-Space Trees (HS-Trees) for streaming anomaly detection.

    Implements the algorithm of Tan, Ting & Liu, "Fast Anomaly Detection for
    Streaming Data" (IJCAI 2011). The object keeps no state of its own: trees
    are ``Node`` graphs that are passed into and returned from the methods.
    Each node carries two mass counters, ``r`` (reference window) and ``l``
    (latest window). Every ``psi`` instances the latest window becomes the new
    reference window.

    Raw scores are mass-based. A *low* score means the instance landed in
    sparsely populated leaves, i.e. it is *more* anomalous.
    """

    def __init__(self):
        pass

    def _generate_max_min(self, dimensions):
        """Return randomly perturbed work-space bounds ``(max_arr, min_arr)``.

        For each dimension a pivot ``s`` is drawn uniformly from [0, 1). The
        range is ``s +/- 2 * max(s, 1 - s)``, which always contains [0, 1].
        NOTE(review): the paper assumes features scaled to [0, 1]; this is not
        enforced here.

        Args:
            dimensions: number of features.

        Returns:
            Two float arrays of shape ``(dimensions,)``.
        """
        s = np.random.random_sample(dimensions)
        half_width = 2 * np.maximum(s, 1 - s)
        return s + half_width, s - half_width

    def _BuildSingleHSTree(self, max_arr, min_arr, k, h, dimensions):
        """Recursively build a complete HS-Tree of depth ``h`` rooted at depth ``k``.

        Each internal node splits a randomly chosen dimension at the midpoint
        of its current work-space range. The left child halves the upper bound
        and the right child raises the lower bound.

        ``max_arr``/``min_arr`` are modified in place while recursing and are
        restored before returning, so the caller's arrays are left unchanged.

        Returns:
            The root ``Node`` of the (sub)tree.
        """
        if k >= h:  # ">=" so a negative h yields a single leaf instead of infinite recursion
            return Node(depth=k)

        q = np.random.randint(dimensions)  # randomly select a dimension
        p = (max_arr[q] + min_arr[q]) / 2.0  # split at the mid-point of its range
        old_max, old_min = max_arr[q], min_arr[q]

        max_arr[q] = p
        left = self._BuildSingleHSTree(max_arr, min_arr, k + 1, h, dimensions)
        max_arr[q] = old_max

        min_arr[q] = p
        right = self._BuildSingleHSTree(max_arr, min_arr, k + 1, h, dimensions)
        # Restore the lower bound too. Otherwise the narrowed range leaks into
        # the caller and corrupts every subtree built after this one.
        min_arr[q] = old_min

        return Node(left=left, right=right, split_attrib=q, split_value=p, depth=k)

    def _UpdateMass(self, x, node, ref_window):
        """Add ``x`` to the mass of every node on its root-to-leaf path.

        Args:
            x: one instance, indexable by feature.
            node: tree root.
            ref_window: if True increment ``r`` (reference window), else ``l``.

        The depth-0 root is skipped. Its mass is only read by ``_ScoreTree``
        when the root is itself a leaf (h == 0).
        """
        while node is not None:
            if node.k != 0:
                if ref_window:
                    node.r += 1
                else:
                    node.l += 1
            node = node.right if x[node.split_attrib] > node.split_value else node.left

    def _ScoreTree(self, x, node):
        """Return ``leaf.r * 2**leaf.k`` for the leaf that ``x`` falls into."""
        while node.left and node.right:
            node = node.right if x[node.split_attrib] > node.split_value else node.left
        return node.r * (2 ** node.k)

    def _UpdateResetModel(self, node):
        """Swap windows: latest-window mass becomes reference mass, latest mass resets.

        The copy is unconditional. A region that received no instances in the
        latest window must drop to zero reference mass rather than keep stale
        counts from older windows.
        """
        if node is None:
            return
        node.r = node.l
        node.l = 0
        self._UpdateResetModel(node.left)
        self._UpdateResetModel(node.right)

    def PrintTree(self, node):
        """Print every node in pre-order (split dimension, split value, depth, r)."""
        if(node):
            print(('Dimension of the node is:%d and split value is:%f, depth is:%d, reference_value:%d') %(node.split_attrib, node.split_value, node.k, node.r))
            self.PrintTree(node.left)
            self.PrintTree(node.right)

    def StreamingHSTrees(self, X, psi, t, h):
        """Run streaming HS-Trees over ``X`` and return the raw score of each row.

        Args:
            X: 2-D array-like, one instance per row.
            psi: window size. The first ``psi`` rows only populate the initial
                reference window; afterwards the windows swap every ``psi`` rows.
            t: number of trees.
            h: depth of each tree.

        Returns:
            ``np.ndarray`` of shape ``(len(X),)``. Rows ``0 .. psi-1`` are never
            scored and stay 0. Every later row gets the sum over trees of
            ``leaf.r * 2**leaf.k``, computed before the row's own mass is added.

        Raises:
            ValueError: if ``psi < 1``.
        """
        if psi < 1:
            raise ValueError(f'psi must be a positive integer, got {psi}')
        X = np.asarray(X)
        n_samples, dimensions = X.shape[0], X.shape[1]
        score_list = np.zeros(n_samples)

        print('Building trees...')
        forest = []
        for _ in range(t):
            max_arr, min_arr = self._generate_max_min(dimensions)
            forest.append(self._BuildSingleHSTree(max_arr, min_arr, 0, h, dimensions))

        print('Initial tree updating...')
        for x in X[:psi]:  # slicing tolerates len(X) < psi
            for tree in forest:
                self._UpdateMass(x, tree, True)

        count = 0  # instances seen in the current latest window
        for i in tqdm(range(psi, n_samples), desc='Processing instances'):
            x = X[i]
            s = 0
            for tree in forest:
                s += self._ScoreTree(x, tree)
                self._UpdateMass(x, tree, False)
            score_list[i] = s
            count += 1

            if count == psi:
                for tree in forest:
                    self._UpdateResetModel(tree)
                count = 0
        return score_list
    
class HSTreeAnomalyDetector(HSTree):
    """Fit/score wrapper around Half-Space Trees.

    ``fit`` builds the forest and fills the reference window from the first
    ``psi`` training rows. It then streams the remaining training rows through
    the forest. ``predict_score`` continues streaming through the *same*
    forest, so every test row is scored against mass learned during training.
    Previously, inference rebuilt a fresh forest from the test data, so the
    fitted model was never used and the first ``psi`` test rows were never
    scored.

    Raw HS-Tree scores are mass-based (high = normal). Both public methods
    return their reciprocal, so higher means "more anomalous". Non-positive
    raw scores are clamped to ``1e-10`` before inversion.
    """

    _EPS = 1e-10  # floor applied to raw scores before taking the reciprocal

    def __init__(self):
        super().__init__()
        self.HSTree = HSTree()
        self.trees = []          # fitted forest (list of root Nodes), set by fit()
        self._count = 0          # rows seen in the current latest window
        self._dimensions = None  # number of features seen by fit(); None = not fitted

    def fit(self, X_train, y_train=None, psi=250, t=25, h=15):
        """Build the forest on ``X_train`` and return reversed training scores.

        Args:
            X_train: 2-D array-like ``(n_samples, n_features)``.
            y_train: ignored (the method is unsupervised); kept for API compatibility.
            psi: window size (rows per reference/latest window), must be >= 1.
            t: number of trees, must be >= 1.
            h: depth of each tree, must be >= 0.

        Returns:
            ``np.ndarray`` of shape ``(n_samples,)``. The first ``min(psi, n)``
            rows are scored in-sample against the reference window they
            populated; later rows are scored in streaming fashion.

        Raises:
            ValueError: on non-2-D input or out-of-range ``psi``/``t``/``h``.
        """
        X_train = self._check_2d(X_train)
        if psi < 1 or t < 1 or h < 0:
            raise ValueError(f'need psi >= 1, t >= 1, h >= 0; got psi={psi}, t={t}, h={h}')
        self.psi = psi
        self.t = t
        self.h = h
        self._dimensions = X_train.shape[1]

        print('Building trees...')
        self.trees = []
        for _ in range(t):
            max_arr, min_arr = self.HSTree._generate_max_min(self._dimensions)
            self.trees.append(
                self.HSTree._BuildSingleHSTree(max_arr, min_arr, 0, h, self._dimensions))

        print('Initial tree updating...')
        warm_up = X_train[:psi]
        for x in warm_up:
            for tree in self.trees:
                self.HSTree._UpdateMass(x, tree, True)
        self._count = 0

        warm_scores = np.array([self._forest_score(x) for x in warm_up], dtype=float)
        stream_scores = self._stream_scores(X_train[psi:])
        return self._to_reversed(np.concatenate([warm_scores, stream_scores]))

    def predict_score(self, X):
        """Return reversed anomaly scores for every row of ``X`` using the fitted forest.

        Streaming continues from the state left by ``fit`` or a previous call.
        Each row is scored first, then added to the latest window, and windows
        swap every ``psi`` rows.

        Raises:
            RuntimeError: if called before ``fit``.
            ValueError: if ``X`` is not 2-D or has a different feature count than in ``fit``.
        """
        if self._dimensions is None:
            raise RuntimeError('HSTreeAnomalyDetector is not fitted; call fit() first.')
        X = self._check_2d(X)
        if X.shape[1] != self._dimensions:
            raise ValueError(f'expected {self._dimensions} features, got {X.shape[1]}')
        return self._to_reversed(self._stream_scores(X))

    # -- helpers ---------------------------------------------------------------

    def _forest_score(self, x):
        """Sum of per-tree raw scores for one instance (no mass update)."""
        return sum(self.HSTree._ScoreTree(x, tree) for tree in self.trees)

    def _stream_scores(self, X):
        """Score each row, then add it to the latest window; swap windows every psi rows."""
        scores = np.zeros(len(X))
        for i, x in enumerate(tqdm(X, desc='Processing instances')):
            s = 0
            for tree in self.trees:
                s += self.HSTree._ScoreTree(x, tree)
                self.HSTree._UpdateMass(x, tree, False)
            scores[i] = s
            self._count += 1
            if self._count == self.psi:
                for tree in self.trees:
                    self.HSTree._UpdateResetModel(tree)
                self._count = 0
        return scores

    @staticmethod
    def _check_2d(X):
        """Convert to an ndarray and require shape (n_samples, n_features)."""
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(f'expected a 2-D array (n_samples, n_features), got shape {X.shape}')
        return X

    @classmethod
    def _to_reversed(cls, scores):
        """Clamp non-positive raw scores to _EPS and return their reciprocal."""
        return 1.0 / np.where(scores > 0, scores, cls._EPS)
    
if __name__ == "__main__":
    set_seed(42)

    # Supported dataset names: SMD, SMAP, MSL, SWaT, WADI
    dataset_name = 'SMD'
    model_name = 'HSTree'

    train_set = TimeSeriesDataset(dataset_name=dataset_name, train=True)
    test_set = TimeSeriesDataset(dataset_name=dataset_name, train=False)

    anomaly_pct = test_set.labels.sum() / len(test_set.labels) * 100
    print(f'Dataset & Model name | {dataset_name} - {model_name}')
    print(f'Anomaly ratio        | {anomaly_pct:.2f}%')
    print(f'Dataset shape        | train: {train_set.data.shape}, test: {test_set.data.shape}')

    # Per-model hyperparameters:
    #   psi - window size (rows used to seed the model and rows per window swap)
    #   t   - number of trees
    #   h   - depth of each tree
    hyperparams = {
        'HSTree': dict(psi=250, t=25, h=15),
    }

    detector = HSTreeAnomalyDetector()

    # Training phase
    started = time.time()
    detector.fit(train_set.data, train_set.labels, **hyperparams[model_name])
    train_seconds = time.time() - started
    print(f'Training time        | {train_seconds:.4f} seconds')

    # Inference phase
    started = time.time()
    scores = detector.predict_score(test_set.data)
    infer_seconds = time.time() - started
    print(f'Inference time       | {infer_seconds:.4f} seconds')

    summary = (np.min(scores), np.max(scores), np.mean(scores), np.std(scores))
    print('Anomaly scores       | Min: {:.4f}, Max: {:.4f}, Mean: {:.4f}, Std: {:.4f}'.format(*summary))
    print(f'Metrics              | {cal_metric(y_true=test_set.labels, y_score=scores)}')
    print('-' * 100)