import os
from typing import Iterable, List, Union, Optional, Tuple
import numpy as np
from multiprocessing import Pool
from tqdm import tqdm


class Tree:
    """
    A node in an isolation tree.

    A point ``x`` with ``x[i] < s`` is routed to ``left``; every other point is
    routed to ``right``. A ``None`` child marks the end of a path.

    Attributes:
        s (float): The splitting value.
        i (int): The splitting dimension index.
        left (Optional[Tree]): The left child node (subtree).
        right (Optional[Tree]): The right child node (subtree).
    """

    # A fitted forest holds very many nodes (roughly one per sample per tree,
    # times n_trees), so drop the per-instance __dict__ to save memory.
    # Slotted instances still pickle, which multiprocessing relies on.
    __slots__ = ("s", "i", "left", "right")

    s: float
    i: int
    left: Optional["Tree"]
    right: Optional["Tree"]

    def __init__(
        self, s: float, i: int, left: Optional["Tree"], right: Optional["Tree"]
    ):
        """
        Initialize a Tree node.

        Args:
            s (float): The splitting value.
            i (int): The splitting dimension index.
            left (Optional[Tree]): The left child node (subtree).
            right (Optional[Tree]): The right child node (subtree).
        """
        self.s = s
        self.i = i
        self.left = left
        self.right = right

    def __repr__(self) -> str:
        # Deliberately non-recursive: printing a deep tree in full would be
        # unreadable and could hit the recursion limit.
        return f"{type(self).__name__}(s={self.s!r}, i={self.i!r})"


def generate_tree(X: np.ndarray) -> Optional[Tree]:
    """
    Generate a decision tree recursively based on the given dataset.

    Args:
        X (np.ndarray): The dataset of shape (n_samples, n_features).

    Returns:
        Optional[Tree]: The root node of the tree, or None if the data cannot be split.
    """
    # Only single data point
    if len(X) <= 1 or len(set([tuple(x) for x in X])) == 1:
        return None

    # Randomly choose a dimension
    available_dimensions = list(range(X.shape[1]))
    
    # 使用numpy的操作加速维度选择
    std_vals = np.std(X, axis=0)
    uniq_vals = np.array([len(np.unique(X[:, j])) for j in range(X.shape[1])])
    valid_dims = np.where((std_vals > 0) & (uniq_vals > 1))[0]
    
    if len(valid_dims) == 0:
        return None
        
    i = np.random.choice(valid_dims)

    # Randomly generate a cut point
    m, M = np.min(X[:, i]), np.max(X[:, i])

    # 判断m和M是否接近（可以设置一个小的阈值）
    if np.isclose(m, M, rtol=1e-10, atol=1e-10):
        random_idx = np.random.randint(0, len(X))
        # Put only the randomly selected point on the left, all others on right
        left_mask = np.zeros(len(X), dtype=bool)
        left_mask[random_idx] = True
        right_mask = ~left_mask
        left = X[left_mask]
        right = X[right_mask]
        s = left[0]
    else:
        s = np.random.random() * (M - m) + m
        left_mask = X[:, i] < s
        right_mask = ~left_mask
        left = X[left_mask]
        right = X[right_mask]

    if len(left) == 0 or len(right) == 0:
        # Save X to a file
        np.save("X.npy", X)
        raise ValueError(
            f"Left or right partition is empty. Check the data.\nX: {X}\nm: {m}\nM: {M}\ns: {s}\ni: {i}\nleft: {left}\nright: {right}"
        )

    return Tree(s, i, generate_tree(left), generate_tree(right))


def _get_depth(tree: Optional[Tree], x: np.ndarray, d: int) -> int:
    """
    Recursively compute the depth of a data point in the tree.

    Args:
        tree (Optional[Tree]): The current tree node.
        x (np.ndarray): The data point.
        d (int): The current depth.

    Returns:
        int: The depth of the data point.
    """
    if tree is None:
        return d
    if x[tree.i] < tree.s:
        return _get_depth(tree.left, x, d + 1)
    else:
        return _get_depth(tree.right, x, d + 1)


class IForest:
    """
    Isolation Forest for anomaly detection.

    Attributes:
        trees (List[Tree]): The list of decision trees.
    """

    def __init__(self, n_trees: int = 1000, n_workers: int = 0):
        """
        Initialize the Isolation Forest.

        Args:
            n_trees (int, optional): Number of trees. Defaults to 1000.
            n_workers (int, optional): Number of workers for multiprocessing. Defaults to 0.
        """
        self.n_trees = n_trees
        self.n_workers = os.cpu_count() if n_workers == 0 else n_workers
        self.trees: List[Optional[Tree]] = []

    def fit(self, X: np.ndarray) -> None:
        """
        Fit the Isolation Forest model.

        Args:
            X (np.ndarray): Features of shape (n_samples, n_features).
        """
        assert len(X.shape) == 2, "Input data must be 2-dimensional."

        with Pool(self.n_workers) as p:
            generate_map: Iterable
            if self.n_workers == 1:
                generate_map = map(generate_tree, [X] * self.n_trees)
            else:
                generate_map = p.imap(generate_tree, [X] * self.n_trees)

            self.trees = list(tqdm(generate_map, total=self.n_trees, desc="Fitting"))

    def get_depths(self, X: np.ndarray) -> List[float]:
        """
        Get depths of data points for all trees.

        Args:
            X (np.ndarray): Features of shape (n_samples, n_features).

        Returns:
            List[float]: Average depths of each data point across all trees.
        """
        return [self.get_depth(x) for x in X]

    def get_depth(self, x: np.ndarray) -> float:
        """
        Get the average depth of a single data point across all trees.

        Args:
            x (np.ndarray): A data point of shape (n_features,).

        Returns:
            float: The average depth of the data point.
        """
        depths = [_get_depth(tree, x, 0) for tree in self.trees if tree is not None]
        return sum(depths) / len(depths) if depths else 0.0


if __name__ == "__main__":
    # Demo: fit a forest on uniform random data, then print every point's
    # average path length.
    data = np.random.random((100, 20))
    forest = IForest(n_trees=1000, n_workers=0)
    forest.fit(data)
    depths = forest.get_depths(data)
    print(depths)
