import numpy as np
import os
import matplotlib.pyplot as plt
from multiprocessing import Pool
from typing import Tuple, List, Union


def get_target_depth(X: np.ndarray, target: np.ndarray) -> float:
    """
    Compute the expected depth at which a target point is reached under recursive random splits.

    At every level each feature that still has more than one distinct value in ``X``
    is given equal weight. Within a feature, every midpoint between two consecutive
    distinct values is a candidate split, weighted by the width of that gap relative
    to the feature's full value range (the probability that a split position drawn
    uniformly over the range lands in that gap). The recursion follows only the side
    of the split that ``target`` falls on and stops once no feature can be split.

    The recursion visits every candidate split at every level, so the cost grows
    quickly with the number of distinct values; intended for small datasets.

    Args:
        X (np.ndarray): A 2D array of shape (n_samples, n_features) representing the dataset.
        target (np.ndarray): A 1D array of shape (n_features,) representing the target point.

    Returns:
        float: The expected depth of the target point; 0.0 when ``X`` has no feature
            with more than one distinct value (e.g. all rows identical, or no rows).
    """
    # A feature is splittable only if it has at least two distinct values.
    # Counting distinct values (instead of testing ``std() > 0``) avoids treating a
    # constant float column as splittable when rounding makes its std a tiny
    # positive number, which would wrongly shrink the per-feature weight below.
    # np.unique returns the distinct values already sorted.
    splittable = []
    for i in range(X.shape[1]):
        values = np.unique(X[:, i])
        if len(values) > 1:
            splittable.append((i, values))

    # No splittable feature means every row is identical (or X is empty): depth 0.
    if not splittable:
        return 0.0

    expected_depth = 0.0
    for i, values in splittable:
        column = X[:, i]
        value_range = values[-1] - values[0]
        expected_depth_i = 0.0
        for low, high in zip(values[:-1], values[1:]):
            s = (low + high) / 2
            # Only the partition containing the target is ever recursed into,
            # so build just that one.
            child = X[column < s] if target[i] < s else X[column >= s]
            ratio = (high - low) / value_range
            expected_depth_i += ratio * (1 + get_target_depth(child, target))
        # Every splittable feature is equally likely to be chosen.
        expected_depth += expected_depth_i / len(splittable)
    return expected_depth


def get_target_depth_1_arg(args: Tuple[np.ndarray, np.ndarray]) -> float:
    """
    Single-argument adapter around `get_target_depth`.

    `Pool.map` hands each worker exactly one item, so the dataset and the
    target point are packed into one tuple and unpacked here.

    Args:
        args (Tuple[np.ndarray, np.ndarray]): ``(X, target)`` — the dataset followed
            by the target point, in the order `get_target_depth` expects.

    Returns:
        float: Whatever `get_target_depth` returns for the unpacked arguments.
    """
    depth = get_target_depth(*args)
    return depth


def get_depths(X: np.ndarray) -> List[float]:
    """
    Compute the expected depth of every row of a dataset in parallel.

    Each row of ``X`` is used as the target point against the full dataset, and
    the per-row computations are distributed over a process pool.

    Args:
        X (np.ndarray): A 2D array of shape (n_samples, n_features) representing the dataset.

    Returns:
        List[float]: One depth per row of ``X``, in row order.
    """
    # One job per row: the whole dataset paired with that row as the target.
    jobs = [(X, row) for row in X]
    worker_count = os.cpu_count()
    with Pool(worker_count) as workers:
        # map() blocks until every job is done and preserves input order.
        return workers.map(get_target_depth_1_arg, jobs)


if __name__ == "__main__":
    # Demo: compare the depths computed on a 2-D dataset with the average of the
    # depths computed on each coordinate separately.
    X = np.array([[1, 1], [2, 2], [3, 4], [4, 5], [5, 7]])

    depths = get_depths(X)
    # Each single coordinate is reshaped to a (n_samples, 1) dataset.
    depths_x = get_depths(X[:, 0].reshape(-1, 1))
    depths_y = get_depths(X[:, 1].reshape(-1, 1))

    avg_depth = (np.array(depths_x) + np.array(depths_y)) / 2
    diff = np.array(depths) - avg_depth

    print(f"2D depths = {depths}")
    print(f"depths_x = {depths_x}")
    print(f"depths_y = {depths_y}")
    print(f"avg_depth = {avg_depth}")
    print(f"diff = {diff}")
