import numpy as np
from georegression.simulation.simulation_for_fitting import generate_sample, f_square, coef_strong

# Draw a seeded (random_seed=1) simulated sample of 500 observations; plot=True
# presumably asks the simulator to plot it. X, y and points are presumably the
# features, the target and the spatial coordinates respectively.
X, y, points = generate_sample(500, f_square, coef_strong, random_seed=1, plot=True)
# Append the coordinate columns to the features so the models fitted below can
# also use location as an input column.
X_plus = np.concatenate([X, points], axis=1)

import numpy as np
from georegression.simulation.simulation_for_fitting import generate_sample, f_square, coef_strong

# NOTE(review): this repeats the snippet above verbatim; with the same seed it
# presumably regenerates the same sample (and plots it again).
X, y, points = generate_sample(500, f_square, coef_strong, random_seed=1, plot=True)
# Features followed by the coordinate columns.
X_plus = np.concatenate([X, points], axis=1)

# --- Context ---

from sklearn.linear_model import LinearRegression
from georegression.weight_model import WeightModel

# Geographically weighted regression (GWR) setup: a LinearRegression local model
# wrapped by WeightModel, with Euclidean distances and a bisquare kernel.
distance_measure = "euclidean"
kernel_type = "bisquare"

# A float neighbour count is presumably a fraction of the samples (here 20%)
# used to size each local bandwidth (cf. adaptive_bandwidth) — confirm in WeightModel.
gwr_neighbour_count=0.2
model = WeightModel(
    LinearRegression(),
    distance_measure,
    kernel_type,
    neighbour_count=gwr_neighbour_count,
)
# The coordinate groups used for distances are passed separately from the design matrix.
model.fit(X_plus, y, [points])

# llocv_score_ is reported as the model's R2 score (presumably leave-local-out CV).
print('GWR R2 Score: ', model.llocv_score_)

# --- Alternative ---

# Manual check: R2 of the model's stored local predictions against y.
from sklearn.metrics import r2_score
y_predict = model.local_predict_
score = r2_score(y, y_predict)
print(score)


import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from georegression.weight_model import WeightModel
from georegression.simulation.simulation_for_importance import coef_auto_gau_weak, coef_auto_gau_strong, f_square_2, generate_sample


def draw_graph():
    """
    Fit a geographically weighted random forest and map local feature importances.

    Generates a seeded 5000-point simulated sample, fits a ``WeightModel`` that
    wraps a ``RandomForestRegressor``, and prints the model's ``llocv_score_``.
    It also prints the global importance scores and the shape of the local
    importance matrix. Finally it saves one scatter map per feature, coloured
    by that feature's local importance, to ``Plot/Local_importance_{i}.png``.
    """
    import os

    X, y, points = generate_sample(
        count=5000, f=f_square_2, coef_func=[coef_auto_gau_strong, coef_auto_gau_weak], random_seed=1,
        plot=True
    )
    distance_measure = "euclidean"
    kernel_type = "bisquare"
    neighbour_count = 0.02

    model = WeightModel(
        RandomForestRegressor(n_estimators=50),
        distance_measure,
        kernel_type,
        neighbour_count=neighbour_count,
        cache_data=True,
        cache_estimator=True,
    )
    model.fit(X, y, [points])
    print("GRF:", model.llocv_score_)

    importance_global = model.importance_score_global()
    print("Global Importance Score: ", importance_global)

    importance_local = model.importance_score_local()
    print("Local Importance Socre Shape: ", importance_local.shape)

    # savefig does not create directories, so make sure the output folder exists.
    os.makedirs("Plot", exist_ok=True)

    # Plot the local importance: one map per feature (column of importance_local).
    for i in range(importance_local.shape[1]):
        fig = plt.figure()
        scatter = plt.scatter(
            points[:, 0], points[:, 1], c=importance_local[:, i], cmap="viridis"
        )
        fig.colorbar(scatter)
        fig.savefig(f"Plot/Local_importance_{i}.png")
        # Release the figure; otherwise every iteration leaves one open in pyplot.
        plt.close(fig)


# Run the local-importance demo only when executed as a script, not on import.
if __name__ == "__main__":
    draw_graph()


import numpy as np
from georegression.simulation.simulation_for_fitting import generate_sample, f_square, coef_strong

# Seeded 500-sample simulation; coordinates are appended as extra feature columns.
X, y, points = generate_sample(500, f_square, coef_strong, random_seed=1, plot=True)
X_plus = np.concatenate([X, points], axis=1)

from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import ExtraTreesRegressor
from georegression.stacking_model import StackingWeightModel

distance_measure = "euclidean"
kernel_type = "bisquare"

# Neighbourhood size (a float is presumably a fraction of the samples) and the
# rate of neighbours left out when building the stacking model — semantics are
# defined by StackingWeightModel; confirm there.
stacking_neighbour_count=0.3
stacking_neighbour_leave_out_rate=0.1
model = StackingWeightModel(
    # Randomised-split trees with depth limited to the number of original features.
    DecisionTreeRegressor(splitter="random", max_depth=X.shape[1]),
    # Or use the ExtraTreesRegressor for better predicting performance.
    # ExtraTreesRegressor(n_estimators=10, max_depth=X.shape[1]),
    distance_measure,
    kernel_type,
    neighbour_count=stacking_neighbour_count,
    neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
)
model.fit(X_plus, y, [points])
print('STST R2 Score: ', model.llocv_stacking_)

# --- Alternative ---

# Manual check: R2 of the model's stored stacking predictions against y.
from sklearn.metrics import r2_score
y_predict = model.stacking_predict_
score = r2_score(y, y_predict)
print(score)


# --- Context ---

# Held-out sample drawn with a different seed (2) and no plotting.
X_test, y_test, points_test = generate_sample(500, f_square, coef_strong, random_seed=2, plot=False)
X_test_plus = np.concatenate([X_test, points_test], axis=1)

# Stacking model: this call passes the training data as well as the test data.
y_predict = model.predict_by_fit(X_plus, y, [points], X_test_plus, [points_test])

# For weight model:
# y_predict = model.predict_by_fit(X_test_plus, [points_test])

# For predict by weight:
# y_predict = model.predict_by_weight(X_test_plus, [points_test])
# Out-of-sample R2 (computed but not printed here).
score = r2_score(y_test, y_predict)

import numpy as np
from georegression.simulation.simulation_for_fitting import generate_sample, f_square, coef_strong

X, y, points = generate_sample(500, f_square, coef_strong, random_seed=1, plot=True)
X_plus = np.concatenate([X, points], axis=1)

from sklearn.ensemble import RandomForestRegressor
from georegression.weight_model import WeightModel

# --- Context ---

# Synthetic integer time stamps in [0, 10) as an (N, 1) column. Unseeded, so the
# values differ between runs.
times = np.random.randint(0, 10, size=(X.shape[0], 1))
# Rebuild the design matrix with the spatial coordinates and the time column.
X_plus = np.concatenate([X, points, times], axis=1)

# One distance measure and one kernel per coordinate group ([points, times] in fit).
distance_measure = ["euclidean", 'euclidean']
kernel_type = ["bisquare", 'bisquare']

grf_neighbour_count = 0.3

grf_n_estimators=50
model = WeightModel(
    RandomForestRegressor(n_estimators=grf_n_estimators),
    distance_measure,
    kernel_type,
    neighbour_count=grf_neighbour_count,
)
# Two coordinate groups: space and time.
model.fit(X_plus, y, [points, times])

import os
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestRegressor

from georegression.local_ale import weighted_ale
from georegression.simulation.simulation_utils import coefficient_wrapper
from georegression.visualize.ale import plot_ale
from georegression.weight_model import WeightModel

from georegression.simulation.simulation_for_ale import (
    coef_manual_gau,
    f_interact,
    generate_sample,
)

# Figure styling: bold Times New Roman text with 18 pt base and axis-label fonts
# and 15 pt tick labels.
plt.rcParams.update(
    {
        "font.family": "Times New Roman",
        "font.size": 18,
        "axes.labelsize": 18,
        "font.weight": "bold",
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
    }
)

# Simulation components for this demo: the coefficient generator, the target
# function, and a second coefficient presumably equal to the manual Gaussian
# surface scaled by 3 (coefficient_wrapper's semantics are not visible here).
coef_func = coef_manual_gau
f = f_interact
x2_coef = coefficient_wrapper(partial(np.multiply, 3), coef_func())


def draw_graph():
    """
    Plot a local (weighted) ALE curve for the neighbourhood of every sample.

    Fits a ``WeightModel`` wrapping ``RandomForestRegressor`` with cached local
    estimators. For each local model it computes ``weighted_ale`` for feature 0
    over that model's neighbours, using their kernel weights. It then saves a
    figure to ``Plot/LocalAle_BigFont/<index>.png`` containing:

    - the ALE curve;
    - the neighbours coloured by weight;
    - the local point in red.
    """
    X, y, points, _, _ = generate_sample()
    distance_measure = "euclidean"
    kernel_type = "bisquare"
    neighbour_count = 0.05

    model = WeightModel(
        RandomForestRegressor(n_estimators=50),
        distance_measure,
        kernel_type,
        neighbour_count=neighbour_count,
        cache_data=True,
        cache_estimator=True,
    )
    model.fit(X, y, [points])
    print("GRF:", model.llocv_score_)

    feature_index = 0

    # Create the output folder once rather than on every iteration.
    folder_name = "Plot/LocalAle_BigFont"
    os.makedirs(folder_name, exist_ok=True)

    for local_index in range(model.N):
        estimator = model.local_estimator_list[local_index]
        neighbour_mask = model.neighbour_matrix_[local_index]
        neighbour_weight = model.weight_matrix_[local_index][neighbour_mask]
        X_local = model.X[neighbour_mask]
        fval, ale = weighted_ale(
            X_local, feature_index, estimator.predict, neighbour_weight
        )

        # Raw neighbour data for the scatter overlay.
        x_neighbour = X[neighbour_mask, feature_index]
        y_neighbour = y[neighbour_mask]
        weight_neighbour = model.weight_matrix_[local_index, neighbour_mask]

        fig = plot_ale(fval, ale, x_neighbour)
        fig.set_size_inches(10, 6)
        ax1 = fig.get_axes()[0]
        ax2 = fig.get_axes()[1]

        ax1.set_xlabel("Feature value", fontweight="bold")
        ax1.set_ylabel("Function value", fontweight="bold")
        ax2.set_ylabel("Density", fontweight="bold")

        scatter = ax1.scatter(x_neighbour, y_neighbour, c=weight_neighbour)
        ax1.scatter(
            X[local_index, feature_index], y[local_index], c="red", label="Local point"
        )
        cbar = fig.colorbar(scatter, ax=ax1, label="Weight", pad=0.1)
        cbar.set_label("Weight", weight="bold")
        cbar.ax.tick_params(labelsize=15)

        # Add legend
        handles, labels = ax1.get_legend_handles_labels()
        handles.append(scatter)
        labels.append("Weight")
        ax1.legend(handles, labels, fontsize=15)

        # Save and close the figure built above explicitly instead of relying on
        # pyplot's implicit "current figure".
        fig.savefig(f"{folder_name}/{local_index}.png", dpi=300)
        plt.close(fig)


# Run the local-ALE plotting demo only when executed as a script, not on import.
if __name__ == "__main__":
    draw_graph()


import numpy as np
from georegression.simulation.simulation_for_fitting import generate_sample, f_square, coef_strong

# Seeded 500-sample simulation; coordinates are appended as extra feature columns.
X, y, points = generate_sample(500, f_square, coef_strong, random_seed=1, plot=True)
X_plus = np.concatenate([X, points], axis=1)

# --- Context ---

from sklearn.ensemble import RandomForestRegressor
from georegression.weight_model import WeightModel

distance_measure = "euclidean"
kernel_type = "bisquare"

# Local random forests with 50 trees; a float neighbour count is presumably a
# fraction of the samples (30%) — confirm against WeightModel.
grf_neighbour_count=0.3
grf_n_estimators=50
model = WeightModel(
    RandomForestRegressor(n_estimators=grf_n_estimators),
    distance_measure,
    kernel_type,
    neighbour_count=grf_neighbour_count,
)
model.fit(X_plus, y, [points])
print('STRF R2 Score: ', model.llocv_score_)

# --- Alternative ---

# Manual check: R2 of the model's stored local predictions against y.
from sklearn.metrics import r2_score
y_predict = model.local_predict_
score = r2_score(y, y_predict)
print(score)


import numpy as np
from georegression.simulation.simulation_for_fitting import generate_sample, f_square, coef_strong

# Seeded 500-sample simulation; coordinates are appended as extra feature columns.
X, y, points = generate_sample(500, f_square, coef_strong, random_seed=1, plot=True)
X_plus = np.concatenate([X, points], axis=1)

# --- Context ---

from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import ExtraTreesRegressor
from georegression.stacking_model import StackingWeightModel

distance_measure = "euclidean"
kernel_type = "bisquare"

# Neighbourhood size and leave-out rate for the stacking model; their exact
# semantics live in StackingWeightModel — confirm there.
stacking_neighbour_count=0.3
stacking_neighbour_leave_out_rate=0.1
model = StackingWeightModel(
    # Randomised-split trees with depth limited to the number of original features.
    DecisionTreeRegressor(splitter="random", max_depth=X.shape[1]),
    # Or use the ExtraTreesRegressor for better predicting performance.
    # ExtraTreesRegressor(n_estimators=10, max_depth=X.shape[1]),
    distance_measure,
    kernel_type,
    neighbour_count=stacking_neighbour_count,
    neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
)
model.fit(X_plus, y, [points])
print('STST R2 Score: ', model.llocv_stacking_)

# --- Alternative ---

# Manual check: R2 of the model's stored stacking predictions against y.
from sklearn.metrics import r2_score
y_predict = model.stacking_predict_
score = r2_score(y, y_predict)
print(score)


"""
References: https://github.com/SeldonIO/alibi/blob/master/alibi/explainers/ale.py
"""

from functools import partial
from typing import Callable, Tuple

import numpy as np


def get_quantiles(values: np.ndarray, num_quantiles: int = 11, interpolation='linear') -> np.ndarray:
    """
    Calculate evenly spaced quantiles of values in an array.

    Parameters
    ----------
    values
        Array of values. Quantiles are taken along axis 0.
    num_quantiles
        Number of quantiles to calculate, evenly spaced from the 0th percentile
        (minimum) to the 100th percentile (maximum), both inclusive.
    interpolation
        Estimation method forwarded to ``np.percentile`` as ``method``
        (e.g. ``'linear'``, ``'lower'``, ``'higher'``, ``'nearest'``).

    Returns
    -------
    Array of quantiles of the input values.

    """
    percentiles = np.linspace(0, 100, num=num_quantiles)
    # NumPy 1.22 deprecated the ``interpolation=`` keyword in favour of
    # ``method=``; the public parameter name is kept for backward compatibility.
    quantiles = np.percentile(values, percentiles, axis=0, method=interpolation)
    return quantiles


def bisect_fun(fun: Callable, target: float, lo: int, hi: int) -> int:
    """
    Integer bisection over a function, with ``bisect.bisect_left`` semantics.

    ``fun`` is assumed non-decreasing on ``[lo, hi]``. The result is the integer
    ``v`` such that ``fun(x) < target`` for every ``x < v`` and
    ``fun(x) >= target`` for every ``x >= v``. If ``fun`` stays below
    ``target`` on ``[lo, hi)``, the result is ``hi``.

    Parameters
    ----------
    fun
        Function defined on integers in ``[lo, hi]`` that returns floats.
    target
        Value to search for.
    lo
        Lower bound of the search domain.
    hi
        Upper bound of the search domain.

    Returns
    -------
    Integer index.

    """
    left, right = lo, hi
    while right > left:
        # Same as (left + right) // 2 for integers.
        middle = left + (right - left) // 2
        if fun(middle) >= target:
            right = middle
        else:
            left = middle + 1
    return left


def minimum_satisfied(values: np.ndarray, min_bin_points: int, n: int) -> int:
    """
    Check whether the bins induced by ``n`` quantiles are all populated enough.

    The comparison is strict: every bin must contain MORE than
    ``min_bin_points`` values for the check to pass.

    Parameters
    ----------
    values
        Array of feature values.
    min_bin_points
        Threshold that each bin's count must strictly exceed.
    n
        Number of (possibly repeated) quantiles used as bin edges.

    Returns
    -------
    1 if every bin holds more than ``min_bin_points`` values, 0 otherwise.

    """
    edges = np.unique(get_quantiles(values, num_quantiles=n))
    # Interval index of each value; values equal to the lowest edge (index 0)
    # are merged into the first interval.
    bin_index = np.maximum(np.searchsorted(edges, values, side='left'), 1)
    counts = np.bincount(bin_index)[1:]
    return 1 if (counts > min_bin_points).all() else 0


def adaptive_grid(values: np.ndarray, min_bin_points: int = 1) -> Tuple[np.ndarray, int]:
    """
    Choose a quantile grid for ``values`` whose bins are adequately populated.

    The number of quantiles is searched by integer bisection between 0 and
    ``len(values)``. The grid is built from one fewer than the first count at
    which ``minimum_satisfied`` reports a failure.

    Parameters
    ----------
    values
        Array of feature values.
    min_bin_points
        Bin-population threshold forwarded to ``minimum_satisfied``.

    Returns
    -------
    q
        Unique quantile values (the grid edges).
    num_quantiles
        Number of (possibly repeated) quantiles the grid was built from.

    Notes
    -----
    Heuristic: the bin-size criterion is not monotonic in the number of
    quantiles, so bisection is not guaranteed to find the largest valid count.
    """
    def violated(n: int) -> int:
        # 1 when the n-quantile partition breaks the bin-size constraint, else 0.
        return 1 - minimum_satisfied(values, min_bin_points, n)

    # bisect_fun returns the first n reported as violating the constraint;
    # the count just before it is the one used for the grid.
    num_quantiles = bisect_fun(fun=violated, target=0.5, lo=0, hi=len(values)) - 1
    edges = np.unique(get_quantiles(values, num_quantiles=num_quantiles))
    return edges, num_quantiles


# TODO: https://jaykmody.com/blog/distance-matrices-with-numpy/
# TODO: https://stackoverflow.com/questions/22720864/efficiently-calculating-a-euclidean-distance-matrix-using-numpy
# TODO: Ref to https://github.com/talboger/fastdist and https://github.com/numba/numba-scipy/issues/38#issuecomment-623569703 to speed up by parallel computing
from pathlib import Path


import numpy as np
from numba import njit
from scipy.spatial.distance import pdist, cdist


def _distance_matrices(
    source_coords: list[np.ndarray], target_coords=None, metrics='euclidean',
        use_dask=False, cache_sort=False, args=None, **kwargs
):
    """
    Check the validatiton of the parameter list.
    If single value is provided, convert it to a list with length equal to the dimension of the vector list.
    Then, call the distance_matrix function for each dimension.

    Args:
        source_coords:
        target_coords:
        metrics:
        use_dask:
        args:

    Returns:

    """

    dimension = len(source_coords)

    # Check equal length of source and target coordinates
    if target_coords is not None:
        if dimension != len(target_coords):
            raise Exception("Source and target coordinate length not match")
    else:
        target_coords = [None] * dimension

    # Check whether the input parameters are lists, and if not, convert them to lists with length equal to the dimension of the vector list.
    if not isinstance(metrics, list):
        metrics = [metrics] * dimension

    if args is None:
        args = [kwargs] * dimension

    return [
        _distance_matrix(
            source_coords[dim],
            target_coords[dim],
            metrics[dim],
            use_dask,
            cache_sort,
            **args[dim],
        )
        for dim in range(dimension)
    ]


def _distance_matrix(source_coord, target_coord, metric, use_dask, cache_sort, **kwargs):
    # Check equal dimension of source and target coordinates
    if target_coord is not None:
        if source_coord.shape[1] != target_coord.shape[1]:
            raise Exception("Source and target coordinate dimension not match")

    dimension = source_coord.shape[1]

    if use_dask:
        import dask.array as da
        import dask_distance

        filepath = kwargs.get("filepath", "distance_matrix.zarr")

        if not filepath.endswith(".zarr"):
            filepath = filepath + ".zarr"

        if (kwargs.get("filepath", None) is not None) and (not kwargs.get("overwrite", False)):
            if Path(filepath).exists():
                return da.from_zarr(filepath)

        if target_coord is None:
            # TODO: Size error even after the rechunk.
            distance_matrix = dask_distance.pdist(source_coord, metric=metric)
        else:
            distance_matrix = dask_distance.cdist(source_coord, target_coord, metric=metric)

        distance_matrix = distance_matrix.rechunk({0: "auto", 1: -1})
        distance_matrix.to_zarr(filepath, overwrite=True)

        if cache_sort:
            distance_matrix_sorted = distance_matrix.map_blocks(np.sort)
            distance_matrix_sorted.to_zarr(filepath.replace(".zarr", "_sorted.zarr"), overwrite=True)

            return distance_matrix, distance_matrix_sorted

        return distance_matrix

    else:
        if metric == "great-circle":
            if dimension != 2:
                raise Exception("Great-circle distance only applicable to 2D coordinates")
            return np.array(
                [great_circle_distance(coord, target_coord) for coord in source_coord]
            ).astype(np.float32)

        if target_coord is None:
            return pdist(source_coord.astype(np.float32), metric, **kwargs).astype(np.float32)
        else:
            return cdist(
                source_coord.astype(np.float32),
                target_coord.astype(np.float32),
                metric,
                **kwargs,
            ).astype(np.float32)


@njit
def great_circle_distance(one_lonlat, many_lonlat):
    """
    Compute great-circle distances from one point to many using the Haversine formula.

    Args:
        one_lonlat: Single coordinate ``(lon, lat)`` in degrees.
        many_lonlat: 2-D array whose rows are ``(lon, lat)`` in degrees.

    Returns:
        1-D array of distances in kilometres (Earth radius 6371 km).
    """

    lon_diff = np.radians(many_lonlat[:, 0] - one_lonlat[0])
    lat_diff = np.radians(many_lonlat[:, 1] - one_lonlat[1])
    lat_one = np.radians(one_lonlat[1])
    lat_many = np.radians(many_lonlat[:, 1])

    a = (
        np.sin(lat_diff / 2) ** 2
        + np.cos(lat_many) * np.cos(lat_one) * np.sin(lon_diff / 2) ** 2
    )
    # Rounding can push ``a`` marginally above 1 for (near-)antipodal points,
    # which would make arcsin(sqrt(a)) NaN; clamp to the valid domain.
    a = np.minimum(a, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    R = 6371.0

    return R * c


import math
from typing import Union

import dask.array as da
import numpy as np

from slab_utils.quick_logger import logger

# Kernel names handled by ``kernel_function``. Of these, linear, boxcar, bisquare
# and tricube have compact support (weight 0 beyond the bandwidth).
KERNEL_TYPE_ENUM = ["linear", "uniform", "gaussian", "exponential", "boxcar", "bisquare", "tricube"]


def kernel_function(
    distance: np.ndarray,
    bandwidth: Union[float, list[float], np.ndarray],
    kernel_type: str,
) -> Union[np.ndarray, da.Array]:
    """
    Using kernel function to calculate the weight

    Args:
        distance: The distances to be weighted by kernel function, as a scalar,
            vector or matrix (NumPy or dask array).

        bandwidth: Parameter of the kernel function that sets how fast the weight
            decreases. It must broadcast against ``distance``, e.g. an ``(N, 1)``
            column of per-row bandwidths for an ``(N, M)`` matrix. For compactly
            supported kernels the weight is 0 where the distance exceeds the
            bandwidth.

        kernel_type: One of "uniform", "gaussian", "exponential" (non-compact) or
            "linear", "boxcar", "bisquare", "tricube" (compact support).

    Returns:
        np.ndarray: weight without normalization (a dask array for dask input).

    Raises:
        Exception: If ``kernel_type`` is not supported.
    """

    # Reshape bandwidth to have proper broadcasting in division. It has been done in previous step.
    normalize_distance = distance / bandwidth
    # 0/0 (zero distance and zero bandwidth) gives NaN, treated as distance 0
    # (full weight). x/0 for x > 0 gives +inf, which is kept as inf.
    if isinstance(normalize_distance, da.Array):
        # TODO: Why dask array do not support keyword argument for this method? It can be easily implemented with map_blocks
        # da.map_blocks(np.nan_to_num, normalize_distance, copy=False, nan=0.0, posinf=np.inf)
        normalize_distance = da.nan_to_num(normalize_distance, False, 0.0, np.inf)
    else:
        # Keep the return value: for scalar / 0-d input ``copy=False`` cannot work
        # in place, so the converted value would otherwise be lost.
        normalize_distance = np.nan_to_num(
            normalize_distance, copy=False, nan=0.0, posinf=np.inf
        )

    # Continuous kernel
    if kernel_type == "uniform":
        weight = np.ones_like(normalize_distance)
    elif kernel_type == "gaussian":
        weight = np.exp(-0.5 * normalize_distance**2)
    elif kernel_type == "exponential":
        weight = np.exp(-0.5 * np.abs(normalize_distance))

    # Compact supported kernel
    elif kernel_type == "linear":
        weight = 1 - normalize_distance
    elif kernel_type == "boxcar":
        weight = np.ones_like(normalize_distance)
    elif kernel_type == "bisquare":
        # Optimize for dask array: clip instead of masking afterwards.
        if isinstance(normalize_distance, da.Array):
            normalize_distance[normalize_distance > 1] = 1
            weight = (1 - normalize_distance**2) ** 2

            # It will never be negative!!!
            # weight[weight < 0] = 0

            return weight

        weight = (1 - normalize_distance**2) ** 2
    elif kernel_type == "tricube":
        weight = (1 - np.abs(normalize_distance) ** 3) ** 3
    else:
        raise Exception("Unsupported kernel")

    # compact support: zero weight beyond the bandwidth
    if kernel_type in ["linear", "boxcar", "bisquare", "tricube"]:
        outside = distance > bandwidth
        if np.ndim(weight) == 0:
            # Scalars / 0-d results do not support boolean-mask assignment.
            weight = np.where(outside, 0.0, weight)
        else:
            weight[outside] = 0

    return weight


def adaptive_bandwidth(
    distance: np.ndarray,
    neighbour_count: Union[int, float],
    distance_sorted: "da.Array" = None,
) -> np.ndarray:
    """
    Find, per row, the bandwidth that covers the specified number of neighbours.

    Args:
        distance: Distances (source points x target points) as a NumPy or dask
            array. NumPy input may also be a single vector.
        neighbour_count: Number of neighbours the bandwidth should include.
            - An integer is an absolute count: the bandwidth is the distance to
              the ``neighbour_count``-th nearest target.
            - A float is a fraction of the targets. It is resolved with
              ``np.quantile`` for NumPy input and with ``ceil(M * fraction)``
              in the sorted-dask path.
        distance_sorted: Optional row-sorted copy of a dask ``distance`` matrix.
            When given, the bandwidth is read directly from it.

    Returns:
        Bandwidths with the last axis kept (``(N, 1)`` for an ``(N, M)`` matrix)
        so they broadcast against ``distance``.

    Raises:
        Exception: If ``neighbour_count`` is not a positive number, or exceeds
            the number of targets in the sorted-dask path.
    """
    # Accept NumPy scalar types as well as Python int / float.
    if isinstance(neighbour_count, np.integer):
        neighbour_count = int(neighbour_count)
    elif isinstance(neighbour_count, np.floating):
        neighbour_count = float(neighbour_count)

    # Validate up front so both the NumPy and dask paths reject bad counts
    # (a count of 0 would otherwise silently select the last column below).
    if not isinstance(neighbour_count, (int, float)) or neighbour_count <= 0:
        raise Exception("Invalid neighbour count")

    # Checking np.ndarray first keeps pure-NumPy callers independent of dask.
    if not isinstance(distance, np.ndarray) and isinstance(distance, da.Array):
        # Support for dask array
        # Duplicated coordinate is not supported
        # Neighbours are counted along the last axis (targets), not over rows.
        target_count = distance.shape[1]

        if distance_sorted is not None:
            # Must be the sorted distance matrix
            if isinstance(neighbour_count, float):
                neighbour_count = math.ceil(target_count * neighbour_count)

            if neighbour_count > target_count:
                raise Exception("Invalid neighbour count")

            return distance_sorted[:, [neighbour_count - 1]]

        logger.warning('Sorted distance matrix is not provided, requiring longer time to perform quantile function')

        if isinstance(neighbour_count, float):
            return distance.map_blocks(
                np.quantile,
                neighbour_count,
                axis=1,
                keepdims=True,
                drop_axis=1,
                new_axis=1,
            )

        # np.quantile needs a fraction; for an absolute count select the k-th
        # smallest distance per row. Assumes each block spans complete rows.
        kth = neighbour_count - 1
        return distance.map_blocks(
            lambda block: np.partition(block, kth, axis=1)[:, [kth]],
            drop_axis=1,
            new_axis=1,
            dtype=distance.dtype,
        )

    if isinstance(neighbour_count, float):
        # quantile calls the partition function internally.
        return np.quantile(distance, neighbour_count, axis=-1, keepdims=True, method='median_unbiased')

    kth = neighbour_count - 1
    return np.partition(distance, kth, axis=-1)[..., [kth]]


def adaptive_kernel(
        distance: np.ndarray,
        neighbour_count: Union[int, float],
        kernel_type: str,
        distance_sorted: da.Array = None
) -> Union[np.ndarray, da.Array]:
    """
    Deduce the bandwidth from the neighbour count and calculate weight using kernel function.

    Args:
        distance: Distances to weight (NumPy or dask array).
        neighbour_count: Neighbours each bandwidth should cover, passed to
            ``adaptive_bandwidth`` (int = absolute count, float = fraction).
        kernel_type: Kernel name passed to ``kernel_function``.
        distance_sorted: Optional row-sorted dask distance matrix, forwarded to
            ``adaptive_bandwidth``.

    Returns:
        Unnormalised weights returned by ``kernel_function``.

    """

    bandwidth = adaptive_bandwidth(distance, neighbour_count, distance_sorted)
    return kernel_function(distance, bandwidth, kernel_type)


"""
References: https://github.com/SeldonIO/alibi/blob/master/alibi/explainers/ale.py
"""

import numpy as np
import pandas as pd

from georegression.ale_utils import adaptive_grid


def weighted_ale(X, feature, predictor, weights=None, normalize=False, min_bin_points=5):
    """
    Compute a sample-weighted Accumulated Local Effects (ALE) curve for one feature.

    The feature range is split into quantile intervals by ``adaptive_grid``. For
    each sample, the prediction difference between its interval's upper and
    lower edge is computed. These finite differences are averaged per interval
    using ``weights`` and then accumulated over the intervals.

    Args:
        X: 2-D feature matrix (samples x features).
        feature: Column index of the feature to explain.
        predictor: Callable mapping a feature matrix to 1-D predictions.
        weights: Optional 1-D per-sample weights; ``None`` weights all samples equally.
        normalize: If True, centre the curve by its mean effect (Eq. 16 of
            "Visualizing the effects of predictor variables in black box supervised
            learning models"). Otherwise, shift it by the weighted mean prediction
            of the first interval's samples evaluated at the lowest grid value.
        min_bin_points: Passed to ``adaptive_grid`` to control interval population.

    Returns:
        ``(fvals, ale)``: the grid values and the ALE value at each of them.
    """
    if weights is None:
        # Unweighted ALE: every sample contributes equally.
        weights = np.ones(X.shape[0])
    weights = np.asarray(weights, dtype=float)

    fvals, _ = adaptive_grid(X[:, feature], min_bin_points)
    n_edges = len(fvals)

    # find which interval each observation falls into
    indices = np.searchsorted(fvals, X[:, feature], side="left")
    indices[indices == 0] = 1  # put the smallest data point in the first interval
    interval_n = np.bincount(indices, minlength=n_edges)  # number of points in each interval

    # predictions for the upper and lower ranges of intervals
    z_low = X.copy()
    z_high = X.copy()
    z_low[:, feature] = fvals[indices - 1]
    z_high[:, feature] = fvals[indices]
    p_low = predictor(z_low)
    p_high = predictor(z_high)

    # finite differences
    p_deltas = p_high - p_low

    # base value, which is the weighted average prediction for the lowest interval
    in_first = indices == 1
    base_value = np.average(p_low[in_first], weights=weights[in_first])

    # Weighted mean finite difference per interval 1..n_edges-1. bincount keeps
    # the result aligned with ``fvals`` even if an interval is empty. Intervals
    # whose weights sum to zero contribute no change (delta 0) instead of raising.
    delta_sum = np.bincount(indices, weights=weights * p_deltas, minlength=n_edges)[1:]
    weight_sum = np.bincount(indices, weights=weights, minlength=n_edges)[1:]
    avg_p_deltas = np.divide(
        delta_sum, weight_sum, out=np.zeros_like(delta_sum), where=weight_sum > 0
    )

    # accumulate over intervals and pre-pend 0 for the left-most point
    accum_p_deltas = np.insert(np.cumsum(avg_p_deltas, axis=0), 0, 0.0, axis=0)

    # center
    if normalize:
        # mean effect, R's `ALEPlot` and `iml` version (approximation per interval)
        # Eq.16 from original paper "Visualizing the effects of predictor variables in black box supervised learning models"
        ale0 = (
            0.5 * (accum_p_deltas[:-1] + accum_p_deltas[1:]) * interval_n[1:]
        ).sum(axis=0)
        ale0 = ale0 / interval_n.sum()

        ale = accum_p_deltas - ale0
    else:
        ale = accum_p_deltas + base_value

    return fvals, ale


import numpy as np
from sklearn.inspection import partial_dependence


def local_partial_dependence(local_estimator, X, weight):
    """
    Compute the weight-averaged partial dependence of a local estimator per feature.

    Individual conditional expectation (ICE) curves are computed for every sample
    with non-zero weight. They are then averaged using those weights, so samples
    that matter more to the local model dominate the curve.

    Args:
        local_estimator: Fitted estimator accepted by
            ``sklearn.inspection.partial_dependence``.
        X: 2-D feature matrix (samples x features).
        weight: 1-D per-sample weights aligned with the rows of ``X``.

    Returns:
        List with one dict per feature: ``{'x': grid values, 'pd': weighted
        average partial dependence over that grid}``.

    Raises:
        ValueError: If every weight is zero.
    """
    # TODO: More detailed reason to justify the weighted partial dependence.
    # Care more on the local range.
    # Unweighted points will dominate the tendency which may not be the interested one.
    # Only calculate the local ICE?
    # Better explanation in ALE: the adverse consequences of extrapolation in PD plots
    # Ref: Visualizing the Effects of Predictor Variables in Black Box Supervised Learning Models

    feature_count = X.shape[1]

    # Select X to speed up calculation: zero-weight samples cannot affect the average.
    select_mask = weight != 0
    if not np.any(select_mask):
        raise ValueError("weight must contain at least one non-zero entry.")
    X = X[select_mask]
    weight = weight[select_mask]

    # Partial result of each features
    feature_list = []

    for feature_index in range(feature_count):
        pdp = partial_dependence(
            local_estimator,
            X,
            [feature_index],
            kind='both'
        )

        # Must get individual partial dependence to weight the result.
        # Points with more weight matter more to the local model, so the ICE
        # curves are averaged according to the weight.
        individual = pdp['individual'][0]
        # scikit-learn 1.3 renamed 'values' to 'grid_values'; 1.5 removed 'values'.
        grid = pdp['grid_values'] if 'grid_values' in pdp else pdp['values']
        values = grid[0]
        weight_average = np.average(individual, axis=0, weights=weight)

        # TODO: Pack the result
        feature_list.append({
            'x': values,
            'pd': weight_average
        })

    return feature_list

# TODO: Add local ICE


import numpy as np
from numba import njit, prange
from numpy import ndarray
from scipy.sparse import csr_array


def second_order_neighbour(neighbour_matrix, neighbour_leave_out=None):
    """
    Calculate second-order neighbour matrix.

    Row ``i`` of the result marks every point that is a first-order neighbour of
    at least one of ``i``'s neighbours taken from ``neighbour_leave_out``.

    Args:
        neighbour_matrix: First-order neighbour matrix, as a dense ``np.ndarray``
            or a ``scipy.sparse.csr_array``.
        neighbour_leave_out: The subset of neighbours (same format) that should be
            considered as first-order neighbour. If None, use neighbour_matrix.

    Returns:
        The second-order neighbour matrix. For dense input it is the result of
        ``_second_order_neighbour_dense``. For sparse input it is a
        ``csr_array`` with unit entries and ``neighbour_matrix``'s column count.

    Raises:
        ValueError: If ``neighbour_matrix`` is neither ``np.ndarray`` nor ``csr_array``.
    """
    if neighbour_leave_out is None:
        neighbour_leave_out = neighbour_matrix

    if isinstance(neighbour_matrix, ndarray):
        return _second_order_neighbour_dense(neighbour_matrix, neighbour_leave_out)

    if isinstance(neighbour_matrix, csr_array):
        indices_list = _second_order_neighbour_sparse(
            neighbour_matrix.indptr,
            neighbour_matrix.indices,
            neighbour_leave_out.indptr,
            neighbour_leave_out.indices,
        )

        # Row pointers from per-row neighbour counts. int64 avoids overflow once
        # the total number of stored entries exceeds 2**31 - 1.
        row_lengths = np.array([len(row) for row in indices_list], dtype=np.int64)
        indptr = np.zeros(len(indices_list) + 1, dtype=np.int64)
        np.cumsum(row_lengths, out=indptr[1:])

        # np.hstack raises on an empty list (a matrix without rows).
        if len(indices_list) > 0:
            indices = np.hstack(indices_list)
        else:
            indices = np.empty(0, dtype=np.int64)

        # Give the shape explicitly: inferring it from the largest column index
        # would shrink the matrix when trailing columns have no entries.
        shape = (len(indices_list), neighbour_matrix.shape[1])
        return csr_array((np.ones_like(indices), indices, indptr), shape=shape)

    raise ValueError("neighbour_matrix should be np.ndarray or csr_array.")


@njit(parallel=True)
def _second_order_neighbour_sparse(
    indptr, indices, indptr_leave_out, indices_leave_out
):
    """
    Numba kernel: per-row column indices of second-order neighbours (CSR input).

    For each row, walks that row's neighbours in the ``*_leave_out`` CSR arrays
    and unions the column indices stored in those neighbours' rows of the
    ``indptr`` / ``indices`` CSR arrays.

    Args:
        indptr, indices: CSR row pointer / column index arrays of the
            first-order neighbour matrix.
        indptr_leave_out, indices_leave_out: CSR arrays of the neighbour subset
            to expand from.

    Returns:
        List of length N (number of rows). Entry ``i`` is the ascending array of
        column indices that are second-order neighbours of row ``i``.
    """
    N = len(indptr) - 1
    # Manually create the list with specified length to avoid parallel Mutating error.
    # (All slots initially alias the same empty array; each slot is reassigned below.)
    indices_list = [np.empty(0, dtype=np.int64)] * N
    for row_index in prange(N):
        # Neighbours of this row that are expanded from.
        neighbour_indices = indices_leave_out[
            indptr_leave_out[row_index] : indptr_leave_out[row_index + 1]
        ]
        # Dense length-N marker; assumes column indices are < N (square matrix) — TODO confirm.
        second_neighbour_indices_union = np.zeros((N,))
        for neighbour_index in neighbour_indices:
            second_neighbour_indices = indices[
                indptr[neighbour_index] : indptr[neighbour_index + 1]
            ]
            for second_neighbour_index in second_neighbour_indices:
                second_neighbour_indices_union[second_neighbour_index] = True

        # Marked positions, in ascending order.
        second_neighbour_indices_union = np.nonzero(second_neighbour_indices_union)[0]
        indices_list[row_index] = second_neighbour_indices_union

    return indices_list


@njit(parallel=True)
def _second_order_neighbour_dense(neighbour_matrix, neighbour_leave_out):
    """
    Numba kernel: dense second-order neighbour matrix.

    Row ``i`` of the result is True wherever any row of ``neighbour_matrix``
    selected by ``neighbour_leave_out[i]`` is non-zero. This presumes
    ``neighbour_leave_out`` rows are boolean masks — TODO confirm.

    NOTE(review): the output is sized ``(shape[1], shape[1])`` and rows are
    iterated over ``shape[1]``, so a square neighbour matrix is assumed.
    """
    second_order_matrix = np.empty((neighbour_matrix.shape[1], neighbour_matrix.shape[1]), dtype=np.bool_)
    for i in prange(neighbour_matrix.shape[1]):
        # Column-wise count over the selected rows; any non-zero count becomes True.
        second_order_matrix[i] = np.sum(
            neighbour_matrix[neighbour_leave_out[i]], axis=0
        )
    return second_order_matrix


def neighbour_shrink(weight_matrix, shrink_rate, return_weight_matrix=False, inplace=True):
    """
    Drop, in every row, the non-zero weights below a per-row quantile.

    For each row, stored weights smaller than the ``shrink_rate`` quantile of
    that row's non-zero (dense) or stored (sparse) weights are set to zero.

    Args:
        weight_matrix: Dense ``np.ndarray`` or ``scipy.sparse.csr_array`` of weights.
        shrink_rate: Quantile in [0, 1] of each row's weights used as the cut-off.
        return_weight_matrix: Return the shrunk weights instead of a boolean
            neighbour mask.
        inplace: If False, work on a copy so the input is left untouched.

    Returns:
        The shrunk weight matrix if ``return_weight_matrix``, else ``weights > 0``.

    Raises:
        ValueError: If ``weight_matrix`` is neither ``np.ndarray`` nor ``csr_array``.
    """
    if isinstance(weight_matrix, np.ndarray):
        if not inplace:
            weight_matrix = weight_matrix.copy()
        weight_matrix = _neighbour_shrink(weight_matrix, shrink_rate)
    elif isinstance(weight_matrix, csr_array):
        if not inplace:
            weight_matrix = weight_matrix.copy()
        weight_matrix.data = _neighbour_shrink_sparse(weight_matrix.data, weight_matrix.indptr, shrink_rate)
        # Drop the explicit zeros written by the shrink so they are not stored entries.
        weight_matrix.eliminate_zeros()
    else:
        # Previously an unsupported type silently returned None.
        raise ValueError("weight_matrix should be np.ndarray or csr_array.")

    if return_weight_matrix:
        return weight_matrix
    return weight_matrix > 0

@njit(parallel=True)
def _neighbour_shrink(weight_matrix: np.ndarray, shrink_rate=0.5):
    """
    Numba kernel: per row, zero the non-zero weights below the row's quantile.

    For each row, the ``shrink_rate`` quantile of its non-zero weights is computed,
    and every non-zero weight strictly below it is set to 0. ``weight_matrix`` is
    modified in place and also returned.

    NOTE(review): a row with no non-zero weight passes an empty array to
    ``np.quantile``; the behaviour under numba for that case is unverified.
    """
    for i in prange(weight_matrix.shape[0]):
        neighbour_indices = np.nonzero(weight_matrix[i])[0]
        positive_value = weight_matrix[i]
        # Fancy indexing returns a copy, so results are written back explicitly below.
        positive_value = positive_value[neighbour_indices]
        shrink_value = np.quantile(positive_value, shrink_rate)
        positive_value[positive_value < shrink_value] = 0
        # TODO: Rename j
        for j in range(len(neighbour_indices)):
            weight_matrix[i, neighbour_indices[j]] = positive_value[j]
    return weight_matrix

@njit(parallel=True)
def _neighbour_shrink_sparse(weight_matrix_data, weight_matrix_indptr, shrink_rate=0.5):
    """
    Numba kernel: CSR counterpart of the dense per-row shrink.

    For each row (delimited by ``weight_matrix_indptr``), stored values below the
    ``shrink_rate`` quantile of that row's stored values are set to 0.
    ``weight_matrix_data`` is modified in place and returned; the explicit zeros
    it leaves behind are removed by the caller (``neighbour_shrink`` calls
    ``eliminate_zeros``).
    """
    for i in prange(len(weight_matrix_indptr) - 1):
        # Basic slicing yields a view, so the masking below already edits the data in place.
        positive_value = weight_matrix_data[
            weight_matrix_indptr[i] : weight_matrix_indptr[i + 1]
        ]
        shrink_value = np.quantile(positive_value, shrink_rate)
        positive_value[positive_value < shrink_value] = 0
        # Redundant with the in-place edit above, but harmless.
        weight_matrix_data[
            weight_matrix_indptr[i] : weight_matrix_indptr[i + 1]
        ] = positive_value
    return weight_matrix_data


def sample_neighbour(weight_matrix, sample_rate, shrink_rate=None, random_state=0):
    """
    Randomly sample a subset of every observation's neighbours.

    A neighbour of row ``i`` is a column ``j != i`` with a positive weight, taken after the
    optional ``shrink_rate`` quantile shrink. For each row, ``ceil(count * sample_rate)``
    neighbours are drawn without replacement. A row with any neighbour gets at least one
    draw, and never more draws than it has neighbours. The diagonal (the observation
    itself) is always set in the result.

    ``StackingWeightModel.fit`` uses the result as its leave-out set. Sampled neighbours are
    kept out of base-estimator fitting so that their estimators can feed the meta model.

    Args:
        weight_matrix: Square weight matrix, ``np.ndarray`` or ``scipy.sparse.csr_array``.
        sample_rate: Fraction of each row's neighbours to draw.
        shrink_rate: Optional per-row quantile. Weights below it are discarded
            (via ``neighbour_shrink``) before sampling.
        random_state: Seed (int or None) for a private ``np.random.RandomState``.
            The default ``0`` gives the same draws as the former ``np.random.seed(0)``
            call. Unlike that call, it no longer resets NumPy's global random state.

    Returns:
        Boolean matrix, dense or CSR like the input, marking the sampled neighbours
        and the diagonal.

    Raises:
        ValueError: If ``weight_matrix`` is neither ``np.ndarray`` nor ``csr_array``.
    """
    # Shrink first so sampling only considers the strongest neighbours.
    if shrink_rate is not None:
        neighbour_matrix = neighbour_shrink(weight_matrix, shrink_rate, inplace=False)
    else:
        neighbour_matrix = weight_matrix > 0

    # An observation is never sampled as its own neighbour.
    if isinstance(neighbour_matrix, np.ndarray):
        np.fill_diagonal(neighbour_matrix, False)
    elif isinstance(neighbour_matrix, csr_array):
        neighbour_matrix.setdiag(False)
        neighbour_matrix.eliminate_zeros()
    else:
        raise ValueError("weight_matrix should be np.ndarray or csr_array.")

    # Per-row draw count: ceil(count * rate), at least 1, at most count.
    # asarray/ravel gives a flat vector even if sparse row sums come back 2-D.
    neighbour_count = np.asarray(neighbour_matrix.sum(axis=1)).ravel()
    neighbour_count_sampled = np.ceil(neighbour_count * sample_rate).astype(int)
    neighbour_count_sampled[neighbour_count_sampled == 0] = 1
    neighbour_count_sampled = np.minimum(neighbour_count_sampled, neighbour_count)

    # Private generator: reproducible draws without touching np.random's global state.
    rng = np.random.RandomState(random_state)

    if isinstance(neighbour_matrix, np.ndarray):
        neighbour_matrix_sampled = np.zeros(neighbour_matrix.shape, dtype=bool)
        for i in range(neighbour_matrix.shape[0]):
            candidates = np.flatnonzero(neighbour_matrix[i])
            chosen = rng.choice(candidates, neighbour_count_sampled[i], replace=False)
            neighbour_matrix_sampled[i, chosen] = True
        # Keep the observation itself.
        np.fill_diagonal(neighbour_matrix_sampled, True)
        return neighbour_matrix_sampled

    # CSR: build the sampled structure row by row (no dense N x N allocation).
    indices_list = []
    for i in range(neighbour_matrix.shape[0]):
        candidates = neighbour_matrix.indices[
            neighbour_matrix.indptr[i] : neighbour_matrix.indptr[i + 1]
        ]
        chosen = rng.choice(candidates, neighbour_count_sampled[i], replace=False)
        # Add the observation itself; sorted column indices keep the CSR canonical.
        indices_list.append(np.sort(np.append(chosen, i)))

    indptr = np.zeros(len(indices_list) + 1, dtype=np.int32)
    indptr[1:] = np.cumsum([row_indices.size for row_indices in indices_list])
    indices = np.hstack(indices_list)

    # Explicit shape: do not let SciPy infer the column count from the max index.
    return csr_array(
        (np.ones_like(indices), indices, indptr),
        shape=neighbour_matrix.shape,
        dtype=bool,
    )


import numpy as np
from numba import njit, float64, boolean


@njit()
def mean(x, axis, weight):
    """
    Weighted mean of ``x`` along ``axis`` (numba-compiled).

    ``weight`` is reshaped into a column, so every row of ``x`` is scaled by its own
    weight. The weighted sum along ``axis`` is then divided by the total weight.
    With ``axis=0``, as used by ``ridge_cholesky``, this yields weighted column means.
    """
    column_weight = weight.reshape((-1, 1))
    return np.sum(x * column_weight, axis) / np.sum(column_weight)

@njit()
def ridge_cholesky(X, y, alpha, weight):
    """
    Weighted ridge regression via the normal equations (numba-compiled).

    X and y are centred on their weighted means, which removes the intercept from the
    penalised problem. Each centred row is multiplied by ``sqrt(weight)`` to apply the
    sample weights. Then ``(Xc^T Xc + alpha * I) coef = Xc^T yc`` is solved with
    ``np.linalg.solve``; despite the name, no explicit Cholesky factorisation is used.

    Args:
        X: Design matrix of shape (m, n).
        y: Targets with m elements; reshaped to a column (m, 1).
        alpha: Ridge penalty added to the diagonal of the Gram matrix.
        weight: Per-sample weights with m elements.

    Returns:
        ``(coef, intercept)`` with shapes (n, 1) and (1,).
    """
    y_column = y.reshape((-1, 1))

    # Weighted means: (n,) for X, (1,) for y.
    x_mean = mean(X, 0, weight)
    y_mean = mean(y_column, 0, weight)

    # Centre, then fold the sample weights in by scaling each row with sqrt(weight).
    # A dedicated name avoids shadowing the `weight` argument.
    row_scale = np.sqrt(weight).reshape((-1, 1))
    x_scaled = (X - x_mean) * row_scale
    y_scaled = (y_column - y_mean) * row_scale

    # (n, n) penalised Gram matrix and (n, 1) right-hand side.
    gram = np.dot(x_scaled.T, x_scaled) + alpha * np.eye(X.shape[1])
    moment = np.dot(x_scaled.T, y_scaled)

    coef = np.linalg.solve(gram, moment)
    intercept = y_mean - np.dot(x_mean, coef)

    return coef, intercept

@njit()
def r2_score(y_true, y_pred):
    """
    Coefficient of determination ``1 - RSS / TSS`` (numba-compiled, single output).

    Mirrors scikit-learn's ``r2_score`` with its default ``force_finite=True``. When
    ``y_true`` is constant (TSS == 0), the score is 1.0 for a perfect prediction and
    0.0 otherwise. It no longer rewards arbitrary predictions with 1.

    Adapted from https://github.com/jcatankard/NumbaML/blob/main/numbaml/scoring.py
    See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html#sklearn.metrics.r2_score
    """
    total_sum_squares = np.sum((y_true - np.mean(y_true)) ** 2)
    residual_sum_squares = np.sum((y_true - y_pred) ** 2)

    if total_sum_squares == 0:
        # Undefined ratio: perfect fit -> 1.0, anything else -> 0.0 (sklearn semantics).
        return 1.0 if residual_sum_squares == 0 else 0.0

    return 1.0 - residual_sum_squares / total_sum_squares


if __name__ == '__main__':
    # Smoke run of ridge_cholesky on random data; prints the fit and its shapes.
    n_samples, n_features = 1000, 100
    X = np.random.randn(n_samples, n_features)
    y = np.random.randn(n_samples)
    alpha = 10
    weight = np.random.random((n_samples, ))

    coef, intercept = ridge_cholesky(X, y, alpha, weight)

    print(coef, intercept)
    print(coef.shape, intercept.shape)
    print()


import numpy as np
from scipy.stats import norm
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression


# TODO: Note: the diagonal of the weight matrix is set to 0.
# TODO: Formula revision required.

def spatiotemporal_MI(y, w):
    """
    Global Moran's I, its expectation and its variance under randomisation.

    ``w`` is row-standardised (each row divided by its own sum) before use, so every
    row must have a positive sum.

    References:
        https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-statistics/h-how-spatial-autocorrelation-moran-s-i-spatial-st.htm
        https://pro.arcgis.com/en/pro-app/latest/tool-reference/spatial-statistics/h-global-morans-i-additional-math.htm

    Args:
        y (np.ndarray): Observed values, 1-D of length n.
        w (np.ndarray): Dense weight matrix of shape (n, n).

    Returns:
        tuple: ``(I, EI, VI)``, i.e. Moran's I, its expectation ``-1 / (n - 1)`` and its
        variance (Cliff-Ord randomisation assumption).
    """
    # keepdims makes the division row-wise. A 1-D row-sum vector would broadcast
    # along columns and divide column j by the sum of row j.
    w = w / np.sum(w, axis=1, keepdims=True)

    n = y.shape[0]

    z = (y - np.mean(y)) / np.std(y)
    S0 = np.sum(w)
    I = (n / S0) * (np.matmul(np.matmul(w, z), z)) / np.sum(z ** 2)

    S1 = (1 / 2) * np.sum((w + w.T) ** 2)
    S2 = np.sum(np.sum(w + w.T, axis=1) ** 2)
    # Sample kurtosis m4 / m2^2 = n * sum(z^4) / (sum(z^2))^2 (same as STMI's S3).
    D = n * np.sum(z ** 4) / np.sum(z ** 2) ** 2
    A = n * ((n ** 2 - 3 * n + 3) * S1 - n * S2 + 3 * S0 ** 2)
    B = D * ((n ** 2 - n) * S1 - 2 * n * S2 + 6 * S0 ** 2)
    C = (n - 1) * (n - 2) * (n - 3) * S0 ** 2

    EI = -1 / (n - 1)
    EI2 = (A - B) / C
    VI = EI2 - EI ** 2

    return I, EI, VI


def STMI(y, spatial_temporal_weight):
    """
    Spatio-temporal global Moran's I with randomisation moments and a normal CDF value.

    The weight matrix is row-standardised (each row divided by its own sum) before use,
    so every row must have a positive sum.

    Args:
        y (np.ndarray): Observed values, 1-D of length N.
        spatial_temporal_weight (np.ndarray): Dense weight matrix of shape (N, N).

    Returns:
        tuple: ``(I, expectation, variance, p)``. ``p`` is the standard-normal CDF of
        ``(I - expectation) / sqrt(variance)``, i.e. a lower-tail probability; the
        right-tail p-value for positive autocorrelation is ``1 - p``.
    """
    # keepdims makes the division row-wise. A 1-D row-sum vector would broadcast
    # along columns and divide column j by the sum of row j.
    spatial_temporal_weight = spatial_temporal_weight / np.sum(
        spatial_temporal_weight, axis=1, keepdims=True
    )

    N = y.shape[0]
    y_mean = np.mean(y)
    # Deviation from the mean without normalization by std.
    u = y - y_mean
    numerator = N * np.matmul(u, np.matmul(spatial_temporal_weight, u))
    W = np.sum(spatial_temporal_weight)
    denominator = W * np.sum(u ** 2)
    I = numerator / denominator
    expectation = -1 / (N - 1)
    S1 = (1 / 2) * np.sum((spatial_temporal_weight + spatial_temporal_weight.T) ** 2)
    S2 = np.sum(np.sum(spatial_temporal_weight + spatial_temporal_weight.T, axis=1) ** 2)
    # Sample kurtosis m4 / m2^2.
    S3 = N ** (-1) * np.sum(u ** 4) / (N ** (-1) * np.sum(u ** 2)) ** 2
    S4 = (N ** 2 - 3 * N + 3) * S1 - N * S2 + 3 * W ** 2
    S5 = (N ** 2 - N) * S1 - 2 * N * S2 + 6 * W ** 2
    variance = (N * S4 - S3 * S5) / ((N - 1) * (N - 2) * (N - 3) * W ** 2) - expectation ** 2
    p = norm.cdf((I - expectation) / np.sqrt(variance))
    return I, expectation, variance, p


def spatiotemporal_LMI(y, spatial_temporal_weight):
    """
    Local Moran's I for every observation.

    ``I_i = (u_i / m2) * sum_j w_ij * u_j`` with ``u = y - mean(y)``,
    ``m2 = sum(u ** 2) / N`` and ``w`` the row-standardised weight matrix. Every row
    must have a positive sum.

    Args:
        y (np.ndarray): Observed values, 1-D of length N.
        spatial_temporal_weight (np.ndarray): Dense weight matrix of shape (N, N).

    Returns:
        np.ndarray: Local statistics, one per observation.
    """
    # keepdims makes the division row-wise. A 1-D row-sum vector would broadcast
    # along columns and divide column j by the sum of row j.
    spatial_temporal_weight = spatial_temporal_weight / np.sum(
        spatial_temporal_weight, axis=1, keepdims=True
    )

    N = y.shape[0]
    y_mean = np.mean(y)
    u = y - y_mean
    m2 = np.sum(u ** 2) / N
    return (u / m2) * np.matmul(spatial_temporal_weight, u)


def plot_moran_diagram(y, weight_matrix):
    """
    Draw a Moran scatter plot on the current matplotlib axes.

    x: mean-centred values. y: their lag ``weight_matrix @ centred``, with the weight
    matrix used as given (not row-standardised here). Black axes are drawn through the
    origin, and a dashed black line shows the intercept-free least-squares fit of lag on
    value over the observed x range.
    """
    centred = y - np.mean(y)
    lagged = np.matmul(weight_matrix, centred)

    plt.scatter(centred, lagged, s=10, edgecolors="k", alpha=0.5)
    # Reference axes through the origin.
    plt.axhline(y=0, color='k', linewidth=1)
    plt.axvline(x=0, color='k', linewidth=1)

    x_ends = np.array([[np.min(centred)], [np.max(centred)]])

    slope_model = LinearRegression(fit_intercept=False)
    slope_model.fit(centred.reshape(-1, 1), lagged)
    # '--k' = black dashed line.
    plt.plot(x_ends, slope_model.predict(x_ends), '--k')


import math
from itertools import compress
from time import time

import numpy as np
from numba import njit, prange
from scipy.sparse import csr_array
from sklearn.base import BaseEstimator, clone
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.utils import check_X_y
from slab_utils.quick_logger import logger
from joblib import Parallel, delayed

from georegression.neighbour_utils import (
    second_order_neighbour,
    sample_neighbour,
    neighbour_shrink,
)
from georegression.numba_impl import ridge_cholesky
from georegression.weight_matrix import weight_matrix_from_points
from georegression.weight_model import WeightModel
from georegression.numba_impl import r2_score as r2_numba


def _fit_local_estimator(
        local_estimator, X, y,
        sample_weight,
        X_second_neighbour,
        local_x, return_estimator=False
):
    """
    Wrapper for parallel fitting.
    """

    # TODO: Add partial calculation for non-cache solution.

    local_estimator.fit(X, y, sample_weight=sample_weight)
    local_predict = local_estimator.predict(local_x.reshape(1, -1))
    second_neighbour_predict = local_estimator.predict(X_second_neighbour)

    if return_estimator:
        return local_predict, second_neighbour_predict, local_estimator
    else:
        return local_predict, second_neighbour_predict, None

def _assemble_meta_matrix(weight_matrix, second_neighbour_matrix, local_indices, second_neighbour_predict, N):
    """
    Scatter each fitted estimator's second-neighbour predictions into the meta matrices.

    Row ``i`` of ``X_meta_T`` receives estimator ``i``'s predictions at the positions
    flagged in row ``i`` of ``second_neighbour_matrix``. ``X_meta`` is its transpose, so
    ``X_meta[j, i]`` is the prediction of estimator ``i`` on observation ``j``. Rows of
    estimators outside ``local_indices`` stay zero.
    """
    if isinstance(weight_matrix, np.ndarray):
        X_meta_T = np.zeros((N, N))
        for row, prediction in zip(local_indices, second_neighbour_predict):
            X_meta_T[row, second_neighbour_matrix[row]] = prediction
        X_meta = X_meta_T.T.copy()
        return X_meta, X_meta_T

    indptr = second_neighbour_matrix.indptr
    data = np.zeros(second_neighbour_matrix.indices.shape[0])
    for row, prediction in zip(local_indices, second_neighbour_predict):
        data[indptr[row]:indptr[row + 1]] = prediction
    # Explicit shape: inferring it from the largest column index could drop trailing columns.
    X_meta_T = csr_array(
        (data, second_neighbour_matrix.indices, indptr),
        shape=second_neighbour_matrix.shape,
    )
    # Plain transpose (equal to the conjugate transpose for real predictions);
    # `getH` is not available on recent SciPy sparse arrays.
    X_meta = X_meta_T.T.tocsr()
    return X_meta, X_meta_T


def _fit(
    X,
    y,
    estimator_list,
    weight_matrix,
    second_neighbour_matrix,
    local_indices=None,
    cache_estimator=False,
    X_predict=None,
    n_patches=None,
):
    """
    Fit one local estimator per location in parallel batches and build the meta features.

    Using joblib to parallelize the meta predicting process separately fails to accelerate,
    because the overhead of pickling/unpickling the models is too heavy. As a compromise, the
    second-neighbour prediction is done inside the fitting job itself.
    # TODO: To better solve the problem, use numba, cython or another language to fully utilize the multicore.

    Args:
        X: Training features, indexed by observation.
        y: Training targets.
        estimator_list: One unfitted estimator per row of ``weight_matrix``.
        weight_matrix: ``np.ndarray`` or ``csr_array``. Row ``i`` selects and weights the
            training data of estimator ``i``; zero weights are excluded.
        second_neighbour_matrix: Matrix of the same kind. Row ``i`` marks the observations
            estimator ``i`` predicts for the meta model.
        local_indices: Rows (estimators) to fit; defaults to all rows.
        cache_estimator: Return the fitted estimators instead of ``None`` placeholders.
        X_predict: Features of the locations themselves; defaults to ``X``.
        n_patches: Number of batches / joblib workers; ``None`` means a single batch.

    Returns:
        ``(local_predict, X_meta, X_meta_T, local_estimator_list)``. ``X_meta_T[i, j]`` is
        the prediction of estimator ``i`` on observation ``j``; ``X_meta`` is its transpose.

    Raises:
        ValueError: If ``weight_matrix`` is neither ``np.ndarray`` nor ``csr_array``.
    """

    t_start = time()

    N = weight_matrix.shape[0]
    # Use all data if sample indices not provided.
    if local_indices is None:
        local_indices = range(N)
    # Data used for local prediction. Different from X when source and target are not same for weight matrix.
    if X_predict is None:
        X_predict = X
    # np.array_split cannot split into `None` parts; run a single batch instead of crashing.
    if n_patches is None:
        n_patches = 1

    # `select_row(i)` returns (training selector, sample weights, second-neighbour selector) for estimator i.
    if isinstance(weight_matrix, np.ndarray):
        # Select non-zero weight to avoid zero weight input.
        neighbour_matrix = weight_matrix != 0

        def select_row(i):
            neighbour_mask = neighbour_matrix[i]
            return neighbour_mask, weight_matrix[i][neighbour_mask], second_neighbour_matrix[i]

    elif isinstance(weight_matrix, csr_array):
        # Canonical CSR (sorted, duplicates summed); the former `weight_matrix != 0` did this implicitly.
        weight_matrix.sum_duplicates()

        def select_row(i):
            start, stop = weight_matrix.indptr[i], weight_matrix.indptr[i + 1]
            row_weight = weight_matrix.data[start:stop]
            # Indices and weights come from the same structure, so they always line up;
            # explicitly stored zeros are skipped.
            keep = row_weight != 0
            neighbour_index = weight_matrix.indices[start:stop][keep]
            second_neighbour_index = second_neighbour_matrix.indices[
                second_neighbour_matrix.indptr[i]:second_neighbour_matrix.indptr[i + 1]
            ]
            return neighbour_index, row_weight[keep], second_neighbour_index

    else:
        raise ValueError("weight_matrix should be np.ndarray or csr_array.")

    def batch_wrapper(batch_indices):
        local_prediction_list = []
        second_neighbour_prediction_list = []
        local_estimator_list = []
        for i in batch_indices:
            neighbour_selector, row_weight, second_selector = select_row(i)
            local_predict, second_neighbour_predict, local_estimator = _fit_local_estimator(
                estimator_list[i], X[neighbour_selector], y[neighbour_selector], local_x=X_predict[i],
                sample_weight=row_weight, X_second_neighbour=X[second_selector],
                return_estimator=cache_estimator
            )
            local_prediction_list.append(local_predict)
            second_neighbour_prediction_list.append(second_neighbour_predict)
            local_estimator_list.append(local_estimator)

        return local_prediction_list, second_neighbour_prediction_list, local_estimator_list

    # Split the local indices into contiguous batches; results come back in the same order.
    local_indices_batch_list = np.array_split(local_indices, n_patches)
    parallel_batch_result = Parallel(n_patches)(
        delayed(batch_wrapper)(local_indices_batch) for local_indices_batch in local_indices_batch_list
    )

    local_predict = []
    second_neighbour_predict = []
    local_estimator_list = []
    for local_prediction_batch, second_neighbour_prediction_batch, local_estimator_batch in parallel_batch_result:
        local_predict.extend(local_prediction_batch)
        second_neighbour_predict.extend(second_neighbour_prediction_batch)
        local_estimator_list.extend(local_estimator_batch)

    X_meta, X_meta_T = _assemble_meta_matrix(
        weight_matrix, second_neighbour_matrix, local_indices, second_neighbour_predict, N
    )

    t_end = time()
    logger.debug(f"Parallel fit time: {t_end - t_start}")

    return local_predict, X_meta, X_meta_T, local_estimator_list

class StackingWeightModel(WeightModel):
    def __init__(
        self,
        local_estimator,
        # Weight matrix param
        distance_measure=None,
        kernel_type=None,
        distance_ratio=None,
        bandwidth=None,
        neighbour_count=None,
        midpoint=None,
        distance_args=None,
        # Model param
        leave_local_out=True,
        sample_local_rate=None,
        cache_data=False,
        cache_estimator=False,
        n_jobs=None,
        n_patches=None,
        alpha=10.0,
        neighbour_leave_out_rate=None,
        neighbour_leave_out_shrink_rate=None,
        meta_fitting_shrink_rate=None,
        estimator_sample_rate=None,
        use_numba=False,
        *args,
        **kwargs
    ):
        """
        Configure the base (local) estimators and the stacking stage.

        Weight-matrix and base-model arguments are forwarded unchanged to
        ``WeightModel.__init__``. Stacking-specific arguments:

        Args:
            alpha: Ridge penalty of each location's meta estimator.
            neighbour_leave_out_rate: Sampling rate handed to ``sample_neighbour`` to pick
                the neighbours left out of base fitting.
            neighbour_leave_out_shrink_rate: ``shrink_rate`` handed to ``sample_neighbour``.
            meta_fitting_shrink_rate: Quantile used by ``neighbour_shrink`` on the weight
                matrix before meta fitting.
            estimator_sample_rate: Fraction of neighbour estimators sampled for each meta
                estimator (dense weight matrices only).
            use_numba: Run the meta fitting with the numba implementation (sparse weight
                matrices only).
        """
        # TODO: _fit require n_patches to be set. In the parent class, the n_patches will be set automatically if n_jobs is not set.
        super().__init__(
            local_estimator,
            *args,
            # Weight matrix param
            distance_measure=distance_measure,
            kernel_type=kernel_type,
            distance_ratio=distance_ratio,
            bandwidth=bandwidth,
            neighbour_count=neighbour_count,
            midpoint=midpoint,
            distance_args=distance_args,
            # Model param
            leave_local_out=leave_local_out,
            sample_local_rate=sample_local_rate,
            cache_data=cache_data,
            cache_estimator=cache_estimator,
            n_jobs=n_jobs,
            n_patches=n_patches,
            **kwargs
        )

        # Stacking configuration.
        self.alpha = alpha
        self.neighbour_leave_out_rate = neighbour_leave_out_rate
        self.neighbour_leave_out_shrink_rate = neighbour_leave_out_shrink_rate
        self.meta_fitting_shrink_rate = meta_fitting_shrink_rate
        self.estimator_sample_rate = estimator_sample_rate
        self.use_numba = use_numba

        # Result attributes, reset here; several are filled in by fit().
        self.base_estimator_list = None
        self.meta_estimator_list = None
        self.stacking_predict_ = None
        self.stacking_scores_ = None
        self.llocv_stacking_ = None


    def fit(self, X, y, coordinate_vector_list=None, weight_matrix=None):
        """
        Fit an estimator at every location using the local data.
        Then, given a location, use the neighbour estimators to blending a new estimator, fitted also by the local data.

        Args:
            X:
            y:
            coordinate_vector_list ():
            weight_matrix:

        Returns:

        """
        self.log_stacking_before_fitting()

        X, y = check_X_y(X, y)
        self.is_fitted_ = True
        self.n_features_in_ = X.shape[1]
        self.N = X.shape[0]

        if coordinate_vector_list is None and weight_matrix is None:
            raise Exception('At least one of coordinate_vector_list or weight_matrix should be provided')

        # Cache data for local predict
        if self.cache_data:
            self.X = X
            self.y = y
            # TODO: Cache the weight_matrix, neighbor_matrix to make it compatible with the local diagonalization.

        cache_estimator = self.cache_estimator
        self.cache_estimator = True
        self.N = X.shape[0]

        if weight_matrix is None:
            weight_matrix = weight_matrix_from_points(
                coordinate_vector_list,
                coordinate_vector_list,
                self.distance_measure,
                self.kernel_type,
                self.distance_ratio,
                self.bandwidth,
                self.neighbour_count,
                self.distance_args,
            )

            # TODO: Tweak for inspection.
            self.weight_matrix_ = weight_matrix
            self.neighbour_matrix_ = weight_matrix > 0

        t_neighbour_process_start = time()

        # Do the leave out neighbour sampling.
        neighbour_leave_out = None
        if self.neighbour_leave_out_rate is not None:
            neighbour_leave_out = sample_neighbour(
                weight_matrix, self.neighbour_leave_out_rate, self.neighbour_leave_out_shrink_rate
            )

            if isinstance(neighbour_leave_out, csr_array):
                neighbour_leave_out_ = neighbour_leave_out

            # From (i,j) is that i-th observation will not be used to fit the j-th base estimator
            # so that the j-th base estimator will be used for meta-estimator.
            # To (j,i) is that j-th observation will not consider i-th observation as neighbour while fitting base estimator.
            if isinstance(neighbour_leave_out, np.ndarray):
                neighbour_leave_out = neighbour_leave_out.T
            else:
                # Structure not change for sparse matrix. BUG HERE.
                neighbour_leave_out = csr_array(neighbour_leave_out.T)

        # Do not change the original weight matrix to remain the original neighbour relationship.
        # Consider the phenomenon that weight_matrix_local[neighbour_leave_out.nonzero()] is not zero?
        # Because the neighbour relationship is not symmetric.
        weight_matrix_local = weight_matrix.copy()
        weight_matrix_local[neighbour_leave_out.nonzero()] = 0
        if isinstance(weight_matrix_local, csr_array):
            # To set the value for sparse matrix, convert it first to lil_array, then convert back to csr_array.
            # This can make sure the inner structure of csr_array is correct to be able to manipulate directly .
            # Or just use eliminate_zeros() to remove the zero elements.
            weight_matrix_local.eliminate_zeros()

        if self.leave_local_out:
            if isinstance(weight_matrix_local, np.ndarray):
                np.fill_diagonal(weight_matrix_local, 0)
            else:
                # TODO: High cost for sparse matrix
                weight_matrix_local.setdiag(0)
                weight_matrix_local.eliminate_zeros()

        if self.sample_local_rate is not None:
            self.local_indices_ = np.sort(np.random.choice(self.N, int(self.sample_local_rate * self.N), replace=False))
        else:
            self.local_indices_ = range(self.N)
        self.y_sample_ = y[self.local_indices_]

        t_neighbour_process_end = time()

        if isinstance(neighbour_leave_out, np.ndarray):
            avg_neighbour_count = np.count_nonzero(weight_matrix_local) / self.N
            avg_leave_out_count = np.count_nonzero(neighbour_leave_out) / self.N
        elif isinstance(neighbour_leave_out, csr_array):
            avg_neighbour_count = weight_matrix_local.count_nonzero() / self.N
            avg_leave_out_count = neighbour_leave_out.count_nonzero() / self.N

        logger.debug(
            f"End of sampling leave out neighbour and setting weight matrix for base learner: {t_neighbour_process_end - t_neighbour_process_start}\n"
            f"Average neighbour count for fitting base learner: {avg_neighbour_count}\n"
            f"Average leave out count for fitting meta learner (n): {avg_leave_out_count}"
        )

        # Just one line of addition here to implement meta_fitting_shrink_rate.
        # TODO: BUG CHECK: the weight matrix is shrinked in place? Yes. Other operation should be checked!
        if self.meta_fitting_shrink_rate is not None:
            neighbour_shrink(weight_matrix, self.meta_fitting_shrink_rate, True)
        # weight_matrix = neighbour_shrink(weight_matrix, self.meta_fitting_shrink_rate, True)

        if isinstance(weight_matrix, np.ndarray):
            avg_neighbour_count = np.count_nonzero(weight_matrix) / self.N
        elif isinstance(weight_matrix, csr_array):
            avg_neighbour_count = weight_matrix.count_nonzero() / self.N
        logger.debug(f"End of shrinking weight matrix for meta learner. Average neighbour count for fitting meta learner (m): {avg_neighbour_count}\n")

        neighbour_matrix = weight_matrix > 0

        # Indicator of input data for each local estimator.
        # Before the local itself is set False in neighbour_matrix. Avoid no meta prediction for local.
        t_second_order_start = time()
        second_neighbour_matrix = second_order_neighbour(
            neighbour_matrix, neighbour_leave_out
        )
        t_second_order_end = time()
        logger.debug(f"End of Generating Second order neighbour matrix: {t_second_order_end - t_second_order_start}")

        if isinstance(neighbour_matrix, np.ndarray):
            np.fill_diagonal(neighbour_matrix, False)
        elif isinstance(neighbour_matrix, csr_array):
            # BUG HERE. setdiag doesn't change the structure (indptr, indices), only data change from True to False.
            neighbour_matrix.setdiag(False)
            # TO FIX: Just use eliminate_zeros
            neighbour_matrix.eliminate_zeros()

        # Iterate the stacking estimator list to get the transformed X meta.
        # Cache all the data that will be used by neighbour estimators in one iteration by using second_neighbour_matrix.
        # First dimension is data index, second dimension is estimator index.
        # X_meta[i, j] means the prediction of estimator j on data i.
        t_predict_s = time()

        t_base_fit_start = time()
        local_predict, X_meta, X_meta_T, local_estimator_list = _fit(
            X,
            y,
            estimator_list=[clone(self.local_estimator) for _ in range(self.N)],
            weight_matrix=weight_matrix_local,
            second_neighbour_matrix=second_neighbour_matrix,
            cache_estimator=True,
            n_patches=self.n_patches,
        )
        t_base_fit_end = time()

        self.local_predict_ = local_predict
        self.local_estimator_list = local_estimator_list

        self.llocv_score_ = r2_score(self.y_sample_, self.local_predict_)
        self.local_residual_ = self.y_sample_ - self.local_predict_

        self.cache_estimator = cache_estimator
        self.base_estimator_list = self.local_estimator_list
        self.local_estimator_list = None

        t_predict_e = time()
        logger.debug(f"End of predicting X_meta: {t_predict_e - t_predict_s}")

        if not self.use_numba:
            local_stacking_predict = []
            local_stacking_estimator_list = []
            indexing_time = 0
            stacking_time = 0

            if isinstance(neighbour_leave_out, np.ndarray):
                for i in range(self.N):
                    # TODO: Use RidgeCV to find best alpha
                    final_estimator = Ridge(alpha=self.alpha, solver="lsqr")

                    t_indexing_start = time()

                    neighbour_sample = neighbour_matrix[[i], :]

                    if self.neighbour_leave_out_rate is not None:
                        # neighbour_sample = neighbour_leave_out[i]
                        neighbour_sample = neighbour_leave_out[:, i]
                        # neighbour_sample = neighbour_leave_out_[[i]]

                    # Sample from neighbour bool matrix to get sampled neighbour index.
                    if self.estimator_sample_rate is not None:
                        neighbour_indexes = np.nonzero(neighbour_sample[i])

                        neighbour_indexes = np.random.choice(
                            neighbour_indexes[0],
                            math.ceil(
                                neighbour_indexes[0].shape[0] * self.estimator_sample_rate
                            ),
                            replace=False,
                        )
                        # Convert back to bool matrix.
                        neighbour_sample = np.zeros_like(neighbour_matrix[i])
                        neighbour_sample[neighbour_indexes] = 1

                    X_fit = X_meta_T[neighbour_sample][:, neighbour_matrix[i]].T
                    y_fit = y[neighbour_matrix[i]]
                    t_indexing_end = time()

                    t_stacking_start = time()
                    final_estimator.fit(
                        X_fit, y_fit, sample_weight=weight_matrix[i, neighbour_matrix[i]]
                    )
                    t_stacking_end = time()

                    local_stacking_predict.append(
                        final_estimator.predict(
                            np.expand_dims(X_meta[i, neighbour_sample], 0)
                        )
                    )

                    # TODO: Unordered coef for each estimator.
                    stacking_estimator = StackingEstimator(
                        final_estimator,
                        list(compress(self.base_estimator_list, neighbour_sample)),
                    )
                    local_stacking_estimator_list.append(stacking_estimator)

                    indexing_time = indexing_time + t_indexing_end - t_indexing_start
                    stacking_time = stacking_time + t_stacking_end - t_stacking_start

                self.stacking_predict_ = np.array(local_stacking_predict).reshape(-1)
                self.llocv_stacking_ = r2_score(self.y_sample_, local_stacking_predict)
                self.local_estimator_list = local_stacking_estimator_list

            elif isinstance(neighbour_leave_out, csr_array):
                for i in range(self.N):
                    final_estimator = Ridge(alpha=self.alpha, solver='lsqr')

                    t_indexing_start = time()

                    # neighbour_sample = neighbour_leave_out[:, [i]]
                    # neighbour_sample = neighbour_leave_out_[[i]]

                    # Wrong leave out neighbour cause partial data leak.
                    # neighbour_leave_out_indices = neighbour_leave_out.indices[
                    #                               neighbour_leave_out.indptr[i]:neighbour_leave_out.indptr[i + 1]
                    #                               ]
                    neighbour_leave_out_indices = neighbour_leave_out_.indices[
                        neighbour_leave_out_.indptr[i] : neighbour_leave_out_.indptr[i + 1]
                    ]
                    neighbour_indices = neighbour_matrix.indices[
                        neighbour_matrix.indptr[i] : neighbour_matrix.indptr[i + 1]
                    ]

                    X_fit = (
                        X_meta_T[neighbour_leave_out_indices][:, neighbour_indices].toarray().T
                    )
                    y_fit = y[neighbour_indices]
                    t_indexing_end = time()

                    t_stacking_start = time()
                    final_estimator.fit(
                        X_fit, y_fit, sample_weight=weight_matrix[[i], neighbour_indices]
                    )
                    t_stacking_end = time()

                    local_stacking_predict.append(
                        final_estimator.predict(
                            np.expand_dims(X_meta[[i], neighbour_leave_out_indices], 0)
                        )
                    )

                    # TODO: Unordered coef for each estimator.
                    stacking_estimator = StackingEstimator(
                        final_estimator,
                        [
                            self.base_estimator_list[leave_out_index]
                            for leave_out_index in neighbour_leave_out_indices
                        ],
                    )
                    local_stacking_estimator_list.append(stacking_estimator)

                    indexing_time = indexing_time + t_indexing_end - t_indexing_start
                    stacking_time = stacking_time + t_stacking_end - t_stacking_start

                self.stacking_predict_ = np.array(local_stacking_predict).reshape(-1)
                self.llocv_stacking_ = r2_score(self.y_sample_, local_stacking_predict)
                self.local_estimator_list = local_stacking_estimator_list

            logger.debug(f"End of fitting meta estimator without numba. Indexing/Stacking time: {indexing_time}/{stacking_time}")

        else:
            if isinstance(weight_matrix, np.ndarray):
                raise Exception("Currently, Numba not support ndarray weight matrix.")

            @njit(parallel=True)
            def stacking_numba(
                leave_out_matrix_indptr,
                leave_out_matrix_indices,
                neighbour_matrix_indptr,
                neighbour_matrix_indices,
                X_meta_T_indptr,
                X_meta_T_indices,
                X_meta_T_data,
                y,
                weight_matrix_indptr,
                weight_matrix_indices,
                weight_matrix_data,
                alpha,
            ):
                N = len(leave_out_matrix_indptr) - 1
                coef_list = [np.empty((0, 0))] * N
                intercept_list = [np.empty(0)] * N
                y_predict_list = [np.empty(0)] * N
                score_fit_list = [.0] * N

                for i in prange(N):
                    leave_out_indices = leave_out_matrix_indices[
                        leave_out_matrix_indptr[i] : leave_out_matrix_indptr[i + 1]
                    ]
                    neighbour_indices = neighbour_matrix_indices[
                        neighbour_matrix_indptr[i] : neighbour_matrix_indptr[i + 1]
                    ]

                    # Find the index of the first element equals i
                    # for index_i in range(len(neighbour_indices)):
                    #     if neighbour_indices[index_i] == i:
                    #         break

                    # Delete self from neighbour_indices
                    # neighbour_indices = np.hstack((neighbour_indices[:index_i], neighbour_indices[index_i + 1:]))
                    neighbour_indices = neighbour_indices[neighbour_indices != i]

                    X_fit_T = np.zeros((len(leave_out_indices), len(neighbour_indices)))

                    # Needed to sort?
                    # leave_out_indices = np.sort(leave_out_indices)

                    for X_fit_row_index in range(len(leave_out_indices)):
                        neighbour_available_indices = X_meta_T_indices[
                            X_meta_T_indptr[
                                leave_out_indices[X_fit_row_index]
                            ] : X_meta_T_indptr[leave_out_indices[X_fit_row_index] + 1]
                        ]
                        current_column = 0
                        for available_iter_i in range(len(neighbour_available_indices)):
                            if (
                                neighbour_available_indices[available_iter_i]
                                in neighbour_indices
                            ):
                                X_fit_T[X_fit_row_index, current_column] = X_meta_T_data[
                                    X_meta_T_indptr[leave_out_indices[X_fit_row_index]]
                                    + available_iter_i
                                ]
                                current_column = current_column + 1

                    y_fit = y[neighbour_indices]

                    weight_indices = weight_matrix_indices[
                        weight_matrix_indptr[i] : weight_matrix_indptr[i + 1]
                    ]
                    # weight_indices = weight_indices[weight_indices != i]
                    weight_fit = weight_matrix_data[
                        weight_matrix_indptr[i] : weight_matrix_indptr[i + 1]
                    ]
                    weight_fit = weight_fit[weight_indices != i]

                    # weight_fit = np.hstack((weight_fit[:index_i], weight_fit[index_i + 1:]))

                    # TODO: If (m, n) m < n, then the matrix is not full rank, coef will be wrong.
                    coef, intercept = ridge_cholesky(X_fit_T.T, y_fit, alpha, weight_fit)

                    y_fit_predict = np.dot(X_fit_T.T, coef) + intercept
                    # TODO: Even worse, if m = 1, error will occur, the code below will be skipped in numba mode. The root cause is total_sum_squares becomes zero.
                    score_fit = r2_numba(y_fit, y_fit_predict.flatten())
                    score_fit_list[i] = score_fit

                    X_predict = np.zeros((len(leave_out_indices),))
                    for X_predict_row_index in range(len(leave_out_indices)):
                        neighbour_available_indices = X_meta_T_indices[
                            X_meta_T_indptr[
                                leave_out_indices[X_predict_row_index]
                            ] : X_meta_T_indptr[leave_out_indices[X_predict_row_index] + 1]
                        ]

                        # Find the index of the first element equals i
                        for available_iter_i in range(len(neighbour_available_indices)):
                            if neighbour_available_indices[available_iter_i] == i:
                                break

                        X_predict[X_predict_row_index] = X_meta_T_data[
                            X_meta_T_indptr[leave_out_indices[X_predict_row_index]]
                            + available_iter_i
                        ]

                    y_predict = np.dot(X_predict, coef) + intercept

                    coef_list[i] = coef.T
                    intercept_list[i] = intercept
                    y_predict_list[i] = y_predict

                return coef_list, intercept_list, y_predict_list, score_fit_list

            t_numba_start = time()
            # Different solver makes a little difference.
            coef_list, intercept_list, y_predict_list, score_fit_list = stacking_numba(
                neighbour_leave_out_.indptr,
                neighbour_leave_out_.indices,
                neighbour_matrix.indptr,
                neighbour_matrix.indices,
                X_meta_T.indptr,
                X_meta_T.indices,
                X_meta_T.data,
                y,
                weight_matrix.indptr,
                weight_matrix.indices,
                weight_matrix.data,
                self.alpha,
            )
            t_numba_end = time()
            logger.debug("Numba running time: %s \n", t_numba_end - t_numba_start)

            self.stacking_scores_ = score_fit_list
            self.stacking_predict_ = np.array(y_predict_list).reshape(-1)
            self.llocv_stacking_ = r2_score(self.y_sample_, self.stacking_predict_)

            self.local_estimator_list = []
            for i in range(self.N):
                final_estimator = Ridge(alpha=self.alpha, solver="cholesky")
                final_estimator.coef_ = coef_list[i]
                final_estimator.intercept_ = intercept_list[i]

                stacking_estimator = StackingEstimator(
                    final_estimator,
                    [
                        self.base_estimator_list[leave_out_index]
                        for leave_out_index in neighbour_leave_out_.indices[
                            neighbour_leave_out_.indptr[i] : neighbour_leave_out_.indptr[
                                i + 1
                            ]
                        ]
                    ],
                )

                self.local_estimator_list.append(stacking_estimator)

        # Summarize the fitting time in a single string.
        log_str = f"Leave local out elapsed: {t_neighbour_process_end - t_neighbour_process_start} \n" \
                    f"Base estimator fitting elapsed: {t_base_fit_end - t_base_fit_start} \n" \
                    f"Second order neighbour matrix elapsed: {t_second_order_end - t_second_order_start} \n" \
                    f"Meta estimator prediction elapsed: {t_predict_e - t_predict_s} \n"
        if self.use_numba:
            log_str += f"Numba running time: {t_numba_end - t_numba_start} \n"
        else:
            log_str += f"Indexing time: {indexing_time} \n" \
                    f"Stacking time: {stacking_time} \n"
        logger.debug(log_str)
        return self

    # TODO: Support sparse weight matrices and the numba stacking kernel in predict_by_fit.
    def predict_by_fit(self,
                       X_train, y_train,
                       coordinate_vector_list_train,
                       X_predict, coordinate_vector_list_predict,
                       *args, **kwargs):
        """
        Fit local base learners on the training data and predict the test data by local stacking.

        The difference between fit and predict_by_fit:
        - Training data is used to fit the local base learners, without sampling leave-out neighbours
          (the diagonal of the train/train weight matrix is still removed when ``self.leave_local_out`` is set).
        - Test data is considered as the left out neighbours of the training data.
        - For each test sample, a Ridge meta estimator is fitted on the predictions that the base learners of its
          neighbouring training samples make on those same neighbouring training samples, weighted by the
          test-to-train weights. The variables of the test data are not used during fitting; they are only used
          to make the final prediction through the resulting ``StackingEstimator``.

        Only ``use_numba=False`` with dense (``np.ndarray``) train/test weight matrices is currently supported.

        Args:
            X_train: Training features, validated together with ``y_train`` by ``check_X_y``.
            y_train: Training targets.
            coordinate_vector_list_train: Coordinate vectors of the training samples.
            X_predict: Features of the test samples. Converted with ``np.asarray`` and indexed by row.
            coordinate_vector_list_predict: Coordinate vectors of the test samples.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            list: One prediction array (``StackingEstimator.predict`` on a single row) per test sample.

        Raises:
            NotImplementedError: If ``self.use_numba`` is set, or if the train/test neighbour matrix is not a
                dense ``np.ndarray``. These configurations previously failed with a ``NameError`` after fitting.
        """
        if self.use_numba:
            # The numba stacking kernel used by ``fit`` needs the sampled leave-out neighbour matrix,
            # which this method does not build.
            raise NotImplementedError("predict_by_fit does not support use_numba=True yet.")

        self.log_stacking_before_fitting()

        X, y = check_X_y(X_train, y_train)
        # Base learners are fitted on validated ndarray data; index the test rows positionally as well
        # (``X_predict[[i]]`` on a DataFrame would select a column instead of a row).
        X_predict = np.asarray(X_predict)
        self.is_fitted_ = True
        self.n_features_in_ = X.shape[1]
        self.N = X.shape[0]

        # Cache data for local predict
        if self.cache_data:
            self.X = X
            self.y = y
            # TODO: Cache the weight_matrix, neighbor_matrix to make it compatible with the local diagonalization.

        weight_matrix = weight_matrix_from_points(
            coordinate_vector_list_train, coordinate_vector_list_train,
            self.distance_measure, self.kernel_type, self.distance_ratio,
            self.bandwidth, self.neighbour_count, self.distance_args
        )

        # TODO: Tweak for inspection.
        self.weight_matrix_ = weight_matrix
        self.neighbour_matrix_ = weight_matrix > 0

        t_neighbour_process_start = time()

        weight_matrix_local = weight_matrix.copy()
        if self.leave_local_out:
            if isinstance(weight_matrix_local, np.ndarray):
                np.fill_diagonal(weight_matrix_local, 0)
            else:
                # TODO: High cost for sparse matrix
                weight_matrix_local.setdiag(0)
                weight_matrix_local.eliminate_zeros()

        t_neighbour_process_end = time()

        if isinstance(weight_matrix_local, np.ndarray):
            avg_neighbour_count = np.count_nonzero(weight_matrix_local) / self.N
        elif isinstance(weight_matrix_local, csr_array):
            avg_neighbour_count = weight_matrix_local.count_nonzero() / self.N
        else:
            # Only used for logging; avoid an UnboundLocalError for other matrix types.
            avg_neighbour_count = None

        logger.debug(
            f"End of sampling leave out neighbour and setting weight matrix for base learner: {t_neighbour_process_end - t_neighbour_process_start}\n"
            f"Average neighbour count for fitting base learner: {avg_neighbour_count}\n"
        )

        # Rows are training samples, columns are test samples.
        train_test_weight_matrix = weight_matrix_from_points(
            coordinate_vector_list_train, coordinate_vector_list_predict,
            self.distance_measure, self.kernel_type, self.distance_ratio,
            self.bandwidth, self.neighbour_count, self.distance_args
        )
        test_train_weight_matrix = train_test_weight_matrix.T.copy()

        train_test_neighbour_matrix = train_test_weight_matrix > 0
        test_train_neighbour_matrix = train_test_neighbour_matrix.T.copy()

        # Fail before the expensive base fitting instead of with a NameError afterwards.
        if not isinstance(test_train_neighbour_matrix, np.ndarray):
            raise NotImplementedError(
                "predict_by_fit currently supports only dense (np.ndarray) weight matrices."
            )

        N_test = X_predict.shape[0]

        # Indicator of input data for each local estimator; passed to _fit as second_neighbour_matrix.
        t_second_order_start = time()
        second_neighbour_matrix = second_order_neighbour(
            train_test_neighbour_matrix.T, neighbour_leave_out=train_test_neighbour_matrix
        )
        t_second_order_end = time()
        logger.debug(f"End of Generating Second order neighbour matrix: {t_second_order_end - t_second_order_start}")

        # Fit the base estimators and collect their predictions (X_meta / X_meta_T).
        # X_meta[i, j] is the prediction of estimator j on data i, so X_meta_T[j, i] is the same value.
        t_predict_s = time()

        t_base_fit_start = time()
        # The fitted base estimators are reused by the stacking estimators below, so caching is forced while
        # fitting; the user's setting is restored even if _fit raises.
        cache_estimator = self.cache_estimator
        self.cache_estimator = True
        try:
            local_predict, _, X_meta_T, local_estimator_list = _fit(
                X,
                y,
                estimator_list=[clone(self.local_estimator) for _ in range(self.N)],
                weight_matrix=weight_matrix_local,
                second_neighbour_matrix=second_neighbour_matrix,
                cache_estimator=True,
                n_patches=self.n_patches,
            )
        finally:
            self.cache_estimator = cache_estimator
        t_base_fit_end = time()

        self.local_predict_ = local_predict
        self.y_sample_ = y[range(self.N)]
        self.llocv_score_ = r2_score(self.y_sample_, self.local_predict_)
        self.local_residual_ = self.y_sample_ - self.local_predict_

        self.base_estimator_list = local_estimator_list
        self.local_estimator_list = None

        t_predict_e = time()
        logger.debug(f"End of predicting X_meta: {t_predict_e - t_predict_s}")

        predictions = []
        local_stacking_estimator_list = []
        indexing_time = 0
        stacking_time = 0

        for i in range(N_test):
            # TODO: Use RidgeCV to find best alpha
            final_estimator = Ridge(alpha=self.alpha, solver="lsqr")

            t_indexing_start = time()
            # Boolean mask over training samples (and their base estimators) neighbouring test sample i.
            neighbour_sample = test_train_neighbour_matrix[i]
            # Rows: neighbouring training samples; columns: predictions of the neighbouring base estimators.
            X_fit = X_meta_T[neighbour_sample][:, neighbour_sample].T
            y_fit = y[neighbour_sample]
            t_indexing_end = time()

            t_stacking_start = time()
            final_estimator.fit(
                X_fit, y_fit, sample_weight=test_train_weight_matrix[i, neighbour_sample]
            )
            t_stacking_end = time()

            # TODO: Unordered coef for each estimator.
            stacking_estimator = StackingEstimator(
                final_estimator,
                list(compress(self.base_estimator_list, neighbour_sample)),
            )
            predictions.append(stacking_estimator.predict(X_predict[[i]]))
            local_stacking_estimator_list.append(stacking_estimator)

            indexing_time += t_indexing_end - t_indexing_start
            stacking_time += t_stacking_end - t_stacking_start

        self.local_estimator_list = local_stacking_estimator_list

        logger.debug(
            f"End of fitting meta estimator without numba. Indexing/Stacking time: {indexing_time}/{stacking_time}")

        # Summarize the fitting time in a single string.
        logger.debug(
            f"Leave local out elapsed: {t_neighbour_process_end - t_neighbour_process_start} \n"
            f"Base estimator fitting elapsed: {t_base_fit_end - t_base_fit_start} \n"
            f"Second order neighbour matrix elapsed: {t_second_order_end - t_second_order_start} \n"
            f"Meta estimator prediction elapsed: {t_predict_e - t_predict_s} \n"
            f"Indexing time: {indexing_time} \n"
            f"Stacking time: {stacking_time} \n"
        )
        return predictions

    def log_stacking_before_fitting(self):
        """
        Write the stacking hyper-parameters to the debug log.

        Reads ``alpha``, ``neighbour_leave_out_rate``, ``estimator_sample_rate``,
        ``neighbour_leave_out_shrink_rate``, ``meta_fitting_shrink_rate`` and ``use_numba`` from ``self``,
        and emits them as one multi-line debug message.

        Returns:
            self, so the call can be chained.
        """
        header = "\nStacking Model start fitting with parameters:\n"
        parameter_lines = [
            f"alpha: {self.alpha}\n",
            f"neighbour_leave_out_rate: {self.neighbour_leave_out_rate}\n",
            f"estimator_sample_rate: {self.estimator_sample_rate}\n",
            f"neighbour_leave_out_shrink_rate: {self.neighbour_leave_out_shrink_rate}\n",
            f"meta_fitting_shrink_rate: {self.meta_fitting_shrink_rate}\n",
            f"use_numba: {self.use_numba}\n",
        ]
        logger.debug(header + "".join(parameter_lines))
        return self


class StackingEstimator(BaseEstimator):
    """
    Combine already-fitted base estimators through an already-fitted meta estimator.

    ``predict`` stacks the predictions of ``base_estimators`` column-wise (one column per base estimator, in list
    order) and passes them to ``meta_estimator.predict``. Nothing is fitted by this class itself.

    Args:
        meta_estimator: Fitted estimator whose ``predict`` accepts one feature per base estimator.
        base_estimators: Fitted estimators whose predictions form the meta features.
    """

    def __init__(self, meta_estimator, base_estimators):
        self.meta_estimator = meta_estimator
        self.base_estimators = base_estimators

    def predict(self, X):
        """Return the meta estimator's prediction on the stacked base-estimator predictions for ``X``."""
        X_meta = np.column_stack(
            [base_estimator.predict(X) for base_estimator in self.base_estimators]
        )
        return self.meta_estimator.predict(X_meta)

    def score(self, X, y, sample_weight=None):
        """
        R^2 score of ``predict(X)`` against ``y``.

        To make compatible with permutation_importance.
        """
        y_pred = self.predict(X)
        return r2_score(y, y_pred, sample_weight=sample_weight)

    def fit(self, X, y=None, **fit_params):
        """
        No-op fit, used to avoid the exception in check_scoring of permutation_importance.

        The wrapped estimators are expected to be fitted already, so nothing is done here. The previous
        signature ``fit(X, y)`` lacked ``self`` and raised ``TypeError`` when called on an instance.

        Returns:
            self, following the scikit-learn estimator convention.
        """
        return self

from time import time
from typing import Union

import dask.array as da
import numpy as np
from dask.graph_manipulation import wait_on
from scipy import sparse
from slab_utils.quick_logger import logger

from georegression.distance_utils import _distance_matrices
from georegression.kernel import kernel_function, adaptive_kernel


def weight_matrix_from_points(
        source_coordinate_vector_list: list[np.ndarray],
        target_coordinate_vector_list: list[np.ndarray] = None,
        distance_measure: Union[str, list[str]] = None,
        kernel_type: Union[str, list[str]] = None,
        distance_ratio: Union[float, list[float]] = None,
        bandwidth: Union[float, list[float]] = None,
        neighbour_count: Union[float, list[float]] = None,
        distance_args: Union[dict, list[dict]] = None
) -> np.ndarray:
    """
    Build the weight matrix between source and target points.

    The distance matrices are computed by ``_distance_matrices`` and turned into a single compound weight
    matrix by ``weight_matrix_from_distance``. Each row represents a source and each column a target, so the
    shape is (number of sources, number of targets). Both steps are timed and logged at debug level.

    Args:
        source_coordinate_vector_list: List of coordinate arrays for the source points.
        target_coordinate_vector_list: List of coordinate arrays for the target points; forwarded to
            ``_distance_matrices`` (how ``None`` is handled is decided there).
        distance_measure: Distance measure(s), forwarded to ``_distance_matrices``.
        kernel_type: Kernel type(s), forwarded to ``weight_matrix_from_distance``.
        distance_ratio: Forwarded to ``weight_matrix_from_distance``.
        bandwidth: Forwarded to ``weight_matrix_from_distance``.
        neighbour_count: Forwarded to ``weight_matrix_from_distance``.
        distance_args: Extra distance arguments, forwarded to ``_distance_matrices``.

    Returns:
        Whatever ``weight_matrix_from_distance`` returns for the computed distance matrices.
    """
    header = (
        "\nWeight Matrix from Points Info:\n"
        f"Coordinate Dimension: {len(source_coordinate_vector_list)}\n"
    )
    shape_lines = [
        f"coordinate_vector[{i}].shape: {coordinate_vector.shape}\n"
        for i, coordinate_vector in enumerate(source_coordinate_vector_list)
    ]
    logger.debug(header + "".join(shape_lines))

    t_distance_start = time()
    distance_matrices = _distance_matrices(
        source_coordinate_vector_list, target_coordinate_vector_list, distance_measure, distance_args,
    )
    t_distance_end = time()

    t_kernel_start = time()
    weight_matrix = weight_matrix_from_distance(
        distance_matrices, kernel_type, distance_ratio, bandwidth, neighbour_count,
    )
    t_kernel_end = time()

    logger.debug(f"Distance Time: {t_distance_end - t_distance_start}. Kernel Time: {t_kernel_end - t_kernel_start}")
    return weight_matrix


def weight_matrix_from_distance(
        distance_matrices,
        kernel_type: Union[str, list[str]],
        distance_ratio: Union[float, list[float], None] = None,
        bandwidth: Union[float, list[float], None] = None,
        neighbour_count: Union[float, list[float], None] = None,
        distance_matrices_sorted: Union[np.ndarray, da.array] = None
) -> Union[np.ndarray, da.array]:
    """
    Convert one or more distance matrices into a single row-normalised weight matrix.

    By default, a weight matrix is computed for each distance matrix (e.g. a location distance
    matrix and a time distance matrix), and the per-dimension weights are combined by
    element-wise multiplication.

    If `distance_ratio` is provided, the order is reversed: the distance matrices are first merged
    into one (each matrix scaled by its ratio, then summed) and a single weight matrix is computed
    from the merged distances.

    `kernel_type`, `bandwidth` and `neighbour_count` may be given as lists with one entry per
    distance matrix; scalars are broadcast to every dimension.

    Args:
        distance_matrices: Sequence of equally shaped distance matrices (numpy or dask arrays),
            one per coordinate vector. Rows are sources, columns are targets.
        kernel_type: Kernel name, or one per dimension.
        distance_ratio: Scale factor(s) used to merge the distance matrices. A scalar is only
            accepted for exactly two matrices and is interpreted as `[1, distance_ratio]`.
        bandwidth: Fixed-kernel bandwidth(s). Must not be combined with `neighbour_count`
            for the same dimension.
        neighbour_count: Adaptive-kernel neighbour count(s).
        distance_matrices_sorted: Optional pre-sorted distance matrices forwarded to the adaptive
            kernel; only used when the distance matrices are dask arrays.

    Returns:
        Row-normalised weight matrix of shape (number of source, number of target): a numpy array
        for numpy input, or a `sparse.csr_array` for dask input. Rows whose weights are all zero
        stay all zero.

    Raises:
        Exception: If the matrices differ in shape, if neither bandwidth nor neighbour count is
            given, or if `distance_ratio` does not match the number of matrices.
    """

    log_str = f"\nWeight Matrix from Distance Info:\n"
    log_str += f"Distance Dimension: {len(distance_matrices)}\n"
    for i, distance_matrix in enumerate(distance_matrices):
        log_str += f"distance_matrix[{i}].shape: {distance_matrix.shape}\n"
    log_str += (
        f"kernel_type: {kernel_type}\n"
        f"distance_ratio: {distance_ratio}\n"
        f"bandwidth: {bandwidth}\n"
        f"neighbour_count: {neighbour_count}\n"
    )
    logger.debug(log_str)

    # Dimension of the vector list. (Len of the vector list)
    dimension = len(distance_matrices)

    # Check whether the size of distance matrices are the same.
    if len(set([distance_matrix.shape for distance_matrix in distance_matrices])) != 1:
        raise Exception("Size of distance matrices are not the same")

    # Check whether to use fixed kernel or adaptive kernel
    if bandwidth is None and neighbour_count is None:
        raise Exception(
            "At least one of bandwidth or neighbour count should be provided"
        )

    # Merge distance matrices.
    if distance_ratio is not None:
        if not isinstance(distance_ratio, list) and dimension != 2:
            raise Exception(
                "Distance ratio list must be provided for dimension larger than 2"
            )

        # TODO: Normalization step should be considered.

        # TODO: More operation, not only addition, should be considered.
        #  Like different distance measurements (replace distance_diff in distance_utils.py).
        #  Or some arithmetic operations like multiplication or division?

        if not isinstance(distance_ratio, list) and dimension == 2:
            distance_ratio = [1, distance_ratio]
        else:
            if len(distance_ratio) != dimension:
                raise Exception(
                    "Length of distance ratio list must match the dimension of the vector list"
                )

        # Scale each distance matrix by its ratio, then sum them element-wise into one matrix.
        distance_matrices_temp = [
            distance_matrices[dim] * distance_ratio[dim] for dim in range(dimension)
        ]

        distance_matrix = np.sum(distance_matrices_temp, axis=0)
        distance_matrices = [distance_matrix]

        dimension = 1

    # Then, calculate weight matrices and merge.

    # Also should check the dimension of the parameters if it is already a list?
    if not isinstance(kernel_type, list):
        kernel_type = [kernel_type] * dimension

    if not isinstance(bandwidth, list):
        bandwidth = [bandwidth] * dimension

    if not isinstance(neighbour_count, list):
        neighbour_count = [neighbour_count] * dimension

    if not isinstance(distance_matrices_sorted, list):
        distance_matrices_sorted = [distance_matrices_sorted] * dimension

    weights = []
    for dim in range(dimension):
        if isinstance(distance_matrices[0], da.Array):
            weights.append(
                # Need to wait on?
                weight_by_distance(
                    distance_matrices[dim], kernel_type[dim], bandwidth[dim], neighbour_count[dim],
                    distance_matrices_sorted[dim]
                )
            )
        else:
            weights.append(
                weight_by_distance(distance_matrices[dim], kernel_type[dim], bandwidth[dim], neighbour_count[dim], None)
            )

    weights = np.stack(weights)

    # TODO: Not only multiplication? e.g. Addition, minimum, maximum, average
    weight_matrix = np.prod(weights, axis=0)

    # Normalization
    # TODO: More normalization option. The key point is the proportion in a row?

    # default use row normalization
    row_sum = np.sum(weight_matrix, axis=1)
    # for some row with all 0 weight.
    row_sum[row_sum == 0] = 1
    # Notice the axis of division
    weight_matrix_norm = weight_matrix / np.expand_dims(row_sum, 1)

    if isinstance(weight_matrix_norm, da.Array):
        weight_matrix_norm = weight_matrix_norm.map_blocks(sparse.coo_matrix).compute()
        weight_matrix_norm = sparse.csr_array(weight_matrix_norm)

    # Stat the non-zero weight ratio
    if isinstance(weight_matrix_norm, np.ndarray):
        nonzero_count = np.count_nonzero(weight_matrix_norm)
    else:
        nonzero_count = weight_matrix_norm.nnz

    # Use the dense element count: for scipy sparse arrays `.size` is the number of stored
    # entries, which would make the ratio below always 1. Guard against empty matrices.
    n_rows, n_cols = weight_matrix_norm.shape
    n_elements = n_rows * n_cols
    logger.debug(
        f"Non-zero weight ratio: {nonzero_count / n_elements if n_elements else 0}\n"
        f"Average neighbour count: {nonzero_count / n_rows if n_rows else 0}"
    )

    return weight_matrix_norm


def weight_by_distance(distance, kernel_type, bandwidth, neighbour_count, distance_sorted):
    """
    Turn a distance array into weights with exactly one kind of kernel.

    Exactly one of `bandwidth` (fixed kernel) and `neighbour_count` (adaptive kernel)
    must be given.

    Args:
        distance: Distance values to weight.
        kernel_type: Kernel name passed to the kernel function.
        bandwidth: Bandwidth for `kernel_function`, or None.
        neighbour_count: Neighbour count for `adaptive_kernel`, or None.
        distance_sorted: Pre-sorted distances forwarded to `adaptive_kernel` (may be None).

    Returns:
        The weights returned by the selected kernel.

    Raises:
        Exception: If both or neither of `bandwidth` and `neighbour_count` are given.
    """
    fixed = bandwidth is not None
    adaptive = neighbour_count is not None

    # Both set, or both missing: the kernel choice is ambiguous.
    if fixed == adaptive:
        raise Exception(
            "Choose bandwidth for fixed kernel or neighbour count for adaptive kernel"
        )

    if fixed:
        return kernel_function(distance, bandwidth, kernel_type)
    return adaptive_kernel(distance, neighbour_count, kernel_type, distance_sorted)


from time import time

import numpy as np
import pandas as pd

from joblib import Parallel, delayed
from sklearn.base import BaseEstimator, clone, RegressorMixin
from sklearn.inspection import permutation_importance, partial_dependence
from sklearn.inspection._partial_dependence import _grid_from_X
from sklearn.metrics import r2_score
from sklearn.utils.validation import check_X_y
from slab_utils.quick_logger import logger
from scipy.sparse import csr_array


from georegression.weight_matrix import weight_matrix_from_points
from georegression.local_ale import weighted_ale
from georegression.ale_utils import adaptive_grid


def fit_local_estimator(
        local_estimator, X, y,
        sample_weight=None, local_x=None,
        return_estimator=False):
    """
    Fit one local estimator and optionally predict a single point (parallel job wrapper).

    Args:
        local_estimator: Estimator fitted in place via `fit(X, y, sample_weight=...)`.
        X: Training features of the local neighbourhood.
        y: Training targets of the local neighbourhood.
        sample_weight: Per-sample weights passed to `fit`.
        local_x: A single feature vector to predict after fitting, or None.
        return_estimator: Whether to hand the fitted estimator back to the caller.

    Returns:
        (prediction, estimator): `prediction` is `predict` on `local_x` reshaped to one row,
        or None when `local_x` is None; `estimator` is the fitted estimator or None.
    """

    # TODO: Add partial calculation for non-cache solution.

    local_estimator.fit(X, y, sample_weight=sample_weight)

    prediction = None if local_x is None else local_estimator.predict(local_x.reshape(1, -1))
    kept_estimator = local_estimator if return_estimator else None

    return prediction, kept_estimator


def _local_fit_arguments(i, X, y, weight_matrix, X_predict):
    """
    Return (X_local, y_local, weight_local, x_predict) for the local model of weight-matrix row `i`.

    Only targets with a non-zero weight are kept, so zero-weight samples never reach the estimator.
    `weight_matrix` is either a dense ndarray or a CSR sparse array.
    """
    if isinstance(weight_matrix, np.ndarray):
        row_weight = weight_matrix[i]
        neighbour_mask = row_weight != 0
        return X[neighbour_mask], y[neighbour_mask], row_weight[neighbour_mask], X_predict[i]

    # CSR layout: the stored entries of row i are data/indices[indptr[i]:indptr[i + 1]].
    # Columns and weights are taken from the same matrix so they stay aligned.
    start, end = weight_matrix.indptr[i], weight_matrix.indptr[i + 1]
    row_weight = weight_matrix.data[start:end]
    columns = weight_matrix.indices[start:end]
    # Drop explicitly stored zeros, matching the dense branch.
    nonzero = row_weight != 0
    columns, row_weight = columns[nonzero], row_weight[nonzero]
    return X[columns], y[columns], row_weight, X_predict[i]


def _fit_local_batch(indices, estimators, X, y, weight_matrix, X_predict, cache_estimator):
    """Fit, sequentially, the local models of one batch (`estimators[k]` belongs to row `indices[k]`)."""
    predictions = []
    fitted_estimators = []
    for i, estimator in zip(indices, estimators):
        X_local, y_local, weight_local, x = _local_fit_arguments(i, X, y, weight_matrix, X_predict)
        prediction, fitted = fit_local_estimator(
            estimator, X_local, y_local, local_x=x,
            sample_weight=weight_local,
            return_estimator=cache_estimator
        )
        predictions.append(prediction)
        fitted_estimators.append(fitted)
    return predictions, fitted_estimators


def _fit(X, y, estimator_list, weight_matrix,
         local_indices=None, cache_estimator=False,
         X_predict=None,
         n_jobs=None, n_patches=None):
    """
    Fit one local estimator per selected weight-matrix row, in parallel.

    Row `i` of `weight_matrix` gives the sample weights of all training samples for local
    model `i`; samples with zero weight are excluded. Each fitted model then predicts
    `X_predict[i]`.

    Args:
        X: Training features, one row per target column of `weight_matrix`.
        y: Training targets aligned with `X`.
        estimator_list: Unfitted estimators indexed by weight-matrix row.
        weight_matrix: Dense ndarray or scipy sparse matrix (any format; converted to CSR).
        local_indices: Rows to fit. Defaults to all rows.
        cache_estimator: Whether to return the fitted estimators (otherwise None entries).
        X_predict: Points to predict, indexed by row. Defaults to `X`.
        n_jobs: If given, one joblib task per local model with this many workers.
        n_patches: If given, split the rows into this many batches, one joblib task each.
            Defaults to the CPU count when neither option is given.

    Returns:
        (local_predict, local_estimator_list): lists aligned with `local_indices`.

    Raises:
        ValueError: If both `n_jobs` and `n_patches` are given.
    """

    if n_jobs is not None and n_patches is not None:
        raise ValueError("Cannot specify both `n_jobs` and `n_patches`")
    if n_jobs is None and n_patches is None:
        import multiprocessing
        n_patches = multiprocessing.cpu_count()

    t_start = time()

    # Row slicing relies on the CSR layout; e.g. a transposed CSR matrix is CSC.
    if not isinstance(weight_matrix, np.ndarray):
        weight_matrix = csr_array(weight_matrix)

    # Use all data if sample indices not provided.
    N = weight_matrix.shape[0]
    if local_indices is None:
        local_indices = range(N)

    # Data used for local prediction. Different from X when source and target are not same for weight matrix.
    if X_predict is None:
        X_predict = X

    if n_jobs is not None:
        # One task per local model; each task is bound to its own row index.
        def task_generator():
            for i in local_indices:
                X_local, y_local, weight_local, x = _local_fit_arguments(i, X, y, weight_matrix, X_predict)
                yield delayed(fit_local_estimator)(
                    estimator_list[i], X_local, y_local, local_x=x,
                    sample_weight=weight_local,
                    return_estimator=cache_estimator
                )

        parallel_result = Parallel(n_jobs)(task_generator())
        local_predict = [prediction for prediction, _ in parallel_result]
        local_estimator_list = [estimator for _, estimator in parallel_result]
    else:
        # Manual batching: fewer, larger tasks reduce inter-process transfer overhead.
        local_indices_batch_list = np.array_split(local_indices, n_patches)
        parallel_batch_result = Parallel(n_patches)(
            delayed(_fit_local_batch)(
                batch, [estimator_list[i] for i in batch],
                X, y, weight_matrix, X_predict, cache_estimator
            )
            for batch in local_indices_batch_list
        )

        local_predict = []
        local_estimator_list = []
        for local_prediction_batch, local_estimator_batch in parallel_batch_result:
            local_predict.extend(local_prediction_batch)
            local_estimator_list.extend(local_estimator_batch)

    t_end = time()
    logger.debug(f"Parallel fit time: {t_end - t_start}")

    return local_predict, local_estimator_list


def local_partial_dependence(local_estimator, X_local, weight, n_features_in_):
    """
    Weighted partial dependence of every feature for one local estimator (parallel job wrapper).

    Args:
        local_estimator: Fitted local estimator.
        X_local: Features of the estimator's neighbourhood.
        weight: Sample weights of the rows of `X_local`, used to average the individual
            (ICE) curves.
        n_features_in_: Number of features; one partial dependence is computed per feature.

    Returns:
        List with one `(grid_values, weighted_average)` tuple per feature.
    """

    # Partial result of each feature.
    # [(x for feature1, y for feature1), (x for feature2, y for feature2), (...), ...]
    feature_list = []

    for feature_index in range(n_features_in_):
        pdp = partial_dependence(
            local_estimator,
            X_local,
            [feature_index],
            kind='both'
        )

        # scikit-learn >= 1.3 exposes the grid as 'grid_values'; 'values' is the older,
        # deprecated key that later releases removed.
        grid_key = 'grid_values' if 'grid_values' in pdp else 'values'
        values = pdp[grid_key][0]
        individual = pdp['individual'][0]

        # Must get individual partial dependence to weight the result
        # Weight: Performance Weight. The point with more weight performance better in the model.
        # So average the partial performance according to the weight.
        weight_average = np.average(individual, axis=0, weights=weight)

        feature_list.append((values, weight_average))

    return feature_list


def local_importance(local_estimator, X_local, y_local, weight, n_repeats=5):
    """
    Permutation importance of one local estimator on its own neighbourhood (parallel job wrapper).

    Args:
        local_estimator: Fitted local estimator.
        X_local: Neighbourhood features.
        y_local: Neighbourhood targets.
        weight: Sample weights of the neighbourhood rows.
        n_repeats: Number of permutations per feature.

    Returns:
        Mean importance per feature (`importances_mean` of `permutation_importance`).
    """
    result = permutation_importance(
        local_estimator, X_local, y_local,
        n_repeats=n_repeats,
        sample_weight=weight,
    )
    return result.importances_mean


class WeightModel(BaseEstimator, RegressorMixin):
    """
    Inherit from sklearn BaseEstimator to support sklearn workflow, e.g. GridSearchCV.
    """

    def __init__(self,
                 local_estimator,
                 # Weight matrix param
                 distance_measure=None,
                 kernel_type=None,
                 distance_ratio=None,
                 bandwidth=None,
                 neighbour_count=None,
                 distance_args=None,
                 # Model param
                 leave_local_out=True,
                 sample_local_rate=None,
                 cache_data=False,
                 cache_estimator=False,
                 n_jobs=None, n_patches=None,
                 *args, **kwargs):
        """
        Store the hyper-parameters and reset every fitted attribute to None.

        Args:
            local_estimator: Unfitted estimator, cloned once per local model in `fit`.
            distance_measure, kernel_type, distance_ratio, bandwidth, neighbour_count, distance_args:
                Forwarded to `weight_matrix_from_points` when the weight matrix is built from coordinates.
            leave_local_out: Zero each location's weight on its own local model during `fit`.
            sample_local_rate: Fraction of locations that get a local model; None fits all of them.
            cache_data: Keep X, y and the coordinates after `fit` (required for prediction/inspection).
            cache_estimator: Keep the fitted local estimators after `fit`.
            n_jobs: Number of joblib workers, with one task per local model.
            n_patches: Number of parallel batches. Defaults to the CPU count when neither
                `n_jobs` nor `n_patches` is given.
            *args, **kwargs: Stored unchanged on `self.args` / `self.kwargs`.
        """
        # Hyper-parameters, stored under their argument names.
        self.local_estimator = local_estimator
        self.distance_measure = distance_measure
        self.kernel_type = kernel_type
        self.distance_ratio = distance_ratio
        self.bandwidth = bandwidth
        self.neighbour_count = neighbour_count
        self.distance_args = distance_args
        self.leave_local_out = leave_local_out
        self.sample_local_rate = sample_local_rate
        self.cache_data = cache_data
        self.cache_estimator = cache_estimator
        self.n_jobs = n_jobs
        if n_patches is None and n_jobs is None:
            import multiprocessing
            n_patches = multiprocessing.cpu_count()
        self.n_patches = n_patches

        # Fitted state, populated by fit() and the inspection methods.
        self.is_fitted_ = self.n_features_in_ = None

        self.X = self.y = self.N = None
        self.coordinate_vector_list = self.coordinate_vector_dimension_ = None

        self.weight_matrix_ = self.neighbour_matrix_ = None
        self.local_indices_ = self.y_sample_ = None

        self.local_estimator_list = self.local_predict_ = None
        # Leave-local-out cross-validation score and residuals
        self.llocv_score_ = self.local_residual_ = None

        # Permutation importance
        self.permutation_score_decrease_ = self.interaction_matrix_ = None
        # Partial dependence
        self.local_partial_ = self.feature_partial_ = None
        # Individual conditional expectation (ICE)
        self.local_ice_ = self.feature_ice_ = None

        self.args, self.kwargs = args, kwargs


    def fit(self, X, y, coordinate_vector_list=None, weight_matrix=None):
        """
        Fit the model: one local estimator per location (or per sampled location).

        Args:
            X: Feature matrix, one row per location.
            y: Target vector aligned with `X`.
            coordinate_vector_list: Coordinate vectors used to build the weight matrix
                (locations as both source and target) when `weight_matrix` is not given.
            weight_matrix: Precomputed square weight matrix (rows = local models, columns =
                training samples), dense or sparse. With `leave_local_out` it is copied before
                its diagonal is zeroed, so the caller's matrix is left untouched.

        Returns:
            self

        Raises:
            Exception: If neither `coordinate_vector_list` nor `weight_matrix` is provided.
        """
        self.log_before_fitting(X, y, coordinate_vector_list, weight_matrix)

        X, y = check_X_y(X, y)

        if coordinate_vector_list is None and weight_matrix is None:
            raise Exception('At least one of coordinate_vector_list or weight_matrix should be provided')

        # Only mark the model as fitted once the inputs have been validated.
        self.is_fitted_ = True
        self.n_features_in_ = X.shape[1]
        self.N = X.shape[0]

        # Cache data for local predict
        if self.cache_data:
            self.X = X
            self.y = y
            self.coordinate_vector_list = coordinate_vector_list

        if weight_matrix is None:
            weight_matrix = weight_matrix_from_points(coordinate_vector_list, coordinate_vector_list, self.distance_measure, self.kernel_type, self.distance_ratio, self.bandwidth,
                                                      self.neighbour_count, self.distance_args)
        elif self.leave_local_out:
            # The diagonal is zeroed in place below; do not mutate the caller's matrix.
            weight_matrix = weight_matrix.copy()

        # TODO: Tweak for inspection.
        self.weight_matrix_ = weight_matrix
        # Set the diagonal value of the weight matrix to exclude the local location to get CV score
        if self.leave_local_out:
            if isinstance(self.weight_matrix_, np.ndarray):
                np.fill_diagonal(self.weight_matrix_, 0)
            else:
                # TODO: High cost for sparse matrix
                self.weight_matrix_.setdiag(0)
                self.weight_matrix_.eliminate_zeros()

        self.neighbour_matrix_ = self.weight_matrix_ > 0

        # TODO: Repeatable randomness of sklearn style.
        # Sample Procedure. Fit local models for only a part of points to reduce computation.
        if self.sample_local_rate is not None:
            self.local_indices_ = np.sort(np.random.choice(self.N, int(self.sample_local_rate * self.N), replace=False))
        else:
            self.local_indices_ = range(self.N)
        self.y_sample_ = y[self.local_indices_]

        # Clone local estimators
        estimator_list = [clone(self.local_estimator) for _ in range(self.N)]

        # Fit
        # Embed the local predict in the parallel build process to speed up.
        # Reduce the expense of resource allocation between processes.
        self.local_predict_, self.local_estimator_list = _fit(
            X, y, estimator_list, self.weight_matrix_, self.local_indices_,
            cache_estimator=self.cache_estimator, n_jobs=self.n_jobs, n_patches=self.n_patches
        )
        self.local_predict_ = np.array(self.local_predict_).squeeze()

        # Calculate the CV score and other metrics according the local result
        self.llocv_score_ = r2_score(self.y_sample_, self.local_predict_)
        self.local_residual_ = self.y_sample_ - self.local_predict_

        return self

    def predict_by_weight(self, X, coordinate_vector_list=None, weight_matrix=None,
                          search_optimal=False, y=None, *args, **kwargs):
        """
        Predict by blending the predictions of all cached local estimators.

        Every cached local estimator predicts all rows of X. The predictions are then combined
        with a weight matrix whose rows are the prediction points (source) and whose columns
        are the fitted locations (target).

        Args:
            X: Features to predict.
            coordinate_vector_list: Coordinates of the prediction points. Used to build the
                weight matrix against the cached training coordinates when `weight_matrix` is None.
            weight_matrix: Precomputed weight matrix, shape (rows of X, number of local estimators).
            search_optimal: Placeholder; requires `y` but currently has no effect.
            y: Targets for `search_optimal`.
            *args, **kwargs: Unused.

        Returns:
            Row-wise weighted sum of the local predictions, one value per row of X.

        Raises:
            Exception: If data/estimators were not cached, if neither coordinates nor a weight
                matrix is given, or if `search_optimal` is set without `y`.
        """

        if not self.cache_data or not self.cache_estimator:
            raise Exception('Prediction by weight need cache_data and cache_estimator set True')

        # Compare with None explicitly: truth-testing a numpy weight matrix raises ValueError.
        if coordinate_vector_list is None and weight_matrix is None:
            raise Exception('At least one of coordinate_vector_list or weight_matrix should be provided')

        # Parallel will slow down the process for the reason of data allocation between processes.
        # One row per local estimator (assuming each predict returns a 1-D array).
        local_predict = np.vstack([local_estimator.predict(X) for local_estimator in self.local_estimator_list])

        # Search the best parameters for prediction.
        if search_optimal:
            if y is None:
                raise Exception('y needs to be provided to search the best parameter for prediction')
            # TODO: set the best weight matrix for prediction.
            pass

        if weight_matrix is None:
            weight_matrix = weight_matrix_from_points(coordinate_vector_list, self.coordinate_vector_list, self.distance_measure, self.kernel_type, self.distance_ratio, self.bandwidth,
                                                      self.neighbour_count, self.distance_args)

        return np.sum(weight_matrix * local_predict.T, axis=1)

    def predict_by_fit(self, X, coordinate_vector_list=None, weight_matrix=None, *args, **kwargs):
        """
        Fit a new local model for every prediction point on the cached training data, then predict.

        When `weight_matrix` is not given, it is computed with the training points as source and
        the prediction points as target, then transposed so rows correspond to prediction points.

        Args:
            X: Features of the prediction points.
            coordinate_vector_list: Coordinates of the prediction points.
            weight_matrix: Precomputed (rows of X, training samples) weight matrix.
            *args, **kwargs: Unused.

        Returns:
            List with the one-row prediction of each new local model, in the row order of X.

        Raises:
            Exception: If data was not cached, or if neither coordinates nor a weight matrix is given.
        """

        if not self.cache_data:
            raise Exception('Prediction by fit need cache_data set True')

        # Compare with None explicitly: truth-testing a numpy weight matrix raises ValueError.
        if coordinate_vector_list is None and weight_matrix is None:
            raise Exception('At least one of coordinate_vector_list or weight_matrix should be provided')

        if weight_matrix is None:
            weight_matrix = weight_matrix_from_points(self.coordinate_vector_list, coordinate_vector_list,
                                                      self.distance_measure, self.kernel_type, self.distance_ratio,
                                                      self.bandwidth, self.neighbour_count, self.distance_args)
            if isinstance(weight_matrix, np.ndarray):
                weight_matrix = weight_matrix.T.copy()
            else:
                # Transposing a CSR array yields CSC; `_fit` slices rows, so convert back to CSR.
                weight_matrix = csr_array(weight_matrix.T)

        N = X.shape[0]

        # One fresh estimator per prediction point.
        estimator_list = [clone(self.local_estimator) for _ in range(N)]

        # Train local model using training set
        # but with weight matrix constructed from distance between prediction points and training point for prediction.
        local_predict, _ = _fit(
            self.X, self.y, estimator_list, weight_matrix, X_predict=X,
            n_jobs=self.n_jobs, n_patches=self.n_patches
        )

        return local_predict

    def importance_score_local(self, n_repeats=5, n_jobs=-1):
        """
        Permutation importance (MDA, mean decrease accuracy) of each local estimator on its neighbourhood.

        Args:
            n_repeats: Number of permutations per feature.
            n_jobs: joblib worker count.

        Returns:
            Array of shape (number of fitted local estimators, Feature). That is N rows unless
            `sample_local_rate` was set, in which case rows follow `self.local_indices_`.

        Raises:
            Exception: If data or estimators were not cached.
        """
        # TODO: Consider use out-of-bag sample for permutation importance to improve the ability of generalization.
        # TODO: Weight of the OOB error/score?

        if not self.cache_data or not self.cache_estimator:
            raise Exception('Importance score of local needs cache_data and cache_estimator set True')

        # Pair each fitted estimator with the location it was fitted for. With sample_local_rate the
        # estimator list only covers self.local_indices_, so indexing it by range(self.N) misaligns.
        job_list = []
        for local_index, local_estimator in zip(self.local_indices_, self.local_estimator_list):
            weight = self.weight_matrix_[local_index]
            neighbour_mask = self.neighbour_matrix_[local_index]

            job_list.append(
                delayed(local_importance)(
                    local_estimator, self.X[neighbour_mask], self.y[neighbour_mask], weight[neighbour_mask], n_repeats
                )
            )

        importance_matrix = Parallel(n_jobs=n_jobs)(job_list)
        importance_matrix = np.array(importance_matrix)

        return importance_matrix

    def importance_score_global(self, n_repeats=5):
        """
        Global permutation importance from the local estimators.

        For each fitted local estimator, one feature of its own location is replaced by the value
        of a randomly permuted sample. The resulting local predictions are scored against the
        targets, and the drop from `llocv_score_` is recorded.

        Args:
            n_repeats: Number of permutations per feature in this call. Results accumulate across calls.

        Returns:
            Average score decrease per feature over all accumulated repeats.
            Also extends `permutation_score_decrease_`, shape (Feature, total repeats).

        Raises:
            Exception: If data or estimators were not cached.
        """
        if not self.cache_data or not self.cache_estimator:
            raise Exception('Importance score of global needs cache_data and cache_estimator set True')

        # Feed bulk/batch instead of atomic data to local_estimator to speed up 100x.

        # Fitted estimators and their locations (all N, or self.local_indices_ when sampled).
        n_local = len(self.local_estimator_list)

        # Prediction after permuting. Shape(n_local, Feature, n_repeats)
        permute_predict = np.empty((n_local, self.n_features_in_, n_repeats))

        # Index of permutation along the N axis. Shape(N, Feature, n_repeats)
        permute_index = np.tile(np.arange(self.N).reshape((-1, 1, 1)), (1, self.n_features_in_, n_repeats))
        permute_index = np.apply_along_axis(np.random.permutation, 0, permute_index)

        # Predict for each local estimator
        for k, (local_index, local_estimator) in enumerate(zip(self.local_indices_, self.local_estimator_list)):
            # Input of the estimator. Shape(Feature, Feature, n_repeats)
            x = self.X[local_index].reshape(-1, 1, 1)
            x = np.tile(x, (1, self.n_features_in_, n_repeats))

            # Permute the x using permute_index. Iterate by feature.
            for feature_index in range(self.n_features_in_):
                x[feature_index, feature_index, :] = self.X[permute_index[k, feature_index, :], feature_index]

            # Flatten and transpose the x for estimation input. Shape(Feature * n_repeats, Feature)
            x = np.transpose(x.reshape(self.n_features_in_, -1))

            # Predict and reshape back. Shape(Feature, n_repeats) corresponding to "Feature * n_repeats".
            permute_predict[k, :, :] = local_estimator.predict(x).reshape(self.n_features_in_, n_repeats)

        # Score against the targets of the fitted locations, the same targets used for llocv_score_.
        def inner_score(y_hat):
            return r2_score(self.y_sample_, y_hat)

        # The lower score means higher importance.
        score_decrease = self.llocv_score_ - np.apply_along_axis(inner_score, 0, permute_predict)

        # More trail added.
        if self.permutation_score_decrease_ is not None:
            self.permutation_score_decrease_ = np.concatenate(
                [self.permutation_score_decrease_, score_decrease], axis=1
            )
        else:
            self.permutation_score_decrease_ = score_decrease

        importance_score = np.average(self.permutation_score_decrease_, axis=1)
        return importance_score

    def interaction_score_global(self, n_repeats=5):
        """
        Pairwise interaction importance between features.

        Interaction importance of feature-pair {i,j} = Importance of {i,j} - Importance of i - Importance of j,
        where importance is the decrease of the score from `llocv_score_` after permutation.
        References to SHAP Interation. The diagonal (a feature with itself) is 0.

        Args:
            n_repeats: Number of joint permutations per feature pair in this call. Results accumulate across calls.

        Returns:
            Symmetric interaction matrix averaged over all accumulated repeats. Shape(Feature, Feature).

        Raises:
            Exception: If data or estimators were not cached.
        """
        if not self.cache_data or not self.cache_estimator:
            raise Exception('Importance score of global needs cache_data and cache_estimator set True')

        # Single feature score decrease
        if self.permutation_score_decrease_ is None:
            self.importance_score_global()
        single_score_decrease = np.average(self.permutation_score_decrease_, axis=1)

        # Shape(Feature, Feature, n_repeats). Zero-initialised: only the upper triangle is filled
        # below, and the lower triangle / diagonal must not contain uninitialised memory.
        interaction_matrix = np.zeros((self.n_features_in_, self.n_features_in_, n_repeats))
        for first_feature_index in range(self.n_features_in_):
            for second_feature_index in range(first_feature_index + 1, self.n_features_in_):

                # Shape(n_repeats, )
                feature_score_decrease_list = []

                # TODO: Parallel or bulk/batch data to speed up.
                for repeat in range(n_repeats):
                    shuffling_idx = np.arange(self.N)
                    X_permuted = self.X.copy()

                    np.random.shuffle(shuffling_idx)
                    X_permuted[:, first_feature_index] = X_permuted[shuffling_idx, first_feature_index]
                    np.random.shuffle(shuffling_idx)
                    X_permuted[:, second_feature_index] = X_permuted[shuffling_idx, second_feature_index]

                    # Predict each fitted location with its own estimator (all N, or the sampled subset).
                    local_predict = [
                        local_estimator.predict(X_permuted[local_index].reshape(1, -1))
                        for local_index, local_estimator in
                        zip(self.local_indices_, self.local_estimator_list)
                    ]
                    local_predict = np.array(local_predict).squeeze()

                    score = r2_score(self.y_sample_, local_predict)
                    score_decrease = self.llocv_score_ - score
                    feature_score_decrease_list.append(score_decrease)

                feature_score_decrease = np.array(feature_score_decrease_list)
                feature_score_interaction = feature_score_decrease \
                                            - single_score_decrease[first_feature_index] \
                                            - single_score_decrease[second_feature_index]

                interaction_matrix[first_feature_index, second_feature_index] = feature_score_interaction

        # Mirror the upper triangle: swap only the two feature axes and keep the repeat axis last.
        # (`.T` would reverse all three axes.)
        interaction_matrix = interaction_matrix + np.transpose(interaction_matrix, (1, 0, 2))
        if self.interaction_matrix_ is not None:
            self.interaction_matrix_ = np.concatenate([self.interaction_matrix_, interaction_matrix], axis=2)
        else:
            self.interaction_matrix_ = interaction_matrix

        interaction_score = np.average(self.interaction_matrix_, axis=2)
        return interaction_score

    def partial_dependence(self):
        """
        Weighted partial dependence of every feature for every fitted local estimator.

        Each local estimator's individual (ICE) curves over its neighbourhood are averaged with the
        neighbourhood weights (see `local_partial_dependence`).

        Returns:
            `self.local_partial_`: object array of shape (number of fitted local estimators, Feature, 2).
            Entry [i, j, 0] holds the grid values and [i, j, 1] the weighted average curve.
            `self.feature_partial_` holds the same data with the first two axes swapped.
        """
        # TODO: More detailed reason to justify the weighted partial dependence.
        # Care more on the local range.
        # Unweighted points will dominate the tendency which may not be the interested one.

        # TODO: Batch optimize for partial_dependence. Native weighted partial dependence.
        # Pair each fitted estimator with its own location (all N, or self.local_indices_ when sampled).
        job_list = []
        for local_index, local_estimator in zip(self.local_indices_, self.local_estimator_list):
            neighbour_mask = self.neighbour_matrix_[local_index]
            weight = self.weight_matrix_[local_index][neighbour_mask]
            X_local = self.X[neighbour_mask]

            job_list.append(
                delayed(local_partial_dependence)(local_estimator, X_local, weight, self.n_features_in_)
            )

        # [feature_list for estimator1, ...2, ...]
        local_partial_list = Parallel(n_jobs=-1)(job_list)

        # Fill an explicitly shaped object array. np.array(..., dtype=object) would descend into the
        # grid arrays, so its resulting shape depends on whether all grids share one length.
        local_partial = np.empty((len(local_partial_list), self.n_features_in_, 2), dtype=object)
        for estimator_position, feature_list in enumerate(local_partial_list):
            for feature_index, (values, weight_average) in enumerate(feature_list):
                local_partial[estimator_position, feature_index, 0] = values
                local_partial[estimator_position, feature_index, 1] = weight_average
        self.local_partial_ = local_partial

        # Convert local based result to feature based result.
        '''
        [
            Feature1: [(x for estimator1, y for estimator1), (x for estimator2, y for estimator2), (...), ...],
            Feature2: [...],
            ...
        ]
        '''
        self.feature_partial_ = self.local_partial_.transpose((1, 0, 2))

        return self.local_partial_

    def local_ICE(self, percentiles=(0.05, 0.95), grid_resolution=100):
        """
        Compute local individual-conditional-expectation (ICE) curves.

        For local point ``i`` and feature ``j`` a grid is built by ``_grid_from_X``
        from column ``j`` of the neighbourhood rows ``self.X[self.neighbour_matrix_[i]]``.
        The point's own row ``self.X[i]`` is repeated over that grid with only
        column ``j`` varied, and the rows are predicted by local estimator ``i``.

        Parameters
        ----------
        percentiles : tuple of float, default (0.05, 0.95)
            Passed to ``_grid_from_X``.
        grid_resolution : int, default 100
            Passed to ``_grid_from_X``.

        Returns
        -------
        numpy.ndarray
            Object array of shape ``(N, n_features_in_, 2)``, also stored as
            ``self.local_ice_``. Entry ``[i, j]`` is ``(grid_values, predictions)``.
            ``self.feature_ice_`` is the same array with the first two axes swapped.
        """
        # Filled explicitly: np.array(list, dtype=object) would recurse into
        # equal-length grids, build a 4-D array and break the transpose below.
        local_ice = np.empty((self.N, self.n_features_in_, 2), dtype=object)

        for local_index in range(self.N):
            X_local = self.X[self.neighbour_matrix_[local_index]]
            local_estimator = self.local_estimator_list[local_index]

            for feature_index in range(self.n_features_in_):
                # NOTE(review): newer scikit-learn releases add an `is_categorical`
                # argument to this private helper; confirm the installed signature.
                _, values = _grid_from_X(X_local[:, [feature_index]], percentiles, grid_resolution)
                values = values[0]

                X_individual = np.tile(self.X[local_index], (len(values), 1))
                X_individual[:, feature_index] = values

                local_ice[local_index, feature_index, 0] = values
                local_ice[local_index, feature_index, 1] = local_estimator.predict(X_individual)

        self.local_ice_ = local_ice
        self.feature_ice_ = self.local_ice_.transpose((1, 0, 2))

        return self.local_ice_

    def local_ALE(self, feature=0):
        """
        Compute a weighted ALE curve of ``feature`` for every local estimator.

        Local model ``i`` is evaluated by ``weighted_ale`` on the rows selected by
        ``self.neighbour_matrix_[i]``. Each row is weighted by the matching entry
        of ``self.weight_matrix_[i]``, and the model's own ``predict`` is used.

        Returns
        -------
        list
            One ``weighted_ale`` result per local estimator, in estimator order.
        """
        return [
            weighted_ale(
                self.X[self.neighbour_matrix_[idx]],
                feature,
                local_model.predict,
                self.weight_matrix_[idx][self.neighbour_matrix_[idx]],
            )
            for idx, local_model in enumerate(self.local_estimator_list)
        ]


    def global_ALE(self, feature=0):
        """
        Compute a global accumulated-local-effects (ALE) curve for ``feature``.

        Row ``i`` of ``self.X`` is moved to the lower and upper edge of its grid
        interval. Both copies are predicted by *its own* local estimator
        ``self.local_estimator_list[i]``, so the curve aggregates the local models.
        The finite differences are averaged per interval, accumulated, and then
        centred with Eq. 16 of "Visualizing the effects of predictor variables in
        black box supervised learning models".

        Parameters
        ----------
        feature : int, default 0
            Column index of ``self.X``.

        Returns
        -------
        fvals : numpy.ndarray
            Grid returned by ``adaptive_grid`` for the feature column.
        ale : numpy.ndarray
            Centred ALE values. There is one leading 0-based point plus one value
            per non-empty interval, which matches ``fvals`` when every interval
            contains data.
        """
        x_feature = self.X[:, feature]
        fvals, _ = adaptive_grid(x_feature)

        # find which interval each observation falls into: interval k spans
        # (fvals[k - 1], fvals[k]]; the smallest data point goes into interval 1.
        indices = np.searchsorted(fvals, x_feature, side="left")
        indices[indices == 0] = 1
        interval_n = np.bincount(indices)  # number of points in each interval

        # predictions for the upper and lower ranges of intervals
        z_low = self.X.copy()
        z_high = self.X.copy()
        z_low[:, feature] = fvals[indices - 1]
        z_high[:, feature] = fvals[indices]

        # Row i <-> local estimator i; both edges are predicted in one call.
        p_low_list = []
        p_high_list = []
        for i, local_estimator in enumerate(self.local_estimator_list):
            p_local = local_estimator.predict(np.vstack((z_low[[i], :], z_high[[i], :])))
            p_low_list.append(p_local[0])
            p_high_list.append(p_local[1])

        # finite differences
        p_deltas = np.ravel(np.array(p_high_list) - np.array(p_low_list))

        # Plain (unweighted) mean difference per non-empty interval, in ascending
        # interval order. Grouping a Series avoids DataFrameGroupBy.apply over the
        # grouping column, which pandas >= 2.2 deprecates.
        avg_p_deltas = pd.Series(p_deltas).groupby(indices).apply(np.average).to_numpy()

        # accumulate over intervals, pre-pending 0 for the left-most point
        accum_p_deltas = np.insert(np.cumsum(avg_p_deltas), 0, 0.0)

        # mean effect, R's `ALEPlot` and `iml` version (approximation per interval)
        # Eq.16 from original paper "Visualizing the effects of predictor variables in black box supervised learning models"
        ale0 = (
            0.5 * (accum_p_deltas[:-1] + accum_p_deltas[1:]) * interval_n[1:]
        ).sum()
        ale0 = ale0 / interval_n.sum()

        # center
        ale = accum_p_deltas - ale0

        return fvals, ale


    def log_before_fitting(self, X, y, coordinate_vector_list=None, weight_matrix=None):
        """
        Emit one DEBUG record describing the input data and fitting parameters.

        Only ``X.shape`` and the estimator configuration stored on ``self`` are
        written. ``y``, ``coordinate_vector_list`` and ``weight_matrix`` are
        accepted but not currently included in the message.

        Returns ``self`` so the call can be chained.
        """
        message_parts = [
            "\nWeight Model start fitting with data:\n",
            f"X.shape: {X.shape}\n",
            "\nWeight Model start fitting with parameters:\n",
            f"local_estimator: {self.local_estimator}\n",
            f"leave_local_out: {self.leave_local_out}\n",
            f"sample_local_rate: {self.sample_local_rate}\n",
            f"cache_data: {self.cache_data}\n",
            f"cache_estimator: {self.cache_estimator}\n",
            f"n_jobs: {self.n_jobs}\n",
            f"n_patches: {self.n_patches}\n",
            f"args: {self.args}\n",
            f"kwargs: {self.kwargs}\n",
        ]

        # NOTE: coordinate-vector and weight-matrix details (distance measure, kernel,
        # bandwidth, ...) are deliberately not logged here; it is an open question
        # whether they belong in the weight-matrix code instead.

        logger.debug("".join(message_parts))
        return self









import logging

from slab_utils.quick_logger import logger

# Change the format for file output.
# Only the first FileHandler of the shared logger (if there is one) is reformatted.
# A logger without a file handler is left untouched instead of raising
# IndexError at import time.
fileHandler = next(
    (handler for handler in logger.handlers if isinstance(handler, logging.FileHandler)),
    None,
)
formatter4File = logging.Formatter(
    '[%(levelname)s] - %(asctime)s:\n%(message)s',
    '%Y-%m-%d %H:%M:%S')
if fileHandler is not None:
    fileHandler.setFormatter(formatter4File)

# set LOKY_PICKLER=pickle using os
# os.environ['LOKY_PICKLER'] = 'pickle'

import matplotlib

# Interactive Tk backend for the plotting scripts.
matplotlib.use('TKAgg')


import hyperopt
from georegression.stacking_model import StackingWeightModel
import hyperopt
from hyperopt import hp

# Hyperparameter search over StackingWeightModel with hyperopt's TPE algorithm,
# minimising the negative `llocv_stacking_` score.
# NOTE(review): this script cannot work as written -- TODO fix before running:
#   * `objective` constructs the model but never calls `fit`, so `llocv_stacking_`
#     is presumably not set yet (elsewhere in this file it is read after `fit`).
#   * The search space holds random-forest *classifier* options ('gini'/'entropy'
#     criteria, `max_features='auto'`). Elsewhere in this file StackingWeightModel
#     is built as (local_estimator, distance_measure, kernel_type,
#     neighbour_count=..., neighbour_leave_out_rate=...), so `**params` won't match.
#   * No training data (X, y, points) is defined for the search to fit on.
space = {
    'n_estimators': hp.choice('n_estimators', range(10, 101)),
    'max_depth': hp.choice('max_depth', range(1, 21)),
    'max_features': hp.choice('max_features', ['auto', 'sqrt', 'log2']),
    'criterion': hp.choice('criterion', ['gini', 'entropy'])
}

def objective(params):
    """Return the hyperopt result dict for one sampled parameter set (loss = -llocv_stacking_)."""
    clf = StackingWeightModel(**params)
    return {'loss': -clf.llocv_stacking_, 'status': hyperopt.STATUS_OK}


# Run 100 TPE evaluations, recording every trial in `trials`.
trials = hyperopt.Trials()
best = hyperopt.fmin(
    fn=objective,
    space=space,
    algo=hyperopt.tpe.suggest,
    max_evals=100,
    trials=trials
)

# NOTE(review): for `hp.choice` parameters `fmin` reports the chosen *index*;
# `hyperopt.space_eval(space, best)` maps it back to the actual value.
print("Best Hyperparameters:", best)

"""
Generate simulated data for testing purposes.
A Bayesian Implementation of the Multiscale Geographically Weighted Regression Model with INLA
https://doi.org/10.1080/24694452.2023.2187756
"""
from functools import partial
import matplotlib.pyplot as plt
import numpy as np

from georegression.simulation.simulation_utils import (
    gaussian_coefficient,
    interaction_function,
    radial_coefficient,
    directional_coefficient,
    sine_coefficient,
    coefficient_wrapper,
    polynomial_function,
    sigmoid_function,
    sample_points,
    sample_x,
)


def f_interact(X, C, points):
    """Evaluate ``interaction_function(C[0])`` on the first two columns of ``X`` at ``points``."""
    interaction = interaction_function(C[0])
    response = interaction(X[:, 0], X[:, 1], points)
    return response + 0


def f_square_2(X, C, points):
    """
    Sum of two degree-2 ``polynomial_function`` terms.

    The first term uses coefficient ``C[0]`` on column 0 of ``X``; the second
    uses ``C[1]`` on column 1. Both are evaluated at ``points``.
    """
    first_term = polynomial_function(C[0], 2)(X[:, 0], points)
    second_term = polynomial_function(C[1], 2)(X[:, 1], points)
    return first_term + second_term + 0


def coef_strong():
    """
    Build a strongly varying spatial coefficient.

    Combines, with ``coefficient_wrapper(np.sum, ...)``, a radial term, a
    directional term, a pair of sine terms and a pair of gaussian bumps.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))
    sines = coefficient_wrapper(
        np.sum,
        sine_coefficient(1, np.array([-1, 1]), 1),
        sine_coefficient(1, np.array([1, 1]), 1),
    )
    gaussians = coefficient_wrapper(
        np.sum,
        gaussian_coefficient(np.array([-5, 5]), 3),
        gaussian_coefficient(np.array([-5, -5]), 3, amplitude=2),
    )
    return coefficient_wrapper(np.sum, radial, directional, sines, gaussians)


def coef_manual_gau():
    """
    Build the hand-placed coefficient used by the default simulation.

    Returns ``coefficient_wrapper(np.sum, radial, gaussians)``, where
    ``gaussians`` is an ``np.sum`` combination of six hand-placed gaussian bumps.
    """
    coef_radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))

    # NOTE: only the terms below contribute. A directional term (np.array([1, 1]))
    # and a bump at (-10, -10) with sigma 15 / amplitude 1.5 were built here before
    # but never used; adding them would change the simulated data.
    coef_gau = coefficient_wrapper(
        np.sum,
        gaussian_coefficient(np.array([-5, 5]), [[3, 4], [4, 8]], amplitude=-1),
        gaussian_coefficient(np.array([-2, -5]), 5, amplitude=2),
        gaussian_coefficient(np.array([8, 3]), 10, amplitude=-1.5),
        gaussian_coefficient(np.array([2, 8]), [[3, 0], [0, 15]], amplitude=0.8),
        gaussian_coefficient(np.array([5, -10]), 1, amplitude=1),
        gaussian_coefficient(np.array([-11, 0]), 5, amplitude=2),
    )

    return coefficient_wrapper(np.sum, coef_radial, coef_gau)


def coef_auto_gau_weak():
    """
    Build a coefficient from a radial term, a directional term and 1000 random,
    widely spread, anisotropic gaussian bumps.

    Draws from NumPy's global random state (seed it first for reproducibility).
    For each bump: center ~ U(-10, 10)^2, |amplitude| ~ U(1, 2) with a random
    sign, variances ~ U(0.5, 5), and a covariance bounded by
    sqrt(var_1 * var_2), which keeps the 2x2 matrix positive semi-definite.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))

    bumps = []
    for _ in range(1000):
        center = np.random.uniform(-10, 10, 2)
        magnitude = np.random.uniform(1, 2)
        signed_amplitude = magnitude * np.random.choice([-1, 1])
        var_1 = np.random.uniform(0.5, 5)
        var_2 = np.random.uniform(0.5, 5)
        cov_bound = np.sqrt(var_1 * var_2)
        covariance = np.random.uniform(-cov_bound, cov_bound)
        cov_matrix = np.array([[var_1, covariance], [covariance, var_2]])
        bumps.append(gaussian_coefficient(center, cov_matrix, amplitude=signed_amplitude))

    return coefficient_wrapper(
        np.sum, radial, directional, coefficient_wrapper(np.sum, *bumps)
    )


def coef_auto_gau_strong():
    """
    Build a coefficient from a radial term, a directional term and 1000 random,
    tightly concentrated, anisotropic gaussian bumps.

    Draws from NumPy's global random state (seed it first for reproducibility).
    For each bump: center ~ U(-10, 10)^2, |amplitude| ~ U(1, 2) with a random
    sign, variances ~ U(0.2, 1), and a covariance bounded by
    sqrt(var_1 * var_2), which keeps the 2x2 matrix positive semi-definite.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))

    bumps = []
    for _ in range(1000):
        center = np.random.uniform(-10, 10, 2)
        magnitude = np.random.uniform(1, 2)
        signed_amplitude = magnitude * np.random.choice([-1, 1])
        var_1 = np.random.uniform(0.2, 1)
        var_2 = np.random.uniform(0.2, 1)
        cov_bound = np.sqrt(var_1 * var_2)
        covariance = np.random.uniform(-cov_bound, cov_bound)
        cov_matrix = np.array([[var_1, covariance], [covariance, var_2]])
        bumps.append(gaussian_coefficient(center, cov_matrix, amplitude=signed_amplitude))

    return coefficient_wrapper(
        np.sum, radial, directional, coefficient_wrapper(np.sum, *bumps)
    )


# Module defaults: the response function and coefficient factory that
# generate_sample binds as its default `f` / `coef_func` arguments.
f = f_interact
coef_func = coef_manual_gau
# Mean surface for x2 (coef_func() wrapped with a factor of 3), built once at import
# time. Presumably imported by other scripts to rebuild the analytic marginal
# effect -- TODO confirm. generate_sample builds its own local copy.
x2_coef = coefficient_wrapper(partial(np.multiply, 3), coef_func())


def generate_sample(random_seed=None, count=100, f=f, coef_func=coef_func):
    """
    Simulate a two-feature spatial regression sample.

    Parameters
    ----------
    random_seed : int or None
        Seed for NumPy's global random state (``np.random.seed``).
    count : int
        Number of locations / observations.
    f : callable
        Response function called as ``f(X, coefficients, points)``.
    coef_func : callable or list/tuple of callables
        Coefficient factory (or factories). Every factory is called once to
        build ``coefficients``. The first factory is also called to build the
        mean of x1 and, scaled by 3, the mean of x2.

    Returns
    -------
    tuple
        ``(X, y, points, f, coefficients)``. ``X`` stacks x1 and x2 along the
        last axis; ``points`` come from ``sample_points`` with bounds
        ``[[-10, 10], [-10, 10]]``.

    Raises
    ------
    ValueError
        If an empty list/tuple of coefficient factories is given.
    """
    np.random.seed(random_seed)

    # Normalise to a list up front: calling a list directly used to raise
    # TypeError before the list branch below was ever reached.
    coef_funcs = list(coef_func) if isinstance(coef_func, (list, tuple)) else [coef_func]
    if not coef_funcs:
        raise ValueError("coef_func must contain at least one coefficient factory.")
    mean_func = coef_funcs[0]

    points = sample_points(count, bounds=[[-10, 10], [-10, 10]])

    # NOTE: mean_func is called separately for x1, x2 and the coefficients. For
    # random factories (e.g. coef_auto_gau_*) each call yields a different surface.
    x1 = sample_x(count, mean=mean_func(), bounds=(-1, 1), points=points)
    x2_mean = coefficient_wrapper(partial(np.multiply, 3), mean_func())
    x2 = sample_x(count, mean=x2_mean, bounds=(-2, 2), points=points)

    coefficients = [func() for func in coef_funcs]

    X = np.stack((x1, x2), axis=-1)
    y = f(X, coefficients, points)

    return X, y, points, f, coefficients


def show_sample(X, y, points, coefficients, folder="Plot"):
    """
    Save diagnostic plots of a simulated sample as PNG files in ``folder``.

    Each plot gets its own figure (there are no subplots):
    ``Simulation_X_<k>.png`` for every column of ``X`` (k starts at 1),
    ``Simulation_y.png`` and ``Simulation_y_boxplot.png`` for ``y``, and
    ``Simulation_Coefficients_<i>.png`` for every coefficient (i starts at 0),
    evaluated as ``coefficients[i](points)``. ``folder`` is created if missing.
    The figures are left open so the caller can ``plt.show()`` them.

    Raises
    ------
    ValueError
        If ``points`` does not have exactly two columns (a plane is assumed).
    """
    import os

    if points.shape[1] != 2:
        raise ValueError("Dimension of points must be 2.")

    # savefig does not create directories.
    os.makedirs(folder, exist_ok=True)

    # One scatter per feature of X, coloured by the feature value.
    for i in range(X.shape[1]):
        plt.figure()
        plt.scatter(points[:, 0], points[:, 1], c=X[:, i], cmap="Spectral")
        plt.colorbar()
        plt.xlabel("x")
        plt.ylabel("y")
        plt.title(f"The {i+1}-th feature of X")
        plt.savefig(f"{folder}/Simulation_X_{i+1}.png")

    # Plot y using scatter and boxplot
    plt.figure()
    plt.scatter(points[:, 0], points[:, 1], c=y, cmap="Spectral")
    plt.colorbar()
    plt.title("The value of y across the plane")
    plt.savefig(f"{folder}/Simulation_y.png")

    plt.figure()
    plt.boxplot(y)
    plt.title("The distribution of y")
    plt.savefig(f"{folder}/Simulation_y_boxplot.png")

    # Each coefficient evaluated at the sample locations.
    for i, coefficient in enumerate(coefficients):
        plt.figure()
        plt.scatter(points[:, 0], points[:, 1], c=coefficient(points), cmap="Spectral")
        plt.colorbar()
        plt.title(f"The {i+1}-th coefficient across the plane")
        plt.savefig(f"{folder}/Simulation_Coefficients_{i}.png")


def main():
    """Simulate the default 5000-point sample, save its plots and display them (blocking)."""
    features, response, locations, _, coefs = generate_sample(count=5000, random_seed=1)
    show_sample(features, response, locations, coefs)
    plt.show(block=True)


if __name__ == "__main__":
    main()


import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from georegression.local_ale import weighted_ale

from georegression.stacking_model import StackingWeightModel
from georegression.simulation.simulation import generate_sample
from georegression.visualize.ale import plot_ale
from georegression.weight_model import WeightModel



def draw_graph():
    """
    Fit a geographically weighted random forest on the default simulated sample
    and visualise the local and global effects of feature 0.

    For every local model a blocking interactive figure is shown. It contains
    the weighted local ALE, the neighbourhood data coloured by kernel weight,
    the analytic marginal effect, and an unweighted ALE over the same x-range.
    Afterwards the normalised local importance is shown, and the absolute
    residuals and global ALE are saved to ``residual.png`` and ``ale_global.png``.
    """
    # Loop-invariant: module-level x2 mean surface of the simulation module.
    from georegression.simulation.simulation import x2_coef

    X, y, points, _, coef = generate_sample(count=5000, random_seed=1)
    distance_measure = "euclidean"
    kernel_type = "bisquare"
    neighbour_count = 0.05

    model = WeightModel(
        RandomForestRegressor(n_estimators=50),
        distance_measure,
        kernel_type,
        neighbour_count=neighbour_count,
        cache_data=True,
        cache_estimator=True,
    )
    model.fit(X, y, [points])
    print('GRF:', model.llocv_score_)

    feature_index = 0

    for local_index in range(model.N):
        estimator = model.local_estimator_list[local_index]
        neighbour_mask = model.neighbour_matrix_[local_index]
        neighbour_weight = model.weight_matrix_[local_index][neighbour_mask]
        X_local = model.X[neighbour_mask]

        # Weighted ALE of the local model over its neighbourhood.
        fval, ale = weighted_ale(X_local, feature_index, estimator.predict, neighbour_weight)

        x_neighbour = X[neighbour_mask, feature_index]
        y_neighbour = y[neighbour_mask]
        x_min, x_max = np.min(x_neighbour), np.max(x_neighbour)

        # Analytic marginal effect of x1 for f = beta * x1 * x2, assuming
        # x2 ~ U(x2_average - 2, x2_average + 2) (this reduces to beta * x2_average * x).
        x_grid = np.linspace(x_min, x_max, 1000)
        beta = coef[0](points[local_index])
        x2_average = x2_coef(points[local_index])
        y_grid = beta * 0.5 * ((x2_average + 2) ** 2 - (x2_average - 2) ** 2) * (1 / 4) * x_grid

        # Anchor the ALE curve at the model's mean prediction at the left edge.
        x1_base = np.full(500, x_min, dtype=float)
        x2_base = np.random.uniform(x2_average - 2, x2_average + 2, 500)
        base_value_real = estimator.predict(np.stack([x1_base, x2_base], axis=-1)).mean()
        ale = ale - (ale[0] - base_value_real)

        fig = plot_ale(fval, ale, x_neighbour)
        ax = fig.get_axes()[0]
        scatter = ax.scatter(x_neighbour, y_neighbour, c=neighbour_weight)
        ax.scatter(X[local_index, feature_index], y[local_index], c='red', label='Local point')
        fig.colorbar(scatter, ax=ax, label='Weight')

        ax.plot(x_grid, y_grid, label="Function value")

        # Unweighted ALE over every sample whose feature value lies in the
        # neighbourhood's range, anchored the same way with x2 drawn from the data.
        in_range = (X[:, feature_index] >= x_min) & (X[:, feature_index] <= x_max)
        x_global_ale = X[in_range]
        fval, ale = weighted_ale(
            x_global_ale, feature_index, estimator.predict, np.ones(x_global_ale.shape[0]))

        x1_base = np.full(500, x_min, dtype=float)
        x2_base = np.random.choice(X[:, 1], 500)
        base_value_real = estimator.predict(np.stack([x1_base, x2_base], axis=-1)).mean()
        ale = ale - (ale[0] - base_value_real)

        ax.plot(fval, ale, label="Non-weighted ALE")

        # Add legend
        handles, labels = ax.get_legend_handles_labels()
        handles.append(scatter)
        labels.append('Weight')
        ax.legend(handles, labels)

        plt.show(block=True)

    importance_global = model.importance_score_global()
    print(importance_global)

    importance_local = model.importance_score_local()
    print(importance_local)

    # Normalize the local importance to [0, 1] (all zeros if it is constant).
    importance_range = importance_local.max() - importance_local.min()
    if importance_range > 0:
        importance_local = (importance_local - importance_local.min()) / importance_range
    else:
        importance_local = np.zeros_like(importance_local, dtype=float)

    # Plot the local importance
    fig, ax = plt.subplots()
    scatter = ax.scatter(points[:, 0], points[:, 1],
                         c=importance_local[:, 0], cmap='viridis')
    fig.colorbar(scatter)
    plt.show()

    # Show the absolute residual across the space. A WeightModel exposes
    # `local_predict_` (scored against y elsewhere); `stacking_predict_` /
    # `y_sample_` appear to belong to StackingWeightModel -- TODO confirm.
    residual = np.abs(model.local_predict_ - y)
    residual_max = residual.max()
    fig, ax = plt.subplots()
    # Lower residual values are drawn more transparent.
    scatter = ax.scatter(points[:, 0], points[:, 1], c=residual,
                         alpha=residual / residual_max if residual_max > 0 else None)
    fig.colorbar(scatter)
    fig.savefig('residual.png')

    fval, ale = model.global_ALE(feature_index)
    fig = plot_ale(fval, ale, X[:, feature_index])
    fig.savefig('ale_global.png')


if __name__ == "__main__":
    draw_graph()


import os
from functools import partial

import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor

from georegression.local_ale import weighted_ale
from georegression.simulation.simulation_utils import *
from georegression.visualize.ale import plot_ale
from georegression.weight_model import WeightModel

# Publication typography: Times New Roman, bold text, large axis/tick labels.
plt.rcParams.update(
    {
        "font.family": "Times New Roman",
        "font.size": 18,
        "axes.labelsize": 18,
        "font.weight": "bold",
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
    }
)

def coef_manual_gau():
    """
    Build the hand-placed coefficient used by this script's simulation.

    Returns ``coefficient_wrapper(np.sum, radial, gaussians)``, where
    ``gaussians`` is an ``np.sum`` combination of six hand-placed gaussian bumps.
    """
    coef_radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))

    # NOTE: only the terms below contribute. A directional term (np.array([1, 1]))
    # and a bump at (-10, -10) with sigma 15 / amplitude 1.5 were built here before
    # but never used; adding them would change the simulated data.
    coef_gau = coefficient_wrapper(
        np.sum,
        gaussian_coefficient(np.array([-5, 5]), [[3, 4], [4, 8]], amplitude=-1),
        gaussian_coefficient(np.array([-2, -5]), 5, amplitude=2),
        gaussian_coefficient(np.array([8, 3]), 10, amplitude=-1.5),
        gaussian_coefficient(np.array([2, 8]), [[3, 0], [0, 15]], amplitude=0.8),
        gaussian_coefficient(np.array([5, -10]), 1, amplitude=1),
        gaussian_coefficient(np.array([-11, 0]), 5, amplitude=2),
    )

    return coefficient_wrapper(np.sum, coef_radial, coef_gau)

def coef_auto_gau_strong():
    """
    Build a coefficient from a radial term, a directional term and 1000 random,
    tightly concentrated, anisotropic gaussian bumps.

    Draws from NumPy's global random state. For each bump: center ~ U(-10, 10)^2,
    |amplitude| ~ U(1, 2) with a random sign, variances ~ U(0.2, 1), and a
    covariance bounded by sqrt(var_1 * var_2), which keeps the matrix PSD.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))

    bumps = []
    for _ in range(1000):
        center = np.random.uniform(-10, 10, 2)
        magnitude = np.random.uniform(1, 2)
        signed_amplitude = magnitude * np.random.choice([-1, 1])
        var_1 = np.random.uniform(0.2, 1)
        var_2 = np.random.uniform(0.2, 1)
        cov_bound = np.sqrt(var_1 * var_2)
        covariance = np.random.uniform(-cov_bound, cov_bound)
        cov_matrix = np.array([[var_1, covariance], [covariance, var_2]])
        bumps.append(gaussian_coefficient(center, cov_matrix, amplitude=signed_amplitude))

    return coefficient_wrapper(
        np.sum, radial, directional, coefficient_wrapper(np.sum, *bumps)
    )

def f_interact(X, C, points):
    """Evaluate ``interaction_function(C[0])`` on the first two columns of ``X`` at ``points``."""
    interaction = interaction_function(C[0])
    response = interaction(X[:, 0], X[:, 1], points)
    return response + 0


# Module-level settings read by generate_sample at call time: the response
# function and the coefficient factory.
f = f_interact
coef_func = coef_manual_gau
# Mean surface for x2 (coef_func() wrapped with a factor of 3), built at import time.
# NOTE(review): draw_graph imports `x2_coef` from georegression.simulation.simulation
# instead, and generate_sample builds its own local copy, so this name appears unused here.
x2_coef = coefficient_wrapper(partial(np.multiply, 3), coef_func())


def generate_sample(random_seed=1, count=5000):
    """
    Simulate a two-feature spatial sample from the module-level ``f`` and ``coef_func``.

    Parameters
    ----------
    random_seed : int or None, default 1
        Seed for NumPy's global random state (``np.random.seed``).
    count : int, default 5000
        Number of locations / observations.

    Returns
    -------
    tuple
        ``(X, y, points, f, coefficients)``. ``X`` stacks x1 and x2 along the
        last axis; ``coefficients`` holds one built coefficient per factory.

    Raises
    ------
    ValueError
        If ``coef_func`` is an empty list/tuple.
    """
    np.random.seed(random_seed)

    # Normalise to a list up front: calling a list directly used to raise
    # TypeError before the list branch was ever reached.
    coef_funcs = list(coef_func) if isinstance(coef_func, (list, tuple)) else [coef_func]
    if not coef_funcs:
        raise ValueError("coef_func must contain at least one coefficient factory.")
    mean_func = coef_funcs[0]

    points = sample_points(count, bounds=[[-10, 10], [-10, 10]])
    x1 = sample_x(count, mean=mean_func(), bounds=(-1, 1), points=points)
    x2_mean = coefficient_wrapper(partial(np.multiply, 3), mean_func())
    x2 = sample_x(count, mean=x2_mean, bounds=(-2, 2), points=points)

    coefficients = [func() for func in coef_funcs]

    X = np.stack((x1, x2), axis=-1)
    y = f(X, coefficients, points)

    return X, y, points, f, coefficients


def draw_graph():
    """
    Fit a geographically weighted random forest on the simulated sample and save
    one local-ALE figure per location to ``Plot/LocalAle_BigFont/<index>.png``.

    Each figure shows the weighted local ALE of feature 0, the neighbourhood data
    coloured by kernel weight, the analytic marginal effect ("True value") and an
    unweighted ALE over the same feature range ("ALE").
    """
    # Loop-invariant: module-level x2 mean surface of the simulation module.
    from georegression.simulation.simulation import x2_coef

    X, y, points, _, coef = generate_sample()
    distance_measure = "euclidean"
    kernel_type = "bisquare"
    neighbour_count = 0.05

    model = WeightModel(
        RandomForestRegressor(n_estimators=50),
        distance_measure,
        kernel_type,
        neighbour_count=neighbour_count,
        cache_data=True,
        cache_estimator=True,
    )
    model.fit(X, y, [points])
    print("GRF:", model.llocv_score_)

    feature_index = 0
    folder_name = "Plot/LocalAle_BigFont"
    os.makedirs(folder_name, exist_ok=True)

    for local_index in range(model.N):
        estimator = model.local_estimator_list[local_index]
        neighbour_mask = model.neighbour_matrix_[local_index]
        neighbour_weight = model.weight_matrix_[local_index][neighbour_mask]
        X_local = model.X[neighbour_mask]

        # Weighted ALE of the local model over its neighbourhood.
        fval, ale = weighted_ale(
            X_local, feature_index, estimator.predict, neighbour_weight
        )

        x_neighbour = X[neighbour_mask, feature_index]
        y_neighbour = y[neighbour_mask]
        x_min, x_max = np.min(x_neighbour), np.max(x_neighbour)

        # Analytic marginal effect of x1 for f = beta * x1 * x2, assuming
        # x2 ~ U(x2_average - 2, x2_average + 2) (reduces to beta * x2_average * x).
        x_grid = np.linspace(x_min, x_max, 1000)
        beta = coef[0](points[local_index])
        x2_average = x2_coef(points[local_index])
        y_grid = (
            beta
            * 0.5
            * ((x2_average + 2) ** 2 - (x2_average - 2) ** 2)
            * (1 / 4)
            * x_grid
        )

        # Anchor the ALE curve at the model's mean prediction at the left edge.
        x1_base = np.full(500, x_min, dtype=float)
        x2_base = np.random.uniform(x2_average - 2, x2_average + 2, 500)
        base_value_real = estimator.predict(
            np.stack([x1_base, x2_base], axis=-1)
        ).mean()
        ale = ale - (ale[0] - base_value_real)

        fig = plot_ale(fval, ale, x_neighbour)
        fig.set_size_inches(10, 6)
        axes = fig.get_axes()
        ax1, ax2 = axes[0], axes[1]

        ax1.set_xlabel("Feature value", fontweight='bold')
        ax1.set_ylabel("Function value", fontweight='bold')
        ax2.set_ylabel('Density', fontweight='bold')

        scatter = ax1.scatter(x_neighbour, y_neighbour, c=neighbour_weight)
        ax1.scatter(
            X[local_index, feature_index], y[local_index], c="red", label="Local point"
        )
        cbar = fig.colorbar(scatter, ax=ax1, pad=0.1)
        cbar.set_label('Weight', weight='bold')
        cbar.ax.tick_params(labelsize=15)

        ax1.plot(x_grid, y_grid, label="True value")

        # Unweighted ALE over every sample whose feature value lies in the
        # neighbourhood's range, anchored with x2 drawn from the data.
        in_range = (X[:, feature_index] >= x_min) & (X[:, feature_index] <= x_max)
        x_global_ale = X[in_range]
        fval, ale = weighted_ale(
            x_global_ale,
            feature_index,
            estimator.predict,
            np.ones(x_global_ale.shape[0]),
        )

        x1_base = np.full(500, x_min, dtype=float)
        x2_base = np.random.choice(X[:, 1], 500)
        base_value_real = estimator.predict(
            np.stack([x1_base, x2_base], axis=-1)
        ).mean()
        ale = ale - (ale[0] - base_value_real)

        ax1.plot(fval, ale, label="ALE")

        # Add legend
        handles, labels = ax1.get_legend_handles_labels()
        handles.append(scatter)
        labels.append("Weight")
        ax1.legend(handles, labels, fontsize=15)

        fig.savefig(f"{folder_name}/{local_index}.png", dpi=300)
        plt.close(fig)


if __name__ == "__main__":
    draw_graph()


import json
import os
import time
from functools import partial

from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import LeaveOneOut
from sklearn.tree import DecisionTreeRegressor
from xgboost import XGBRegressor

from georegression.simulation.simulation import show_sample
from georegression.simulation.simulation_utils import *
from georegression.stacking_model import StackingWeightModel
from georegression.weight_model import WeightModel


# TODO: Explain why the improvement becomes significant when there are more points.


def fit_models(
    X,
    y,
    points,
    stacking_neighbour_count=0.03,
    stacking_neighbour_leave_out_rate=0.15,
    grf_neighbour_count=0.03,
    grf_n_estimators=50,
    gwr_neighbour_count=0.03,
    rf_n_estimators=2000,
    info=None,
):
    """
    Fit and time a suite of models on one sample, then append the scores to
    ``simulation_data_size.jsonl``.

    Models, all trained on the features concatenated with the coordinates:
    StackingWeightModel with a random-split decision tree, StackingWeightModel
    with a small extra-trees ensemble, a geographically weighted random forest
    (GRF), a geographically weighted linear regression (GWR), a global random
    forest (RF) and a global linear regression (LR).

    Scores: ``llocv_score_`` / ``llocv_stacking_`` for the weighted models,
    out-of-bag R^2 for RF, and *in-sample* ``LinearRegression.score`` for LR.
    NOTE: the LR score is in-sample and not directly comparable to the others.

    Returns
    -------
    dict
        Scores and fit times in seconds, merged with ``info`` (``info`` wins on
        key clashes). The same dict is appended as one JSON line.
    """
    if info is None:
        info = {}

    X_plus = np.concatenate([X, points], axis=1)

    distance_measure = "euclidean"
    kernel_type = "bisquare"

    result = {}

    def timed_fit(model, *fit_args):
        """Fit ``model`` on ``fit_args`` and return the wall-clock seconds spent."""
        start = time.time()
        model.fit(*fit_args)
        return time.time() - start

    # Stacking with a random-split decision tree as local base learner.
    model = StackingWeightModel(
        DecisionTreeRegressor(splitter="random", max_depth=X.shape[1]),
        distance_measure,
        kernel_type,
        neighbour_count=stacking_neighbour_count,
        neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
    )
    elapsed = timed_fit(model, X_plus, y, [points])
    print("Stacking:", model.llocv_score_, model.llocv_stacking_)
    print(elapsed)
    result["Stacking_Base"] = model.llocv_score_
    result["Stacking"] = model.llocv_stacking_
    result["Stacking_Time"] = elapsed

    # Stacking with a small extra-trees ensemble as local base learner.
    model = StackingWeightModel(
        ExtraTreesRegressor(n_estimators=10, max_depth=X.shape[1]),
        distance_measure,
        kernel_type,
        neighbour_count=stacking_neighbour_count,
        neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
    )
    elapsed = timed_fit(model, X_plus, y, [points])
    print("Stacking_Extra:", model.llocv_score_, model.llocv_stacking_)
    print(elapsed)
    result["Stacking_Extra_Base"] = model.llocv_score_
    # Follows the "<name>_Base" / "<name>" / "<name>_Time" key pattern
    # (was mistakenly written as "Stacking_Extra_").
    result["Stacking_Extra"] = model.llocv_stacking_
    result["Stacking_Extra_Time"] = elapsed

    # Geographically weighted random forest.
    model = WeightModel(
        RandomForestRegressor(n_estimators=grf_n_estimators),
        distance_measure,
        kernel_type,
        neighbour_count=grf_neighbour_count,
    )
    elapsed = timed_fit(model, X_plus, y, [points])
    print("GRF:", model.llocv_score_)
    print(elapsed)
    result["GRF"] = model.llocv_score_
    result["GRF_Time"] = elapsed

    # Geographically weighted linear regression.
    model = WeightModel(
        LinearRegression(),
        distance_measure,
        kernel_type,
        neighbour_count=gwr_neighbour_count,
    )
    elapsed = timed_fit(model, X_plus, y, [points])
    print("GWR:", model.llocv_score_)
    print(elapsed)
    result["GWR"] = model.llocv_score_
    result["GWR_Time"] = elapsed

    # Global random forest, scored out-of-bag.
    model = RandomForestRegressor(
        oob_score=True, n_estimators=rf_n_estimators, n_jobs=-1
    )
    elapsed = timed_fit(model, X_plus, y)
    print("RF:", model.oob_score_)
    print(elapsed)
    result["RF"] = model.oob_score_
    result["RF_Time"] = elapsed

    # Global linear regression, scored in-sample (computed once).
    model = LinearRegression()
    elapsed = timed_fit(model, X_plus, y)
    lr_score = model.score(X_plus, y)
    print("LR:", lr_score)
    print(elapsed)
    result["LR"] = lr_score
    result["LR_Time"] = elapsed

    result = {**result, **info}
    with open("simulation_data_size.jsonl", "a", encoding="utf-8") as out_file:
        out_file.write(json.dumps(result) + "\n")

    return result


def coef_auto_gau_strong():
    """Build a coefficient surface from two trends plus 1000 random Gaussian terms.

    The returned wrapper sums:
    - a ``radial_coefficient`` around (0, 0);
    - a ``directional_coefficient`` along (1, 1);
    - 1000 ``gaussian_coefficient`` terms.

    Each Gaussian term has a centre drawn from [-10, 10)^2, an amplitude of
    random sign with magnitude in [1, 2), and variances drawn from [0.2, 1).
    All draws use NumPy's global RNG, so seed it first for a reproducible
    surface.
    """
    radial_trend = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional_trend = directional_coefficient(np.array([1, 1]))

    bumps = []
    for _ in range(1000):
        centre = np.random.uniform(-10, 10, 2)
        height = np.random.uniform(1, 2)
        height *= np.random.choice([-1, 1])
        var_a = np.random.uniform(0.2, 1)
        var_b = np.random.uniform(0.2, 1)
        # |cov| <= sqrt(var_a * var_b) keeps the 2x2 matrix positive semi-definite.
        limit = np.sqrt(var_a * var_b)
        cov = np.random.uniform(-limit, limit)
        covariance_matrix = np.array([[var_a, cov], [cov, var_b]])
        bumps.append(gaussian_coefficient(centre, covariance_matrix, amplitude=height))

    bump_sum = coefficient_wrapper(np.sum, *bumps)
    return coefficient_wrapper(np.sum, radial_trend, directional_trend, bump_sum)


def f_square(X, C, points):
    """Response built from ``polynomial_function(C[0], 2)`` applied to feature 0.

    Presumably a degree-2 polynomial in ``X[:, 0]`` with the spatially varying
    coefficient ``C[0]`` evaluated at ``points``.
    """
    square_term = polynomial_function(C[0], 2)
    return square_term(X[:, 0], points) + 0


def generate_sample(count, f, coef_func, random_seed=1, plot=False):
    """Draw a seeded single-feature synthetic sample.

    Seeds NumPy's global RNG, then samples ``count`` locations and one feature
    in (-10, 10) bounds, builds the coefficient with ``coef_func()`` and
    computes ``y = f(X, [coef], points)``. When ``plot`` is true, figures are
    written under ``Plot/<coef_func>_<f>_<count>``.

    Returns:
        ``(X, y, points)`` where ``X`` stacks the single sampled feature on
        the last axis.
    """
    np.random.seed(random_seed)
    # Draw order (points, feature, coefficient) fixes the random stream.
    points = sample_points(count, bounds=(-10, 10))
    feature = sample_x(count, bounds=(-10, 10))
    coefficients = [coef_func()]

    X = np.stack((feature,), axis=-1)
    y = f(X, coefficients, points)

    if plot:
        plot_dir = f"Plot/{coef_func.__name__}_{f.__name__}_{count}"
        os.makedirs(plot_dir, exist_ok=True)
        show_sample(X, y, points, coefficients, plot_dir)

    return X, y, points


def square_gau_strong_100():
    """Benchmark 100 ``coef_auto_gau_strong`` samples against added noise features.

    For each ``feature_size`` in (1, 5, 25, 50), that many N(0, 1) columns are
    drawn from the global RNG (seeded by ``generate_sample``) and appended to
    ``X`` before calling ``fit_models``. The size is recorded in
    ``info["feature_size"]``. Returns nothing.
    """
    X, y, points = generate_sample(
        100, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    n_rows = X.shape[0]

    for feature_size in (1, 5, 25, 50):
        noise = np.random.normal(0, 1, (n_rows, feature_size))
        fit_models(
            np.concatenate([X, noise], axis=1),
            y,
            points,
            stacking_neighbour_count=0.45,
            stacking_neighbour_leave_out_rate=0.2,
            grf_neighbour_count=0.45,
            grf_n_estimators=50,
            gwr_neighbour_count=0.5,
            rf_n_estimators=2000,
            info={
                "f": "f_square",
                "coef": "coef_gau_strong",
                "count": 100,
                "feature_size": feature_size,
            },
        )


def square_gau_strong_500():
    """Benchmark 500 ``coef_auto_gau_strong`` samples against added noise features.

    For each ``feature_size`` in (1, 5, 25, 50), that many N(0, 1) columns are
    appended to ``X`` before calling ``fit_models``. The size is recorded in
    ``info["feature_size"]``. Returns nothing.
    """
    X, y, points = generate_sample(
        500, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    n_rows = X.shape[0]

    for feature_size in (1, 5, 25, 50):
        noise = np.random.normal(0, 1, (n_rows, feature_size))
        fit_models(
            np.concatenate([X, noise], axis=1),
            y,
            points,
            stacking_neighbour_count=0.08,
            stacking_neighbour_leave_out_rate=0.1,
            grf_neighbour_count=0.1,
            grf_n_estimators=50,
            gwr_neighbour_count=0.1,
            rf_n_estimators=2000,
            info={
                "f": "f_square",
                "coef": "coef_gau_strong",
                "count": 500,
                "feature_size": feature_size,
            },
        )


def square_gau_strong_1000():
    """Benchmark 1000 ``coef_auto_gau_strong`` samples against added noise features.

    For each ``feature_size`` in (1, 5, 25, 50), that many N(0, 1) columns are
    appended to ``X`` before calling ``fit_models``. The size is recorded in
    ``info["feature_size"]``. Returns nothing.
    """
    X, y, points = generate_sample(
        1000, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    n_rows = X.shape[0]

    for feature_size in (1, 5, 25, 50):
        noise = np.random.normal(0, 1, (n_rows, feature_size))
        fit_models(
            np.concatenate([X, noise], axis=1),
            y,
            points,
            stacking_neighbour_count=0.02,
            stacking_neighbour_leave_out_rate=0.05,
            grf_neighbour_count=0.01,
            grf_n_estimators=50,
            gwr_neighbour_count=0.04,
            rf_n_estimators=2000,
            info={
                "f": "f_square",
                "coef": "coef_gau_strong",
                "count": 1000,
                "feature_size": feature_size,
            },
        )


def square_gau_strong_5000():
    """Benchmark 5000 ``coef_auto_gau_strong`` samples against added noise features.

    For each ``feature_size`` in (1, 5, 25, 50), that many N(0, 1) columns are
    appended to ``X`` before calling ``fit_models``. The size is recorded in
    ``info["feature_size"]``. Returns nothing.
    """
    X, y, points = generate_sample(
        5000, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    n_rows = X.shape[0]

    for feature_size in (1, 5, 25, 50):
        noise = np.random.normal(0, 1, (n_rows, feature_size))
        fit_models(
            np.concatenate([X, noise], axis=1),
            y,
            points,
            stacking_neighbour_count=0.008,
            stacking_neighbour_leave_out_rate=0.2,
            grf_neighbour_count=0.01,
            grf_n_estimators=50,
            gwr_neighbour_count=0.01,
            rf_n_estimators=2000,
            info={
                "f": "f_square",
                "coef": "coef_gau_strong",
                "count": 5000,
                "feature_size": feature_size,
            },
        )


def square_gau_strong_10000():
    """Benchmark 10000 ``coef_auto_gau_strong`` samples against added noise features.

    For each ``feature_size`` in (1, 5, 25, 50), that many N(0, 1) columns are
    appended to ``X`` before calling ``fit_models``. The size is recorded in
    ``info["feature_size"]``. Returns nothing.
    """
    X, y, points = generate_sample(
        10000, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    n_rows = X.shape[0]

    for feature_size in (1, 5, 25, 50):
        noise = np.random.normal(0, 1, (n_rows, feature_size))
        fit_models(
            np.concatenate([X, noise], axis=1),
            y,
            points,
            stacking_neighbour_count=0.008,
            stacking_neighbour_leave_out_rate=0.2,
            grf_neighbour_count=0.005,
            grf_n_estimators=50,
            gwr_neighbour_count=0.008,
            rf_n_estimators=2000,
            info={
                "f": "f_square",
                "coef": "coef_gau_strong",
                "count": 10000,
                "feature_size": feature_size,
            },
        )


def test_models(
    X,
    y,
    points,
    stacking_neighbour_count,
    stacking_neighbour_leave_out_rate,
    grf_neighbour_count,
    gwr_neighbour_count,
    count,
    func,
    coef,
):
    """Grid-search stacking, GRF and GWR on one sample and log every result.

    Each ``*_neighbour_count`` / ``*_leave_out_rate`` argument is a sequence
    of candidate values. The test_* helpers iterate it once per
    ``use_x_plus`` setting, so a one-shot iterator would be exhausted after
    the first pass.

    Every fitted configuration is tagged with ``count``, ``func`` and
    ``coef`` and appended to ``simulation_params.jsonl``. The best
    configuration per family (highest ``Stacking`` / ``GRF`` / ``GWR``
    score) is printed and appended to ``simulation_param_best.jsonl``.

    Raises:
        ValueError: if any candidate grid is empty (``max`` of no results).
    """
    stacking_params = test_stacking(
        X, y, points, stacking_neighbour_count, stacking_neighbour_leave_out_rate
    )
    grf_params = test_GRF(X, y, points, grf_neighbour_count)
    gwr_params = test_GWR(X, y, points, gwr_neighbour_count)

    # Tag each record with the experiment identity so rows from different runs
    # can be told apart in the shared append-only log.
    experiment = {"count": count, "func": func, "coef": coef}
    with open("simulation_params.jsonl", "a") as f:
        for params in (*stacking_params, *grf_params, *gwr_params):
            params.update(experiment)
            f.write(json.dumps(params) + "\n")

    # Compute each family's best configuration once and reuse it for printing and logging.
    best_records = [
        max(stacking_params, key=lambda x: x["Stacking"]),
        max(grf_params, key=lambda x: x["GRF"]),
        max(gwr_params, key=lambda x: x["GWR"]),
    ]
    for record in best_records:
        print(record)

    with open("simulation_param_best.jsonl", "a") as f:
        for record in best_records:
            f.write(json.dumps(record) + "\n")


def test_GRF(X, y, points, neighbour_counts):
    """Grid-search the neighbour count of a random-forest ``WeightModel`` (GRF).

    Every neighbour count is fitted twice:
    - on ``X`` with the coordinates appended (``use_x_plus=True``);
    - on ``X`` alone (``use_x_plus=False``).

    ``points`` is always passed to ``fit`` as the third argument.
    ``neighbour_counts`` is iterated once per setting, so pass a sequence.

    Returns:
        list of dicts with keys ``GRF`` (the model's ``llocv_score_``),
        ``neighbour_count`` and ``use_x_plus``.
    """
    X_plus = np.concatenate([X, points], axis=1)

    distance_measure = "euclidean"
    kernel_type = "bisquare"

    result = []

    for use_x_plus in [True, False]:
        # Matches test_stacking: without use_x_plus the model is fitted on X
        # only (fitting X_plus in both branches duplicated the True half).
        features = X_plus if use_x_plus else X
        for neighbour_count in neighbour_counts:
            model = WeightModel(
                RandomForestRegressor(n_estimators=50),
                distance_measure,
                kernel_type,
                neighbour_count=neighbour_count,
            )
            model.fit(features, y, [points])
            print("GRF:", model.llocv_score_, neighbour_count, use_x_plus)
            result.append(
                {
                    "GRF": model.llocv_score_,
                    "neighbour_count": neighbour_count,
                    "use_x_plus": use_x_plus,
                }
            )

    return result


def test_stacking(X, y, points, neighbour_counts, leave_out_rates):
    """Grid-search ``StackingWeightModel`` over neighbour count and leave-out rate.

    Every (neighbour_count, leave_out_rate) pair is fitted twice: on ``X``
    with the coordinates appended (``use_x_plus=True``) and on ``X`` alone.

    Returns:
        list of dicts with ``Stacking_Base`` (``llocv_score_``), ``Stacking``
        (``llocv_stacking_``), ``neighbour_count``, ``leave_out_rate`` and
        ``use_x_plus``.
    """
    X_plus = np.concatenate([X, points], axis=1)

    distance_measure = "euclidean"
    kernel_type = "bisquare"
    # NOTE(review): this single estimator instance is handed to every model
    # below; confirm StackingWeightModel clones it rather than refitting it in place.
    base_estimator = DecisionTreeRegressor(splitter="random", max_depth=X.shape[1])

    records = []

    for use_x_plus in [True, False]:
        features = X_plus if use_x_plus else X
        for neighbour_count in neighbour_counts:
            for leave_out_rate in leave_out_rates:
                model = StackingWeightModel(
                    base_estimator,
                    distance_measure,
                    kernel_type,
                    neighbour_count=neighbour_count,
                    neighbour_leave_out_rate=leave_out_rate,
                )
                model.fit(features, y, [points])
                print(
                    "Stacking:",
                    model.llocv_score_,
                    model.llocv_stacking_,
                    "neighbour_count:",
                    neighbour_count,
                    "leave_out_rate:",
                    leave_out_rate,
                    "use_x_plus:",
                    use_x_plus,
                )
                records.append(
                    {
                        "Stacking_Base": model.llocv_score_,
                        "Stacking": model.llocv_stacking_,
                        "neighbour_count": neighbour_count,
                        "leave_out_rate": leave_out_rate,
                        "use_x_plus": use_x_plus,
                    }
                )

    return records


def test_GWR(X, y, points, neighbour_counts):
    """Grid-search the neighbour count of a linear ``WeightModel`` (GWR).

    Every neighbour count is fitted twice:
    - on ``X`` with the coordinates appended (``use_x_plus=True``);
    - on ``X`` alone (``use_x_plus=False``).

    ``points`` is always passed to ``fit`` as the third argument.
    ``neighbour_counts`` is iterated once per setting, so pass a sequence.

    Returns:
        list of dicts with keys ``GWR`` (the model's ``llocv_score_``),
        ``neighbour_count`` and ``use_x_plus``.
    """
    X_plus = np.concatenate([X, points], axis=1)

    distance_measure = "euclidean"
    kernel_type = "bisquare"

    result = []

    for use_x_plus in [True, False]:
        # Matches test_stacking: without use_x_plus the model is fitted on X
        # only (fitting X_plus in both branches duplicated the True half).
        features = X_plus if use_x_plus else X
        for neighbour_count in neighbour_counts:
            model = WeightModel(
                LinearRegression(),
                distance_measure,
                kernel_type,
                neighbour_count=neighbour_count,
            )
            model.fit(features, y, [points])
            print("GWR:", model.llocv_score_, neighbour_count, use_x_plus)
            result.append(
                {
                    "GWR": model.llocv_score_,
                    "neighbour_count": neighbour_count,
                    "use_x_plus": use_x_plus,
                }
            )

    return result


if __name__ == "__main__":
    # Run the noise-feature benchmarks from the smallest to the largest sample.
    for run_benchmark in (
        square_gau_strong_100,
        square_gau_strong_500,
        square_gau_strong_1000,
        square_gau_strong_5000,
        square_gau_strong_10000,
    ):
        run_benchmark()


import json
import os
import time
from functools import partial

from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import LeaveOneOut
from sklearn.tree import DecisionTreeRegressor
from xgboost import XGBRegressor

from georegression.simulation.simulation import show_sample
from georegression.simulation.simulation_utils import *
from georegression.stacking_model import StackingWeightModel
from georegression.weight_model import WeightModel


# TODO: Explain why the improvement becomes significant when there are more points.


def fit_models(
    X,
    y,
    points,
    stacking_neighbour_count=0.03,
    stacking_neighbour_leave_out_rate=0.15,
    grf_neighbour_count=0.03,
    grf_n_estimators=50,
    gwr_neighbour_count=0.03,
    rf_n_estimators=2000,
    info=None,
):
    """Fit every benchmark model on one sample and append the scores to a log.

    ``points`` is appended column-wise to ``X`` and every model is fitted on
    the result. Models, in order:
    1. ``StackingWeightModel`` over a random-split decision tree;
    2. ``StackingWeightModel`` over a 10-tree extra-trees ensemble;
    3. ``WeightModel`` over a random forest (GRF);
    4. ``WeightModel`` over linear regression (GWR);
    5. a global random forest (out-of-bag score);
    6. a global linear regression (in-sample R^2).

    Args:
        X: feature matrix.
        y: target vector.
        points: coordinates, concatenated column-wise with ``X``.
        stacking_neighbour_count: neighbour count for both stacking models.
        stacking_neighbour_leave_out_rate: leave-out rate for both stacking
            models.
        grf_neighbour_count: neighbour count of the GRF.
        grf_n_estimators: number of trees in the GRF's random forest.
        gwr_neighbour_count: neighbour count of the GWR.
        rf_n_estimators: number of trees in the global random forest.
        info: extra keys merged into the record. They win on key collisions.

    Returns:
        dict of scores and fit durations in seconds, merged with ``info``.
        The same record is appended as one JSON line to
        ``simulation_result.jsonl``.
    """
    if info is None:
        info = {}

    X_plus = np.concatenate([X, points], axis=1)

    distance_measure = "euclidean"
    kernel_type = "bisquare"

    def timed_fit(model, *fit_args):
        """Fit ``model`` with ``fit_args`` and return the elapsed seconds."""
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # system clock and can jump, distorting the recorded durations.
        start = time.perf_counter()
        model.fit(*fit_args)
        return time.perf_counter() - start

    result = {}

    model = StackingWeightModel(
        DecisionTreeRegressor(splitter="random", max_depth=X.shape[1]),
        distance_measure,
        kernel_type,
        neighbour_count=stacking_neighbour_count,
        neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
    )
    elapsed = timed_fit(model, X_plus, y, [points])
    print("Stacking:", model.llocv_score_, model.llocv_stacking_)
    print(elapsed)
    result["Stacking_Base"] = model.llocv_score_
    result["Stacking"] = model.llocv_stacking_
    result["Stacking_Time"] = elapsed

    model = StackingWeightModel(
        ExtraTreesRegressor(n_estimators=10, max_depth=X.shape[1]),
        distance_measure,
        kernel_type,
        neighbour_count=stacking_neighbour_count,
        neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
    )
    elapsed = timed_fit(model, X_plus, y, [points])
    print("Stacking_Extra:", model.llocv_score_, model.llocv_stacking_)
    print(elapsed)
    result["Stacking_Extra_Base"] = model.llocv_score_
    result["Stacking_Extra"] = model.llocv_stacking_
    result["Stacking_Extra_Time"] = elapsed

    model = WeightModel(
        RandomForestRegressor(n_estimators=grf_n_estimators),
        distance_measure,
        kernel_type,
        neighbour_count=grf_neighbour_count,
    )
    elapsed = timed_fit(model, X_plus, y, [points])
    print("GRF:", model.llocv_score_)
    print(elapsed)
    result["GRF"] = model.llocv_score_
    result["GRF_Time"] = elapsed

    model = WeightModel(
        LinearRegression(),
        distance_measure,
        kernel_type,
        neighbour_count=gwr_neighbour_count,
    )
    elapsed = timed_fit(model, X_plus, y, [points])
    print("GWR:", model.llocv_score_)
    print(elapsed)
    result["GWR"] = model.llocv_score_
    result["GWR_Time"] = elapsed

    model = RandomForestRegressor(
        oob_score=True, n_estimators=rf_n_estimators, n_jobs=-1
    )
    elapsed = timed_fit(model, X_plus, y)
    print("RF:", model.oob_score_)
    print(elapsed)
    result["RF"] = model.oob_score_
    result["RF_Time"] = elapsed

    model = LinearRegression()
    elapsed = timed_fit(model, X_plus, y)
    # In-sample R^2 on the training data. The other scores (llocv_*, oob_score_)
    # are presumably held-out estimates, so LR is not directly comparable.
    lr_score = model.score(X_plus, y)
    print("LR:", lr_score)
    print(elapsed)
    result["LR"] = lr_score
    result["LR_Time"] = elapsed

    result = {**result, **info}
    with open("simulation_result.jsonl", "a") as f:
        f.write(json.dumps(result) + "\n")

    return result


def fit_llocv_models(
    X,
    y,
    points,
    n_estimators=500,
):
    """Score an ``XGBRegressor`` with leave-one-out cross-validation.

    ``points`` is appended column-wise to ``X``. For each of the ``len(X)``
    folds, a fresh regressor is trained on every other row and predicts the
    held-out row. The returned score is the R^2 between ``y`` and these
    out-of-fold predictions.

    Args:
        X: feature matrix.
        y: target vector, one value per row of ``X``.
        points: coordinates, concatenated column-wise with ``X``.
        n_estimators: boosting rounds per fold model (default 500, the
            previously hard-coded value).

    Returns:
        dict with keys ``"n_estimators"`` and ``"score"``.
    """
    from sklearn.metrics import r2_score

    X_plus = np.concatenate([X, points], axis=1)

    fold_predictions = []
    for train, test in LeaveOneOut().split(X_plus):
        estimator = XGBRegressor(n_estimators=n_estimators, n_jobs=-1)
        estimator.fit(X_plus[train], y[train])
        fold_predictions.append(estimator.predict(X_plus[test]))

    # Each fold yields a length-1 array. Flatten to a 1-D vector aligned with y
    # instead of handing r2_score an implicit (n, 1) array.
    y_predict = np.concatenate(fold_predictions)
    score = r2_score(y, y_predict)

    return {
        "n_estimators": n_estimators,
        "score": score,
    }


def coef_auto_gau_weak():
    """Build a coefficient surface from two trends plus 1000 wide random Gaussian terms.

    Same construction as ``coef_auto_gau_strong``, except the Gaussian
    variances are drawn from [0.5, 5) instead of [0.2, 1), which gives
    smoother local variation. All draws use NumPy's global RNG, so seed it
    first for a reproducible surface.
    """
    radial_trend = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional_trend = directional_coefficient(np.array([1, 1]))

    bumps = []
    for _ in range(1000):
        centre = np.random.uniform(-10, 10, 2)
        height = np.random.uniform(1, 2)
        height *= np.random.choice([-1, 1])
        var_a = np.random.uniform(0.5, 5)
        var_b = np.random.uniform(0.5, 5)
        # |cov| <= sqrt(var_a * var_b) keeps the 2x2 matrix positive semi-definite.
        limit = np.sqrt(var_a * var_b)
        cov = np.random.uniform(-limit, limit)
        covariance_matrix = np.array([[var_a, cov], [cov, var_b]])
        bumps.append(gaussian_coefficient(centre, covariance_matrix, amplitude=height))

    bump_sum = coefficient_wrapper(np.sum, *bumps)
    return coefficient_wrapper(np.sum, radial_trend, directional_trend, bump_sum)


def coef_auto_gau_strong():
    """Build a coefficient surface from two trends plus 1000 narrow random Gaussian terms.

    The returned wrapper sums:
    - a ``radial_coefficient`` around (0, 0);
    - a ``directional_coefficient`` along (1, 1);
    - 1000 ``gaussian_coefficient`` terms.

    Each Gaussian term has a centre drawn from [-10, 10)^2, an amplitude of
    random sign with magnitude in [1, 2), and variances drawn from [0.2, 1).
    All draws use NumPy's global RNG, so seed it first for a reproducible
    surface.
    """
    radial_trend = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional_trend = directional_coefficient(np.array([1, 1]))

    bumps = []
    for _ in range(1000):
        centre = np.random.uniform(-10, 10, 2)
        height = np.random.uniform(1, 2)
        height *= np.random.choice([-1, 1])
        var_a = np.random.uniform(0.2, 1)
        var_b = np.random.uniform(0.2, 1)
        # |cov| <= sqrt(var_a * var_b) keeps the 2x2 matrix positive semi-definite.
        limit = np.sqrt(var_a * var_b)
        cov = np.random.uniform(-limit, limit)
        covariance_matrix = np.array([[var_a, cov], [cov, var_b]])
        bumps.append(gaussian_coefficient(centre, covariance_matrix, amplitude=height))

    bump_sum = coefficient_wrapper(np.sum, *bumps)
    return coefficient_wrapper(np.sum, radial_trend, directional_trend, bump_sum)


def coef_strong():
    """Build a hand-specified coefficient surface with strong spatial variation.

    The returned wrapper sums:
    - a radial term around (0, 0);
    - a directional term along (1, 1);
    - two sine terms along (-1, 1) and (1, 1);
    - two Gaussian terms at (-5, 5) and (-5, -5), the latter with amplitude 2.
    """
    radial_part = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional_part = directional_coefficient(np.array([1, 1]))
    sine_part = coefficient_wrapper(
        np.sum,
        sine_coefficient(1, np.array([-1, 1]), 1),
        sine_coefficient(1, np.array([1, 1]), 1),
    )
    gaussian_part = coefficient_wrapper(
        np.sum,
        gaussian_coefficient(np.array([-5, 5]), 3),
        gaussian_coefficient(np.array([-5, -5]), 3, amplitude=2),
    )

    return coefficient_wrapper(
        np.sum, radial_part, directional_part, sine_part, gaussian_part
    )


def f_square(X, C, points):
    """Response built from ``polynomial_function(C[0], 2)`` applied to feature 0.

    Presumably a degree-2 polynomial in ``X[:, 0]`` with the spatially varying
    coefficient ``C[0]`` evaluated at ``points``.
    """
    square_term = polynomial_function(C[0], 2)
    return square_term(X[:, 0], points) + 0


def f_square_2(X, C, points):
    """Sum of two ``polynomial_function(..., 2)`` terms.

    The first term uses ``C[0]`` on feature 0 and the second uses ``C[1]`` on
    feature 1, both evaluated at ``points``.
    """
    first_term = polynomial_function(C[0], 2)(X[:, 0], points)
    second_term = polynomial_function(C[1], 2)(X[:, 1], points)
    return first_term + second_term + 0


def f_square_const(X, C, points):
    """``f_square``-style term plus ten times ``C[0]`` evaluated at ``points``."""
    coef = C[0]
    square_term = polynomial_function(coef, 2)(X[:, 0], points)
    offset_term = coef(points) * 10
    return square_term + offset_term + 0


def f_sigmoid(X, C, points):
    """Response from ``sigmoid_function(C[0])`` applied to feature 0 at ``points``."""
    sigmoid_term = sigmoid_function(C[0])
    return sigmoid_term(X[:, 0], points) + 0


def f_interact(X, C, points):
    """Response from ``interaction_function(C[0])`` applied to features 0 and 1."""
    interaction_term = interaction_function(C[0])
    return interaction_term(X[:, 0], X[:, 1], points) + 0


def generate_sample(count, f, coef_func, random_seed=1, plot=False):
    """Draw a seeded single-feature synthetic sample.

    Seeds NumPy's global RNG, then samples ``count`` locations and one feature
    in (-10, 10) bounds, builds the coefficient with ``coef_func()`` and
    computes ``y = f(X, [coef], points)``. When ``plot`` is true, figures are
    written under ``Plot/<coef_func>_<f>_<count>``.

    Returns:
        ``(X, y, points)`` where ``X`` stacks the single sampled feature on
        the last axis.
    """
    np.random.seed(random_seed)
    # Draw order (points, feature, coefficient) fixes the random stream.
    points = sample_points(count, bounds=(-10, 10))
    feature = sample_x(count, bounds=(-10, 10))
    coefficients = [coef_func()]

    X = np.stack((feature,), axis=-1)
    y = f(X, coefficients, points)

    if plot:
        plot_dir = f"Plot/{coef_func.__name__}_{f.__name__}_{count}"
        os.makedirs(plot_dir, exist_ok=True)
        show_sample(X, y, points, coefficients, plot_dir)

    return X, y, points


def square_strong_100():
    """Fit all benchmark models on 100 ``f_square`` samples with ``coef_strong``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models: all neighbour counts and
    # the leave-out rate in [0.1, 0.2, 0.3, 0.4, 0.5].
    sample = generate_sample(100, f_square, coef_strong, random_seed=1, plot=True)
    return fit_models(
        *sample,
        stacking_neighbour_count=0.3,
        stacking_neighbour_leave_out_rate=0.4,
        grf_neighbour_count=0.3,
        grf_n_estimators=50,
        gwr_neighbour_count=0.5,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_strong", "count": 100},
    )


def square_strong_500():
    """Fit all benchmark models on 500 ``f_square`` samples with ``coef_strong``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models: all neighbour counts and
    # the leave-out rate in [0.1, 0.2, 0.3, 0.4, 0.5].
    sample = generate_sample(500, f_square, coef_strong, random_seed=1, plot=True)
    return fit_models(
        *sample,
        stacking_neighbour_count=0.3,
        stacking_neighbour_leave_out_rate=0.1,
        grf_neighbour_count=0.3,
        grf_n_estimators=50,
        gwr_neighbour_count=0.2,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_strong", "count": 500},
    )


def square_strong_1000():
    """Fit all benchmark models on 1000 ``f_square`` samples with ``coef_strong``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.01, 0.02, 0.03, 0.04]; leave-out rates [0.1, 0.2, 0.3, 0.4].
    sample = generate_sample(1000, f_square, coef_strong, random_seed=1, plot=True)
    return fit_models(
        *sample,
        stacking_neighbour_count=0.02,
        stacking_neighbour_leave_out_rate=0.3,
        grf_neighbour_count=0.02,
        grf_n_estimators=50,
        gwr_neighbour_count=0.03,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_strong", "count": 1000},
    )


def square_strong_5000():
    """Fit all benchmark models on 5000 ``f_square`` samples with ``coef_strong``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.003, 0.005, 0.008, 0.01, 0.015, 0.02];
    # leave-out rates [0.1, 0.2, 0.3, 0.4, 0.5].
    sample = generate_sample(5000, f_square, coef_strong, random_seed=1, plot=True)
    return fit_models(
        *sample,
        stacking_neighbour_count=0.015,
        stacking_neighbour_leave_out_rate=0.4,
        grf_neighbour_count=0.01,
        grf_n_estimators=50,
        gwr_neighbour_count=0.015,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_strong", "count": 5000},
    )


def square_gau_strong_100():
    """Fit all benchmark models on 100 ``f_square`` samples with ``coef_auto_gau_strong``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.05, 0.08, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5];
    # leave-out rates [0.05, 0.1, 0.15, 0.2, 0.25].
    sample = generate_sample(
        100, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    return fit_models(
        *sample,
        stacking_neighbour_count=0.45,
        stacking_neighbour_leave_out_rate=0.2,
        grf_neighbour_count=0.45,
        grf_n_estimators=50,
        gwr_neighbour_count=0.5,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_strong", "count": 100},
    )


def square_gau_strong_500():
    """Fit all benchmark models on 500 ``f_square`` samples with ``coef_auto_gau_strong``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # stacking neighbour counts [0.05, 0.08, 0.1, 0.15, 0.2];
    # leave-out rates [0.05, 0.1, 0.15, 0.2];
    # GRF/GWR neighbour counts [0.05, 0.1, 0.2].
    sample = generate_sample(
        500, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    return fit_models(
        *sample,
        stacking_neighbour_count=0.08,
        stacking_neighbour_leave_out_rate=0.1,
        grf_neighbour_count=0.1,
        grf_n_estimators=50,
        gwr_neighbour_count=0.1,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_strong", "count": 500},
    )


def square_gau_strong_1000():
    """Fit all benchmark models on 1000 ``f_square`` samples with ``coef_auto_gau_strong``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.01, 0.02, 0.03, 0.04, 0.05];
    # leave-out rates [0.05, 0.1, 0.15, 0.2].
    sample = generate_sample(
        1000, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    return fit_models(
        *sample,
        stacking_neighbour_count=0.02,
        stacking_neighbour_leave_out_rate=0.05,
        grf_neighbour_count=0.01,
        grf_n_estimators=50,
        gwr_neighbour_count=0.04,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_strong", "count": 1000},
    )


def square_gau_strong_5000():
    """Fit all benchmark models on 5000 ``f_square`` samples with ``coef_auto_gau_strong``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.003, 0.005, 0.008, 0.01, 0.015, 0.02];
    # leave-out rates [0.05, 0.1, 0.15, 0.2].
    sample = generate_sample(
        5000, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    return fit_models(
        *sample,
        stacking_neighbour_count=0.008,
        stacking_neighbour_leave_out_rate=0.2,
        grf_neighbour_count=0.01,
        grf_n_estimators=50,
        gwr_neighbour_count=0.01,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_strong", "count": 5000},
    )


def square_gau_strong_10000():
    """Grid-search stacking, GRF and GWR on 10000 ``coef_auto_gau_strong`` samples.

    Only tuning happens here (via ``test_models``); nothing is returned. A
    follow-up ``fit_models`` run could start from the values used by the
    5000-sample run: stacking 0.008 with leave-out 0.2, GRF 0.01, GWR 0.01.
    """
    X, y, points = generate_sample(
        10000, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    # The same neighbour grid is reused for all three families; test_models
    # only iterates it, never mutates it.
    neighbour_grid = [0.001, 0.002, 0.003, 0.005, 0.008, 0.01, 0.012, 0.015]
    leave_out_grid = [0.05, 0.1, 0.15, 0.2]
    test_models(
        X,
        y,
        points,
        neighbour_grid,  # stacking neighbour counts
        leave_out_grid,  # stacking leave-out rates
        neighbour_grid,  # GRF neighbour counts
        neighbour_grid,  # GWR neighbour counts
        10000,
        "f_square",
        "coef_gau_strong",
    )


def square_gau_strong_50000():
    """Grid-search stacking, GRF and GWR on 50000 ``coef_auto_gau_strong`` samples.

    Only tuning happens here (via ``test_models``); nothing is returned. A
    follow-up ``fit_models`` run could start from the values used by the
    5000-sample run: stacking 0.008 with leave-out 0.2, GRF 0.01, GWR 0.01.
    """
    X, y, points = generate_sample(
        50000, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    # The same neighbour grid is reused for all three families; test_models
    # only iterates it, never mutates it.
    neighbour_grid = [0.001, 0.002, 0.003, 0.005, 0.008, 0.01, 0.012, 0.015]
    leave_out_grid = [0.05, 0.1, 0.15, 0.2]
    test_models(
        X,
        y,
        points,
        neighbour_grid,  # stacking neighbour counts
        leave_out_grid,  # stacking leave-out rates
        neighbour_grid,  # GRF neighbour counts
        neighbour_grid,  # GWR neighbour counts
        50000,
        "f_square",
        "coef_gau_strong",
    )


def square_gau_weak_100():
    """Fit all benchmark models on 100 ``f_square`` samples with ``coef_auto_gau_weak``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.05, 0.08, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5];
    # leave-out rates [0.05, 0.1, 0.15, 0.2, 0.25].
    sample = generate_sample(
        100, f_square, coef_auto_gau_weak, random_seed=1, plot=True
    )
    return fit_models(
        *sample,
        stacking_neighbour_count=0.25,
        stacking_neighbour_leave_out_rate=0.25,
        grf_neighbour_count=0.08,
        grf_n_estimators=50,
        gwr_neighbour_count=0.3,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_weak", "count": 100},
    )


def square_gau_weak_500():
    """Fit all benchmark models on 500 ``f_square`` samples with ``coef_auto_gau_weak``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # stacking neighbour counts [0.05, 0.08, 0.1, 0.15, 0.2];
    # leave-out rates [0.05, 0.1, 0.15, 0.2];
    # GRF/GWR neighbour counts [0.05, 0.1, 0.2].
    sample = generate_sample(
        500, f_square, coef_auto_gau_weak, random_seed=1, plot=True
    )
    return fit_models(
        *sample,
        stacking_neighbour_count=0.08,
        stacking_neighbour_leave_out_rate=0.15,
        grf_neighbour_count=0.05,
        grf_n_estimators=50,
        gwr_neighbour_count=0.1,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_weak", "count": 500},
    )


def square_gau_weak_1000():
    """Fit all benchmark models on 1000 ``f_square`` samples with ``coef_auto_gau_weak``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.03, 0.04, 0.05, 0.06];
    # leave-out rates [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35].
    sample = generate_sample(
        1000, f_square, coef_auto_gau_weak, random_seed=1, plot=True
    )
    return fit_models(
        *sample,
        stacking_neighbour_count=0.05,
        stacking_neighbour_leave_out_rate=0.25,
        grf_neighbour_count=0.06,
        grf_n_estimators=50,
        gwr_neighbour_count=0.06,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_weak", "count": 1000},
    )


def square_gau_weak_5000():
    """Fit all benchmark models on 5000 ``f_square`` samples with ``coef_auto_gau_weak``.

    Returns the record produced by ``fit_models``.
    """
    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.008, 0.01, 0.015, 0.02, 0.025];
    # leave-out rates [0.2, 0.25, 0.3].
    sample = generate_sample(
        5000, f_square, coef_auto_gau_weak, random_seed=1, plot=True
    )
    return fit_models(
        *sample,
        stacking_neighbour_count=0.02,
        stacking_neighbour_leave_out_rate=0.3,
        grf_neighbour_count=0.01,
        grf_n_estimators=50,
        gwr_neighbour_count=0.02,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_weak", "count": 5000},
    )


def square_2_gau_strong_weak_5000():
    """Fit all benchmark models on 5000 samples of the two-feature ``f_square_2``.

    Feature 0 uses ``coef_auto_gau_strong()`` wrapped with
    ``partial(np.multiply, 2)`` (presumably doubling it). Feature 1 uses
    ``coef_auto_gau_weak()``. The sample is built here rather than by
    ``generate_sample``, which only produces a single feature.

    Returns the record produced by ``fit_models``.
    """
    np.random.seed(1)

    n_samples = 5000
    # Draw order (points, x1, x2, then the coefficients) fixes the random stream.
    points = sample_points(n_samples, bounds=(-10, 10))
    x1 = sample_x(n_samples, bounds=(-10, 10))
    x2 = sample_x(n_samples, bounds=(-10, 10))

    strong_scaled = coefficient_wrapper(partial(np.multiply, 2), coef_auto_gau_strong())
    weak = coef_auto_gau_weak()
    coefficients = [strong_scaled, weak]

    X = np.stack((x1, x2), axis=-1)
    y = f_square_2(X, coefficients, points)

    # Candidate grid for re-tuning with test_models:
    # neighbour counts [0.02, 0.03, 0.04, 0.05]; leave-out rates [0.1, 0.15, 0.2, 0.25].
    return fit_models(
        X,
        y,
        points,
        stacking_neighbour_count=0.02,
        stacking_neighbour_leave_out_rate=0.2,
        grf_neighbour_count=0.02,
        grf_n_estimators=50,
        gwr_neighbour_count=0.02,
        rf_n_estimators=2000,
        info={"f": "f_square_2", "coef": "coef_gau_strong2_weak", "count": 5000},
    )


def interact_ale():
    """Benchmark all models on an interaction response with spatially drifting features.

    Samples 5000 locations. Both features are drawn with ``sample_x`` using a
    spatially varying ``mean``: x1's mean field is ``coef_auto_gau_weak`` and
    x2's mean field is that same field multiplied by 3. The response is
    ``f_interact`` with a single ``coef_auto_gau_strong`` coefficient field.

    Returns whatever ``fit_models`` returns. The record is labelled
    ``f_interact`` / ``coef_gau_strong`` so it does not collide with the
    ``f_square_2`` experiment's label.
    """
    random_seed = 1
    np.random.seed(random_seed)

    count = 5000
    points = sample_points(count, bounds=[[-10, 10], [-10, 10]])

    # x2's mean field is x1's mean field scaled by 3, so the features are
    # spatially correlated.
    coef_x1 = coef_auto_gau_weak()
    coef_x2 = coefficient_wrapper(partial(np.multiply, 3), coef_x1)
    x1 = sample_x(count, mean=coef_x1, bounds=(-1, 1), points=points)
    x2 = sample_x(count, mean=coef_x2, bounds=(-2, 2), points=points)

    # Drawn after x1/x2 so the global RNG sequence matches earlier runs.
    coefficients = [coef_auto_gau_strong()]

    X = np.stack((x1, x2), axis=-1)
    y = f_interact(X, coefficients, points)

    # Candidate grid for test_models (sweep not run here):
    #   stacking grid empty; GRF / GWR neighbour counts: [0.02, 0.03, 0.04, 0.05]

    return fit_models(
        X,
        y,
        points,
        stacking_neighbour_count=0.03,
        stacking_neighbour_leave_out_rate=0.15,
        grf_neighbour_count=0.03,
        grf_n_estimators=50,
        gwr_neighbour_count=0.02,
        rf_n_estimators=2000,
        info={"f": "f_interact", "coef": "coef_gau_strong", "count": 5000},
    )


def test_llocv():
    """Append leave-one-out scores for ``f_square`` over sample sizes and coefficient fields.

    For every (count, coefficient) pair, a sample is generated with seed 1 and
    scored by ``fit_llocv_models``. The record is tagged with the count and with
    the function and coefficient names, then appended as one JSON line to
    ``simulation_result_llocv.jsonl``.
    """
    func = f_square
    sample_counts = [100, 500, 1000, 5000]
    coef_factories = [coef_strong, coef_auto_gau_strong, coef_auto_gau_weak]

    for sample_count in sample_counts:
        for coef_factory in coef_factories:
            X, y, points = generate_sample(
                sample_count, func, coef_factory, random_seed=1, plot=True
            )
            record = fit_llocv_models(X, y, points)
            record.update(
                {
                    "count": sample_count,
                    "func": func.__name__,
                    "coef": coef_factory.__name__,
                }
            )
            with open("simulation_result_llocv.jsonl", "a") as out:
                out.write(json.dumps(record) + "\n")


def test_models(
    X,
    y,
    points,
    stacking_neighbour_count,
    stacking_neighbour_leave_out_rate,
    grf_neighbour_count,
    gwr_neighbour_count,
    count,
    func,
    coef,
):
    """Grid-search the Stacking, GRF and GWR models and log the trials.

    Args:
        X, y, points: features, target and coordinates.
        stacking_neighbour_count, stacking_neighbour_leave_out_rate: grids for
            ``test_stacking``.
        grf_neighbour_count: grid for ``test_GRF``.
        gwr_neighbour_count: grid for ``test_GWR``.
        count, func, coef: metadata stamped onto every trial record.

    Every trial is appended to ``simulation_params.jsonl``. For each model, the
    best trial by its score key is printed and appended to
    ``simulation_param_best.jsonl``. A model whose grid is empty produces no
    trials and is skipped; it no longer makes ``max()`` raise.
    """
    stacking_params = test_stacking(
        X, y, points, stacking_neighbour_count, stacking_neighbour_leave_out_rate
    )
    grf_params = test_GRF(X, y, points, grf_neighbour_count)
    gwr_params = test_GWR(X, y, points, gwr_neighbour_count)

    # (score key used to rank trials, trials) for each model.
    model_trials = [
        ("Stacking", stacking_params),
        ("GRF", grf_params),
        ("GWR", gwr_params),
    ]

    with open("simulation_params.jsonl", "a") as f:
        for _, trials in model_trials:
            for params in trials:
                params["count"] = count
                params["func"] = func
                params["coef"] = coef
                f.write(json.dumps(params) + "\n")

    # Print the param with the best score (computed once per model).
    best_trials = []
    for score_key, trials in model_trials:
        if not trials:
            continue
        best = max(trials, key=lambda params: params[score_key])
        print(best)
        best_trials.append(best)

    # Output the best result to jsonl
    with open("simulation_param_best.jsonl", "a") as f:
        for best in best_trials:
            f.write(json.dumps(best) + "\n")


def test_GRF(X, y, points, neighbour_counts):
    """Evaluate a GRF (WeightModel around a 50-tree random forest) per neighbour count.

    Each count is tried with coordinates appended to the features
    (``use_x_plus=True``) and without. Returns one dict per trial holding the
    LLOCV score under "GRF", the neighbour count and the ``use_x_plus`` flag.
    """
    feature_sets = {True: np.concatenate([X, points], axis=1), False: X}

    trials = []
    for use_x_plus in (True, False):
        features = feature_sets[use_x_plus]
        for neighbour_count in neighbour_counts:
            grf = WeightModel(
                RandomForestRegressor(n_estimators=50),
                "euclidean",
                "bisquare",
                neighbour_count=neighbour_count,
            )
            grf.fit(features, y, [points])
            score = grf.llocv_score_
            print("GRF:", score, neighbour_count, use_x_plus)
            trials.append(
                {
                    "GRF": score,
                    "neighbour_count": neighbour_count,
                    "use_x_plus": use_x_plus,
                }
            )

    return trials


def test_stacking(X, y, points, neighbour_counts, leave_out_rates):
    """Grid-search StackingWeightModel over neighbour counts and leave-out rates.

    Every combination is tried with coordinates appended to the features
    (``use_x_plus=True``) and without. Returns one dict per trial with the base
    LLOCV score ("Stacking_Base"), the stacking LLOCV score ("Stacking") and
    the trial's parameters.
    """
    feature_sets = {True: np.concatenate([X, points], axis=1), False: X}
    # NOTE(review): a single estimator instance is handed to every model in the
    # grid; confirm StackingWeightModel clones it rather than refitting in place.
    shared_estimator = DecisionTreeRegressor(splitter="random", max_depth=X.shape[1])

    trials = []
    for use_x_plus in (True, False):
        features = feature_sets[use_x_plus]
        for neighbour_count in neighbour_counts:
            for leave_out_rate in leave_out_rates:
                stacker = StackingWeightModel(
                    shared_estimator,
                    "euclidean",
                    "bisquare",
                    neighbour_count=neighbour_count,
                    neighbour_leave_out_rate=leave_out_rate,
                )
                stacker.fit(features, y, [points])
                base_score = stacker.llocv_score_
                stacking_score = stacker.llocv_stacking_
                print(
                    "Stacking:",
                    base_score,
                    stacking_score,
                    "neighbour_count:",
                    neighbour_count,
                    "leave_out_rate:",
                    leave_out_rate,
                    "use_x_plus:",
                    use_x_plus,
                )
                trials.append(
                    {
                        "Stacking_Base": base_score,
                        "Stacking": stacking_score,
                        "neighbour_count": neighbour_count,
                        "leave_out_rate": leave_out_rate,
                        "use_x_plus": use_x_plus,
                    }
                )

    return trials


def test_GWR(X, y, points, neighbour_counts):
    """Evaluate GWR (WeightModel around LinearRegression) per neighbour count.

    Each count is tried with coordinates appended to the features
    (``use_x_plus=True``) and without. Returns one dict per trial holding the
    LLOCV score under "GWR", the neighbour count and the ``use_x_plus`` flag.
    """
    feature_sets = {True: np.concatenate([X, points], axis=1), False: X}

    trials = []
    for use_x_plus in (True, False):
        features = feature_sets[use_x_plus]
        for neighbour_count in neighbour_counts:
            gwr = WeightModel(
                LinearRegression(),
                "euclidean",
                "bisquare",
                neighbour_count=neighbour_count,
            )
            gwr.fit(features, y, [points])
            score = gwr.llocv_score_
            print("GWR:", score, neighbour_count, use_x_plus)
            trials.append(
                {
                    "GWR": score,
                    "neighbour_count": neighbour_count,
                    "use_x_plus": use_x_plus,
                }
            )

    return trials


if __name__ == "__main__":
    # Available experiments; uncomment the ones to run:
    #   square_strong_100() / _500() / _1000() / _5000()
    #   square_gau_strong_100() / _500() / _1000() / _5000()
    #   square_gau_weak_100() / _500() / _1000() / _5000()
    #   square_2_gau_strong_weak_5000()
    #   test_llocv()
    #   square_gau_strong_10000() / _50000()
    interact_ale()


import time
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from georegression.local_ale import weighted_ale

from georegression.stacking_model import StackingWeightModel
from georegression.simulation.simulation import show_sample
from georegression.visualize.ale import plot_ale
from georegression.weight_model import WeightModel
from georegression.simulation.simulation_utils import *


def coef_manual_gau():
    """Build a hand-placed coefficient field: a radial term plus six Gaussian bumps.

    The bumps have fixed centres, (co)variances and signed amplitudes. This
    function draws no random numbers itself.

    Returns:
        ``coefficient_wrapper(np.sum, radial, sum_of_bumps)``.
    """
    coef_radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))

    coef_gau_1 = gaussian_coefficient(np.array([-5, 5]), [[3, 4], [4, 8]], amplitude=-1)
    coef_gau_2 = gaussian_coefficient(np.array([-2, -5]), 5, amplitude=2)
    coef_gau_3 = gaussian_coefficient(np.array([8, 3]), 10, amplitude=-1.5)
    coef_gau_4 = gaussian_coefficient(
        np.array([2, 8]), [[3, 0], [0, 15]], amplitude=0.8
    )
    coef_gau_5 = gaussian_coefficient(np.array([5, -10]), 1, amplitude=1)
    # NOTE(review): this slot was previously assigned three times in a row. The
    # assignments were a bump at (-10, -10) (sigma 15, amplitude 1.5) followed
    # by two identical bumps at (-11, 0), and only the last assignment ever
    # took effect. The effective bump is kept here; confirm whether the
    # (-10, -10) bump was meant to be an additional term in the sum.
    coef_gau_6 = gaussian_coefficient(np.array([-11, 0]), 5, amplitude=2)
    coef_gau = coefficient_wrapper(
        np.sum, coef_gau_1, coef_gau_2, coef_gau_3, coef_gau_4, coef_gau_5, coef_gau_6
    )

    # A directional term, directional_coefficient(np.array([1, 1])), is not
    # part of the sum. Add it to the wrapper below to include it.
    coef_sum = coefficient_wrapper(np.sum, coef_radial, coef_gau)

    return coef_sum


def coef_auto_gau_weak():
    """Build a coefficient field from a radial term, a directional term and 1000 random bumps.

    The radial term is centred at the origin and the directional term points
    along (1, 1). Each Gaussian bump has:

    * a centre drawn uniformly from [-10, 10]^2;
    * an amplitude of random sign with magnitude in [1, 2];
    * a 2x2 sigma matrix with diagonal entries in [0.5, 5] and an off-diagonal
      entry bounded by the square root of their product.

    All draws come from NumPy's global RNG, so seed it beforehand for
    reproducible fields.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))

    def draw_bump():
        # The order of the random draws below fixes the field for a given seed.
        center = np.random.uniform(-10, 10, 2)
        magnitude = np.random.uniform(1, 2)
        amplitude = magnitude * np.random.choice([-1, 1])
        var_a = np.random.uniform(0.5, 5)
        var_b = np.random.uniform(0.5, 5)
        limit = np.sqrt(var_a * var_b)
        off_diag = np.random.uniform(-limit, limit)
        sigma = np.array([[var_a, off_diag], [off_diag, var_b]])
        return gaussian_coefficient(center, sigma, amplitude=amplitude)

    bumps = [draw_bump() for _ in range(1000)]
    bump_sum = coefficient_wrapper(np.sum, *bumps)
    return coefficient_wrapper(np.sum, radial, directional, bump_sum)


def coef_auto_gau_strong():
    """Build a doubled coefficient field from radial, directional and 1000 random bump terms.

    Result = 2 * (radial term centred at the origin + directional term along
    (1, 1) + sum of 1000 Gaussian bumps). Each bump has:

    * a centre drawn uniformly from [-10, 10]^2;
    * an amplitude of random sign with magnitude in [1, 2];
    * a 2x2 sigma matrix with diagonal entries in [0.2, 1] and an off-diagonal
      entry bounded by the square root of their product.

    All draws come from NumPy's global RNG, so seed it beforehand.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))

    def draw_bump():
        # The order of the random draws below fixes the field for a given seed.
        center = np.random.uniform(-10, 10, 2)
        magnitude = np.random.uniform(1, 2)
        amplitude = magnitude * np.random.choice([-1, 1])
        var_a = np.random.uniform(0.2, 1)
        var_b = np.random.uniform(0.2, 1)
        limit = np.sqrt(var_a * var_b)
        off_diag = np.random.uniform(-limit, limit)
        sigma = np.array([[var_a, off_diag], [off_diag, var_b]])
        return gaussian_coefficient(center, sigma, amplitude=amplitude)

    bumps = [draw_bump() for _ in range(1000)]
    bump_sum = coefficient_wrapper(np.sum, *bumps)
    total = coefficient_wrapper(np.sum, radial, directional, bump_sum)
    return coefficient_wrapper(partial(np.multiply, 2), total)


def f_square_2(X, C, points):
    """Sum ``polynomial_function(C[j], 2)`` applied to feature columns 0 and 1 at ``points``."""
    first_term = polynomial_function(C[0], 2)(X[:, 0], points)
    second_term = polynomial_function(C[1], 2)(X[:, 1], points)
    intercept = 0
    return first_term + second_term + intercept


def generate_sample(count, f, coef_func, random_seed=1, plot=False):
    """Simulate ``count`` observations with two feature columns.

    Reseeds NumPy's global RNG with ``random_seed``, then samples locations and
    two feature columns (both with ``bounds=(-10, 10)``). It builds one
    coefficient field per factory in ``coef_func`` (a single callable or a
    list of callables) and evaluates ``y = f(X, coefficients, points)``.
    If ``plot`` is true, the sample is passed to ``show_sample``.

    Returns:
        ``(X, y, points)``.
    """
    np.random.seed(random_seed)
    points = sample_points(count, bounds=(-10, 10))
    columns = [sample_x(count, bounds=(-10, 10)) for _ in range(2)]

    factories = coef_func if isinstance(coef_func, list) else [coef_func]
    coefficients = [make() for make in factories]

    X = np.stack(columns, axis=-1)
    y = f(X, coefficients, points)

    if plot:
        show_sample(X, y, points, coefficients)

    return X, y, points


def draw_graph():
    """Fit a GRF on the two-feature simulation and plot per-location importance maps.

    Steps:

    1. Draw 5000 samples from ``f_square_2`` with coefficient fields
       ``coef_auto_gau_strong`` and ``coef_auto_gau_weak``.
    2. Fit a ``WeightModel`` around a 50-tree ``RandomForestRegressor``
       (euclidean distance, bisquare kernel, neighbour_count 0.02).
    3. Print the LLOCV score and the global and local importance scores.
    4. Save one scatter map per feature to ``Plot/Local_importance_{i}.png``,
       creating ``Plot/`` if needed.
    """
    import os

    X, y, points = generate_sample(
        count=5000, f=f_square_2, coef_func=[coef_auto_gau_strong, coef_auto_gau_weak], random_seed=1,
        plot=True
    )
    distance_measure = "euclidean"
    kernel_type = "bisquare"
    neighbour_count = 0.02

    model = WeightModel(
        RandomForestRegressor(n_estimators=50),
        distance_measure,
        kernel_type,
        neighbour_count=neighbour_count,
        cache_data=True,
        cache_estimator=True,
    )
    model.fit(X, y, [points])
    print("GRF:", model.llocv_score_)

    importance_global = model.importance_score_global()
    print(importance_global)

    importance_local = model.importance_score_local()
    print(importance_local)

    # Min-max normalize each row to [0, 1]. The plotting loop below treats rows
    # as locations and columns as features. keepdims=True keeps the per-row
    # statistics as an (n, 1) column so they broadcast against the (n, k)
    # matrix; a plain min(axis=1) has shape (n,) and fails to broadcast unless
    # n == k. Use axis=0 instead to normalize each feature across locations.
    row_min = importance_local.min(axis=1, keepdims=True)
    row_range = importance_local.max(axis=1, keepdims=True) - row_min
    # Rows whose features are all equally important have zero range. Dividing
    # by 1 there maps them to 0 instead of NaN.
    row_range = np.where(row_range == 0, 1, row_range)
    importance_local = (importance_local - row_min) / row_range

    # Plot the local importance; savefig fails if the folder is missing.
    os.makedirs("Plot", exist_ok=True)
    for i in range(importance_local.shape[1]):
        fig = plt.figure()
        scatter = plt.scatter(
            points[:, 0], points[:, 1], c=importance_local[:, i], cmap="viridis"
        )
        fig.colorbar(scatter)
        fig.savefig(f"Plot/Local_importance_{i}.png")
        fig.show()


def fit_stacking():
    """Fit and time a StackingWeightModel on the two-feature importance simulation.

    Uses 5000 samples of ``f_square_2`` with coordinates appended to the
    features. The model uses random-split decision trees of depth
    ``X.shape[1]``, neighbour_count 0.02 and leave-out rate 0.2. Prints the
    base and stacking LLOCV scores, then the fit time in seconds.
    """
    X, y, points = generate_sample(
        count=5000, f=f_square_2, coef_func=[coef_auto_gau_strong, coef_auto_gau_weak], random_seed=1,
        plot=True
    )
    features = np.concatenate([X, points], axis=1)

    stacker = StackingWeightModel(
        # Alternative base learner: ExtraTreesRegressor(n_estimators=10, max_depth=X.shape[1])
        DecisionTreeRegressor(splitter="random", max_depth=X.shape[1]),
        "euclidean",
        "bisquare",
        neighbour_count=0.02,
        neighbour_leave_out_rate=0.2,
    )

    started = time.time()
    stacker.fit(features, y, [points])
    elapsed = time.time() - started

    print("Stacking:", stacker.llocv_score_, stacker.llocv_stacking_)
    print(elapsed)


if __name__ == "__main__":
    # Swap in fit_stacking() to time the stacking model instead.
    draw_graph()


"""
Copied from `simulation_for_fitting`.
Modification: each model is fitted on a training sample and then scored on a
separate test sample to validate its generalization.
"""

import json
import os
import time
from functools import partial

from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import LeaveOneOut
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from xgboost import XGBRegressor

from georegression.simulation.simulation import show_sample
from georegression.simulation.simulation_utils import *
from georegression.stacking_model import StackingWeightModel
from georegression.weight_model import WeightModel



def fit_models(
    X,
    y,
    points,
    X_test,
    y_test,
    points_test,
    stacking_neighbour_count=0.03,
    stacking_neighbour_leave_out_rate=0.15,
    grf_neighbour_count=0.03,
    grf_n_estimators=50,
    gwr_neighbour_count=0.03,
    rf_n_estimators=2000,
    info=None,
):
    """Fit every benchmark model on a training sample and score it on a test sample.

    Models, in order:

    * StackingWeightModel over random-split decision trees;
    * StackingWeightModel over extra-trees;
    * WeightModel around a random forest (labelled "GRF");
    * WeightModel around LinearRegression (labelled "GWR");
    * a global RandomForestRegressor ("RF");
    * XGBRegressor ("XGB");
    * a global LinearRegression ("LR").

    Coordinates are appended to the features for every model. For each model
    the record holds:

    * the in-sample score (LLOCV, out-of-bag or training R^2);
    * the fit time (XGB included);
    * the R^2 on the test sample under ``<name>_Prediction``.

    Each model (except XGB) is then refitted on the test sample and that score
    is printed.

    Args:
        X, y, points: training features, target and coordinates.
        X_test, y_test, points_test: test features, target and coordinates.
        stacking_neighbour_count, stacking_neighbour_leave_out_rate: used by
            both stacking variants.
        grf_neighbour_count, grf_n_estimators: GRF neighbourhood and forest size.
        gwr_neighbour_count: GWR neighbourhood size.
        rf_n_estimators: tree count for the global RF and for XGBoost.
        info: extra metadata merged into the record; its keys win on clashes.

    Returns:
        The record dict. It is also appended as one JSON line to
        ``simulation_result.jsonl``.
    """
    if info is None:
        info = {}

    X_plus = np.concatenate([X, points], axis=1)
    X_plus_test = np.concatenate([X_test, points_test], axis=1)

    distance_measure = "euclidean"
    kernel_type = "bisquare"

    result = {}

    # --- Stacking over random-split decision trees ---
    model = StackingWeightModel(
        DecisionTreeRegressor(splitter="random", max_depth=X.shape[1]),
        distance_measure,
        kernel_type,
        neighbour_count=stacking_neighbour_count,
        neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
    )
    t1 = time.time()
    model.fit(X_plus, y, [points])
    t2 = time.time()
    print("Stacking:", model.llocv_score_, model.llocv_stacking_)
    print(t2 - t1)
    result["Stacking_Base"] = model.llocv_score_
    result["Stacking"] = model.llocv_stacking_
    result["Stacking_Time"] = t2 - t1

    # StackingWeightModel.predict_by_fit is given the training data explicitly.
    prediction = model.predict_by_fit(X_plus, y, [points], X_plus_test, [points_test])
    score = r2_score(y_test, prediction)
    print("Stacking Prediction:", score)
    result["Stacking_Prediction"] = score

    model.fit(X_plus_test, y_test, [points_test])
    print("Stacking Fit on Test:", model.llocv_score_, model.llocv_stacking_)

    # --- Stacking over extra-trees ---
    model = StackingWeightModel(
        ExtraTreesRegressor(n_estimators=10, max_depth=X.shape[1]),
        distance_measure,
        kernel_type,
        neighbour_count=stacking_neighbour_count,
        neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
    )
    t1 = time.time()
    model.fit(X_plus, y, [points])
    t2 = time.time()
    print("Stacking_Extra:", model.llocv_score_, model.llocv_stacking_)
    print(t2 - t1)
    result["Stacking_Extra_Base"] = model.llocv_score_
    result["Stacking_Extra"] = model.llocv_stacking_
    result["Stacking_Extra_Time"] = t2 - t1

    prediction = model.predict_by_fit(X_plus, y, [points], X_plus_test, [points_test])
    score = r2_score(y_test, prediction)
    print("Stacking_Extra Prediction:", score)
    result["Stacking_Extra_Prediction"] = score

    model.fit(X_plus_test, y_test, [points_test])
    print("Stacking_Extra Fit on Test:", model.llocv_score_, model.llocv_stacking_)

    # --- GRF ---
    model = WeightModel(
        RandomForestRegressor(n_estimators=grf_n_estimators),
        distance_measure,
        kernel_type,
        neighbour_count=grf_neighbour_count,
        cache_data=True,
    )
    t1 = time.time()
    model.fit(X_plus, y, [points])
    t2 = time.time()
    print("GRF:", model.llocv_score_)
    print(t2 - t1)
    result["GRF"] = model.llocv_score_
    result["GRF_Time"] = t2 - t1

    # WeightModel.predict_by_fit receives only the test inputs; presumably it
    # reuses the training data kept by cache_data=True -- confirm.
    prediction = model.predict_by_fit(X_plus_test, [points_test])
    score = r2_score(y_test, prediction)
    print("GRF Prediction:", score)
    result["GRF_Prediction"] = score

    model.fit(X_plus_test, y_test, [points_test])
    print("GRF Fit on Test:", model.llocv_score_)

    # --- GWR ---
    model = WeightModel(
        LinearRegression(),
        distance_measure,
        kernel_type,
        neighbour_count=gwr_neighbour_count,
        cache_data=True
    )
    t1 = time.time()
    model.fit(X_plus, y, [points])
    t2 = time.time()
    print("GWR:", model.llocv_score_)
    print(t2 - t1)
    result["GWR"] = model.llocv_score_
    result["GWR_Time"] = t2 - t1

    prediction = model.predict_by_fit(X_plus_test, [points_test])
    score = r2_score(y_test, prediction)
    print("GWR Prediction:", score)
    result["GWR_Prediction"] = score

    model.fit(X_plus_test, y_test, [points_test])
    print("GWR Fit on Test:", model.llocv_score_)

    # --- Global random forest (out-of-bag score as in-sample estimate) ---
    model = RandomForestRegressor(
        oob_score=True, n_estimators=rf_n_estimators, n_jobs=-1
    )
    t1 = time.time()
    model.fit(X_plus, y)
    t2 = time.time()
    print("RF:", model.oob_score_)
    print(t2 - t1)
    result["RF"] = model.oob_score_
    result["RF_Time"] = t2 - t1

    prediction = model.predict(X_plus_test)
    score = r2_score(y_test, prediction)
    print("RF Prediction:", score)
    result["RF_Prediction"] = score

    model.fit(X_plus_test, y_test)
    print("RF Fit on Test:", model.oob_score_)

    # --- XGBoost (a leave-one-out variant lives in fit_llocv_models) ---
    model = XGBRegressor(n_estimators=rf_n_estimators, n_jobs=-1)
    t1 = time.time()
    model.fit(X_plus, y)
    t2 = time.time()
    result["XGB_Time"] = t2 - t1

    prediction = model.predict(X_plus_test)
    score = r2_score(y_test, prediction)
    print("XGB Prediction:", score)
    result["XGB_Prediction"] = score

    # --- Global linear regression (training R^2 as in-sample score) ---
    model = LinearRegression()
    t1 = time.time()
    model.fit(X_plus, y)
    t2 = time.time()
    train_score = model.score(X_plus, y)
    print("LR:", train_score)
    print(t2 - t1)
    result["LR"] = train_score
    result["LR_Time"] = t2 - t1

    prediction = model.predict(X_plus_test)
    score = r2_score(y_test, prediction)
    print("LR Prediction:", score)
    result["LR_Prediction"] = score

    model.fit(X_plus_test, y_test)
    print("LR Fit on Test:", model.score(X_plus_test, y_test))

    result = {**result, **info}
    with open("simulation_result.jsonl", "a") as f:
        f.write(json.dumps(result) + "\n")

    return result


def fit_llocv_models(
    X,
    y,
    points,
    n_estimators=500,
):
    """Compute the leave-one-out R^2 of an XGBoost regressor on features plus coordinates.

    One ``XGBRegressor`` is trained per left-out observation, so the cost is
    ``len(y)`` model fits.

    Args:
        X: feature matrix, one row per observation.
        y: target values.
        points: coordinates, appended to ``X`` as extra feature columns.
        n_estimators: trees per XGBoost model. The default of 500 is the
            previously hard-coded value.

    Returns:
        ``{"n_estimators": n_estimators, "score": <leave-one-out R^2>}``.
    """
    X_plus = np.concatenate([X, points], axis=1)

    y_predicts = []
    for train, test in LeaveOneOut().split(X_plus):
        estimator = XGBRegressor(n_estimators=n_estimators, n_jobs=-1)
        estimator.fit(X_plus[train], y[train])
        y_predicts.append(estimator.predict(X_plus[test]))

    # Each prediction is a length-1 array; join them into one vector aligned with y.
    score = r2_score(y, np.concatenate(y_predicts))

    return {
        "n_estimators": n_estimators,
        "score": score,
    }


def coef_auto_gau_weak():
    """Build a coefficient field from a radial term, a directional term and 1000 random bumps.

    The radial term is centred at the origin and the directional term points
    along (1, 1). Each Gaussian bump has:

    * a centre drawn uniformly from [-10, 10]^2;
    * an amplitude of random sign with magnitude in [1, 2];
    * a 2x2 sigma matrix with diagonal entries in [0.5, 5] and an off-diagonal
      entry bounded by the square root of their product.

    All draws come from NumPy's global RNG, so seed it beforehand for
    reproducible fields.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))

    bumps = []
    for _ in range(1000):
        # The order of the random draws below fixes the field for a given seed.
        center = np.random.uniform(-10, 10, 2)
        magnitude = np.random.uniform(1, 2)
        amplitude = magnitude * np.random.choice([-1, 1])
        var_a = np.random.uniform(0.5, 5)
        var_b = np.random.uniform(0.5, 5)
        limit = np.sqrt(var_a * var_b)
        off_diag = np.random.uniform(-limit, limit)
        sigma = np.array([[var_a, off_diag], [off_diag, var_b]])
        bumps.append(gaussian_coefficient(center, sigma, amplitude=amplitude))

    bump_sum = coefficient_wrapper(np.sum, *bumps)
    return coefficient_wrapper(np.sum, radial, directional, bump_sum)


def coef_auto_gau_strong():
    """Build a coefficient field from a radial term, a directional term and 1000 random bumps.

    The radial term is centred at the origin and the directional term points
    along (1, 1). Each Gaussian bump has:

    * a centre drawn uniformly from [-10, 10]^2;
    * an amplitude of random sign with magnitude in [1, 2];
    * a 2x2 sigma matrix with diagonal entries in [0.2, 1] and an off-diagonal
      entry bounded by the square root of their product.

    All draws come from NumPy's global RNG, so seed it beforehand.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))

    bumps = []
    for _ in range(1000):
        # The order of the random draws below fixes the field for a given seed.
        center = np.random.uniform(-10, 10, 2)
        magnitude = np.random.uniform(1, 2)
        amplitude = magnitude * np.random.choice([-1, 1])
        var_a = np.random.uniform(0.2, 1)
        var_b = np.random.uniform(0.2, 1)
        limit = np.sqrt(var_a * var_b)
        off_diag = np.random.uniform(-limit, limit)
        sigma = np.array([[var_a, off_diag], [off_diag, var_b]])
        bumps.append(gaussian_coefficient(center, sigma, amplitude=amplitude))

    bump_sum = coefficient_wrapper(np.sum, *bumps)
    return coefficient_wrapper(np.sum, radial, directional, bump_sum)


def coef_strong():
    """Build a fixed coefficient field from radial, directional, sine and Gaussian terms.

    The terms are:

    * a radial term centred at the origin;
    * a directional term along (1, 1);
    * two sine terms along (-1, 1) and (1, 1);
    * two Gaussian bumps at (-5, 5) and at (-5, -5), the second with amplitude 2.

    This function draws no random numbers itself.
    """
    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))
    sine = coefficient_wrapper(
        np.sum,
        sine_coefficient(1, np.array([-1, 1]), 1),
        sine_coefficient(1, np.array([1, 1]), 1),
    )
    gaussian = coefficient_wrapper(
        np.sum,
        gaussian_coefficient(np.array([-5, 5]), 3),
        gaussian_coefficient(np.array([-5, -5]), 3, amplitude=2),
    )
    return coefficient_wrapper(np.sum, radial, directional, sine, gaussian)


def f_square(X, C, points):
    """Apply ``polynomial_function(C[0], 2)`` to feature column 0 at ``points`` (zero intercept)."""
    term = polynomial_function(C[0], 2)(X[:, 0], points)
    intercept = 0
    return term + intercept


def f_square_2(X, C, points):
    """Sum ``polynomial_function(C[j], 2)`` applied to feature columns 0 and 1 at ``points``."""
    first_term = polynomial_function(C[0], 2)(X[:, 0], points)
    second_term = polynomial_function(C[1], 2)(X[:, 1], points)
    intercept = 0
    return first_term + second_term + intercept


def f_square_const(X, C, points):
    """Return ``f_square``'s term plus ten times the coefficient field ``C[0]`` evaluated at ``points``."""
    poly_term = polynomial_function(C[0], 2)(X[:, 0], points)
    spatial_offset = C[0](points) * 10
    intercept = 0
    return poly_term + spatial_offset + intercept


def f_sigmoid(X, C, points):
    """Apply ``sigmoid_function(C[0])`` to feature column 0 at ``points`` (zero intercept)."""
    term = sigmoid_function(C[0])(X[:, 0], points)
    intercept = 0
    return term + intercept


def f_interact(X, C, points):
    """Apply ``interaction_function(C[0])`` to feature columns 0 and 1 at ``points`` (zero intercept)."""
    term = interaction_function(C[0])(X[:, 0], X[:, 1], points)
    intercept = 0
    return term + intercept


def generate_sample(count, f, coef_func, random_seed=1, plot=False):
    """Simulate ``count`` observations with a single feature column.

    Reseeds NumPy's global RNG with ``random_seed``. It then samples locations
    and one feature column (both with ``bounds=(-10, 10)``), calls
    ``coef_func()`` once and evaluates ``y = f(X, [coefficient], points)``.
    The coefficient is built after the reseed, so coefficient factories that
    draw random numbers produce a different field for each seed.

    If ``plot`` is true, the sample is shown via ``show_sample`` into
    ``Plot/{coef_func.__name__}_{f.__name__}_{count}``, which is created if
    missing.

    Returns:
        ``(X, y, points)`` with ``X`` holding one column.
    """
    np.random.seed(random_seed)
    points = sample_points(count, bounds=(-10, 10))
    x1 = sample_x(count, bounds=(-10, 10))
    coefficients = [coef_func()]

    X = np.stack([x1], axis=-1)
    y = f(X, coefficients, points)
    sample = (X, y, points)

    if not plot:
        return sample

    folder = f"Plot/{coef_func.__name__}_{f.__name__}_{count}"
    os.makedirs(folder, exist_ok=True)
    show_sample(X, y, points, coefficients, folder)
    return sample


def square_strong_100():
    """Benchmark all models on 100 samples of ``f_square`` with ``coef_strong``.

    ``fit_models`` requires a test sample. As in ``square_strong_1000`` and
    ``square_strong_5000``, it is drawn with ``random_seed=2``. This is valid
    because ``coef_strong`` draws no random numbers, so both samples share the
    same coefficient field.

    Returns whatever ``fit_models`` returns.
    """
    X, y, points = generate_sample(100, f_square, coef_strong, random_seed=1, plot=True)
    X_test, y_test, points_test = generate_sample(
        100, f_square, coef_strong, random_seed=2, plot=False
    )

    # Candidate grid for test_models (sweep not run here):
    #   neighbour counts and leave-out rates: [0.1, 0.2, 0.3, 0.4, 0.5]

    return fit_models(
        X,
        y,
        points,
        X_test, y_test, points_test,
        stacking_neighbour_count=0.3,
        stacking_neighbour_leave_out_rate=0.4,
        grf_neighbour_count=0.3,
        grf_n_estimators=50,
        gwr_neighbour_count=0.5,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_strong", "count": 100},
    )


def square_strong_500():
    """Benchmark all models on 500 samples of ``f_square`` with ``coef_strong``.

    ``fit_models`` requires a test sample. As in ``square_strong_1000`` and
    ``square_strong_5000``, it is drawn with ``random_seed=2``. This is valid
    because ``coef_strong`` draws no random numbers, so both samples share the
    same coefficient field.

    Returns whatever ``fit_models`` returns.
    """
    X, y, points = generate_sample(500, f_square, coef_strong, random_seed=1, plot=True)
    X_test, y_test, points_test = generate_sample(
        500, f_square, coef_strong, random_seed=2, plot=False
    )

    # Candidate grid for test_models (sweep not run here):
    #   neighbour counts and leave-out rates: [0.1, 0.2, 0.3, 0.4, 0.5]

    return fit_models(
        X,
        y,
        points,
        X_test, y_test, points_test,
        stacking_neighbour_count=0.3,
        stacking_neighbour_leave_out_rate=0.1,
        grf_neighbour_count=0.3,
        grf_n_estimators=50,
        gwr_neighbour_count=0.2,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_strong", "count": 500},
    )


def square_strong_1000():
    """Benchmark all models on 1000 samples of ``f_square`` with ``coef_strong``.

    The training sample uses seed 1 (and is plotted); the test sample uses
    seed 2. Returns whatever ``fit_models`` returns.
    """
    sample_spec = (1000, f_square, coef_strong)
    X, y, points = generate_sample(*sample_spec, random_seed=1, plot=True)
    X_test, y_test, points_test = generate_sample(*sample_spec, random_seed=2, plot=False)

    # Candidate grid for test_models (sweep not run here):
    #   neighbour counts: [0.01, 0.02, 0.03, 0.04]; leave-out rates: [0.1, 0.2, 0.3, 0.4]
    tuned = dict(
        stacking_neighbour_count=0.02,
        stacking_neighbour_leave_out_rate=0.3,
        grf_neighbour_count=0.02,
        grf_n_estimators=50,
        gwr_neighbour_count=0.03,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_strong", "count": 1000},
    )
    return fit_models(X, y, points, X_test, y_test, points_test, **tuned)


def square_strong_5000():
    """Benchmark all models on 5000 samples of ``f_square`` with ``coef_strong``.

    The training sample uses seed 1 (and is plotted); the test sample uses
    seed 2. Returns whatever ``fit_models`` returns.
    """
    sample_spec = (5000, f_square, coef_strong)
    X, y, points = generate_sample(*sample_spec, random_seed=1, plot=True)
    X_test, y_test, points_test = generate_sample(*sample_spec, random_seed=2, plot=False)

    # Candidate grid for test_models (sweep not run here):
    #   neighbour counts: [0.003, 0.005, 0.008, 0.01, 0.015, 0.02]
    #   leave-out rates:  [0.1, 0.2, 0.3, 0.4, 0.5]
    tuned = dict(
        stacking_neighbour_count=0.015,
        stacking_neighbour_leave_out_rate=0.4,
        grf_neighbour_count=0.01,
        grf_n_estimators=50,
        gwr_neighbour_count=0.015,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_strong", "count": 5000},
    )
    return fit_models(X, y, points, X_test, y_test, points_test, **tuned)


def square_gau_strong_100():
    """Benchmark all models on 100 training samples of ``f_square`` with ``coef_auto_gau_strong``.

    ``fit_models`` requires a test sample drawn from the same coefficient
    field. ``coef_auto_gau_strong`` draws its Gaussian bumps from NumPy's
    global RNG *after* ``generate_sample`` reseeds it. A sample generated with
    another seed would therefore come from a different field. Instead,
    2 * count observations are drawn from one field and split evenly into
    train and test, so training still has ``count`` samples. The plot folder
    name reflects 2 * count.

    Returns whatever ``fit_models`` returns.
    """
    count = 100
    X_all, y_all, points_all = generate_sample(
        2 * count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    X, X_test, y, y_test, points, points_test = train_test_split(
        X_all, y_all, points_all, test_size=0.5, random_state=1
    )

    # Candidate grid for test_models (sweep not run here):
    #   neighbour counts: [0.05, 0.08, 0.1, 0.15, 0.2, ..., 0.5]
    #   leave-out rates:  [0.05, 0.1, 0.15, 0.2, 0.25]

    return fit_models(
        X,
        y,
        points,
        X_test, y_test, points_test,
        stacking_neighbour_count=0.45,
        stacking_neighbour_leave_out_rate=0.2,
        grf_neighbour_count=0.45,
        grf_n_estimators=50,
        gwr_neighbour_count=0.5,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_strong", "count": 100},
    )


def square_gau_strong_500():
    """Benchmark all models on 500 training samples of ``f_square`` with ``coef_auto_gau_strong``.

    ``fit_models`` requires a test sample drawn from the same coefficient
    field. ``coef_auto_gau_strong`` draws its Gaussian bumps from NumPy's
    global RNG *after* ``generate_sample`` reseeds it. A sample generated with
    another seed would therefore come from a different field. Instead,
    2 * count observations are drawn from one field and split evenly into
    train and test, so training still has ``count`` samples. The plot folder
    name reflects 2 * count.

    Returns whatever ``fit_models`` returns.
    """
    count = 500
    X_all, y_all, points_all = generate_sample(
        2 * count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    X, X_test, y, y_test, points, points_test = train_test_split(
        X_all, y_all, points_all, test_size=0.5, random_state=1
    )

    # Candidate grid for test_models (sweep not run here):
    #   stacking neighbour counts: [0.05, 0.08, 0.1, 0.15, 0.2]
    #   leave-out rates: [0.05, 0.1, 0.15, 0.2]; GRF/GWR: [0.05, 0.1, 0.2]

    return fit_models(
        X,
        y,
        points,
        X_test, y_test, points_test,
        stacking_neighbour_count=0.08,
        stacking_neighbour_leave_out_rate=0.1,
        grf_neighbour_count=0.1,
        grf_n_estimators=50,
        gwr_neighbour_count=0.1,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_strong", "count": 500},
    )


def square_gau_strong_1000():
    """Benchmark all models on 1000 training samples of ``f_square`` with ``coef_auto_gau_strong``.

    ``fit_models`` requires a test sample drawn from the same coefficient
    field. ``coef_auto_gau_strong`` draws its Gaussian bumps from NumPy's
    global RNG *after* ``generate_sample`` reseeds it. A sample generated with
    another seed would therefore come from a different field. Instead,
    2 * count observations are drawn from one field and split evenly into
    train and test, so training still has ``count`` samples. The plot folder
    name reflects 2 * count.

    Returns whatever ``fit_models`` returns.
    """
    count = 1000
    X_all, y_all, points_all = generate_sample(
        2 * count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    X, X_test, y, y_test, points, points_test = train_test_split(
        X_all, y_all, points_all, test_size=0.5, random_state=1
    )

    # Candidate grid for test_models (sweep not run here):
    #   neighbour counts: [0.01, 0.02, 0.03, 0.04, 0.05]
    #   leave-out rates:  [0.05, 0.1, 0.15, 0.2]

    return fit_models(
        X,
        y,
        points,
        X_test, y_test, points_test,
        stacking_neighbour_count=0.02,
        stacking_neighbour_leave_out_rate=0.05,
        grf_neighbour_count=0.01,
        grf_n_estimators=50,
        gwr_neighbour_count=0.04,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_strong", "count": 1000},
    )


def square_gau_strong_5000():
    """Fit the tuned model set on 5000 samples of ``f_square`` with ``coef_auto_gau_strong``.

    Grid previously searched with ``test_models`` (kept for reference):
    stacking neighbour_count [0.003, 0.005, 0.008, 0.01, 0.015, 0.02];
    stacking leave_out_rate [0.05, 0.1, 0.15, 0.2];
    GRF / GWR neighbour_count [0.003, 0.005, 0.008, 0.01, 0.015, 0.02].
    """
    sample_count = 5000
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    tuned_params = dict(
        stacking_neighbour_count=0.008,
        stacking_neighbour_leave_out_rate=0.2,
        grf_neighbour_count=0.01,
        grf_n_estimators=50,
        gwr_neighbour_count=0.01,
        rf_n_estimators=2000,
    )
    labels = {"f": "f_square", "coef": "coef_gau_strong", "count": sample_count}
    return fit_models(X, y, points, info=labels, **tuned_params)


def square_gau_strong_10000():
    """Grid-search hyper-parameters on 10000 samples of ``f_square`` / ``coef_auto_gau_strong``.

    Results are appended by ``test_models``; nothing is returned. The
    commented-out ``fit_models`` settings in the original matched the
    5000-sample case (stacking 0.008 / leave-out 0.2, GRF 0.01, GWR 0.01).
    """
    sample_count = 10000
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    neighbour_grid = [0.001, 0.002, 0.003, 0.005, 0.008, 0.01, 0.012, 0.015]
    leave_out_grid = [0.05, 0.1, 0.15, 0.2]
    test_models(
        X,
        y,
        points,
        neighbour_grid,
        leave_out_grid,
        neighbour_grid,
        neighbour_grid,
        sample_count,
        "f_square",
        "coef_gau_strong",
    )


def square_gau_strong_50000():
    """Grid-search hyper-parameters on 50000 samples of ``f_square`` / ``coef_auto_gau_strong``.

    Results are appended by ``test_models``; nothing is returned. The
    commented-out ``fit_models`` settings in the original matched the
    5000-sample case (stacking 0.008 / leave-out 0.2, GRF 0.01, GWR 0.01).
    """
    sample_count = 50000
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    neighbour_grid = [0.001, 0.002, 0.003, 0.005, 0.008, 0.01, 0.012, 0.015]
    leave_out_grid = [0.05, 0.1, 0.15, 0.2]
    test_models(
        X,
        y,
        points,
        neighbour_grid,
        leave_out_grid,
        neighbour_grid,
        neighbour_grid,
        sample_count,
        "f_square",
        "coef_gau_strong",
    )


def square_gau_weak_100():
    """Fit the tuned model set on 100 samples of ``f_square`` with ``coef_auto_gau_weak``.

    The commented-out ``test_models`` call below records the grid that was
    searched. Its ``count`` argument previously read 500 (a copy-paste
    slip), which would have mislabelled the logged results. It now reads 100
    to match this scenario.
    """
    X, y, points = generate_sample(
        100, f_square, coef_auto_gau_weak, random_seed=1, plot=True
    )
    # test_models(
    #     X,
    #     y,
    #     points,
    #     [0.05, 0.08, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
    #     [0.05, 0.1, 0.15, 0.2, 0.25],
    #     [0.05, 0.08, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
    #     [0.05, 0.08, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
    #     100,
    #     "f_square",
    #     "coef_gau_weak",
    # )

    return fit_models(
        X,
        y,
        points,
        stacking_neighbour_count=0.25,
        stacking_neighbour_leave_out_rate=0.25,
        grf_neighbour_count=0.08,
        grf_n_estimators=50,
        gwr_neighbour_count=0.3,
        rf_n_estimators=2000,
        info={"f": "f_square", "coef": "coef_gau_weak", "count": 100},
    )


def square_gau_weak_500():
    """Fit the tuned model set on 500 samples of ``f_square`` with ``coef_auto_gau_weak``.

    Grid previously searched with ``test_models`` (kept for reference):
    stacking neighbour_count [0.05, 0.08, 0.1, 0.15, 0.2];
    stacking leave_out_rate [0.05, 0.1, 0.15, 0.2];
    GRF neighbour_count [0.05, 0.1, 0.2]; GWR neighbour_count [0.05, 0.1, 0.2].
    """
    sample_count = 500
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_weak, random_seed=1, plot=True
    )
    tuned_params = dict(
        stacking_neighbour_count=0.08,
        stacking_neighbour_leave_out_rate=0.15,
        grf_neighbour_count=0.05,
        grf_n_estimators=50,
        gwr_neighbour_count=0.1,
        rf_n_estimators=2000,
    )
    labels = {"f": "f_square", "coef": "coef_gau_weak", "count": sample_count}
    return fit_models(X, y, points, info=labels, **tuned_params)


def square_gau_weak_1000():
    """Fit the tuned model set on 1000 samples of ``f_square`` with ``coef_auto_gau_weak``.

    Grid previously searched with ``test_models`` (kept for reference):
    stacking neighbour_count [0.03, 0.04, 0.05, 0.06];
    stacking leave_out_rate [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35];
    GRF / GWR neighbour_count [0.03, 0.04, 0.05, 0.06].
    """
    sample_count = 1000
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_weak, random_seed=1, plot=True
    )
    tuned_params = dict(
        stacking_neighbour_count=0.05,
        stacking_neighbour_leave_out_rate=0.25,
        grf_neighbour_count=0.06,
        grf_n_estimators=50,
        gwr_neighbour_count=0.06,
        rf_n_estimators=2000,
    )
    labels = {"f": "f_square", "coef": "coef_gau_weak", "count": sample_count}
    return fit_models(X, y, points, info=labels, **tuned_params)


def square_gau_weak_5000():
    """Fit the tuned model set on 5000 samples of ``f_square`` with ``coef_auto_gau_weak``.

    Grid previously searched with ``test_models`` (kept for reference):
    stacking neighbour_count [0.008, 0.01, 0.015, 0.02, 0.025];
    stacking leave_out_rate [0.2, 0.25, 0.3];
    GRF / GWR neighbour_count [0.008, 0.01, 0.015, 0.02, 0.025].
    """
    sample_count = 5000
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_weak, random_seed=1, plot=True
    )
    tuned_params = dict(
        stacking_neighbour_count=0.02,
        stacking_neighbour_leave_out_rate=0.3,
        grf_neighbour_count=0.01,
        grf_n_estimators=50,
        gwr_neighbour_count=0.02,
        rf_n_estimators=2000,
    )
    labels = {"f": "f_square", "coef": "coef_gau_weak", "count": sample_count}
    return fit_models(X, y, points, info=labels, **tuned_params)


def square_2_gau_strong_weak_5000():
    """Fit the tuned model set on a two-feature ``f_square_2`` simulation (5000 samples).

    The first feature's coefficient is twice ``coef_auto_gau_strong``; the
    second uses ``coef_auto_gau_weak``. Grid previously searched with
    ``test_models``: stacking neighbour_count [0.02, 0.03, 0.04, 0.05],
    leave_out_rate [0.1, 0.15, 0.2, 0.25], GRF / GWR [0.02, 0.03, 0.04, 0.05].
    """
    sample_count = 5000
    np.random.seed(1)

    # The draw order (points, x1, x2, strong coef, weak coef) fixes the RNG stream.
    points = sample_points(sample_count, bounds=(-10, 10))
    x1 = sample_x(sample_count, bounds=(-10, 10))
    x2 = sample_x(sample_count, bounds=(-10, 10))

    doubled_strong = coefficient_wrapper(
        partial(np.multiply, 2), coef_auto_gau_strong()
    )
    coefficients = [doubled_strong, coef_auto_gau_weak()]

    X = np.column_stack((x1, x2))
    y = f_square_2(X, coefficients, points)

    tuned_params = dict(
        stacking_neighbour_count=0.02,
        stacking_neighbour_leave_out_rate=0.2,
        grf_neighbour_count=0.02,
        grf_n_estimators=50,
        gwr_neighbour_count=0.02,
        rf_n_estimators=2000,
    )
    labels = {"f": "f_square_2", "coef": "coef_gau_strong2_weak", "count": sample_count}
    return fit_models(X, y, points, info=labels, **tuned_params)


def interact_ale():
    """Simulate an interaction model with spatially dependent features and fit the model set.

    ``x1`` and ``x2`` are drawn with location-dependent means (derived from
    ``coef_auto_gau_weak``). ``y`` is produced by ``f_interact`` with a
    ``coef_auto_gau_strong`` coefficient surface.

    The ``info`` labels previously claimed ``f_square_2`` /
    ``coef_gau_strong2_weak`` (copied from another scenario), which
    mislabelled the logged results. They now reflect the function and
    coefficient actually used.
    """
    random_seed = 1
    np.random.seed(random_seed)

    def coef_manual_gau():
        # Hand-placed alternative coefficient surface; not called below.
        coef_radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
        coef_dir = directional_coefficient(np.array([1, 1]))

        coef_gau_1 = gaussian_coefficient(
            np.array([-5, 5]), [[3, 4], [4, 8]], amplitude=-1
        )
        coef_gau_2 = gaussian_coefficient(np.array([-2, -5]), 5, amplitude=2)
        coef_gau_3 = gaussian_coefficient(np.array([8, 3]), 10, amplitude=-1.5)
        coef_gau_4 = gaussian_coefficient(
            np.array([2, 8]), [[3, 0], [0, 15]], amplitude=0.8
        )
        coef_gau_5 = gaussian_coefficient(np.array([5, -10]), 1, amplitude=1)
        # NOTE(review): the [-10, -10] bump is immediately overwritten by the
        # [-11, 0] one, so only the latter enters the sum. Confirm whether both
        # were meant to be included (e.g. as coef_gau_6 and coef_gau_7).
        coef_gau_6 = gaussian_coefficient(np.array([-10, -10]), 15, amplitude=1.5)
        coef_gau_6 = gaussian_coefficient(np.array([-11, 0]), 5, amplitude=2)
        coef_gau = coefficient_wrapper(
            np.sum,
            coef_gau_1,
            coef_gau_2,
            coef_gau_3,
            coef_gau_4,
            coef_gau_5,
            coef_gau_6,
        )

        # coef_sum = coefficient_wrapper(np.sum, coef_radial, coef_dir, coef_gau)
        coef_sum = coefficient_wrapper(np.sum, coef_radial, coef_gau)

        return coef_sum

    count = 5000
    points = sample_points(count, bounds=[[-10, 10], [-10, 10]])

    # Feature means vary over space; x2's mean is 3x x1's.
    coef_x1 = coef_auto_gau_weak()
    coef_x2 = coefficient_wrapper(partial(np.multiply, 3), coef_x1)
    x1 = sample_x(count, mean=coef_x1, bounds=(-1, 1), points=points)
    x2 = sample_x(count, mean=coef_x2, bounds=(-2, 2), points=points)

    f = f_interact
    coef_func = coef_auto_gau_strong

    if isinstance(coef_func, list):
        coefficients = [func() for func in coef_func]
    else:
        coefficients = [coef_func()]

    X = np.stack((x1, x2), axis=-1)
    y = f(X, coefficients, points)

    # Exploration kept for reference: a single GRF fit (neighbour_count=0.03),
    # and test_models with GRF/GWR grids [0.02, 0.03, 0.04, 0.05].

    return fit_models(
        X,
        y,
        points,
        stacking_neighbour_count=0.03,
        stacking_neighbour_leave_out_rate=0.15,
        grf_neighbour_count=0.03,
        grf_n_estimators=50,
        gwr_neighbour_count=0.02,
        rf_n_estimators=2000,
        info={"f": f.__name__, "coef": "coef_gau_strong", "count": count},
    )


def test_llocv():
    """Append LLOCV results for ``f_square`` across sample sizes and coefficient surfaces.

    For every (count, coef) pair, a sample is generated with seed 1 and passed
    to ``fit_llocv_models``. The returned dict is tagged with
    ``count``/``func``/``coef`` and appended as one JSON line to
    ``simulation_result_llocv.jsonl``.
    """
    func = f_square
    sample_sizes = [100, 500, 1000, 5000]

    for count in sample_sizes:
        for coef in [coef_strong, coef_auto_gau_strong, coef_auto_gau_weak]:
            X, y, points = generate_sample(count, func, coef, random_seed=1, plot=True)
            record = fit_llocv_models(X, y, points)

            with open("simulation_result_llocv.jsonl", "a") as sink:
                tags = (
                    ("count", count),
                    ("func", func.__name__),
                    ("coef", coef.__name__),
                )
                for key, value in tags:
                    record[key] = value
                sink.write(json.dumps(record) + "\n")


def test_models(
    X,
    y,
    points,
    stacking_neighbour_count,
    stacking_neighbour_leave_out_rate,
    grf_neighbour_count,
    gwr_neighbour_count,
    count,
    func,
    coef,
):
    """Grid-search the stacking, GRF and GWR models and log every and the best result.

    Args:
        X, y, points: simulated features, target and coordinates.
        stacking_neighbour_count: neighbour-count grid for the stacking model.
        stacking_neighbour_leave_out_rate: leave-out-rate grid for the stacking model.
        grf_neighbour_count: neighbour-count grid for GRF.
        gwr_neighbour_count: neighbour-count grid for GWR.
        count, func, coef: scenario labels attached to every logged record.

    Every candidate is appended to ``simulation_params.jsonl``. The best
    candidate per model (by its score key) is printed and appended to
    ``simulation_param_best.jsonl``. A model whose grid is empty, e.g. ``[]``,
    yields no candidates and is skipped. Previously ``max()`` raised ValueError
    in that case.
    """
    # Evaluated in order: stacking, GRF, GWR. Each entry pairs the score key with its candidates.
    runs = [
        (
            "Stacking",
            test_stacking(
                X, y, points, stacking_neighbour_count, stacking_neighbour_leave_out_rate
            ),
        ),
        ("GRF", test_GRF(X, y, points, grf_neighbour_count)),
        ("GWR", test_GWR(X, y, points, gwr_neighbour_count)),
    ]

    with open("simulation_params.jsonl", "a") as f:
        for _, candidates in runs:
            for params in candidates:
                params["count"] = count
                params["func"] = func
                params["coef"] = coef
                f.write(json.dumps(params) + "\n")

    # Pick the best candidate once per model; an empty grid contributes nothing.
    best = [
        max(candidates, key=lambda p, k=score_key: p[k])
        for score_key, candidates in runs
        if candidates
    ]

    for params in best:
        print(params)

    with open("simulation_param_best.jsonl", "a") as f:
        for params in best:
            f.write(json.dumps(params) + "\n")


def test_GRF(X, y, points, neighbour_counts):
    """Score a random-forest ``WeightModel`` (GRF) for each neighbour count.

    Each count is tried twice: with coordinates appended to the features
    (``use_x_plus=True``) and with the raw features only.

    Returns:
        list of dicts with keys "GRF" (the model's ``llocv_score_``),
        "neighbour_count" and "use_x_plus", in evaluation order.
    """
    design_matrices = {True: np.concatenate([X, points], axis=1), False: X}
    records = []

    for use_x_plus in [True, False]:
        features = design_matrices[use_x_plus]
        for neighbour_count in neighbour_counts:
            grf = WeightModel(
                RandomForestRegressor(n_estimators=50),
                "euclidean",
                "bisquare",
                neighbour_count=neighbour_count,
            )
            grf.fit(features, y, [points])
            score = grf.llocv_score_
            print("GRF:", score, neighbour_count, use_x_plus)
            records.append(
                {
                    "GRF": score,
                    "neighbour_count": neighbour_count,
                    "use_x_plus": use_x_plus,
                }
            )

    return records


def test_stacking(X, y, points, neighbour_counts, leave_out_rates):
    """Score ``StackingWeightModel`` over a neighbour-count x leave-out-rate grid.

    Each grid cell is tried with and without coordinates appended to the
    features. All models share one random-split ``DecisionTreeRegressor`` as
    the local estimator, with depth equal to the number of raw features.

    Returns:
        list of dicts with keys "Stacking_Base" (``llocv_score_``), "Stacking"
        (``llocv_stacking_``), "neighbour_count", "leave_out_rate" and
        "use_x_plus".
    """
    design_matrices = {True: np.concatenate([X, points], axis=1), False: X}
    local_estimator = DecisionTreeRegressor(splitter="random", max_depth=X.shape[1])
    records = []

    for use_x_plus in [True, False]:
        features = design_matrices[use_x_plus]
        for neighbour_count in neighbour_counts:
            for leave_out_rate in leave_out_rates:
                stacker = StackingWeightModel(
                    local_estimator,
                    "euclidean",
                    "bisquare",
                    neighbour_count=neighbour_count,
                    neighbour_leave_out_rate=leave_out_rate,
                )
                stacker.fit(features, y, [points])
                base_score = stacker.llocv_score_
                stacked_score = stacker.llocv_stacking_
                print(
                    "Stacking:",
                    base_score,
                    stacked_score,
                    "neighbour_count:",
                    neighbour_count,
                    "leave_out_rate:",
                    leave_out_rate,
                    "use_x_plus:",
                    use_x_plus,
                )
                records.append(
                    {
                        "Stacking_Base": base_score,
                        "Stacking": stacked_score,
                        "neighbour_count": neighbour_count,
                        "leave_out_rate": leave_out_rate,
                        "use_x_plus": use_x_plus,
                    }
                )

    return records


def test_GWR(X, y, points, neighbour_counts):
    """Score a linear-regression ``WeightModel`` (GWR) for each neighbour count.

    Each count is tried twice: with coordinates appended to the features
    (``use_x_plus=True``) and with the raw features only.

    Returns:
        list of dicts with keys "GWR" (the model's ``llocv_score_``),
        "neighbour_count" and "use_x_plus", in evaluation order.
    """
    design_matrices = {True: np.concatenate([X, points], axis=1), False: X}
    records = []

    for use_x_plus in [True, False]:
        features = design_matrices[use_x_plus]
        for neighbour_count in neighbour_counts:
            gwr = WeightModel(
                LinearRegression(),
                "euclidean",
                "bisquare",
                neighbour_count=neighbour_count,
            )
            gwr.fit(features, y, [points])
            score = gwr.llocv_score_
            print("GWR:", score, neighbour_count, use_x_plus)
            records.append(
                {
                    "GWR": score,
                    "neighbour_count": neighbour_count,
                    "use_x_plus": use_x_plus,
                }
            )

    return records


if __name__ == "__main__":
    # Experiments in this module; uncomment the one(s) to run:
    #   square_strong_100 / _500 / _1000 / _5000
    #   square_gau_strong_100 / _500 / _1000 / _5000 / _10000 / _50000
    #   square_gau_weak_100 / _500 / _1000 / _5000
    #   square_2_gau_strong_weak_5000, test_llocv, interact_ale
    square_strong_5000()


import json
import os
import time
from functools import partial

from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import LeaveOneOut
from sklearn.tree import DecisionTreeRegressor
from xgboost import XGBRegressor

from georegression.simulation.simulation import show_sample
from georegression.simulation.simulation_utils import *
from georegression.stacking_model import StackingWeightModel
from georegression.weight_model import WeightModel


# TODO: Explain why the improvement becomes significant when there are more points.


def fit_models(
    X,
    y,
    points,
    stacking_neighbour_count=0.03,
    stacking_neighbour_leave_out_rate=0.15,
    info=None,
    repeats=10,
):
    """Repeatedly fit two stacking variants and log score statistics.

    Both variants are ``StackingWeightModel`` fits on ``X`` with ``points``
    appended as extra columns. One uses a random-split
    ``DecisionTreeRegressor`` as the local estimator (keys prefixed
    ``Stacking``). The other uses a 10-tree ``ExtraTreesRegressor`` (keys
    prefixed ``Stacking_Extra``). Each variant is fitted ``repeats`` times to
    measure run-to-run variation.

    Args:
        X: feature matrix (rows are samples).
        y: target values.
        points: coordinates, concatenated column-wise to ``X``.
        stacking_neighbour_count: ``neighbour_count`` for both stacking models.
        stacking_neighbour_leave_out_rate: ``neighbour_leave_out_rate`` for both.
        info: extra key/values (e.g. scenario labels) merged into the result.
        repeats: number of fits per variant (default 10, as before).

    Returns:
        dict with per-repeat scores (``<prefix>_Base_<i>``, ``<prefix>_<i>``),
        the last fit time (``<prefix>_Time``), the mean fit time
        (``<prefix>_Time_Mean``), and mean/std/max/min of the stacking score,
        merged with ``info``. The same dict is appended as one JSON line to
        ``simulation_variation.jsonl``.

    Raises:
        ValueError: if ``repeats`` is less than 1.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    if info is None:
        info = {}

    X_plus = np.concatenate([X, points], axis=1)

    distance_measure = "euclidean"
    kernel_type = "bisquare"

    result = {}

    _repeat_stacking_fit(
        StackingWeightModel(
            DecisionTreeRegressor(splitter="random", max_depth=X.shape[1]),
            distance_measure,
            kernel_type,
            neighbour_count=stacking_neighbour_count,
            neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
        ),
        "Stacking",
        X_plus,
        y,
        points,
        repeats,
        result,
    )

    _repeat_stacking_fit(
        StackingWeightModel(
            ExtraTreesRegressor(n_estimators=10, max_depth=X.shape[1]),
            distance_measure,
            kernel_type,
            neighbour_count=stacking_neighbour_count,
            neighbour_leave_out_rate=stacking_neighbour_leave_out_rate,
        ),
        "Stacking_Extra",
        X_plus,
        y,
        points,
        repeats,
        result,
    )

    result = {**result, **info}
    with open("simulation_variation.jsonl", "a") as f:
        f.write(json.dumps(result) + "\n")

    return result


def _repeat_stacking_fit(model, prefix, X_plus, y, points, repeats, result):
    """Fit ``model`` ``repeats`` times and record scores/timings in ``result`` under ``prefix``."""
    stackings = []
    durations = []
    for i in range(repeats):
        # perf_counter is monotonic, unlike time.time, so durations cannot go negative.
        start = time.perf_counter()
        model.fit(X_plus, y, [points])
        elapsed = time.perf_counter() - start

        stackings.append(model.llocv_stacking_)
        durations.append(elapsed)

        print(f"{prefix}:", model.llocv_score_, model.llocv_stacking_)
        print(elapsed)

        result[f"{prefix}_Base_{i}"] = model.llocv_score_
        result[f"{prefix}_{i}"] = model.llocv_stacking_
        # Kept for compatibility: holds the time of the most recent fit.
        result[f"{prefix}_Time"] = elapsed

    result[f"{prefix}_Mean"] = np.mean(stackings)
    result[f"{prefix}_Std"] = np.std(stackings)
    result[f"{prefix}_Max"] = np.max(stackings)
    result[f"{prefix}_Min"] = np.min(stackings)
    # Previously only the last repeat's time survived; also record the average.
    result[f"{prefix}_Time_Mean"] = float(np.mean(durations))


def coef_auto_gau_strong():
    """Build a strongly varying spatial coefficient surface.

    The surface is the sum of three parts:
    - a radial term (distance from the origin scaled by 1/sqrt(200));
    - a directional term along (1, 1);
    - 1000 randomly placed Gaussian bumps. Each bump has a centre in
      [-10, 10]^2, |amplitude| in [1, 2) with a random sign, diagonal
      covariance entries in [0.2, 1), and an off-diagonal entry bounded by
      their geometric mean.

    Draws from the global ``np.random`` state, so the result depends on the
    current seed.
    """
    def random_bump():
        # Draw order matches the original loop: centre, amplitude, sign, var_x, var_y, cov.
        center = np.random.uniform(-10, 10, 2)
        amplitude = np.random.uniform(1, 2) * np.random.choice([-1, 1])
        var_x = np.random.uniform(0.2, 1)
        var_y = np.random.uniform(0.2, 1)
        limit = np.sqrt(var_x * var_y)
        covariance = np.random.uniform(-limit, limit)
        cov_matrix = np.array([[var_x, covariance], [covariance, var_y]])
        return gaussian_coefficient(center, cov_matrix, amplitude=amplitude)

    radial = radial_coefficient(np.array([0, 0]), 1 / np.sqrt(200))
    directional = directional_coefficient(np.array([1, 1]))

    bumps = coefficient_wrapper(np.sum, *[random_bump() for _ in range(1000)])
    return coefficient_wrapper(np.sum, radial, directional, bumps)


def f_square(X, C, points):
    """Return ``C[0](points) * X[:, 0] ** 2``, using only the first coefficient and feature."""
    first_feature = X[:, 0]
    square_term = polynomial_function(C[0], 2)
    # The trailing ``+ 0`` is kept so the output is identical to the original.
    return square_term(first_feature, points) + 0


def generate_sample(count, f, coef_func, random_seed=1, plot=False):
    """Simulate a one-feature spatial regression sample.

    Args:
        count: number of samples.
        f: ``f(X, coefficients, points)`` producing the target.
        coef_func: zero-argument factory returning the coefficient surface.
        random_seed: value passed to ``np.random.seed`` before any draw.
        plot: when true, ``show_sample`` writes figures to
            ``Plot/<coef_func>_<f>_<count>``.

    Returns:
        ``(X, y, points)``. ``X`` has a single column drawn uniformly in
        [-10, 10]; ``points`` are sampled in [-10, 10] per dimension.
    """
    np.random.seed(random_seed)

    points = sample_points(count, bounds=(-10, 10))
    feature = sample_x(count, bounds=(-10, 10))
    coefficients = [coef_func()]

    X = np.column_stack((feature,))
    y = f(X, coefficients, points)

    if not plot:
        return X, y, points

    folder = f"Plot/{coef_func.__name__}_{f.__name__}_{count}"
    os.makedirs(folder, exist_ok=True)
    show_sample(X, y, points, coefficients, folder)
    return X, y, points


def square_gau_strong_100():
    """Repeated stacking fits on 100 samples of ``f_square`` with ``coef_auto_gau_strong``."""
    sample_count = 100
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    labels = {"f": "f_square", "coef": "coef_gau_strong", "count": sample_count}
    return fit_models(
        X,
        y,
        points,
        info=labels,
        stacking_neighbour_leave_out_rate=0.2,
        stacking_neighbour_count=0.45,
    )


def square_gau_strong_500():
    """Repeated stacking fits on 500 samples of ``f_square`` with ``coef_auto_gau_strong``."""
    sample_count = 500
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    labels = {"f": "f_square", "coef": "coef_gau_strong", "count": sample_count}
    return fit_models(
        X,
        y,
        points,
        info=labels,
        stacking_neighbour_leave_out_rate=0.1,
        stacking_neighbour_count=0.08,
    )


def square_gau_strong_1000():
    """Repeated stacking fits on 1000 samples of ``f_square`` with ``coef_auto_gau_strong``."""
    sample_count = 1000
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    labels = {"f": "f_square", "coef": "coef_gau_strong", "count": sample_count}
    return fit_models(
        X,
        y,
        points,
        info=labels,
        stacking_neighbour_leave_out_rate=0.05,
        stacking_neighbour_count=0.02,
    )


def square_gau_strong_5000():
    """Repeated stacking fits on 5000 samples of ``f_square`` with ``coef_auto_gau_strong``."""
    sample_count = 5000
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    labels = {"f": "f_square", "coef": "coef_gau_strong", "count": sample_count}
    return fit_models(
        X,
        y,
        points,
        info=labels,
        stacking_neighbour_leave_out_rate=0.2,
        stacking_neighbour_count=0.008,
    )


def square_gau_strong_10000():
    """Repeated stacking fits on 10000 samples of ``f_square`` with ``coef_auto_gau_strong``."""
    sample_count = 10000
    X, y, points = generate_sample(
        sample_count, f_square, coef_auto_gau_strong, random_seed=1, plot=True
    )
    labels = {"f": "f_square", "coef": "coef_gau_strong", "count": sample_count}
    return fit_models(
        X,
        y,
        points,
        info=labels,
        stacking_neighbour_leave_out_rate=0.2,
        stacking_neighbour_count=0.008,
    )





if __name__ == "__main__":
    # Available runs: square_gau_strong_100 / _500 / _1000 / _5000 / _10000.
    # Only the 1000-sample run is currently enabled.
    square_gau_strong_1000()


import numpy as np
from scipy.stats import multivariate_normal

def gaussian_coefficient(mean, cov, amplitude=1):
    """
    Build a Gaussian-bump coefficient whose peak value (at ``mean``) equals ``amplitude``.

    The multivariate normal density with the given ``mean`` and covariance
    ``cov`` is divided by its value at ``mean``, then scaled by ``amplitude``.
    A list ``cov`` is converted to an array first.
    """
    if isinstance(cov, list):
        cov = np.array(cov)

    density = multivariate_normal(mean=mean, cov=cov)
    peak = density.pdf(mean)

    def coefficient(point):
        return density.pdf(point) * amplitude / peak

    return coefficient


def radial_coefficient(origin, amplitude):
    """
    Build a coefficient equal to ``amplitude`` times the Euclidean distance from ``origin``.

    The distance is taken over the last axis, so a batch of points
    (one per row) gives one value per row.
    """
    def coefficient(point):
        offset = point - origin
        distance = np.linalg.norm(offset, axis=-1)
        return distance * amplitude

    return coefficient


def directional_coefficient(direction, amplitude=1):
    """
    Build a coefficient equal to ``amplitude`` times cos(angle between point and ``direction``).

    ``direction`` is normalised once. For each point (the last axis holds the
    coordinates), the value is ``dot(point, direction) / |point| * amplitude``.
    At the origin the angle is undefined. That point now yields 0 instead of
    the NaN (with a divide warning) produced by 0/0, which would otherwise
    contaminate any target built from it.
    """
    # Normalize the direction vector
    direction = direction / np.linalg.norm(direction)

    def coefficient(point):
        point = np.asarray(point, dtype=float)
        norm = np.linalg.norm(point, axis=-1)
        projection = np.dot(point, direction)
        nonzero = norm > 0
        # Divide only where |point| > 0; the origin maps to 0.
        cosine = np.where(nonzero, projection / np.where(nonzero, norm, 1.0), 0.0)
        # ``[()]`` turns a 0-d result (single point) back into a scalar.
        return cosine[()] * amplitude

    return coefficient


def sine_coefficient(frequency, direction, amplitude):
    """
    Build a plane-wave coefficient ``sin(frequency * dot(point, direction)) * amplitude``.

    ``direction`` is used as given (not normalised), so its length also scales the frequency.
    """
    def coefficient(point):
        phase = np.dot(point, direction) * frequency
        return np.sin(phase) * amplitude

    return coefficient


def coefficient_wrapper(operator, *coefficients):
    """
    Combine one or more coefficient functions with ``operator``.

    - One coefficient: returns ``point -> operator(c(point))`` (e.g. a
      ``partial(np.multiply, k)`` scaling).
    - Several coefficients: their per-point outputs are stacked as columns
      and reduced with ``operator(stacked, axis=-1)`` (e.g. ``np.sum``).

    Raises:
        ValueError: if no coefficient is given. Previously this returned a
            function that failed later inside ``np.column_stack``, and the
            ``return None`` branch was unreachable.
    """
    if not coefficients:
        raise ValueError("coefficient_wrapper requires at least one coefficient")

    if len(coefficients) == 1:
        (single,) = coefficients

        def coefficient_func(point):
            return operator(single(point))

        return coefficient_func

    def coefficient_func(point):
        return operator(np.column_stack([c(point) for c in coefficients]), axis=-1)

    return coefficient_func


def polynomial_function(coefficient, degree):
    """Return ``g(x, point) = coefficient(point) * x ** degree``."""
    def evaluate(x, point):
        weight = coefficient(point)
        return weight * x ** degree

    return evaluate


def square_function(coefficient):
    """Return ``g(x, point) = coefficient(point) * x ** 2``."""
    squared = polynomial_function(coefficient, degree=2)
    return squared


def linear_function(coefficient):
    """Return ``g(x, point) = coefficient(point) * x ** 1``."""
    linear = polynomial_function(coefficient, degree=1)
    return linear


def interaction_function(coefficient):
    """Return ``g(x1, x2, point) = coefficient(point) * x1 * x2``."""
    def evaluate(x1, x2, point):
        scaled_x1 = coefficient(point) * x1
        return scaled_x1 * x2

    return evaluate


def sigmoid_function(coefficient):
    """Return ``g(x, point) = coefficient(point) * tanh(x)`` (a tanh-shaped sigmoid)."""
    def evaluate(x, point):
        squashed = np.tanh(x)
        return coefficient(point) * squashed

    return evaluate


def relu_function(coefficient):
    """Return ``g(x, point) = coefficient(point) * max(0, x)`` (element-wise)."""
    def evaluate(x, point):
        rectified = np.maximum(0, x)
        return coefficient(point) * rectified

    return evaluate


def exponential_function(coefficient):
    """Return ``g(x, point) = coefficient(point) * exp(x)``."""
    def evaluate(x, point):
        growth = np.exp(x)
        return coefficient(point) * growth

    return evaluate

def coefficient_function(coefficient):
    """Return ``g(point) = coefficient(point)``: the coefficient itself as a term."""
    return lambda point: coefficient(point)


def sample_points(n, dim=2, bounds=(-1, 1), numeraical_type="continuous"):
    """
    Sample ``n`` points in ``dim`` dimensions, one column per dimension.

    Args:
        n: number of points.
        dim: number of dimensions (columns).
        bounds: for a "continuous" dimension, a ``(low, high)`` pair passed to
            ``np.random.uniform``. For a "discrete" dimension, the values to
            draw from with ``np.random.choice``. A non-list value is reused
            for every dimension; a list supplies one entry per dimension.
        numeraical_type: "continuous" or "discrete", or a list with one entry
            per dimension. (The misspelt name is kept for compatibility.)

    Returns:
        np.ndarray of shape ``(n, dim)``.

    Raises:
        ValueError: if a dimension's type is neither "continuous" nor
            "discrete". Previously such a column was silently left as zeros.
    """
    points = np.zeros((n, dim))

    if not isinstance(bounds, list):
        bounds = [bounds] * dim

    if not isinstance(numeraical_type, list):
        numeraical_type = [numeraical_type] * dim

    for i in range(dim):
        kind = numeraical_type[i]
        if kind == "continuous":
            points[:, i] = np.random.uniform(bounds[i][0], bounds[i][1], n)
        elif kind == "discrete":
            points[:, i] = np.random.choice(bounds[i], n)
        else:
            raise ValueError(
                f"Unknown numerical type {kind!r} for dimension {i}; "
                "expected 'continuous' or 'discrete'"
            )

    return points


def sample_x(n, type='uniform', bounds=(-1, 1), mean=0, variance=1, scale=1, points=None):
    """
    Sample ``n`` x values from a specified distribution.

    Any of ``bounds[0]``, ``bounds[1]``, ``mean``, ``variance`` and ``scale``
    may be a callable. It is then evaluated on ``points`` to give per-sample
    parameters, which introduces spatial dependency. Resolved values are kept
    in locals. The original assigned back into ``bounds``, which raised
    TypeError for tuple bounds (the default type) and mutated a caller's list.

    Args:
        n: number of values.
        type: 'uniform', 'normal' or 'exponential'.
        bounds: ``(low, high)`` for 'uniform'.
        mean: added as an offset for 'uniform'; the centre for 'normal';
            unused for 'exponential'.
        variance: multiplier on a standard-normal draw for 'normal'.
            NOTE: it therefore acts as a standard deviation, not a variance.
        scale: multiplier on a unit-rate exponential draw for 'exponential'.
        points: locations passed to any callable parameter.

    Returns:
        np.ndarray built from an ``n``-length random draw, broadcast with the
        (possibly per-point) parameters.

    Raises:
        ValueError: for an unknown ``type``. Previously this returned None.
    """
    def resolve(value):
        return value(points) if callable(value) else value

    low = resolve(bounds[0])
    high = resolve(bounds[1])
    mean = resolve(mean)
    variance = resolve(variance)
    scale = resolve(scale)

    if type == 'uniform':
        base = np.random.uniform(0, 1, n)
        return base * (high - low) + low + mean
    if type == 'normal':
        base = np.random.normal(0, 1, n)
        return base * variance + mean
    if type == 'exponential':
        base = np.random.exponential(1, n)
        return base * scale

    raise ValueError(
        f"Unknown distribution type {type!r}; expected 'uniform', 'normal' or 'exponential'"
    )




from importlib.resources import path

import pandas as pd

# Seed for the row shuffle performed in load_HP before de-duplication.
random_state = 1003
# Feature columns selected by load_TOD(workday=False).
# NOTE(review): the entries must match the CSV headers byte-for-byte;
# 'ITI * INI ' keeps its trailing space, presumably because the header has it.
TOD_weekend_field_list = [
    'Departure interval (seconds)', 'Operating time (min)', 'Average distance between stations(m)',
    'Building density',
    'Residential land proportion', 'Education, health, cultural facilities land proportion',
    'Industrial land proportion',
    'Average distance to training institutions(m)',
    'Average distance to leisure spaces(m)',
    'Average distance to retail/hotel and catering places(m)',
    'Number of red shared bicycle sites', 'Road network betweenness', 'Educated people proportion',
    'ITI * INI '
]
# Feature columns selected by load_TOD(workday=True); same header-matching caveat as above.
TOD_workday_field_list = [
    'Departure interval (seconds)', 'Operating time (min)',
    'Accessible metro stations number within 20-minutes', 'Building density',
    'Residential land proportion',
    'Industrial land proportion', 'Average distance to training institutions(m)',
    'Average distance to leisure spaces(m)',
    'Average distance to retail/hotel and catering places(m)',
    'Number of red shared bicycle sites',
    'Road network betweenness', 'Educated people proportion',
    'ITI * INI '
]


def load_TOD(workday=True):
    """
    Load the transit-oriented-development (TOD) ridership dataset.

    Q: Dropping the time column and regressing directly does not seem very
    meaningful here. Every station has several records, so ignoring time
    effectively injects noise.

    Args:
        workday (bool): read ``TOD_workday.csv`` with
            ``TOD_workday_field_list`` when true; otherwise ``TOD_weekend.csv``
            with ``TOD_weekend_field_list``.

    Returns:
        ``(X, y, xy_vector, time)``:
        - ``X``: values of the selected feature columns;
        - ``y``: the 'Ridership' column;
        - ``xy_vector``: the 'station.X' / 'station.Y' columns;
        - ``time``: the 'Time' column reshaped to a single column.
        The data is not split. The original note says this is unnecessary for
        local models and OLS, and that random forests use out-of-bag error.
    """
    csv_name = 'TOD_workday.csv' if workday else 'TOD_weekend.csv'
    with path('georegression.test.data', csv_name) as filepath:
        df = pd.read_csv(filepath)

    y = df['Ridership'].values
    df = df.drop(columns=['Ridership'])

    xy_vector = df[['station.X', 'station.Y']].values
    time = df['Time'].values.reshape(-1, 1)

    field_list = TOD_workday_field_list if workday else TOD_weekend_field_list
    X = df.loc[:, field_list].values

    return X, y, xy_vector, time


def load_HP():
    """
    Load the Shanghai house-price dataset with 'Rent' as the target.

    Columns are reordered so that 'Lon84', 'Lat84', 'Time', 'Rent', 'Price'
    come last. Rows are shuffled with ``random_state``, then de-duplicated on
    every column except 'Time', 'Rent' and 'Price', keeping one row per
    entity across time slices.

    Returns:
        ``(X, y, xy_vector, time)``. ``X`` holds the remaining columns after
        removing 'Rent', 'Price', 'Lon84', 'Lat84', 'Time', 'ID' and 'FID'.
    """
    with path('georegression.test.data', 'HousePrice_Shanghai.csv') as filepath:
        df = pd.read_csv(filepath)

    trailing = ['Lon84', 'Lat84', 'Time', 'Rent', 'Price']
    ordered = [column for column in df.columns if column not in trailing] + trailing
    df = df[ordered]

    shuffled = df.sample(frac=1, random_state=random_state)
    df = shuffled.drop_duplicates(subset=ordered[:-3])

    y = df['Rent'].values
    xy_vector = df[['Lon84', 'Lat84']].values
    time = df['Time'].values.reshape(-1, 1)

    features = df.drop(columns=['Rent', 'Price', 'Lon84', 'Lat84', 'Time', 'ID', 'FID'])
    X = features.values

    return X, y, xy_vector, time


def load_ESI():
    """
    Load the ecosystem-services-indicator dataset.

    Column names are in Chinese: '经度' longitude, '纬度' latitude,
    '年份' year, '生境' habitat (the target), '养分' nutrient (the only feature).

    Returns:
        ``(X, y_true, xy_vector, time)``. ``X`` is the single '养分' column;
        ``y_true`` is '生境'; ``xy_vector`` holds longitude/latitude; ``time``
        is the year as a single column.
    """
    with path('georegression.test.data', 'EcosystemServicesIndicator.csv') as filepath:
        df = pd.read_csv(filepath)

    # Optional subset, kept for reference:
    # choice_index = np.random.choice(42410, 1000, replace=False)
    # df = df[df['Id'].isin(choice_index)]

    xy_vector = df[['经度', '纬度']].values
    time = df['年份'].values.reshape(-1, 1)

    y_true = df['生境'].values
    X = df[['养分']].values

    return X, y_true, xy_vector, time


if __name__ == '__main__':
    # Smoke test: the workday TOD dataset should load without error.
    load_TOD(workday=True)


from time import time

import dask
import dask.array as da
import dask_distance
import numpy as np
from dask.distributed import Client, LocalCluster
from dask.graph_manipulation import wait_on
from distributed import get_task_stream
from scipy import sparse

from georegression.weight_matrix import weight_matrix_from_distance


def generate_distance_matrix(size: int = 100, rechunk=True):
    """Build a lazy dask Euclidean distance matrix between ``size`` random 2-D points.

    Args:
        size: number of random points (uniform in [0, 1)^2).
        rechunk: when true, rechunk so each block spans whole rows
            (``{0: 'auto', 1: -1}``).

    Returns:
        A ``size x size`` dask array of pairwise Euclidean distances.
    """
    # dask requires integer chunk sizes. ``size / 10`` produced a float, which
    # fails whenever size is not a multiple of 10 and is 0 for size < 10.
    chunk_size = 4000 if size > 40000 else max(1, size // 10)
    points = da.random.random((size, 2), chunks=(chunk_size, 2))
    distance_matrix = dask_distance.cdist(points, points, "euclidean")
    if rechunk:
        distance_matrix = distance_matrix.rechunk({0: 'auto', 1: -1})
    return distance_matrix


def test_dask_inner_graph():
    """Build a small bisquare weight matrix and print it computed as sparse COO blocks."""
    weights = weight_matrix_from_distance(
        [generate_distance_matrix()], "bisquare", neighbour_count=0.1
    )
    sparse_blocks = weights.map_blocks(sparse.coo_matrix)
    print(sparse_blocks.compute())
    print()


def test_quantile_speed_up_1():
    """Time weight-matrix construction when a sorted distance matrix is also supplied.

    Both matrices are read from zarr stores under ``F://dask``. The function
    prints the graph-building time, then the sparse (COO) result, then the
    time taken to compute it.
    """
    distances = da.from_zarr("F://dask//spatial_distance_matrix.zarr")
    distances_sorted = da.from_zarr("F://dask//spatial_distance_matrix_sorted.zarr")
    # Temporal matrices (F://dask//temporal_distance_matrix*.zarr) and wait_on
    # wrapping were tried here and are currently disabled.

    build_start = time()
    weights = weight_matrix_from_distance(
        [distances],
        "bisquare",
        neighbour_count=0.05,
        distance_matrices_sorted=[distances_sorted],
    )
    print(time() - build_start)

    weights_sparse = weights.map_blocks(sparse.coo_matrix)

    compute_start = time()
    print(weights_sparse.compute())
    print(time() - compute_start)
    # Observed: 224.76215386390686

def test_quantile_speed_up_2():
    """Baseline for test_quantile_speed_up_1: the same timing without passing the sorted matrix.

    The sorted matrix is still opened, so store access matches the other
    benchmark, but it is deliberately not handed to weight_matrix_from_distance.
    """
    distances = da.from_zarr("F://dask//spatial_distance_matrix.zarr")
    distances_sorted = da.from_zarr(  # opened but intentionally unused
        "F://dask//spatial_distance_matrix_sorted.zarr"
    )

    build_start = time()
    weights = weight_matrix_from_distance(
        [distances],
        "bisquare",
        neighbour_count=0.05,
    )
    print(time() - build_start)

    weights_sparse = weights.map_blocks(sparse.coo_matrix)

    compute_start = time()
    print(weights_sparse.compute())
    print(time() - compute_start)
    # Observed: 264.519348859787


def test_distance_optimization_speed_up():
    """Time the sorted-matrix weight construction with both inputs wrapped in ``wait_on``."""
    distances = wait_on(da.from_zarr("F://dask//spatial_distance_matrix.zarr"))
    distances_sorted = wait_on(
        da.from_zarr("F://dask//spatial_distance_matrix_sorted.zarr")
    )

    build_start = time()
    weights = weight_matrix_from_distance(
        [distances],
        "bisquare",
        neighbour_count=0.05,
        distance_matrices_sorted=[distances_sorted],
    )
    print(time() - build_start)

    weights_sparse = weights.map_blocks(sparse.coo_matrix)

    compute_start = time()
    print(weights_sparse.compute())
    print(time() - compute_start)
    # Observed: 228.34707760810852


def test_dask_compatiblity():
    """Time weight-matrix construction and computation on a 100,000-point dask distance matrix."""
    distances = wait_on(generate_distance_matrix(100000))

    build_start = time()
    weights = weight_matrix_from_distance([distances], "bisquare", neighbour_count=0.1)
    print(time() - build_start)

    weights_sparse = weights.map_blocks(sparse.coo_matrix)

    compute_start = time()
    print(weights_sparse.compute())
    print(time() - compute_start)


def test_dask_map_block_valid():
    """Check that a per-row median can be computed block-wise with ``map_blocks``.

    ``generate_distance_matrix`` rechunks with a column chunk of -1 by default,
    so each block holds complete rows. A per-block ``np.percentile`` along
    axis 1 therefore matches the global row-wise result.
    """
    distance_matrix = wait_on(generate_distance_matrix())

    t1 = time()

    percentile = distance_matrix.map_blocks(
        np.percentile,
        50,
        axis=1,
        # drop_axis=1 declares that each block's output has no axis 1, so the
        # block function must return shape (rows,), not (rows, 1).
        # keepdims=True contradicted that metadata.
        keepdims=False,
        drop_axis=1,
        # Passing chunks=(distance_matrix.chunksize[0],) fails because the last
        # row chunk may be smaller than the rest; let dask derive the chunks.
    )

    print(percentile.shape, percentile.compute())
    t2 = time()
    print(t2 - t1)


def test_dask_reduction_valid():
    """Compute a per-row 99th percentile with ``da.reduction`` and time it.

    The chunk step passes blocks through unchanged, so all the work happens in
    the aggregate step, which applies ``np.percentile`` along ``axis``.
    """
    # 57.298909187316895 s with rechunk, 40.120853900909424 s without.
    # Rechunking makes it slower for small data but saves memory for large data.
    distance_matrix = wait_on(generate_distance_matrix())

    start = time()

    def passthrough(x, axis, keepdims):
        """Chunk step: hand each block to the aggregate step untouched."""
        return x

    def row_percentile(x, axis, keepdims):
        """Aggregate step: the 99th percentile along ``axis``."""
        # dask makes a pre-call with an empty (0, 0) array for dimension checking.
        if x.shape == (0, 0):
            return np.array([])
        return np.percentile(x, 99, axis=axis, keepdims=keepdims)

    percentile = da.reduction(
        distance_matrix,
        passthrough,
        row_percentile,
        axis=1,
        dtype=np.float64,
    )

    print(percentile.shape, percentile.compute())
    print(time() - start)


if __name__ == "__main__":
    # Raise the comm retry count and the connect timeout for a busy local cluster.
    dask.config.set({"distributed.comm.retry.count": 10})
    dask.config.set({"distributed.comm.timeouts.connect": 30})

    # Report the current worker memory thresholds. These reads were previously
    # evaluated and thrown away.
    for config_key in (
        "distributed.worker.memory.target",
        "distributed.worker.memory.spill",
        "distributed.worker.memory.pause",
        "distributed.worker.memory.max-spill",
    ):
        print(config_key, dask.config.get(config_key))
    # dask.config.set({"distributed.worker.memory.pause": 0.5})
    dask.config.set({"distributed.worker.memory.terminate": False})

    # Create a local cluster and distributed scheduler. The context managers
    # close both the client and the cluster when the benchmark finishes.
    with LocalCluster(
        local_directory="F:/dask",
        n_workers=4,
        memory_limit="24GiB",
    ) as cluster, Client(cluster) as client:
        print(client.dashboard_link)

        with get_task_stream(plot="save", filename="task-stream.html"):
            # test_dask_inner_graph()
            # test_quantile_speed_up_1()
            # test_quantile_speed_up_2()
            test_distance_optimization_speed_up()
            # test_dask_compatiblity()
            # test_dask_map_block_valid()

        # Must run while the client is still open.
        client.profile(filename="dask-profile.html")


from georegression.distance_utils import (
    euclidean_distance_matrix,
    calculate_distance_one_to_many,
)
import numpy as np
from time import time
from scipy.spatial.distance import pdist, cdist
from scipy.spatial import distance_matrix
from sklearn.metrics import pairwise_distances

# 30,000 random 2-D points shared by the benchmarks below; drawn at import time.
X = np.random.random((30000, 2))


def one_loop_version(X, Y):
    """Return the pairwise Euclidean distance matrix between the rows of X and Y.

    Uses one Python loop over the rows of ``X`` with a vectorised inner step.
    It serves as a baseline in the distance benchmarks.

    Args:
        X: array of shape (m, d).
        Y: array of shape (n, d).

    Returns:
        Array of shape (m, n) whose entry (i, j) is ``||X[i] - Y[j]||``.
    """
    m = X.shape[0]
    n = Y.shape[0]
    dist = np.empty((m, n))
    for i in range(m):
        dist[i, :] = np.sqrt(np.sum((X[i] - Y) ** 2, axis=1))
    # The matrix used to be computed and then silently discarded.
    return dist


def test_distance_matrix():
    """Benchmark several ways of computing distances among the module-level points X.

    Prints one line of elapsed seconds, in this order: euclidean_distance_matrix,
    calculate_distance_one_to_many in a row loop, pdist, cdist,
    one_loop_version, scipy distance_matrix, pairwise_distances, and
    pairwise_distances with n_jobs=-1.
    """

    def one_to_many_loop():
        for row in X:
            calculate_distance_one_to_many(row, X, "euclidean")

    benchmarks = (
        lambda: euclidean_distance_matrix(X, X),
        one_to_many_loop,
        lambda: pdist(X),
        lambda: cdist(X, X),
        lambda: one_loop_version(X, X),
        lambda: distance_matrix(X, X),
        lambda: pairwise_distances(X, X),
        lambda: pairwise_distances(X, X, n_jobs=-1),
    )

    stamps = [time()]
    for run in benchmarks:
        run()
        stamps.append(time())

    print(*(after - before for before, after in zip(stamps, stamps[1:])))
    # For (30000,2):
    # INTEL I7 8700, PYTHON 3.7
    # 65.28857350349426 12.295624017715454 2.637856960296631 4.8066534996032715 14.165263652801514 20.550457000732422
    # AMD R9 7950X, PYTHON 3.11
    # 8.726486682891846 9.398136377334595 1.1142053604125977 2.3201327323913574 10.608065605163574 13.829639911651611 5.806329011917114


def matrix_size(dtype=None, count=30000):
    """Measure the cost of a dense (count, count) matrix and of distance computations.

    Prints, in order:
    - the time to create a random (count, count) matrix;
    - that matrix's size in MB;
    - the times for ``pdist`` and ``pairwise_distances(n_jobs=-1)`` on
      ``count`` random 2-D points.

    Args:
        dtype: NumPy dtype the random data is cast to; ``None`` keeps NumPy's
            default float64.
        count: number of rows/points. The default of 30000 is the originally
            hard-coded size.
    """
    # Time the creation of a dense matrix the size of a full distance matrix
    # (the values are random, not distances).
    t1 = time()
    X = np.random.random((count, count)).astype(dtype)
    t2 = time()
    print(t2 - t1)

    # Memory usage of that matrix in MB.
    print(X.nbytes / 1024**2, "MB")
    # Release the large matrix before allocating the point set, to lower peak memory.
    del X

    # Time the distance computations on `count` random 2-D points.
    points = np.random.random((count, 2)).astype(dtype)
    t1 = time()
    pdist(points)
    t2 = time()
    pairwise_distances(points, points, n_jobs=-1)
    t3 = time()
    print(t2 - t1, t3 - t2)


if __name__ == "__main__":
    # test_distance_matrix()
    # By default only the size/timing check runs, in half precision.
    matrix_size(np.float16)


import dask
import dask.array as da
from distributed import LocalCluster, Client
import numpy as np

from georegression.distance_utils import _distance_matrix, _distance_matrices
from georegression.kernel import adaptive_bandwidth
from georegression.weight_matrix import weight_matrix_from_distance
from scipy import sparse




def test_distance_matrix_using_dask():
    """Compute distance matrices between two sets of 50,000 random 2-D points on a local dask cluster.

    The points are split into 4000-row chunks. The call requests
    ``cache_sort=True`` and ``overwrite=True`` with
    ``filepath="F://test_distance_matrix"``.
    """
    dask.config.set({
        "distributed.comm.retry.count": 10,
        "distributed.comm.timeouts.connect": 30,
        "distributed.worker.memory.terminate": False,
    })

    cluster = LocalCluster(local_directory="F://dask")
    client = Client(cluster)
    print(client.dashboard_link)

    count = 50000
    point_chunks = (4000, 2)
    source_points = da.from_array(np.random.random((count, 2)), chunks=point_chunks)
    target_points = da.from_array(np.random.random((count, 2)), chunks=point_chunks)

    _distance_matrices(
        [source_points],
        [target_points],
        use_dask=True,
        cache_sort=True,
        filepath="F://test_distance_matrix",
        overwrite=True,
    )


def test_weight_matrix_using_sorted_distance_matrix():
    """Build a bisquare weight matrix from cached zarr distance matrices (plain and sorted) and print it."""
    dask.config.set({
        "distributed.comm.retry.count": 10,
        "distributed.comm.timeouts.connect": 30,
        "distributed.worker.memory.terminate": False,
    })

    cluster = LocalCluster(local_directory="F://dask")
    client = Client(cluster)
    print(client.dashboard_link)

    distances = da.from_zarr("F://dask//test_distance_matrix.zarr")
    distances_sorted = da.from_zarr("F://dask//test_distance_matrix_sorted.zarr")

    # Disabled check of the adaptive bandwidth:
    # bandwidth = adaptive_bandwidth(distances_sorted, 2); print(bandwidth.compute())

    weights = weight_matrix_from_distance(
        [distances],
        "bisquare",
        neighbour_count=0.01,
        distance_matrices_sorted=[distances_sorted],
    )
    print(weights.map_blocks(sparse.coo_matrix).compute())


def test_dask_client(local_directory=None):
    """Start a local cluster and check that its client becomes the current default client.

    Args:
        local_directory: scratch directory for the workers; ``None`` lets dask
            choose. (The old body read an undefined ``kwargs`` name here.)
    """
    dask.config.set({"distributed.comm.retry.count": 10})
    dask.config.set({"distributed.comm.timeouts.connect": 30})
    dask.config.set({"distributed.worker.memory.terminate": False})

    cluster = LocalCluster(local_directory=local_directory)
    client = Client(cluster)
    try:
        print(client.dashboard_link)

        # ``Client.get`` is an instance method that runs a task graph and needs
        # arguments. ``Client.current()`` is the classmethod that returns the
        # default client, which should be the one just created.
        assert Client.current() is client
    finally:
        client.close()
        cluster.close()


if __name__ == "__main__":
    # Run the distance-matrix step first to create the zarr files that the
    # weight-matrix step reads.
    # test_distance_matrix_using_dask()
    test_weight_matrix_using_sorted_distance_matrix()

    pass


from time import time

import numpy as np
from numba import njit, prange
# import dpnp as np
# import dpnp.linalg
# import numba_dpex as dpex

# @dpex.dpjit
@njit(parallel=True)
def loop_parallel_inner(iteration_count=1000):
    """Benchmark: solve ``iteration_count`` independent random ridge problems in parallel.

    Each iteration draws a fresh (100, 100) design and (100, 1) target. It then
    solves ridge regression (alpha = 1, no centering/intercept) with
    ``np.linalg.lstsq`` on an augmented system. Results are discarded.
    """
    for i in prange(iteration_count):
        # X = np.identity(100)
        X = np.random.random((100, 100))
        y = np.random.random((100, 1))
        # np.linalg.inv(X)
        # np.linalg.svd(X)

        # Augmentation trick for ridge regression: stack sqrt(alpha) * I under X
        # and zeros under y, so that ordinary least squares on the stacked
        # system solves the ridge problem.
        alpha = 1.0
        A = np.identity(X.shape[1], dtype=X.dtype)

        X_aug = np.vstack((X, np.sqrt(alpha) * A))
        y_aug = np.vstack((y, np.zeros((A.shape[0], 1), dtype=y.dtype)))

        # np.linalg.solve(X, y)
        # np.linalg.solve(X_aug, y_aug)
        np.linalg.lstsq(X_aug, y_aug)


if __name__ == '__main__':
    # A tiny first call triggers JIT compilation, so the timed run excludes it.
    loop_parallel_inner(1)
    start = time()
    loop_parallel_inner(100000)
    print(time() - start)

    # Observed: 128.22677397727966 s without dpjit.

from numba import njit, prange
from numba.typed import List
import numpy as np


@njit()
def test_stack():
    """Check whether ``np.stack`` compiles in nopython mode on a list of ten (2, 3) arrays.

    The stacked result is discarded.
    """
    array = np.ones((2, 3))
    list_of_array = [array] * 10
    np.stack(list_of_array)


@njit()
def test_list_stack(i, array_to_be_stacked):
    """Stack ``i`` references to ``array_to_be_stacked`` by copying into a preallocated array.

    Returns:
        An array of shape ``(i,) + array_to_be_stacked.shape``. ``np.empty``
        is called without a dtype, so the result is float64 whatever the
        input dtype.
    """
    shape = (i,) + array_to_be_stacked.shape
    list_of_array = [array_to_be_stacked] * i
    stacked_array = np.empty(shape)
    # The decorator does not set parallel=True, so prange behaves like range here.
    for j in prange(i):
        stacked_array[j] = list_of_array[j]
    return stacked_array


@njit()
def stack(list_of_array):
    """Manual replacement for ``np.stack`` inside numba: copy each array into one output array.

    Args:
        list_of_array: a list of equally shaped arrays. The caller below
            passes a ``numba.typed.List``.

    Returns:
        An array of shape ``(len(list_of_array),) + list_of_array[0].shape``.
        It is always float64, because ``np.empty`` gets no dtype.
    """
    shape = (len(list_of_array),) + list_of_array[0].shape
    stacked_array = np.empty(shape)
    # The decorator does not set parallel=True, so prange behaves like range here.
    for j in prange(len(list_of_array)):
        stacked_array[j] = list_of_array[j]
    return stacked_array

if __name__ == "__main__":
    test_stack()

    test_list_stack(10, np.ones((2, 3)))

    # `stack` needs a numba typed List here, not a plain Python list.
    typed_list = List()
    # Plain loop: the old list comprehension was used only for its side effects.
    for _ in range(10):
        typed_list.append(np.ones((2, 3)))
    stacked = stack(typed_list)
    print(stacked.shape)
    print(stacked)


import time

import numpy as np
from sklearn.linear_model import LinearRegression, Ridge, RidgeCV


def test_regression():
    """Time ``estimator_count`` fits each of LinearRegression, Ridge and RidgeCV.

    The design matrix is a square random (neighbour_count, neighbour_count)
    array. After a blank line, elapsed seconds per estimator class are printed
    on one line.
    """
    neighbour_count = 500
    X = np.random.random((neighbour_count, neighbour_count))
    y = np.random.random(neighbour_count)
    estimator_count = 100

    elapsed = []
    for estimator_class in (LinearRegression, Ridge, RidgeCV):
        start = time.time()
        for _ in range(estimator_count):
            estimator_class().fit(X, y)
        elapsed.append(time.time() - start)

    print()
    print(*elapsed)
    # 2.859616994857788 0.6065783500671387 3.498260974884033
    # neighbour_count = 500 estimator_count = 100 AMD-7950X 10.886837720870972 0.49491024017333984 6.453542470932007


def test_solver():
    """Time ``estimator_count`` Ridge(alpha=10) fits per solver on a square random problem.

    After a blank line, prints elapsed seconds in this order: default ("auto"),
    svd, cholesky, sparse_cg, lsqr, sag.
    """
    neighbour_count = 1000
    X = np.random.random((neighbour_count, neighbour_count))
    y = np.random.random(neighbour_count)
    estimator_count = 100

    elapsed = []
    for solver in ("auto", "svd", "cholesky", "sparse_cg", "lsqr", "sag"):
        start = time.time()
        for _ in range(estimator_count):
            Ridge(10, solver=solver).fit(X, y)
        elapsed.append(time.time() - start)

    print()
    print(*elapsed)
    # neighbour = 500 0.5866434574127197 4.114291429519653 0.6661701202392578 0.4684324264526367 0.3340754508972168 9.437297821044922
    # neighbour = 1000 2.821377992630005 21.683172464370728 2.867004156112671 0.9687221050262451 0.9672167301177979 64.40824103355408


def test_alpha():
    """Time ``estimator_count`` Ridge fits for alpha = 0, 0.1 and 10 on a square random problem."""
    neighbour_count = 1000
    X = np.random.random((neighbour_count, neighbour_count))
    y = np.random.random(neighbour_count)
    estimator_count = 100

    elapsed = []
    for alpha in (0, 0.1, 10):
        start = time.time()
        for _ in range(estimator_count):
            Ridge(alpha).fit(X, y)
        elapsed.append(time.time() - start)

    print()
    print(*elapsed)
    # neighbour = 1000 57.08750033378601 1.9294869899749756 1.918003797531128


if __name__ == "__main__":
    # Only the estimator-class comparison runs by default; the others are manual toggles.
    test_regression()
    # test_solver()
    # test_alpha()


from time import time

import numpy as np
from numba import jit, njit, prange
from sklearn.linear_model import Ridge


def loop_python(iteration_count=1000):
    """Pure-Python baseline: fit a fresh sklearn Ridge(1) on one random 100x100 problem, repeatedly."""
    features = np.random.random((100, 100))
    target = np.random.random((100, 1))

    for _ in range(iteration_count):
        Ridge(1).fit(features, target)


@jit(forceobj=True, looplift=True)
def loop_jitting(iteration_count=1000):
    """Same work as loop_python, compiled by numba in object mode with loop lifting.

    ``forceobj=True`` is needed because the loop builds sklearn ``Ridge``
    objects, which nopython mode cannot handle.
    """
    X = np.random.random((100, 100))
    y = np.random.random((100, 1))

    for i in range(iteration_count):
        estimator = Ridge(1)
        estimator.fit(X, y)


@njit()
def loop_numba(iteration_count=1000):
    """Call the jitted ``ridge_fit`` ``iteration_count`` times on one random 100x100 problem.

    NOTE(review): ridge_fit calls ``mean(X, axis=0)`` without a weight, so this
    only compiles if ``mean`` accepts an omitted weight.
    """
    X = np.random.random((100, 100))
    y = np.random.random((100, 1))

    for i in range(iteration_count):
        ridge_fit(X, y)


@njit(parallel=True)
def loop_paralle(iteration_count=1000):
    """Parallel (prange) version of loop_numba: repeated ``ridge_fit`` calls on shared X, y."""
    X = np.random.random((100, 100))
    y = np.random.random((100, 1))

    # Stuck when X, y is passed to ridge_fit. Everything is fine if X, y are generated inside ridge_fit.

    for i in prange(iteration_count):
        ridge_fit(X, y)


# @njit(parallel=True)
def loop_paralle_lstsq(iteration_count=1000):
    """Inline ridge-via-lstsq loop (same maths as ``rigde_lstsq``); results are discarded.

    The jit decorator is commented out, so this runs as plain Python and
    ``prange`` acts like ``range``. NOTE(review): ``mean`` is called without a
    weight, which only works if ``mean`` accepts an omitted weight.
    """
    X = np.random.random((100, 100))
    y = np.random.random((100, 1))

    for i in prange(iteration_count):
        alpha = 1.0

        # Center the data to make the intercept term zero
        X_offset = mean(X, axis=0)
        y_offset = mean(y, axis=0)
        X_center = X - X_offset
        y_center = y - y_offset

        dimension = X_center.shape[1]
        A = np.identity(dimension)

        # Augmentation trick: least squares on [X; sqrt(alpha) I] vs [y; 0] is ridge.
        X_aug = np.vstack((X_center, np.sqrt(alpha) * A))
        y_aug = np.vstack((y_center, np.zeros((A.shape[0], 1), dtype=y.dtype)))

        coef = np.linalg.lstsq(X_aug, y_aug)[0]
        # Computed for parity with rigde_lstsq but not used.
        intercept = y_offset - np.dot(X_offset, coef)


@njit(parallel=True)
def loop_parallel_chol(iteration_count=1000):
    """Parallel benchmark: call ``ridge_cholesky`` (unweighted) ``iteration_count`` times on shared data."""
    X = np.random.random((100, 100))
    y = np.random.random((100, 1))

    for i in prange(iteration_count):
        ridge_cholesky(X, y)

        # Previous inline version of ridge_cholesky, kept for reference:
        # alpha = 1.0
        #
        # # Center the data to make the intercept term zero
        # X_offset = mean(X, axis=0)
        # y_offset = mean(y, axis=0)
        # X_center = X - X_offset
        # y_center = y - y_offset
        #
        # A = np.dot(X_center.T, X_center)
        # Xy = np.dot(X_center.T, y_center)
        #
        # A = A + alpha * np.eye(X.shape[1])
        #
        # coef = np.linalg.solve(A, Xy)
        # intercept = y_offset - np.dot(X_offset, coef)

@njit(parallel=True)
def loop_parallel_inner(iteration_count=1000):
    """Parallel micro-benchmark that isolates the inverse of ``X.T @ X + alpha * I``.

    Centering and the coefficient/intercept computation are commented out.
    Each iteration only forms the regularised Gram matrix of the shared
    uncentred X and inverts it.
    """
    X = np.random.random((100, 100))
    y = np.random.random((100, 1))

    # coef_list = []
    # intercept_list = []

    for i in prange(iteration_count):
        alpha = 1.0
        
        # Center the data to make the intercept term zero
        # X_offset = mean(X, axis=0)
        # y_offset = mean(y, axis=0)
        # X_center = X - X_offset
        # y_center = y - y_offset

        # Centering disabled: use the raw data.
        X_center = X
        y_center = y
        
        dimension = X_center.shape[1]
        A = np.identity(dimension)
        A_biased = alpha * A

        temp = X_center.T.dot(X_center) + A_biased

        np.linalg.inv(
            temp
        )

        # coef = np.linalg.inv(
        #     X_center.T.dot(X_center) + A_biased
        # ).dot(X_center.T).dot(y_center)
        # intercept = y_offset - np.dot(X_offset, coef)
        
        # coef_list.append(coef)
        # intercept_list.append(intercept)


@njit()
def mean(x, axis, weight=None):
    """Return the (optionally weighted) mean of ``x`` along ``axis``.

    Args:
        x: array to average. Callers pass X (n, p) and y (n, 1) with axis=0.
        axis: axis to reduce.
        weight: optional per-row weights of length n. They are reshaped into a
            column and broadcast against ``x``, so they are meant for axis=0.
            ``None`` gives the plain arithmetic mean.

    Returns:
        ``sum(x * w, axis) / sum(w)``, or ``sum(x, axis) / x.shape[axis]`` when
        no weight is given.
    """
    # ridge_fit, rigde_lstsq and loop_paralle_lstsq call mean(X, axis=0) without
    # a weight, and ridge_cholesky passes weights=None. The old signature
    # required a weight and crashed on ``None.reshape``.
    if weight is None:
        return np.sum(x, axis) / x.shape[axis]
    column_weight = weight.reshape((-1, 1))
    return np.sum(x * column_weight, axis) / column_weight.sum()


@njit()
def ridge_fit(X, y):
    """Ridge regression (alpha = 1) with intercept, via the explicit inverse of the regularised Gram matrix.

    Returns:
        (coef, intercept): coef = inv(Xc.T Xc + alpha I) Xc.T yc on centred
        data, and intercept = y_offset - X_offset . coef.
    """
    alpha = 1.0

    # Center the data to make the intercept term zero
    # NOTE(review): calls mean() without a weight; requires mean to accept an omitted weight.
    X_offset = mean(X, axis=0)
    y_offset = mean(y, axis=0)
    X_center = X - X_offset
    y_center = y - y_offset

    dimension = X_center.shape[1]
    A = np.identity(dimension)
    A_biased = alpha * A

    coef = np.linalg.inv(X_center.T.dot(X_center) + A_biased).dot(X_center.T).dot(y_center)
    intercept = y_offset - np.dot(X_offset, coef)

    return coef, intercept


@njit()
def rigde_lstsq(X, y):
    """Ridge regression (alpha = 1) with intercept, solved by least squares on an augmented system.

    Returns:
        (coef, intercept), from ``np.linalg.lstsq`` on [Xc; sqrt(alpha) I]
        against [yc; 0], where Xc and yc are the centred data.
    """
    alpha = 1.0

    # Center the data to make the intercept term zero
    # NOTE(review): calls mean() without a weight; requires mean to accept an omitted weight.
    X_offset = mean(X, axis=0)
    y_offset = mean(y, axis=0)
    X_center = X - X_offset
    y_center = y - y_offset

    dimension = X_center.shape[1]
    A = np.identity(dimension)

    X_aug = np.vstack((X_center, np.sqrt(alpha) * A))
    y_aug = np.vstack((y_center, np.zeros((A.shape[0], 1), dtype=y.dtype)))

    coef = np.linalg.lstsq(X_aug, y_aug)[0]
    intercept = y_offset - np.dot(X_offset, coef)

    return coef, intercept

@njit()
def ridge_cholesky(X, y, weights=None, alpha=1.0):
    """Fit (optionally sample-weighted) ridge regression with an intercept via the normal equations.

    Args:
        X: (n, p) design matrix.
        y: (n, 1) target.
        weights: optional length-n sample weights. ``None`` means unweighted.
        alpha: L2 penalty. The default of 1.0 matches the previously
            hard-coded value.

    Returns:
        (coef, intercept): coef has shape (p, 1). intercept is
        y_offset - X_offset . coef.
    """
    # Centre the data (weighted when weights are given) so the intercept drops
    # out of the solve.
    X_offset = mean(X, axis=0, weight=weights)
    y_offset = mean(y, axis=0, weight=weights)

    X_center = X - X_offset
    y_center = y - y_offset

    if weights is not None:
        # Scale each centred row by sqrt(weight), so that X_center.T @ X_center
        # is the weighted Gram matrix.
        weights_sqrt = np.sqrt(weights)
        for index, weight in enumerate(weights_sqrt):
            X_center[index] *= weight
            y_center[index] *= weight

    A = np.dot(X_center.T, X_center)
    Xy = np.dot(X_center.T, y_center)

    A = A + alpha * np.eye(X.shape[1])

    # For alpha > 0, A is symmetric positive definite. np.linalg.solve is used
    # rather than an explicit Cholesky factorisation.
    coef = np.linalg.solve(A, Xy)
    intercept = y_offset - np.dot(X_offset, coef)

    return coef, intercept


def test_ridge_work():
    """Compare sklearn's sample-weighted Ridge(1.0) with the jitted ``ridge_cholesky``.

    Prints sklearn's coef/intercept, then ridge_cholesky's, then four timings:
    - ridge_fit slot (disabled, near zero);
    - sklearn;
    - rigde_lstsq slot (disabled, near zero);
    - ridge_cholesky, second call.
    """
    n_samples, n_features = 1000, 100
    X = np.random.random((n_samples, n_features))
    y = np.random.random((n_samples, 1))
    weight = np.random.random((n_samples,))

    timings = []

    # ridge_fit slot: the call (and its warm-up) is disabled.
    start = time()
    timings.append(time() - start)

    start = time()
    reference = Ridge(1.0).fit(X, y, sample_weight=weight)
    timings.append(time() - start)
    print(reference.coef_, reference.intercept_)

    # rigde_lstsq slot: the call (and its warm-up) is disabled.
    start = time()
    timings.append(time() - start)

    # The first call compiles ridge_cholesky; only the second call is timed.
    ridge_cholesky(X, y, weight)
    start = time()
    coef, intercept = ridge_cholesky(X, y, weight)
    timings.append(time() - start)
    print(coef, intercept)

    for seconds in timings:
        print(seconds)


def test_loop():
    """Time the ridge loop variants. Disabled variants still print a near-zero placeholder."""
    start = time()
    loop_python(100000)
    print(time() - start)

    # Placeholder slots, in order: loop_jitting(), loop_numba(), loop_paralle(),
    # loop_parallel_inner(), loop_paralle_lstsq(10000). Each was disabled
    # together with its warm-up call.
    for _ in range(5):
        start = time()
        print(time() - start)

    # The warm-up call compiles the jitted function before the timed run.
    loop_parallel_chol(1)
    start = time()
    loop_parallel_chol(100000)
    print(time() - start)

if __name__ == "__main__":
    test_ridge_work()

    # Previously recorded test_loop timings (presumably seconds): 41 with the
    # Intel extension for sklearn, 39 with stock sklearn.
    # test_loop()


from time import time

import numpy as np
from joblib import Parallel, delayed
from sklearn.linear_model import Ridge


def test_parallel_regression():
    """Compare joblib-parallel and sequential fitting of ``estimator_count`` Ridge models on one dataset."""
    neighbour_count = 100
    X = np.random.random((neighbour_count, neighbour_count))
    y = np.random.random(neighbour_count)
    estimator_count = 20000

    def fit_stacking():
        return Ridge().fit(X, y)

    start = time()
    jobs = [delayed(fit_stacking)() for _ in range(estimator_count)]
    Parallel(n_jobs=-1)(jobs)
    print(time() - start)
    # Parallel timings:
    # neighbour_count = 500 estimator_count = 100 1.5465185642242432
    # neighbour_count = 1000 estimator_count = 1000 7.498879909515381
    # neighbour_count = 20000 estimator_count = 500 21.9149112701416
    # neighbour_count = 100 estimator_count = 20000 1.5993006229400635

    start = time()
    for _ in range(estimator_count):
        Ridge().fit(X, y)
    print(time() - start)
    # Sequential timings:
    # neighbour_count = 1000 estimator_count = 1000 19.928532361984253
    # neighbour_count = 20000 estimator_count = 500 97.64331030845642
    # neighbour_count = 100 estimator_count = 20000 10.334396123886108




if __name__ == "__main__":
    # Run the joblib-vs-sequential Ridge fitting comparison.
    test_parallel_regression()


import time
from functools import reduce

import numpy as np
from numba import njit,prange, boolean
from scipy.sparse import csr_array
from scipy.spatial.distance import cdist


def second_neighbour_matrix(neighbour_csr: csr_array):
    """Return, for each row, the sorted unique indices reachable in two hops.

    For row ``i``, take the column indices stored in row ``i``. The result is
    the union of the column indices stored in each of those rows.

    Args:
        neighbour_csr: CSR adjacency structure, expected to be square. Only
            ``indptr`` and ``indices`` are read; stored values are ignored.

    Returns:
        A list with one sorted, duplicate-free 1-D integer array per row.
    """
    indptr = neighbour_csr.indptr
    indices = neighbour_csr.indices
    second_neighbour_indices_list = []
    for row_index in range(neighbour_csr.shape[0]):
        column_indices = indices[indptr[row_index]:indptr[row_index + 1]]
        second_hops = [
            indices[indptr[column_index]:indptr[column_index + 1]]
            for column_index in column_indices
        ]
        if second_hops:
            # np.unique sorts and deduplicates even with a single neighbour.
            # reduce(np.union1d, ...) returned a lone slice unchanged, so it
            # could be unsorted, and it was a view into `indices`.
            second_neighbour_indices_list.append(np.unique(np.concatenate(second_hops)))
        else:
            # Rows with no neighbours keep the integer index dtype (previously float64).
            second_neighbour_indices_list.append(np.empty(0, dtype=indices.dtype))

    return second_neighbour_indices_list


@njit()
def second_neighbour_matrix_numba(indptr, indices):
    """
    Dense boolean second-order neighbour matrix built from CSR ``indptr``/``indices``.

    Row ``r`` of the result is True at every column stored in the rows of the
    columns stored in row ``r``. An (N, N) boolean matrix is allocated, plus a
    length-N buffer per row, where N = len(indptr) - 1.

    TODO: More deep understanding of the numba is required.

    TODO: Add reduce parallel
    TODO: Add loop parallel

    Args:
        indptr (np.ndarray): CSR row-pointer array of length N + 1.
        indices (np.ndarray): CSR column-index array.

    Returns:
        np.ndarray: (N, N) boolean matrix.
    """

    N = len(indptr) - 1
    # Numba type instead of numpy type should be provided here.
    second_neighbour_matrix = np.zeros((N, N), dtype=boolean)
    for row_index in range(N):
        neighbour_indices = indices[indptr[row_index]:indptr[row_index + 1]]
        # float64 buffer: True is stored as 1.0 and presumably cast back to bool
        # when the row is assigned below.
        second_neighbour_indices_union = np.zeros((N,))
        for neighbour_index in neighbour_indices:
            second_neighbour_indices = indices[
                                       indptr[neighbour_index]: indptr[neighbour_index + 1]
                                       ]
            for second_neighbour_index in second_neighbour_indices:
                second_neighbour_indices_union[second_neighbour_index] = True

        second_neighbour_matrix[row_index] = second_neighbour_indices_union

    return second_neighbour_matrix


@njit()
def second_neighbour_matrix_numba_2(indptr, indices):
    """
    Return in sparse format: a list holding, per row, the sorted column indices of its second-order neighbours.

    A length-N float64 marker buffer is filled per row, and ``np.nonzero``
    converts it to indices (N = len(indptr) - 1).
    """

    indices_list = []
    N = len(indptr) - 1
    for row_index in range(N):
        neighbour_indices = indices[indptr[row_index]:indptr[row_index + 1]]
        second_neighbour_indices_union = np.zeros((N,))
        for neighbour_index in neighbour_indices:
            second_neighbour_indices = indices[
                                       indptr[neighbour_index]: indptr[neighbour_index + 1]
                                       ]
            for second_neighbour_index in second_neighbour_indices:
                second_neighbour_indices_union[second_neighbour_index] = True

        # Convert the marker buffer into (ascending) column indices.
        second_neighbour_indices_union = np.nonzero(second_neighbour_indices_union)[0]
        indices_list.append(second_neighbour_indices_union)

    return indices_list


@njit(parallel=True)
def second_neighbour_matrix_numba_loop(indptr, indices):
    """
    Return in sparse format: a parallel (prange) version of second_neighbour_matrix_numba_2.

    Each row writes only its own slot of a preallocated list.
    """

    N = len(indptr) - 1
    # Manually create the list with specified length to avoid parallel Mutating error.
    indices_list = [np.empty(0, dtype=np.int64)] * N
    for row_index in prange(N):
        neighbour_indices = indices[indptr[row_index]:indptr[row_index + 1]]
        second_neighbour_indices_union = np.zeros((N,))
        for neighbour_index in neighbour_indices:
            second_neighbour_indices = indices[
                                       indptr[neighbour_index]: indptr[neighbour_index + 1]
                                       ]
            for second_neighbour_index in second_neighbour_indices:
                second_neighbour_indices_union[second_neighbour_index] = True

        second_neighbour_indices_union = np.nonzero(second_neighbour_indices_union)[0]
        indices_list[row_index] = second_neighbour_indices_union

    return indices_list


@njit(parallel=True)
def second_neighbour_matrix_numba_loop_2(indptr, indices):
    """
    Return in sparse format. Currently identical to second_neighbour_matrix_numba_loop.

    This copy is the place to experiment with the np.union1d idea in the TODO below.
    """

    N = len(indptr) - 1
    # Manually create the list with specified length to avoid parallel Mutating error.
    indices_list = [np.empty(0, dtype=np.int64)] * N
    for row_index in prange(N):
        neighbour_indices = indices[indptr[row_index]:indptr[row_index + 1]]
        second_neighbour_indices_union = np.zeros((N,))
        for neighbour_index in neighbour_indices:
            second_neighbour_indices = indices[
                                       indptr[neighbour_index]: indptr[neighbour_index + 1]
                                       ]

            # TODO: Consider using the np.union1d here? Unknown and variate length may cause performance issue.
            for second_neighbour_index in second_neighbour_indices:
                second_neighbour_indices_union[second_neighbour_index] = True

        second_neighbour_indices_union = np.nonzero(second_neighbour_indices_union)[0]
        indices_list[row_index] = second_neighbour_indices_union

    return indices_list

def test_second_neighbour_matrix():
    """Time the second-neighbour implementations on 10,000 random points.

    Points are neighbours when their distance exceeds 0.8. Only the parallel
    numba loop is currently enabled; the other slots print near-zero
    placeholders, in order: python, numba, numba_2, numba_loop (timed),
    numba_loop_2.
    """
    points = np.random.random((10000, 2))
    neighbour_matrix = csr_array(cdist(points, points) > 0.8)
    indptr, indices = neighbour_matrix.indptr, neighbour_matrix.indices

    # Disabled slots: second_neighbour_matrix, _numba and _numba_2.
    for _ in range(3):
        start = time.time()
        print(time.time() - start)

    # The warm-up call compiles the parallel loop before the timed run.
    second_neighbour_matrix_numba_loop(indptr, indices)
    start = time.time()
    second_neighbour_matrix_numba_loop(indptr, indices)
    print(time.time() - start)

    # Disabled slot: second_neighbour_matrix_numba_loop_2.
    start = time.time()
    print(time.time() - start)

    print()


def bool_type():
    """Debugging aid: build a CSR array from a random boolean mask, for inspection in a debugger."""
    dense_mask = np.random.random((100, 100)) > 0.5
    compressed = csr_array(dense_mask)
    # TODO: bool type is not fully compressed, as there is duplicated True value in the data array.
    print()


if __name__ == '__main__':
    # bool_type()
    # Run the second-neighbour timing comparison.
    test_second_neighbour_matrix()
    pass


import time

import numpy as np
from numba import njit, prange


@njit()
def input_indicator_1(neighbour_matrix):
    """For each row i, sum the rows of ``neighbour_matrix`` selected by boolean row i.

    The output is preallocated with ``np.empty_like(neighbour_matrix)``, so it
    has the input's dtype. NOTE(review): for a boolean input, the column sums
    are stored as booleans (non-zero -> True). input_indicator_2/3/4/5 keep the
    counts; confirm this difference is intended.
    """
    indicator_matrix = np.empty_like(neighbour_matrix)
    # parallel=True is not set on the decorator, so prange runs like range.
    for i in prange(neighbour_matrix.shape[0]):
        indicator_matrix[i] = np.sum(neighbour_matrix[neighbour_matrix[i]], axis=0)
    return indicator_matrix


@njit()
def input_indicator_2(neighbour_matrix):
    """Jitted variant that appends each row's masked column sum to a list; returns the list."""
    indicator_matrix = []
    # parallel=True is not set on the decorator, so prange runs like range.
    for i in prange(neighbour_matrix.shape[0]):
        indicator_matrix.append(np.sum(neighbour_matrix[neighbour_matrix[i]], axis=0))
    return indicator_matrix


@njit()
def input_indicator_3(neighbour_matrix):
    """Jitted variant that iterates rows directly and uses each row as the boolean mask; returns a list."""
    indicator_matrix = []
    for row_neighbour in neighbour_matrix:
        indicator_matrix.append(np.sum(neighbour_matrix[row_neighbour], axis=0))
    return indicator_matrix


def input_indicator_4(neighbour_matrix):
    """Pure-Python baseline: for each row i, sum the rows selected by boolean row i.

    Returns:
        A list with one 1-D array of column sums per row of ``neighbour_matrix``.
    """
    # Outside a jitted function, numba's prange is just range.
    return [
        np.sum(neighbour_matrix[neighbour_matrix[row]], axis=0)
        for row in range(neighbour_matrix.shape[0])
    ]


def input_indicator_5(neighbour_matrix):
    indicator_matrix = []
    for row_neighbour in neighbour_matrix:
        indicator_matrix.append(
            neighbour_matrix[row_neighbour].sum(axis=0)
        )
    indicator_matrix = np.array(indicator_matrix)
    return indicator_matrix


def test_weight_indicator():
    """Time the five input_indicator variants on a random 1000x1000 boolean neighbour matrix.

    After a blank line, prints on one line: the warm-up time for the jitted
    variants 1-3, then the elapsed time of each variant 1-5 in order.
    """
    estimator_count = 1000

    weight_matrix = np.random.random((estimator_count, estimator_count)) - 0.5
    neighbour_matrix = weight_matrix > 0
    warmup_data = (np.random.random((10, 10)) - 0.5) > 0

    jitted_variants = (input_indicator_1, input_indicator_2, input_indicator_3)
    all_variants = jitted_variants + (input_indicator_4, input_indicator_5)

    stamps = [time.time()]
    # Warm-up so that compilation is excluded from the per-variant timings.
    for variant in jitted_variants:
        variant(warmup_data)
    stamps.append(time.time())

    for variant in all_variants:
        variant(neighbour_matrix)
        stamps.append(time.time())

    print()
    print(*(after - before for before, after in zip(stamps, stamps[1:])))
    # 1.3265900611877441
    # 0.0700376033782959 0.07053327560424805 0.07101583480834961 0.4606480598449707 0.4606757164001465


import time

import numpy as np
from numba import njit, prange
from scipy.sparse import csr_matrix, csr_array, lil_array
from scipy.spatial.distance import cdist


@njit()
def second_order_neighbour(neighbour_matrix: np.ndarray):
    """Two-hop neighbour matrix for a dense boolean neighbour matrix (numba).

    Row ``i`` of the output is the column-wise sum of the rows selected by the
    boolean mask ``neighbour_matrix[i]``, i.e. of ``i``'s neighbours.

    The kernel indexes with a boolean row mask and is called with a dense
    boolean ndarray (see ``test_second_order_neighbour``). The previous
    ``csr_matrix`` annotation did not match that usage.

    Args:
        neighbour_matrix: square 2-D array (rows are used as row masks).

    Returns:
        Array with the same shape and dtype as the input (``np.empty_like``).
        For a boolean input the row sums are therefore stored as booleans
        (non-zero -> True). NOTE(review): this relies on numba's implicit cast
        on assignment; confirm.
    """
    second_order_matrix = np.empty_like(neighbour_matrix)
    # NOTE(review): without ``njit(parallel=True)`` this ``prange`` runs serially.
    for i in prange(neighbour_matrix.shape[0]):
        second_order_matrix[i] = np.sum(neighbour_matrix[neighbour_matrix[i]], axis=0)
    return second_order_matrix

def second_order_neighbour_sparse(neighbour_matrix: csr_matrix):
    """Boolean two-hop neighbour matrix for a sparse (or dense) input.

    Entry ``(i, j)`` is True when the column-wise sum of the rows selected by
    the non-zero pattern of row ``i`` is positive at column ``j``. That equals
    ``(pattern(M) @ M)[i, j] > 0``, where ``pattern(M)`` is the 0/1 non-zero
    pattern of ``M``. The whole matrix is therefore computed with one sparse
    product instead of a Python loop with per-row LIL assignment, which was
    very slow and used numba's ``prange`` outside a jitted function.

    Args:
        neighbour_matrix: square sparse matrix/array or dense array; non-zero
            entries mark neighbours.

    Returns:
        ``lil_array`` of dtype bool with the same shape as the input.
    """
    matrix = csr_array(neighbour_matrix)
    # 0/1 pattern of the non-zeros. Row i of (pattern @ matrix) sums the rows of
    # ``matrix`` indexed by i's neighbours, matching the per-row sum definition.
    pattern = matrix.astype(bool).astype(np.int64)
    # Comparing a sparse array with 0 stays sparse (zeros map to False).
    second_order = (pattern @ matrix) > 0
    return lil_array(second_order, dtype=bool)

def test_second_order_neighbour(run_sparse=False):
    """Benchmark the numba dense two-hop neighbour kernel.

    Builds a 10000 x 10000 boolean neighbour matrix from random 2-D points
    (pairs farther apart than 0.95). It then warms up the jitted kernel once,
    so compilation is excluded, and times a second call.

    Args:
        run_sparse: also time ``second_order_neighbour_sparse`` on the CSR
            version of the matrix. Off by default; previously this call was
            commented out while its (empty) timing was still printed.
    """
    points = np.random.random((10000, 2))
    distance_matrix = cdist(points, points)
    neighbour_matrix = distance_matrix > 0.95
    # The float64 distance matrix (~800 MB) is no longer needed.
    del distance_matrix

    if run_sparse:
        sparse_matrix = csr_array(neighbour_matrix)
        start = time.perf_counter()
        second_order_neighbour_sparse(sparse_matrix)
        print(f"sparse: {time.perf_counter() - start}")

    # The first call compiles the numba kernel; keep it out of the timed region.
    second_order_neighbour(neighbour_matrix)

    start = time.perf_counter()
    second_order_neighbour(neighbour_matrix)
    print(f"numba dense: {time.perf_counter() - start}")

from sklearnex import patch_sklearn

# patch_sklearn()

from georegression.weight_model import WeightModel
from time import time as t

import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import RidgeCV
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP
from georegression.weight_model import WeightModel

# Load the HP dataset once at import time. The fourth item is passed to ``fit``
# as a coordinate next to ``xy_vector``. Binding it to ``time`` shadows the stdlib
# module name; the stdlib timer is available here as ``t``.
X, y_true, xy_vector, time = load_HP()


def test_performance():
    """Time a WeightModel fit with random-forest local models on the HP data.

    Prints the fit duration and the resulting LLOCV score.
    """
    model = WeightModel(
        RandomForestRegressor(n_estimators=50, max_features=1.0),
        "euclidean",
        "bisquare",
        neighbour_count=0.1,
    )
    started = t()
    model.fit(X, y_true, [xy_vector, time])
    elapsed = t() - started
    print(f"Time taken to fit: {elapsed}")
    print(model.llocv_score_)
    # Recorded runs (* = with the Intel extension patch):
    # * neighbour_count = 0.01 n_estimators=50 12.85382342338562
    # * neighbour_count = 0.1 n_estimators=50 21.567977905273438 -1095426075.984387
    # neighbour_count = 0.01 n_estimators=50 19.778754949569702
    # neighbour_count = 0.1 n_estimators=50 38.774967432022095 0.8090267793183445
    # TODO: intel extension not working. Need to check the reason.


def test_performance_on_stacking():
    """Time a StackingWeightModel fit (random-split stumps) on the HP data.

    Prints the fit duration, the LLOCV score and the stacking score.
    """
    distance_measure = "euclidean"
    kernel_type = "bisquare"
    neighbour_count = 0.1

    estimator = StackingWeightModel(
        ExtraTreeRegressor(splitter="random", max_depth=1),
        distance_measure,
        kernel_type,
        neighbour_count=neighbour_count,
        neighbour_leave_out_rate=0.1,
    )

    # Previously the estimator was only constructed and never fitted, so the
    # test measured nothing.
    start = t()
    estimator.fit(X, y_true, [xy_vector, time])
    print(f"Time taken to fit: {t() - start}")
    print(estimator.llocv_score_)
    print(estimator.llocv_stacking_)


# Allow running the benchmark directly with ``python <file>`` (outside pytest).
if __name__ == "__main__":
    test_performance()


from scipy.sparse import csr_array
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor
from sklearn.metrics import r2_score

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP
from georegression.weight_matrix import weight_matrix_from_points
from georegression.weight_model import WeightModel

from time import time as t

# Load the HP dataset once at import time. The fourth item is used as a
# coordinate alongside ``xy_vector`` and shadows the stdlib ``time`` module name.
X, y_true, xy_vector, time = load_HP()


def test_compatibility():
    """Fit the same WeightModel on a dense and then on a CSR weight matrix.

    Prints the fit time and LLOCV score for each storage format so the two can
    be compared.
    """
    weight_matrix = weight_matrix_from_points(
        [xy_vector, time], [xy_vector, time], "euclidean", "bisquare", None, None, 0.1, None
    )

    for sparse in (False, True):
        if sparse:
            # Second pass: same weights stored as CSR.
            weight_matrix = csr_array(weight_matrix)

        start = t()
        model = WeightModel(
            ExtraTreeRegressor(max_depth=1, splitter="random"),
            "euclidean",
            "bisquare",
            neighbour_count=0.1,
        )
        model.fit(X, y_true, [xy_vector, time], weight_matrix=weight_matrix)
        print(t() - start)
        print(model.llocv_score_)


def test_stacking_compatibility():
    """Fit a non-numba StackingWeightModel on a dense and on a CSR weight matrix.

    Prints the LLOCV score and the stacking score for each storage format.
    """
    weight_matrix = weight_matrix_from_points(
        [xy_vector, time],
        [xy_vector, time],
        "euclidean",
        "bisquare", None, None,
        0.1, None
    )

    # Normal case
    estimator = StackingWeightModel(
        ExtraTreeRegressor(max_depth=1, splitter="random"),
        neighbour_leave_out_rate=0.1,
        use_numba=False
    )
    # The fit was commented out, so the prints below read scores from an
    # unfitted model. Fit before reading them.
    estimator.fit(X, y_true, [xy_vector, time], weight_matrix=weight_matrix)
    print(estimator.llocv_score_)
    print(estimator.llocv_stacking_)

    # Sparse case
    weight_matrix = csr_array(weight_matrix)

    estimator = StackingWeightModel(
        ExtraTreeRegressor(max_depth=1, splitter="random"),
        neighbour_leave_out_rate=0.1,
        use_numba=False
    )
    estimator.fit(X, y_true, [xy_vector, time], weight_matrix=weight_matrix)
    print(estimator.llocv_score_)
    print(estimator.llocv_stacking_)

    print()

# Direct-run entry point: only the stacking check is enabled by default.
if __name__ == '__main__':
    # test_compatibility()
    test_stacking_compatibility()


from scipy.sparse import csr_array
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor
from sklearn.metrics import r2_score

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP
from georegression.weight_matrix import weight_matrix_from_points
from georegression.weight_model import WeightModel

from time import time as t

# Load the HP dataset once at import time. ``time`` is a coordinate array here
# and shadows the stdlib module name.
X, y_true, xy_vector, time = load_HP()

def test_compatibility():
    """Fit a numba-backed StackingWeightModel on a CSR weight matrix and print its scores."""
    dense_weights = weight_matrix_from_points(
        [xy_vector, time], [xy_vector, time], "euclidean", "bisquare", None, None, 0.1, None
    )

    model = StackingWeightModel(
        ExtraTreeRegressor(max_depth=1, splitter="random"),
        neighbour_leave_out_rate=0.1,
        use_numba=True,
    )
    model.fit(X, y_true, [xy_vector, time], weight_matrix=csr_array(dense_weights))
    print(model.llocv_score_)
    print(model.llocv_stacking_)


# Allow running the check directly with ``python <file>``.
if __name__ == '__main__':
    test_compatibility()

import numpy as np
from scipy.sparse import csr_matrix


def test_sparse_operation():
    """Check the CSR idioms used to find second-order neighbours.

    Covers:
    * row selection by the non-zero columns of a sparse row (the sparse
      analogue of boolean-mask indexing on a dense array);
    * writing a row of booleans into an empty sparse matrix;
    * indexing a dense array with the non-zero positions of a sparse row.

    Each step is checked against the equivalent dense NumPy computation.
    """
    dense = (np.random.random((10, 10)) - 0.8) > 0
    # TODO: Consider using lil_array for modifying the array.
    arr = csr_matrix(dense, dtype=bool)

    # Analogy of nonzero. Ref: https://numpy.org/devdocs/user/basics.indexing.html#boolean-array-indexing
    second_nei = np.sum(arr[arr[0].nonzero()[1]], axis=0) > 0
    print(second_nei)
    expected = dense[dense[0]].sum(axis=0) > 0
    assert np.array_equal(np.asarray(second_nei).ravel(), expected)

    # Set value. Assigning into CSR changes its sparsity structure and may emit
    # a SparseEfficiencyWarning; LIL is the cheaper format for this (see TODO).
    second_neighbour_matrix = csr_matrix((10, 10), dtype=bool)
    second_neighbour_matrix[0, second_nei.nonzero()[1]] = True
    print(second_neighbour_matrix)
    assert np.array_equal(second_neighbour_matrix[0].toarray().ravel(), expected)

    # Index a dense array by the non-zero positions of a sparse row.
    values = np.random.random((10, 10))[second_neighbour_matrix[0].nonzero()]
    assert values.shape == (int(expected.sum()),)

import time

import numpy as np
from joblib.parallel import Parallel, delayed
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor


def test_tree_efficiency():
    """Compare the cost of tree-based stacking with a random forest for one local model.

    For a local model with local data X:

    * Stacking: train a single simple tree and predict on the X of each of the
      X neighbours.
      - Training costs train*data(X).
      - Prediction costs predict*data(X)*number(X). In practice the predicted X
        overlap, so the real cost should be lower.
      - Linear stacking costs Inverse(X*X): X samples, with X features supplied
        by the X neighbours.
    * RandomForest: train n fully grown trees and average them. Training costs
      train*data(X)*number(n).

    Prints: tree fit time, per-neighbour predict time, bulk predict time,
    forest fit time.
    """
    neighbour_count = 5000
    feature_count = 30
    X = np.random.random((neighbour_count, feature_count))
    # NOTE: X_bulk holds neighbour_count**2 rows (~6 GB of float64 at 5000 x 30).
    X_bulk = np.concatenate([X] * neighbour_count, axis=0)
    y = np.random.random(neighbour_count)

    # Warm-up (excluded from the timings below).

    tree = DecisionTreeRegressor()
    tree.fit(X, y)
    tree.predict(X)

    forest = RandomForestRegressor()
    forest.fit(X, y)

    # perf_counter: monotonic, high-resolution clock intended for benchmarking.
    t1 = time.perf_counter()

    tree = DecisionTreeRegressor(max_depth=2, splitter='random')
    tree.fit(X, y)

    t2 = time.perf_counter()

    # Predict once per neighbour. Not completely atomic: each call still predicts all of X.
    for _ in range(neighbour_count):
        tree.predict(X)

    t3 = time.perf_counter()

    # Bulk is better for moderate amount of data. (100~500)^2. Degrade when come to 1000^2.
    # https://scikit-learn.org/0.15/modules/computational_performance.html
    # https://scikit-learn.org/0.15/auto_examples/applications/plot_prediction_latency.html#example-applications-plot-prediction-latency-py
    tree.predict(X_bulk)

    t4 = time.perf_counter()

    forest = RandomForestRegressor(n_estimators=100, n_jobs=-1)
    forest.fit(X, y)

    t5 = time.perf_counter()

    print()
    print(t2 - t1, t3 - t2, t4 - t3, t5 - t4)
    # neighbour = 100 0.0 0.00400090217590332 0.0009999275207519531 0.09102034568786621
    # neighbour = 500 0.0 0.02702188491821289 0.016988277435302734 0.12402749061584473
    # neighbour = 1000 0.0009999275207519531 0.06401467323303223 0.06501483917236328 0.1708080768585205
    # neighbour = 2000 0.0010001659393310547 0.18304109573364258 0.2760617733001709 0.3486180305480957
    # neighbour = 5000 0.003000497817993164 0.8315958976745605 3.4696438312530518 1.164928913116455


def fit_estimator(estimator, X, y):
    """Fit ``estimator`` in place on ``(X, y)`` and return that same object.

    Returning the estimator lets joblib workers ship the fitted model back.
    """
    fitted = estimator
    fitted.fit(X, y)
    return fitted


def estimators_predict(estimator_list, X):
    """Predict ``X`` with each estimator in turn (serially).

    Args:
        estimator_list: fitted estimators exposing ``predict``.
        X: input passed unchanged to every estimator.

    Returns:
        List of predictions, one per estimator, in input order.
    """
    # The callers pass plain sklearn trees, which expose ``predict``; they have
    # no ``predict_by_weight`` (that method belongs to WeightModel), so the
    # previous call raised AttributeError.
    return [estimator.predict(X) for estimator in estimator_list]


def estimators_parallel_predict(estimator_list, X):
    """Predict ``X`` with every estimator, dispatching one joblib task per estimator.

    ``Parallel(-1)`` uses all available cores; results keep the input order.
    """
    tasks = (delayed(estimator_predict)(estimator, X) for estimator in estimator_list)
    return Parallel(-1)(tasks)


def estimator_predict(estimator, X):
    """Return ``estimator.predict(X)``; a picklable top-level helper for joblib tasks.

    The estimators used here are sklearn trees, which expose ``predict`` but
    not ``predict_by_weight`` (a WeightModel method), so the previous call
    raised AttributeError.
    """
    return estimator.predict(X)


def estimator_predict_n_time(estimator, X, times):
    """Call ``estimator.predict(X)`` ``times`` times and collect every result.

    Used to benchmark repeated small predictions against one bulk prediction.
    The estimators are sklearn trees, which expose ``predict`` and not
    ``predict_by_weight``, so the previous call raised AttributeError.
    """
    return [estimator.predict(X) for _ in range(times)]


def test_in_joblib():
    """Compare strategies for fitting and predicting many small trees with joblib.

    Steps:
    1. Fit ``estimator_count`` shallow random-split trees in parallel.
    2. Time several prediction strategies: nested parallel loops,
       per-estimator parallel loops, serial loops, bulk prediction on stacked
       data (serial and parallel), and bulk prediction split into
       ``split_num`` chunks.
    3. Time fitting ``estimator_count`` random forests, for comparison.
    """
    # Number of neighbours (rows of local data).
    neighbour_count = 100
    # Number of local models.
    estimator_count = 100

    feature_count = 30
    X = np.random.random((neighbour_count, feature_count))
    X_bulk = np.concatenate([X] * neighbour_count, axis=0)
    y = np.random.random(neighbour_count)

    parallel = Parallel(-1)

    # Warm-up (excluded from the timings below).
    parallel([
        delayed(fit_estimator)(
            DecisionTreeRegressor(), X, y
        )
        for i in range(64)
    ])
    parallel([
        delayed(fit_estimator)(
            RandomForestRegressor(), X, y
        )
        for i in range(64)
    ])

    # perf_counter: monotonic, high-resolution clock intended for benchmarking.
    print('Tree fitting starts.')
    t1 = time.perf_counter()

    tree_list = parallel([
        delayed(fit_estimator)(
            DecisionTreeRegressor(max_depth=2, splitter='random'), X, y
        )
        for i in range(estimator_count)
    ])

    print('Prediction Starts.')
    t2 = time.perf_counter()

    # Really slow. Resource not fully used.

    # Only gives the theoretical time when neighbour count equals estimator count,
    # i.e. when the tree list equals the neighbour estimator list.
    parallel([
        delayed(estimators_predict)(
            tree_list, X
        )
        for i in range(estimator_count)
    ])

    t3 = time.perf_counter()

    parallel(
        delayed(estimator_predict_n_time)(
            estimator, X, neighbour_count
        )
        for estimator in tree_list
    )

    t4 = time.perf_counter()

    # sklearn trees expose ``predict``; ``predict_by_weight`` is a WeightModel
    # method, so calling it here raised AttributeError.
    for estimator in tree_list:
        for _ in range(neighbour_count):
            estimator.predict(X)

    t5 = time.perf_counter()

    for estimator in tree_list:
        estimator.predict(X_bulk)

    t6 = time.perf_counter()

    parallel(
        delayed(estimator_predict)(
            estimator, X_bulk
        )
        for estimator in tree_list
    )

    t7 = time.perf_counter()

    split_num = 6
    parallel(
        delayed(estimators_predict)(
            [tree_list[i] for i in split_index], X_bulk
        )
        for split_index in np.array_split(np.arange(estimator_count), split_num)
    )

    print('Prediction Ends.')
    t8 = time.perf_counter()

    parallel([
        delayed(fit_estimator)(
            RandomForestRegressor(n_estimators=100, n_jobs=-1), X, y
        )
        for i in range(estimator_count)
    ])

    t9 = time.perf_counter()

    print()
    print(t2 - t1)
    print(t3 - t2, t4 - t3, t5 - t4, t6 - t5, t7 - t6, t8 - t7)
    print(t9 - t8)
    # neighbour = estimator = 100
    # 0.03500962257385254
    # 0.65926194190979 0.09602212905883789 0.3746037483215332 0.05401206016540527 0.17014646530151367 0.05001258850097656
    # 1.7216594219207764


# Allow running the joblib benchmark directly with ``python <file>``.
if __name__ == '__main__':
    test_in_joblib()


from time import time as t

from scipy.sparse import csr_array
from sklearn.tree import ExtraTreeRegressor

from georegression.test.data import load_HP
from georegression.weight_matrix import weight_matrix_from_points
from georegression.weight_model import WeightModel

# Load the HP dataset once at import time. ``time`` is a coordinate array here;
# the stdlib timer is imported as ``t``.
X, y_true, xy_vector, time = load_HP()


def test_performance_n_jobs():
    """Time ``WeightModel(n_jobs=-1)`` on a dense and then on a CSR weight matrix.

    Prints the fit time and LLOCV score for each storage format.
    """
    weight_matrix = weight_matrix_from_points(
        [xy_vector, time],
        [xy_vector, time],
        "euclidean",
        "bisquare",
        neighbour_count=0.1,
    )

    for convert_to_sparse in (False, True):
        if convert_to_sparse:
            weight_matrix = csr_array(weight_matrix)

        start = t()
        model = WeightModel(
            ExtraTreeRegressor(max_depth=1, splitter="random"),
            n_jobs=-1
        )
        model.fit(X, y_true, [xy_vector, time], weight_matrix=weight_matrix)
        print(t() - start)
        print(model.llocv_score_)

    # Recorded results (dense time, dense score, sparse time, sparse score):
    # 5.687727212905884
    # 0.7720929346825464
    # 1.841477394104004
    # 0.7732306074996614

def test_performance_n_patches():
    """Time ``WeightModel(n_patches=6)`` on a dense and then on a CSR weight matrix.

    Prints the fit time and LLOCV score for each storage format.
    """
    weight_matrix = weight_matrix_from_points(
        [xy_vector, time],
        [xy_vector, time],
        "euclidean",
        "bisquare",
        neighbour_count=0.1,
    )

    for convert_to_sparse in (False, True):
        if convert_to_sparse:
            weight_matrix = csr_array(weight_matrix)

        start = t()
        model = WeightModel(ExtraTreeRegressor(max_depth=1, splitter="random"), n_patches=6)
        model.fit(X, y_true, [xy_vector, time], weight_matrix=weight_matrix)
        print(t() - start)
        print(model.llocv_score_)

    # Recorded results (dense time, dense score, sparse time, sparse score):
    # 2.56437611579895
    # 0.7760204788183097
    # 0.9893820285797119
    # 0.7697055618854263

# Run both parallelism benchmarks when executed directly.
if __name__ == "__main__":
    test_performance_n_jobs()
    test_performance_n_patches()




import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Toy data: 100 samples x 5 standard-normal features and an independent
# standard-normal target (no relationship between them).
X = np.random.normal(size=(100, 5))
y = np.random.normal(size=100)

# DataFrame view with named columns, for ALE libraries that take a DataFrame.
df = pd.DataFrame(X, columns=["x1", "x2", "x3", "x4", "x5"])

# Forest shared by the ALE demos below; fitted on the raw ndarray (no feature names).
estimator = RandomForestRegressor(n_estimators=100)
estimator.fit(X, y)


def PyALE():
    """Run PyALE's ``ale`` for feature ``x1`` on the module-level forest and DataFrame."""
    from PyALE import ale as pyale_ale

    pyale_ale(df, estimator, ["x1"])


def alibiALE():
    """Explain the module-level forest with alibi's ALE explainer and plot the result."""
    from alibi.explainers import ALE, plot_ale

    explainer = ALE(lambda data: estimator.predict(data))
    explanation = explainer.explain(X)
    plot_ale(explanation)


def ALEPython():
    """Plot an ALE curve for feature ``x1`` of the module-level forest with ALEPython.

    ``ale_plot`` receives the model, the training data as a DataFrame and the
    feature column name.
    """
    from alepython import ale_plot

    # The previous call referenced undefined names (``model``, ``X_train``) and a
    # feature ("cont") that does not exist here. Use this module's fitted forest
    # and DataFrame instead.
    # NOTE: the forest was fitted on an ndarray, so sklearn may warn about
    # feature names when predicting on the DataFrame.
    # Plots ALE of feature 'x1' with Monte-Carlo replicas (default: 50).
    ale_plot(estimator, df, "x1", monte_carlo=True)


# Direct-run entry point: only the alibi variant runs by default.
if __name__ == "__main__":
    alibiALE()


from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from georegression.test.data import load_HP
from georegression.weight_model import WeightModel
from georegression.visualize.ale import plot_ale

# Load the HP dataset and keep only the first 200 samples of every array
# (presumably to keep the ALE test fast).
X, y, xy_vector, time = load_HP()
X = X[:200]
y = y[:200]
xy_vector = xy_vector[:200]
time = time[:200]

def test_ale():
    """Plot the global ALE of each of the last five features for a random-forest WeightModel.

    Uses a local slice of the features instead of rebinding the module-level
    ``X`` through ``global``, so calling the test leaves module state untouched.
    """
    model = WeightModel(
        # LinearRegression(),
        RandomForestRegressor(n_estimators=10),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.1,
        cache_data=True, cache_estimator=True, n_jobs=1
    )

    X_features = X[:, -5:]
    model.fit(X_features, y, [xy_vector, time])
    for feature_index in range(X_features.shape[1]):
        fval, ale = model.global_ALE(feature_index)
        plot_ale(fval, ale, X_features[:, feature_index])

    print()

# Allow running the ALE test directly with ``python <file>``.
if __name__ == '__main__':
    test_ale()


import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

from georegression.local_ale import weighted_ale

# Toy data: 100 samples x 5 standard-normal features and an independent target.
X = np.random.normal(size=(100, 5))
y = np.random.normal(size=100)

# Named-column view of the same features (not used by the test below).
df = pd.DataFrame(X, columns=["x1", "x2", "x3", "x4", "x5"])

# Forest whose ``predict`` is passed to ``weighted_ale``.
estimator = RandomForestRegressor(n_estimators=100)
estimator.fit(X, y)


def test_weighted_ale():
    """Smoke-test ``weighted_ale`` on feature 0 with random per-sample weights."""
    sample_weights = np.random.random(size=100)
    weighted_ale(X, 0, estimator.predict, weights=sample_weights)

# Allow running the smoke test directly with ``python <file>``.
if __name__ == "__main__":
    test_weighted_ale()


from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from georegression.test.data import load_HP
from georegression.weight_model import WeightModel
from georegression.visualize.ale import plot_ale

# Load the full HP dataset once at import time (``time`` is a coordinate array).
X, y, xy_vector, time = load_HP()


def test_ale():
    """Fit a random-forest WeightModel on the last five HP features and plot every local ALE of feature 0."""
    model = WeightModel(
        # LinearRegression(),
        RandomForestRegressor(n_estimators=10),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.1,
        cache_data=True, cache_estimator=True, n_jobs=-1
    )

    model.fit(X[:, -5:], y, [xy_vector, time])
    # local_ALE yields one (feature values, ALE values) pair per local model.
    for fval, ale in model.local_ALE(0):
        plot_ale(fval, ale)

    print()



# Allow running the local-ALE test directly with ``python <file>``.
if __name__ == '__main__':
    test_ale()


from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr
from sklearn.linear_model import LinearRegression
from statsmodels.stats.outliers_influence import variance_inflation_factor

from georegression.test.data import load_HP
from georegression.weight_model import WeightModel

# Load the HP dataset once at import time.
(X, y_true, xy_vector, time) = load_HP()

# TODO: Add labels
# Placeholder labels: the column indices 0..n_features-1, indexed below with
# boolean feature masks.
labels = np.arange(X.shape[1])


def test_spearman(threshold=1, draw=False):
    """Cluster features by Spearman rank correlation and keep one feature per cluster.

    Draws a Ward-linkage dendrogram of ``1 - |rho|`` beside the correlation
    heat-map reordered by the dendrogram leaves. The figure is saved to
    ``test_corr.png`` (and also shown when ``draw`` is true). The tree is then
    cut at distance ``threshold``.

    Args:
        threshold: distance at which ``fcluster`` cuts the linkage tree.
        draw: additionally display the figure interactively.

    Returns:
        List holding the first feature index of each cluster.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(25, 20))
    corr = spearmanr(X).correlation

    # Ensure the correlation matrix is symmetric
    corr = (corr + corr.T) / 2
    np.fill_diagonal(corr, 1)

    # We convert the correlation matrix to a distance matrix before performing
    # hierarchical clustering using Ward's linkage.
    distance_matrix = 1 - np.abs(corr)
    dist_linkage = hierarchy.ward(squareform(distance_matrix))
    # dist_linkage = hierarchy.single(squareform(distance_matrix))
    dendro = hierarchy.dendrogram(
        dist_linkage, labels=labels, ax=ax1, leaf_rotation=90
    )
    dendro_idx = np.arange(0, len(dendro["ivl"]))

    ims = ax2.imshow(corr[dendro["leaves"], :][:, dendro["leaves"]], cmap='PiYG', vmin=-1, vmax=1)
    fig.colorbar(ims, ax=ax2)
    ax2.set_xticks(dendro_idx)
    ax2.set_yticks(dendro_idx)
    ax2.set_xticklabels(dendro["ivl"], rotation="vertical")
    ax2.set_yticklabels(dendro["ivl"])
    fig.tight_layout()
    # Save before showing. Once a blocking ``plt.show()`` returns, the figure may
    # be closed, and a following ``plt.savefig`` would write a new, blank figure.
    fig.savefig('test_corr.png')
    if draw:
        plt.show()
    # Close (not just clear) the figure so repeated calls don't accumulate figures.
    plt.close(fig)

    cluster_ids = hierarchy.fcluster(dist_linkage, threshold, criterion="distance")
    cluster_id_to_feature_ids = defaultdict(list)
    for idx, cluster_id in enumerate(cluster_ids):
        cluster_id_to_feature_ids[cluster_id].append(idx)
    selected_features = [v[0] for v in cluster_id_to_feature_ids.values()]

    return selected_features


def test_vif(threshold=10, output=False):
    """Repeatedly drop the feature with the largest VIF until every VIF is below ``threshold``.

    Args:
        threshold: stop once all variance inflation factors are below this.
        output: print the final (VIF, label) pairs.

    Returns:
        Boolean mask over the columns of ``X`` marking the retained features.
    """
    removed = []
    while True:
        keep_mask = np.full(X.shape[1], True)
        keep_mask[removed] = False
        X_kept = X[:, keep_mask]
        vif_list = [variance_inflation_factor(X_kept, col) for col in range(X_kept.shape[1])]

        worst_vif = max(vif_list)
        if worst_vif < threshold:
            if output:
                print(list(zip(vif_list, labels[keep_mask])))
            return keep_mask

        # Map the position among the kept columns back to the original column index.
        removed.append(np.where(keep_mask)[0][vif_list.index(worst_vif)])


def test_customer_vif():
    """Compute one feature's VIF by hand as ``1 / (1 - R^2)`` of regressing it on the others."""
    from sklearn.linear_model import LinearRegression

    feature_index = 12
    # Every column except ``feature_index``, in original order.
    others = np.delete(X, feature_index, axis=1)
    target = X[:, feature_index]

    regressor = LinearRegression()
    regressor.fit(others, target)
    r2 = regressor.score(others, target)
    VIF = 1 / (1 - r2)
    print(regressor.coef_, r2, VIF)


def test_tree_based_collinear(threshold=0.5, output=False):
    """Repeatedly drop the most "dependent" feature using random-forest OOB dependence scores.

    Each round, ``rfpimp.oob_dependences`` scores how well every currently
    selected feature is predicted by the other selected features. The
    highest-scoring feature is removed until every score is below
    ``threshold``. Rows duplicated in original columns 37 and 38 are dropped
    before scoring.

    May not work.

    Args:
        threshold: stop once every dependence score is below this value.
        output: print the final (score, label) pairs.

    Returns:
        Boolean mask over the columns of ``X`` marking the retained features.
    """

    from rfpimp import oob_dependences
    from sklearn.ensemble import RandomForestRegressor
    estimator = RandomForestRegressor(oob_score=True)

    # Not linear regression, not VIF
    # vif_list = 1 / (1 - df_dep.values)

    # The row filter refers to original column labels, so compute it once on the full data.
    keep_rows = ~pd.DataFrame(X).duplicated(subset=[37, 38]).to_numpy()

    remove_feature_list = []
    while True:
        select_feature_list = np.full(X.shape[1], True)
        select_feature_list[remove_feature_list] = False
        if select_feature_list.sum() < 2:
            # A dependence score needs at least one other feature to predict from.
            break

        # Score only the currently selected features. Previously the full X was
        # scored every round, so removals never changed the scores and the loop
        # could not converge. Columns keep their original labels (ascending), so
        # after ``sort_index`` the scores line up with ``select_feature_list``.
        X_selected = pd.DataFrame(
            X[keep_rows][:, select_feature_list],
            columns=labels[select_feature_list],
        )
        df_dep = oob_dependences(estimator, X_selected).sort_index()
        score_list = df_dep.values.flatten().tolist()

        if max(score_list) < threshold:
            if output:
                print(list(zip(score_list, labels[select_feature_list])))
            break

        # Remove Max Score Feature
        max_index = score_list.index(max(score_list))
        remove_feature_list.append(
            np.where(select_feature_list)[0][max_index]
        )

    return select_feature_list


def test_select_feature():
    """Sweep the tree-based collinearity threshold and report the LLOCV score of each feature subset."""
    for threshold in np.arange(0.1, 1, 0.1):
        select_features = test_tree_based_collinear(threshold)
        X_selected = X[:, select_features]

        estimator = WeightModel(
            LinearRegression(),
            distance_measure='euclidean',
            kernel_type='bisquare',
            neighbour_count=0.1,

            cache_data=True, cache_estimator=True
        )
        # Coordinates are passed as one list, like every other ``fit`` call in
        # these tests. Passing ``xy_vector`` and ``time`` as two separate
        # positional arguments handed ``time`` to the wrong parameter.
        estimator.fit(X_selected, y_true, [xy_vector, time])
        print(f'Score {threshold}')
        print(f'Feature Num: {np.nonzero(select_features)[0].shape[0]}')
        print(f'Feature Name: {labels[select_features]}')
        print(estimator.llocv_score_)
        print()


from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from georegression.test.data import load_HP
from georegression.weight_model import WeightModel

# Load the HP dataset once at import time (``time`` is a coordinate array).
X, y, xy_vector, time = load_HP()


def test_importance():
    """Fit a WeightModel with random-forest local models on 5 HP features and print importance scores.

    Prints local importance, global importance and global interaction scores.
    """
    model = WeightModel(
        LinearRegression(),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.1,
        cache_data=True,
        cache_estimator=True,
    )

    # For non-linear interaction test: swap in a random forest after construction.
    model.local_estimator = RandomForestRegressor()

    model.fit(X[:, :5], y, [xy_vector, time])
    scores = (
        model.importance_score_local(),
        model.importance_score_global(),
        model.interaction_score_global(),
    )
    print(*scores)


# Allow running the importance test directly with ``python <file>``.
if __name__ == '__main__':
    test_importance()


import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

from georegression.spatial_temporal_moran import spatiotemporal_MI, \
    spatiotemporal_LMI, plot_moran_diagram, STMI
from georegression.test.data import load_HP, load_TOD
from georegression.weight_model import WeightModel

# (X, y_true, xy_vector, time) = load_HP()
# The Moran tests use the TOD dataset; swap in the line above to use HP instead.
(X, y_true, xy_vector, time) = load_TOD()


def test_moran():
    """Fit a linear WeightModel and compute spatio-temporal Moran statistics of |residuals|.

    Prints the LLOCV score, then the global statistic (``spatiotemporal_MI``)
    with the local one (``spatiotemporal_LMI``), then the ``STMI`` result.
    """
    model = WeightModel(
        LinearRegression(),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.1,
    )
    model.fit(X, y_true, [xy_vector, time])
    print(model.llocv_score_)

    # Compute |residual| once. Previously it was recomputed for every call, and
    # ``spatiotemporal_MI`` ran twice with its first result discarded.
    # NOTE(review): assumes these statistics do not modify their input in place.
    abs_residual = np.abs(model.local_residual_)
    weight_matrix = model.weight_matrix_

    global_moran = spatiotemporal_MI(abs_residual, weight_matrix)
    st_moran = STMI(abs_residual, weight_matrix)
    local_moran = spatiotemporal_LMI(abs_residual, weight_matrix)

    print(global_moran, local_moran)
    print(st_moran)


def test_moran_diagram():
    """Fit a linear WeightModel and save a Moran diagram of the target to ``test_moran_diagram.png``."""
    model = WeightModel(
        LinearRegression(),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.5,
    )
    model.fit(X, y_true, [xy_vector, time])
    print(model.llocv_score_)

    # Use the fitted weight matrix (``weight_matrix_``, as in ``test_moran``),
    # not the un-suffixed ``weight_matrix`` attribute.
    plot_moran_diagram(y_true, model.weight_matrix_)
    plt.savefig('test_moran_diagram.png')


# Direct-run entry point: only the statistics test runs by default.
if __name__ == '__main__':
    test_moran()


from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from georegression.test.data import load_HP
from georegression.visualize.pd import select_partials
from georegression.weight_model import WeightModel


# Load the HP dataset once at import time (``time`` is a coordinate array).
X, y, xy_vector, time = load_HP()


def test_partial_dependence():
    """Smoke-test partial dependence and local ICE on a 100-sample, 10-feature HP subset."""
    model = WeightModel(
        LinearRegression(),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.5,
        cache_data=True,
        cache_estimator=True,
    )

    head = slice(0, 100)
    model.fit(X[head, :10], y[head], [xy_vector[head], time[head]])
    model.partial_dependence()
    model.local_ICE()


# Allow running the partial-dependence test directly with ``python <file>``.
if __name__ == '__main__':
    test_partial_dependence()




from scipy.sparse import csr_array
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor
from sklearn.metrics import r2_score

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP
from georegression.weight_matrix import weight_matrix_from_points
from georegression.weight_model import WeightModel

from time import time as t

# Load the HP dataset once at import time. ``time`` is a coordinate array; the
# stdlib timer is imported as ``t``.
X, y_true, xy_vector, time = load_HP()


def test_performance():
    """Fit a numba StackingWeightModel of 100-stump extra-trees on a CSR weight matrix and print its scores."""
    dense_weights = weight_matrix_from_points(
        [xy_vector, time],
        [xy_vector, time],
        "euclidean",
        "bisquare",
        None,
        None,
        0.1,
        None,
    )

    model = StackingWeightModel(
        ExtraTreesRegressor(max_depth=1, n_estimators=100),
        neighbour_leave_out_rate=0.1,
        use_numba=True,
    )
    model.fit(X, y_true, [xy_vector, time], weight_matrix=csr_array(dense_weights))
    print(model.llocv_score_)
    print(model.llocv_stacking_)
    # Time taken to predict in single thread: 19.013251066207886
    # Time taken to predict in parallel: 43.77483797073364
    # 0.7931079729983194
    # 0.823086329159848


# Allow running the stacking benchmark directly with ``python <file>``.
if __name__ == "__main__":
    test_performance()


from sklearn.tree import ExtraTreeRegressor

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP

# Load the HP dataset once at import time (``time`` is a coordinate array).
X, y_true, xy_vector, time = load_HP()


def test_logging():
    """Fit a StackingWeightModel of random-split stumps so its logging output can be inspected."""
    model = StackingWeightModel(
        ExtraTreeRegressor(splitter="random", max_depth=1),
        "euclidean",  # distance_measure
        "bisquare",   # kernel_type
        None,         # distance_ratio
        None,         # bandwidth
        0.1,          # neighbour_count
        neighbour_leave_out_rate=0.1,
    )
    model.fit(X, y_true, [xy_vector, time])


# Allow running the logging check directly with ``python <file>``.
if __name__ == "__main__":
    test_logging()


from scipy.sparse import csr_array
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor
from sklearn.metrics import r2_score

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP
from georegression.weight_matrix import weight_matrix_from_points
from georegression.weight_model import WeightModel

from time import time as t

# Load the HP dataset once at import time (``time`` is a coordinate array).
X, y_true, xy_vector, time = load_HP()

def test_neighbour_leave_out_shrink_rate():
    """Fit a stacking model with ``neighbour_leave_out_shrink_rate=0.3`` on two paths.

    First pass: dense weights with ``use_numba=False``. Second pass: CSR weights
    with ``use_numba=True``. Prints the LLOCV and stacking scores of each.
    """
    weight_matrix = weight_matrix_from_points(
        [xy_vector, time], [xy_vector, time], "euclidean", "bisquare", None, None, 0.1, None
    )

    model = StackingWeightModel(
        ExtraTreeRegressor(max_depth=1, splitter="random"),
        neighbour_leave_out_rate=0.1,
        use_numba=False,
        neighbour_leave_out_shrink_rate=0.3,
    )

    for second_pass in (False, True):
        if second_pass:
            weight_matrix = csr_array(weight_matrix)
            model.use_numba = True
        model.fit(X, y_true, [xy_vector, time], weight_matrix=weight_matrix)
        print(model.llocv_score_)
        print(model.llocv_stacking_)


def test_meta_fitting_shrink_rate():
    """Fit a stacking model with ``meta_fitting_shrink_rate=0.9`` on two paths.

    First pass: dense weights with ``use_numba=False``. Second pass: CSR weights
    with ``use_numba=True``. Prints the LLOCV and stacking scores of each.
    """
    weight_matrix = weight_matrix_from_points(
        [xy_vector, time], [xy_vector, time], "euclidean", "bisquare", None, None, 0.1, None
    )

    model = StackingWeightModel(
        ExtraTreeRegressor(max_depth=1, splitter="random"),
        neighbour_leave_out_rate=0.1,
        use_numba=False,
        meta_fitting_shrink_rate=0.9,
    )

    for second_pass in (False, True):
        if second_pass:
            weight_matrix = csr_array(weight_matrix)
            model.use_numba = True
        model.fit(X, y_true, [xy_vector, time], weight_matrix=weight_matrix)
        print(model.llocv_score_)
        print(model.llocv_stacking_)



# Direct-run entry point: only the meta-fitting shrink test runs by default.
if __name__ == '__main__':
    # test_neighbour_leave_out_shrink_rate()
    test_meta_fitting_shrink_rate()

from time import time as t

from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression, Ridge, RidgeCV
from sklearn.tree import DecisionTreeRegressor

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP, load_ESI
from georegression.weight_model import WeightModel

# Load the ESI dataset once at import time. ``time`` is a coordinate array; the
# stdlib timer is imported as ``t``.
X, y_true, xy_vector, time = load_ESI()


def test_large_data():
    """Fit a StackingWeightModel of depth-2 random-split trees on the ESI data and time it.

    Prints the LLOCV and stacking scores, then the elapsed time.
    """
    model = StackingWeightModel(
        DecisionTreeRegressor(splitter="random", max_depth=2),
        "euclidean",  # distance_measure
        "bisquare",   # kernel_type
        None,         # distance_ratio
        None,         # bandwidth
        0.01,         # neighbour_count
    )

    started = t()
    model.fit(X, y_true, [xy_vector, time])
    print(f"{model.llocv_score_}, {model.llocv_stacking_}")
    print(f"Time: {t() - started}")

# Allow running the large-data benchmark directly with ``python <file>``.
if __name__ == '__main__':
    test_large_data()

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

from georegression.test.data import load_HP
from georegression.weight_model import WeightModel

# Load the HP dataset once at import time (``time`` is a coordinate array).
X, y, xy_vector, time = load_HP()


def test_predict():
    """Compare ``local_predict_`` with predict_by_weight / predict_by_fit on ten samples.

    A cached local linear WeightModel is fitted on the full data set; the first
    ten samples are then predicted both ways and the three R2 scores printed.
    """
    head = slice(None, 10)
    model = WeightModel(
        LinearRegression(),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.1,
        cache_data=True,
        cache_estimator=True,
    )
    model.fit(X, y, [xy_vector, time])

    by_weight = model.predict_by_weight(X[head], [xy_vector[head], time[head]])
    by_fit = model.predict_by_fit(X[head], [xy_vector[head], time[head]])

    # The explicit predictions outperform `local_predict_` because the predicted
    # samples were part of the training data (data leak).
    truth = y[head]
    print(
        r2_score(truth, model.local_predict_[head]),
        r2_score(truth, by_weight),
        r2_score(truth, by_fit),
    )


if __name__ == "__main__":
    # Script entry point.
    test_predict()


import numpy as np

from georegression.neighbour_utils import sample_neighbour


def test_sample_neighbour():
    """Smoke-test sample_neighbour on a random 100x100 weight matrix that is mostly zero."""
    dense = np.random.rand(100, 100)
    # Keep only entries >= 0.99 (roughly 1% of them); everything else becomes 0.
    sparse_weights = np.where(dense < 0.99, 0, dense)
    neighbour_matrix_sampled = sample_neighbour(sparse_weights, sample_rate=0.5)

if __name__ == "__main__":
    # Script entry point.
    test_sample_neighbour()


from time import time as t

import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import RidgeCV
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP
from georegression.weight_model import WeightModel

# Module-level dataset shared by the tests below.
# Alternatives used previously: load_TOD() and load_ESI().
(X, y_true, xy_vector, time) = load_HP()


def test_stacking():
    """Fit a StackingWeightModel of shallow random trees and print its LLOCV scores."""
    # (distance_measure, kernel_type, distance_ratio, bandwidth, neighbour_count)
    neighbourhood = ("euclidean", "bisquare", None, None, 0.01)
    model = StackingWeightModel(
        DecisionTreeRegressor(splitter="random", max_depth=2), *neighbourhood
    )

    model.fit(X, y_true, [xy_vector, time])
    print(f"{model.llocv_score_}, {model.llocv_stacking_}")


def test_alpha():
    """Fit a StackingWeightModel with alpha=10 and print the first meta-estimator's coefficients.

    alpha is presumably the meta estimator's regularisation strength (the
    recorded coefficients shrink as it grows) -- confirm in StackingWeightModel.
    Results recorded from earlier runs:

    alpha=0.1 -> stacking_score = 0.5750569627981988; first stacking estimator coefficients::

        [-0.46379083 -0.38453714  0.39963185  0.01484807  0.16410479 -0.59694787
          0.21276714  0.11330034  0.29212005 -0.20581994  0.07942222  0.92542167
          0.44300962  0.26067723  0.03980381 -0.32809317  0.17886772  0.26176183
          0.31227637  0.12423833  0.23946592]

    alpha=10 -> stacking_score = 0.9403979818713938; first stacking estimator coefficients::

        [ 0.07789433 -0.03072463  0.18275214 -0.00193438  0.05766076 -0.00123777
          0.13473063  0.1755927   0.0568057   0.0234573   0.14681941  0.03860493
         -0.06496593  0.1208457   0.06717717  0.0523331   0.0167307   0.14635798
         -0.03296376 -0.04416956  0.26379955]
    """
    # (distance_measure, kernel_type, distance_ratio, bandwidth, neighbour_count)
    neighbourhood = ("euclidean", "bisquare", None, None, 0.01)
    model = StackingWeightModel(
        DecisionTreeRegressor(splitter="random", max_depth=2),
        *neighbourhood,
        alpha=10,
    )

    model.fit(X, y_true, [xy_vector, time])
    print(f"{model.llocv_score_}, {model.llocv_stacking_}")

    # Inspect only the first local estimator (if any).
    for first_estimator in model.local_estimator_list:
        print(first_estimator.meta_estimator.coef_)
        break


def test_estimator_sample():
    """Compare LLOCV scores of StackingWeightModel for several estimator sample rates.

    The same local estimator instance is reused for every run; rates 0.1, 0.5
    and None are tried in that order.
    """
    local_estimator = DecisionTreeRegressor(splitter="random", max_depth=2)
    # (distance_measure, kernel_type, distance_ratio, bandwidth, neighbour_count)
    neighbourhood = ("euclidean", "bisquare", None, None, 0.1)

    for sample_rate in (0.1, 0.5, None):
        model = StackingWeightModel(
            local_estimator, *neighbourhood, estimator_sample_rate=sample_rate
        )
        model.fit(X, y_true, [xy_vector, time])
        print(f"{model.llocv_score_}, {model.llocv_stacking_}")


def test_performance():
    """Compare fit time and LLOCV scores of three local models on one neighbourhood setup.

    Fits, in order, a StackingWeightModel of depth-1 random extra-trees, a
    WeightModel with a 50-tree random forest and a WeightModel with RidgeCV on
    the module-level data, printing the elapsed seconds and scores of each.
    The comment after each print records the result of an earlier run.
    """
    distance_measure = "euclidean"
    kernel_type = "bisquare"
    distance_ratio = None
    bandwidth = None
    neighbour_count = 0.1
    # Positional neighbourhood arguments shared by all three models.
    neighbourhood = (distance_measure, kernel_type, distance_ratio, bandwidth, neighbour_count)

    def timed_fit(estimator):
        """Fit *estimator* on the module-level data and return the elapsed seconds."""
        start = t()
        estimator.fit(X, y_true, [xy_vector, time])
        return t() - start

    # Every timing brackets only `fit`, so prints and model construction are
    # excluded and the three numbers are directly comparable.
    estimator = StackingWeightModel(
        ExtraTreeRegressor(splitter="random", max_depth=1),
        *neighbourhood,
        neighbour_leave_out_rate=0.1,
        # estimator_sample_rate=0.1,
    )
    print(timed_fit(estimator), estimator.llocv_score_, estimator.llocv_stacking_)
    # neighbour_count = 0.1 neighbour_leave_out_rate=0.1 17.498192310333252 0.7670304978651743 0.8134413795992002

    estimator = WeightModel(RandomForestRegressor(n_estimators=50), *neighbourhood)
    print(timed_fit(estimator), estimator.llocv_score_)
    # neighbour_count = 0.1 n_estimators=50 34.99542546272278 0.8096408618045396

    estimator = WeightModel(RidgeCV(), *neighbourhood)
    print(timed_fit(estimator), estimator.llocv_score_)
    # neighbour_count = 0.1 5.488587141036987 0.7706704632683226


def test_stacking_not_leaking():
    """Compare a stacking model with a plain random-forest WeightModel on the same neighbourhood.

    Prints the fit time and LLOCV scores of both. Sweeps over tree depth and
    over neighbour_count x leave_out_rate are kept at the end but disabled.
    """
    # Other local estimators tried: DecisionTreeRegressor(splitter="random", max_depth=2),
    # DecisionTreeRegressor(), RidgeCV().
    local_estimator = RandomForestRegressor(n_estimators=50)
    distance_measure = "euclidean"
    kernel_type = "bisquare"
    distance_ratio = None
    bandwidth = None

    def fit_wrapper(estimator, count, rate):
        """Fit a StackingWeightModel with the given settings and print timing and scores."""
        stacking = StackingWeightModel(
            estimator,
            distance_measure,
            kernel_type,
            distance_ratio,
            bandwidth,
            count,
            neighbour_leave_out_rate=rate,
        )
        started = t()
        stacking.fit(X, y_true, [xy_vector, time])
        elapsed = t() - started
        print("neighbour_count =", count, "leave_out_rate =", rate)
        print(
            "time =", elapsed,
            "llocv_score =", stacking.llocv_score_,
            "llocv_stacking =", stacking.llocv_stacking_,
        )

    neighbour_count = 0.1
    leave_out_rate = 0.1
    fit_wrapper(local_estimator, neighbour_count, leave_out_rate)

    baseline = WeightModel(
        RandomForestRegressor(n_estimators=50),
        distance_measure,
        kernel_type,
        distance_ratio,
        bandwidth,
        neighbour_count,
    )
    started = t()
    baseline.fit(X, y_true, [xy_vector, time])
    print(t() - started, baseline.llocv_score_)

    # Disabled sweeps; uncomment the calls to run them.
    for depth in range(1, 31, 2):
        # fit_wrapper(DecisionTreeRegressor(max_depth=depth), neighbour_count, leave_out_rate)
        pass

    for count in np.arange(0.05, 0.35, 0.05):
        for rate in np.arange(0.05, 0.35, 0.05):
            # fit_wrapper(local_estimator, count, rate)
            pass


if __name__ == "__main__":
    # Only the performance comparison runs by default. Other experiments:
    # test_stacking(), test_alpha(), test_estimator_sample(), test_stacking_not_leaking().
    test_performance()




from time import time as t

import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import RidgeCV
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor

from georegression.stacking_model import StackingWeightModel
from georegression.test.data import load_HP
from georegression.weight_model import WeightModel

# Module-level dataset shared by the tests below.
(X, y_true, xy_vector, time) = load_HP()


def test_logging():
    """Fit a 25-tree random-forest WeightModel on the module-level data (run to inspect its log output)."""
    # (distance_measure, kernel_type, distance_ratio, bandwidth, neighbour_count)
    neighbourhood = ("euclidean", "bisquare", None, None, 0.01)
    estimator = WeightModel(RandomForestRegressor(n_estimators=25), *neighbourhood)
    estimator.fit(X, y_true, [xy_vector, time])

if __name__ == "__main__":
    # Script entry point.
    test_logging()



import hyperopt
from hyperopt import fmin, tpe, hp, Trials
import numpy as np


# Define the search space
def custom_prior(name):
    """Return an ``hp.choice`` over a Normal(0.25, 0.05) and a Uniform(0, 1) prior.

    The choice itself is labelled *name*; the two branches get the labels
    ``<name>_normal`` and ``<name>_uniform``.
    """
    branches = [
        hp.normal(name + "_normal", 0.25, 0.05),
        hp.uniform(name + "_uniform", 0, 1),
    ]
    return hp.choice(name, branches)


# Search space with a single mixture-prior parameter.
space = dict(param_name=custom_prior("param_name"))


# Define the objective function
def objective(params):
    """Toy loss: report failure below 0.25, otherwise (x - 0.25)^2 after clipping x into [0, 1].

    ``params["param_name"]`` is overwritten with the clipped value.
    """
    value = params["param_name"]
    if value < 0.25:
        return {"status": hyperopt.STATUS_FAIL}

    value = max(0, min(1, value))
    params["param_name"] = value

    # Minimum of the fictional objective is at x = 0.25.
    return {"loss": (value - 0.25) ** 2, "status": hyperopt.STATUS_OK}


# Optimization
trials = Trials()
best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=1000, trials=trials)

# `fmin` reports an hp.choice parameter as the index of the chosen option, so
# `best["param_name"]` is 0 or 1. Decode `best` through the search space to get
# the actual sampled value.
best_params = hyperopt.space_eval(space, best)
print(f"Best hyperparameter value: {best_params['param_name']}")


from hyperopt import hp, fmin, tpe, Trials


# 1. Setting up the Custom Search Space
def custom_prior(name):
    """Build a hyperopt mixture prior for the parameter *name*.

    With probability 0.8 the value is drawn from Normal(0.25, 0.05), otherwise
    from Uniform(0, 1). The normal branch can fall outside [0, 1]; the value is
    not clipped here, so the objective must clip it.

    Args:
        name (str): Parameter label; branch labels are derived from it
            (``<name>_choice``, ``<name>_normal``, ``<name>_uniform``).

    Returns:
        A hyperopt search-space expression.
    """
    # The branch must be chosen by hyperopt at *sampling* time. Comparing an
    # `hp.choice(...)` node to 0 in plain Python runs once while the space is
    # built and always takes the same branch; `max`/`min` cannot clip a symbolic
    # node either. `hp.pchoice` also carries the intended 80/20 weighting
    # (`hp.choice` would weight the options equally).
    return hp.pchoice(
        name + "_choice",
        [
            (0.8, hp.normal(name + "_normal", 0.25, 0.05)),
            (0.2, hp.uniform(name + "_uniform", 0, 1)),
        ],
    )


# Search space with a single mixture-prior parameter.
space = dict(param_name=custom_prior("param_name"))


# 2. Defining the Objective Function
def objective(params):
    """Hypothetical loss (x - 0.3)^2, minimised at x = 0.3.

    ``params["param_name"]`` is first clipped into [0, 1] in place; replace the
    formula with the real objective.
    """
    clipped = max(0, min(1, params["param_name"]))
    params["param_name"] = clipped
    return (clipped - 0.3) ** 2


# 3. Optimization
from hyperopt import space_eval

trials = Trials()
best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=100, trials=trials)

# `best` holds raw hyperopt labels (for choice-based spaces, option indices plus
# per-branch values); decode it into the user-facing parameter dict.
print("Best hyperparameter value:", space_eval(space, best))




from sklearn.metrics import *

from georegression.test.data import load_HP

# Module-level dataset shared by the tests below.
X, y_true, xy_vector, time = load_HP()


def test_rf():
    """Fit a 1000-tree random forest on the module-level data and report out-of-bag metrics.

    Prints the OOB R2 score and the MAPE of the OOB predictions.

    Returns:
        np.ndarray: The out-of-bag prediction for every training sample.
    """
    from sklearn.ensemble import RandomForestRegressor

    print('RandomForest Model')
    estimator = RandomForestRegressor(n_estimators=1000, n_jobs=-1, oob_score=True)
    estimator.fit(X, y_true)
    print(estimator.oob_score_)

    MAPE = mean_absolute_percentage_error(y_true, estimator.oob_prediction_)
    print(MAPE)

    # The OOB predictions are already per-sample predictions. They must not be
    # passed to `predict`, which expects a feature matrix shaped like X.
    return estimator.oob_prediction_



import numpy as np
from sklearn.metrics import *
from georegression.test.data import load_HP

# Module-level dataset shared by the tests below.
X, y_true, xy_vector, time = load_HP()


def test_rf():
    """Fit a 1000-tree random forest, print OOB R2 and MAPE, and return the OOB predictions."""
    from sklearn.ensemble import RandomForestRegressor

    print('RandomForest Model')
    forest = RandomForestRegressor(n_estimators=1000, n_jobs=-1, oob_score=True)
    forest.fit(X, y_true)
    print(forest.oob_score_)

    oob_prediction = forest.oob_prediction_
    print(mean_absolute_percentage_error(y_true, oob_prediction))
    return oob_prediction


def test_ols():
    """Print in-sample R2 and MAPE of a global ordinary-least-squares baseline."""
    from sklearn.linear_model import LinearRegression

    y_predict = LinearRegression().fit(X, y_true).predict(X)
    for metric in (r2_score, mean_absolute_percentage_error):
        print(metric(y_true, y_predict))




import pandas as pd
import plotly.graph_objects as go
import numpy as np
from matplotlib import cm
import plotly.express as px

# Plotly's sequential "RdBu" colour list, used as the marker colorscale below.
colors = px.colors.sequential.RdBu
def test_camera_static():
    """Render a toy 3D scene with a fixed orthographic camera and save it as HTML and PNG.

    The figure holds points coloured on the module-level ``colors`` scale, a
    polyline, two square meshes at z=0 and z=1 and a plotly.express overlay.
    Output: ``test_plot.html`` (plotly.js loaded from the CDN) and
    ``test_plot.png`` (1080x1920 px at scale 5).
    """
    fig = go.Figure(data=[
        # Points with a colour bar
        go.Scatter3d(x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 0, 0, 0], hovertext=[1, 2, 3, 4],
                     marker={
                         'colorscale': colors,
                         'colorbar': {
                             'x': 1,
                             'tickformat': '.4s',
                         },
                         'color': [0, 0.25, 0.25, 0.88]
                     }),
        # Line
        go.Scatter3d(x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 1, 1, 0], hovertext=[1, 2, 3, 4],
                     mode='lines',
                     line={}),
        # Bottom surface
        go.Mesh3d(x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 0, 0, 0], color='rgb(0, 255, 0)', hoverinfo='skip',
                  hovertemplate=None, colorbar={'x': 1.1}),
        # Top surface
        go.Mesh3d(x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[1] * 4, color='rgb(255, 255, 0)', hoverinfo='skip',
                  hovertemplate=None)
    ])
    fig.update_layout(
        # Manual axis ratio with z stretched 3x
        scene_aspectmode='manual',
        scene_aspectratio=dict(x=1, y=1, z=3),
        # No side/bottom margins; leave room at the top for the title
        margin=dict(l=0, r=0, t=50, b=0, pad=0),
        hovermode='x',
        spikedistance=1
    )

    fig.update_layout(
        title={
            'text': "Temporal Partial Dependency of Feature",
            'y': 0.95,
            'x': 0.4,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        scene=dict(
            xaxis_title='X Axis Title',
            yaxis_title='Y Axis Title',
            zaxis_title='Z Axis Title',
        ),
        legend_title="Cluster Legend",
        font=dict(
            size=18,
            color="RebeccaPurple"
        ))

    # Overlay a plotly.express scatter whose colours are categorical (stringified values).
    fig2 = px.scatter_3d(x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 0, 0, 0],
                         color=np.array([0, 0.25, 0.25, 0.88]).astype(str))
    fig.add_traces(fig2.data)

    # Fixed orthographic camera, z pointing up, viewing mostly from +y/+z.
    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=0.25, y=13.25, z=13.25),
        projection_type="orthographic"
    )
    fig.update_layout(scene_camera=camera)

    # Shrink the manual aspect ratio tenfold while keeping its proportions.
    x_aspect = fig.layout.scene.aspectratio.x
    y_aspect = fig.layout.scene.aspectratio.y
    z_aspect = fig.layout.scene.aspectratio.z
    fig.update_layout(
        scene_aspectratio={
            "x": x_aspect * 0.1,
            "y": y_aspect * 0.1,
            "z": z_aspect * 0.1,
        },
    )

    fig.write_html('test_plot.html', include_plotlyjs='cdn')
    fig.write_image('test_plot.png', width=1080, height=1920, scale=5)

if __name__ == "__main__":
    # Script entry point.
    test_camera_static()

from sklearn.linear_model import LinearRegression

from georegression.test.data import load_HP
from georegression.visualize.importance import global_importance_plot
from georegression.weight_model import WeightModel

# Module-level dataset shared by the tests below.
(X, y, xy_vector, time) = load_HP()


def test_importance_plot():
    """Fit a local linear WeightModel on the first five features and plot its global importance."""
    local_model = WeightModel(
        LinearRegression(),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.1,
        cache_data=True,
        cache_estimator=True,
    )
    # For a non-linear interaction test, set local_model.local_estimator = RandomForestRegressor().
    local_model.fit(X[:, :5], y, [xy_vector, time])

    # The return value is unused; the call presumably fills
    # `permutation_score_decrease_`, which is plotted next (TODO confirm).
    local_model.importance_score_global()
    global_importance_plot(local_model.permutation_score_decrease_)

    # Interaction scores (local_model.interaction_score_global()) are not exercised yet.


if __name__ == "__main__":
    # Script entry point.
    test_importance_plot()


from georegression.test.data import load_HP
from georegression.test.visualize import get_toy_model
from georegression.visualize.pd import partials_plot_3d, features_partial_cluster, partial_plot_2d, \
    partial_compound_plot, choose_cluster_typical

# Shared fixtures: a small fitted toy model, the embedding/labels returned by
# features_partial_cluster for its partial-dependence data, and the raw dataset.
model = get_toy_model()
(features_embedding, features_cluster_label, _) = features_partial_cluster(
    model.feature_partial_
)

(X, y, xy_vector, time) = load_HP()


def test_compound_plot():
    """Build the compound partial-dependence figures for the first 100 samples."""
    first = slice(100)
    (_partial_figs, _embedding_figs, _cluster_figs, _compass_figs) = partial_compound_plot(
        xy_vector[first], time[first], model.feature_partial_,
        features_embedding, features_cluster_label,
    )


def test_pd_2d_plot():
    """Draw 2D partial-dependence curves using each feature's own cluster labels.

    An integrated variant (one label vector shared by all features) can be
    exercised by passing a 1-D cluster vector and its typical members instead.
    """
    # choose_cluster_typical is applied feature by feature.
    cluster_typical = list(map(choose_cluster_typical, features_embedding, features_cluster_label))
    partial_plot_2d(
        model.feature_partial_, features_cluster_label, cluster_typical,
        alpha_range=[0.3, 1], width_range=[0.5, 3], scale_power=1.5,
    )


def test_pd_3d_plot():
    """Draw 3D partial-dependence plots against the model's second coordinate vector."""
    coordinate_vector = model.coordinate_vector_list[1]
    # A `quantile=[0, 0.2, 0.8, 1]` argument can restrict the plotted curves.
    partials_plot_3d(model.feature_partial_, coordinate_vector, cluster_labels=features_cluster_label)


if __name__ == "__main__":
    # Only the 2D plot runs by default.
    test_pd_2d_plot()


import pandas as pd
import plotly.graph_objects as go
import numpy as np
from matplotlib import cm
import plotly.express as px


def value_to_rgb_str(value, colormap_name='viridis'):
    """Map a scalar through a Matplotlib colormap to a Plotly-style ``'rgb(r, g, b)'`` string.

    Args:
        value: Scalar passed to the colormap (floats in [0, 1] span the map).
        colormap_name (str): Name of a registered Matplotlib colormap.

    Returns:
        str: ``'rgb(r, g, b)'`` with each channel scaled to 0-255 and truncated
        to an int; the alpha channel is dropped.
    """
    try:
        # Matplotlib >= 3.5 colormap registry. `cm.get_cmap` was deprecated in
        # 3.7 and removed in 3.9.
        from matplotlib import colormaps
        colormap = colormaps[colormap_name]
    except ImportError:
        colormap = cm.get_cmap(colormap_name)

    red, green, blue = (int(channel * 255) for channel in colormap(value)[:3])
    return f'rgb({red}, {green}, {blue})'


def value_list_to_rgb_str_list(value_list):
    """Convert every value to an ``'rgb(r, g, b)'`` string using the default colormap."""
    return list(map(value_to_rgb_str, value_list))


# Plotly's sequential "RdBu" colour list, used as the marker colorscale below.
colors = px.colors.sequential.RdBu


def test_mesh_3d():
    """Write a toy 3D scene (points, polyline, two square meshes, express overlay) to test_plot.html.

    Used to experiment with colour scales, colour bars, margins, titles and with
    mixing ``graph_objects`` traces and ``plotly.express`` traces in one figure.
    """
    # Points coloured on the module-level `colors` scale, with their own colour bar.
    # (value_list_to_rgb_str_list can produce explicit per-point colours instead.)
    points = go.Scatter3d(
        x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 0, 0, 0], hovertext=[1, 2, 3, 4],
        marker={
            'colorscale': colors,
            'colorbar': {'x': 1, 'tickformat': '.4s'},
            'color': [0, 0.25, 0.25, 0.88],
        },
    )
    polyline = go.Scatter3d(
        x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 1, 1, 0], hovertext=[1, 2, 3, 4],
        mode='lines',
        line={},
    )
    # Flat squares at z=0 and z=1, both with hover disabled.
    bottom = go.Mesh3d(
        x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 0, 0, 0],
        color='rgb(0, 255, 0)', hoverinfo='skip', hovertemplate=None,
        colorbar={'x': 1.1},
    )
    top = go.Mesh3d(
        x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[1] * 4,
        color='rgb(255, 255, 0)', hoverinfo='skip', hovertemplate=None,
    )
    fig = go.Figure(data=[points, polyline, bottom, top])

    # Manual aspect ratio with z stretched 3x; no side/bottom margins.
    fig.update_layout(
        scene_aspectmode='manual',
        scene_aspectratio=dict(x=1, y=1, z=3),
        margin=dict(l=0, r=0, t=50, b=0, pad=0),
        hovermode='x',
        spikedistance=1,
    )

    title = {
        'text': "Temporal Partial Dependency of Feature",
        'y': 0.95,
        'x': 0.4,
        'xanchor': 'center',
        'yanchor': 'top',
    }
    axis_titles = dict(
        xaxis_title='X Axis Title',
        yaxis_title='Y Axis Title',
        zaxis_title='Z Axis Title',
    )
    fig.update_layout(
        title=title,
        scene=axis_titles,
        legend_title="Cluster Legend",
        font=dict(size=18, color="RebeccaPurple"),
    )

    # Overlay a plotly.express scatter whose colours are categorical (stringified values).
    overlay = px.scatter_3d(
        x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 0, 0, 0],
        color=np.array([0, 0.25, 0.25, 0.88]).astype(str),
    )
    fig.add_traces(overlay.data)

    fig.write_html('test_plot.html')


def test_discrete():
    """Write a four-point 3D scatter with categorical colours to test_plot2.html."""
    categories = ['T1', 'T2', 'T3', 'T4']
    fig = px.scatter_3d(x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 0, 0, 0], color=categories)
    fig.write_html('test_plot2.html')


def test_lengend():
    """Show Plotly's gapminder life-expectancy lines using the G10 qualitative palette."""
    gapminder = px.data.gapminder()
    fig = px.line(
        gapminder, x="year", y="lifeExp", color="continent", line_group="country",
        line_shape="spline", render_mode="svg",
        color_discrete_sequence=px.colors.qualitative.G10,
        title="Built-in G10 color sequence",
    )
    fig.show()


def test_surface_colorbar():
    """Check legend grouping and colour-bar display of two Mesh3d squares; writes test_plot.html."""
    # Owns the visible legend entry of the group and requests a colour bar.
    bottom = go.Mesh3d(
        x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[0, 0, 0, 0],
        color='rgb(0, 255, 0)',
        hoverinfo='skip',
        hovertemplate=None,
        name='Surface Group',
        legendgroup='surface group',
        showlegend=True,
        showscale=True,
        colorbar={'x': 1.1},
    )
    # Same legend group, but no legend entry of its own.
    top = go.Mesh3d(
        x=[0, 1, 1, 0], y=[0, 0, 1, 1], z=[1, 1, 1, 1],
        hoverinfo='skip',
        hovertemplate=None,
        name='surface2',
        legendgroup='surface group',
        showlegend=False,
    )
    fig = go.Figure([bottom, top])

    x_axis_style = {'backgroundcolor': "rgb(200, 200, 230)", 'gridcolor': "white"}
    fig.update_layout(scene={'xaxis': x_axis_style})

    fig.write_html('test_plot.html')


def test_axes_ratio():
    """Show a 2D outline whose y-axis is locked to the x-axis scale at 1:1."""
    outline_x = [0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3]
    outline_y = [0, 0, 1, 1, 3, 3, 2, 2, 3, 3, 1, 1, 0, 0]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=outline_x, y=outline_y))
    fig.update_layout(width=800, height=500, title="fixed-ratio axes")
    # Fix the x range and shrink its domain to fit; y then follows x at ratio 1.
    fig.update_xaxes(range=(-0.5, 3.5), constrain='domain')
    fig.update_yaxes(scaleanchor="x", scaleratio=1)

    fig.show()

if __name__ == "__main__":
    # No plot test runs by default; call one of the functions above manually.
    pass


def test_pd_sample():
    """Smoke-test sample_partial by quantile and by cluster-stratified sampling on the toy model."""
    from georegression.test.visualize import get_toy_model
    first_partial = get_toy_model().feature_partial_[0]

    from georegression.visualize.pd import sample_partial
    sample_partial(first_partial, quantile=[0.1, 0.5])

    from georegression.visualize.pd import partial_cluster
    (_, labels, _) = partial_cluster(first_partial)
    sample_partial(first_partial, sample_size=0.1, cluster_label=labels)


if __name__ == "__main__":
    # Script entry point.
    test_pd_sample()


from sklearn.linear_model import LinearRegression

from georegression.test.data import load_HP
from georegression.visualize.pd import features_partial_cluster
from georegression.weight_model import WeightModel
from georegression.visualize.scatter import scatter_3d

# Module-level dataset shared by the tests below.
(X, y, xy_vector, time) = load_HP()


def test_scatter():
    """Exercise scatter_3d with a continuous colour and with cluster labels (first 100 samples)."""
    first = slice(100)
    xy_head, time_head = xy_vector[first], time[first]

    model = WeightModel(
        LinearRegression(),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.5,
        cache_data=True,
        cache_estimator=True,
    )

    # Continuous case: colour by the target.
    scatter_3d(xy_head, time_head, y[first], 'Title', 'Continuous')

    # Cluster case: fit on ten features, cluster the partial dependence, colour by cluster.
    model.fit(X[first, :10], y[first], [xy_head, time_head])
    model.partial_dependence()
    # NOTE(review): elsewhere features_partial_cluster(feature_partial) is called with a
    # single argument and unpacked into three values; this three-argument, four-value
    # call looks stale -- confirm the current signature.
    feature_distance, feature_cluster_label, distance_matrix, cluster_label = features_partial_cluster(
        xy_head, time_head, model.feature_partial_)

    scatter_3d(xy_head, time_head, cluster_label[first], 'Title', 'Cluster', is_cluster=True)


if __name__ == "__main__":
    # Script entry point.
    test_scatter()


def get_toy_model():
    """Return a cached local linear WeightModel fitted on the first 100 HP samples.

    Only the first ten feature columns are used, and ``partial_dependence()`` is
    called before returning (presumably so ``feature_partial_`` is populated).
    """
    from georegression.test.data import load_HP
    from georegression.weight_model import WeightModel
    from sklearn.linear_model import LinearRegression

    X, y, xy_vector, time = load_HP()
    first = slice(100)

    toy = WeightModel(
        LinearRegression(),
        distance_measure='euclidean',
        kernel_type='bisquare',
        neighbour_count=0.5,
        cache_data=True,
        cache_estimator=True,
    )
    toy.fit(X[first, :10], y[first], [xy_vector[first], time[first]])
    toy.partial_dependence()
    return toy


import matplotlib.pyplot as plt


import numpy as np
from matplotlib import transforms
from sklearn.neighbors import KernelDensity


def plot_ale(fvals, ale, x, bins=10, density=False):
    """Plot an ALE curve with a histogram of the feature on a secondary y-axis.

    Args:
        fvals: Feature values at which the ALE curve is evaluated (x-axis).
        ale: ALE values matching ``fvals``.
        x: Raw feature values whose distribution is drawn as a histogram.
        bins: Histogram bins, anything ``Axes.hist`` accepts (default 10).
        density (bool): If True, normalise the histogram to a probability
            density; otherwise show raw counts. The secondary axis label and the
            histogram label follow this setting.

    Returns:
        matplotlib.figure.Figure: The created figure.
    """
    fig, ax1 = plt.subplots(figsize=(12, 6))

    ax1.plot(fvals, ale, zorder=2, label='STALE')
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    ax1.set_xlabel('Feature value')
    ax1.set_ylabel('ALE')

    # Feature distribution on a twin y-axis. The label must match what is drawn:
    # with density=False the bars are counts, not a density.
    hist_label = 'Density' if density else 'Count'
    ax2 = ax1.twinx()
    ax2.hist(x, bins=bins, density=density, alpha=0.3, color='gray', zorder=1, label=hist_label)
    ax2.grid(False)

    ax2.set_ylabel(hist_label)

    return fig


import numpy as np
from matplotlib import pyplot as plt

from georegression.visualize import default_folder


def global_importance_plot(importance_matrix, labels=None, index=True, folder_=default_folder):
    """
    Save boxplot and bar+boxplot summaries of global feature importance.

    Writes ``ImportanceBoxplot.png`` and ``ImportancePlot.png`` into ``folder_``.
    Features are sorted by mean importance in ascending order, so in the
    horizontal plots the most important feature ends up at the top.

    Args:
        importance_matrix (np.ndarray): Shape(Feature, n_repeats); anything
            ``np.asarray`` accepts.
        labels (list): Feature names, one per row. Defaults to ``Feature 1..n``.
        index (bool): Whether to prefix each label with its 1-based position in
            the input (before sorting).
        folder_ (str or os.PathLike): Output directory.

    Returns:
        matplotlib.figure.Figure: The combined bar/boxplot figure. It is left
        open so callers can show or further edit it.
    """
    from pathlib import Path

    importance_matrix = np.asarray(importance_matrix)
    # Accept plain strings as well as Path objects for the output directory.
    folder_ = Path(folder_)

    # Default labels if not provided.
    if labels is None:
        labels = [f'Feature {i + 1}' for i in range(importance_matrix.shape[0])]

    # Prefix with the original position so features stay identifiable after sorting.
    if index:
        labels = [f'{i + 1}. {label}' for i, label in enumerate(np.asarray(labels))]
    labels = np.asarray(labels)

    # Sort by the mean of importance value (ascending).
    importance_mean = np.mean(importance_matrix, axis=1)
    sort_index = np.argsort(importance_mean)
    importance_matrix = importance_matrix[sort_index, :]
    labels = labels[sort_index]
    importance_mean = importance_mean[sort_index]

    # Standalone boxplot of the repeats per feature.
    # NOTE: Matplotlib 3.9 renamed boxplot's `labels` keyword to `tick_labels`
    # (the old name is deprecated there) -- revisit when upgrading.
    box_fig = plt.figure(figsize=(10, 6))
    plt.boxplot(importance_matrix.T, vert=False, labels=labels)

    plt.xlabel('Global Importance')
    plt.ylabel('Feature name')
    plt.title('Global Importance of Independent Features')
    plt.tight_layout()
    plt.savefig(folder_ / 'ImportanceBoxplot.png')
    # Close (not merely clear) the figure: a cleared figure stays registered
    # with pyplot and accumulates across calls.
    plt.close(box_fig)

    # Integrate the mean bar chart and the boxplot into one figure.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8), sharex='all')
    ax1.barh(labels, importance_mean, height=0.7)
    ax1.set_xlabel('Importance value')
    ax1.set_ylabel('Feature name')
    ax1.set_title('Mean value of feature importance')
    ax1.margins(0.02)

    ax2.boxplot(importance_matrix.T, vert=False)
    ax2.set_xlabel('Importance value')
    ax2.set_yticklabels([])
    ax2.set_title('Boxplot of feature importance')
    ax2.margins(0.02)

    fig.suptitle('Global Importance of Independent Features\n')
    fig.tight_layout()
    fig.savefig(folder_ / 'ImportancePlot.png')

    return fig

# TODO: Add a 2D heatmap plot of feature interactions.


import math
import time
from os.path import join
from pathlib import Path

import matplotlib
import numpy as np
from joblib import Parallel, delayed
from matplotlib import cm, pyplot as plt
from plotly.subplots import make_subplots
from scipy.cluster.hierarchy import dendrogram
from scipy.spatial.distance import cdist, pdist, squareform
from sklearn.cluster import AgglomerativeClustering
from sklearn.inspection import PartialDependenceDisplay

import plotly.graph_objects as go
import plotly.express as px

from scipy.stats import logistic

from georegression.visualize.scatter import scatter_3d
from georegression.visualize.utils import vector_to_color, range_margin

from georegression.visualize import default_folder


def sample_partial(partial, sample_size=None, quantile=None, cluster_label=None, random_state=1003):
    """
    Select a subset of partial rows, either by random sampling or by quantile.

    Exactly one of ``sample_size`` and ``quantile`` must be given.

    Args:
        partial (np.ndarray): Shape(N, 2). Each entry of column 1 is a sequence
            of values whose plain average is used for quantile selection.
        sample_size (int or float): Int for a specific count, float for a rate of N.
        quantile (float or sequence of float): Quantile(s) in [0, 1] of the
            per-row averages. For each, the first row whose average equals the
            nearest observed quantile value is selected.
        cluster_label (array-like, optional): Shape(N,). Only used with
            ``sample_size``. Sampling is then stratified: each cluster gets
            ``ceil(size * sample_size / N)`` rows, clipped to at least one and at
            most the whole cluster, so the total can exceed ``sample_size``.
        random_state (int, np.random.RandomState, np.random.Generator or None):
            Seed or generator for random sampling. ``None`` uses NumPy's global
            random state. An int seed drives a private ``RandomState``, which yields
            the same samples as reseeding the global state without mutating it.

    Returns:
        np.ndarray of row indices when sampling; list of row indices when
        selecting by quantile.

    Raises:
        ValueError: If neither or both selection methods are given.
    """
    if sample_size is None and quantile is None:
        raise ValueError('No selection method is chosen.')
    if sample_size is not None and quantile is not None:
        raise ValueError('Only one selection method is allowed.')

    N = partial.shape[0]

    # Select by sample
    if sample_size is not None:
        if random_state is None:
            rng = np.random  # module-level functions use NumPy's global state
        elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
            rng = random_state
        else:
            rng = np.random.RandomState(random_state)

        # Proportional sample.
        if isinstance(sample_size, float):
            sample_size = int(sample_size * N)

        if cluster_label is None:
            return rng.choice(N, sample_size, replace=False)

        # `==` below needs an ndarray; a plain list would compare as a whole and match nothing.
        cluster_label = np.asarray(cluster_label)
        # Sample size is proportional to cluster size. np.unique (not np.bincount) so
        # negative labels such as -1 for un-clustered rows are handled.
        cluster_values, cluster_sizes = np.unique(cluster_label, return_counts=True)
        cluster_sample_sizes = np.ceil(cluster_sizes * sample_size / N).astype(int)
        # At least one sample per cluster, never more than the cluster holds.
        cluster_sample_sizes = np.clip(cluster_sample_sizes, 1, cluster_sizes)

        return np.concatenate([
            rng.choice(np.where(cluster_label == cluster_value)[0], cluster_sample_size, replace=False)
            for cluster_value, cluster_sample_size in zip(cluster_values, cluster_sample_sizes)
        ])

    # Select by quantile
    def inner_average(x):
        # TODO: Use weighted average.
        return np.average(x)

    feature_y_average = np.vectorize(inner_average)(partial[:, 1])
    quantile = np.atleast_1d(quantile)
    try:
        quantile_values = np.quantile(feature_y_average, quantile, method='nearest')
    except TypeError:
        # NumPy < 1.22 only knows the older `interpolation` keyword.
        quantile_values = np.quantile(feature_y_average, quantile, interpolation='nearest')

    # 'nearest' returns observed values, so each has an exact match; take the
    # index where it first appears.
    return [np.where(feature_y_average == quantile_value)[0][0] for quantile_value in quantile_values]


def sample_suffix(sample_size=None, quantile=None):
    """Return a filename suffix describing the selection: sample size, quantiles, or nothing.

    ``sample_size`` takes precedence when both are given.
    """
    if sample_size is not None:
        return f'_Sample{sample_size}'
    if quantile is not None:
        return '_Q' + ';'.join(str(q) for q in quantile)
    return ''


def partial_plot_2d(
        feature_partial, cluster_vector, cluster_typical,
        weight_style=True, alpha_range=None, width_range=None, use_sigmoid=True, scale_power=1,
        folder_=default_folder
):
    """
    Plot the partial dependence line of the typical local of each cluster, for every feature.

    Saves one grid figure with a subplot per feature (``SPPDP[_Merged].png``) and one standalone figure
    per feature (``SPPDP_Typical[_Merged]<feature number>``) into ``folder_``. Line alpha and width grow
    with the share of locals that belong to the cluster the line represents.

    Args:
        feature_partial (np.ndarray): Shape(Feature, N, 2); ``feature_partial[f, i]`` is the
            (x values, partial values) pair of local ``i`` for feature ``f``.
        cluster_vector (np.ndarray): Shape(N,) for one merged clustering shared by all features,
            or Shape(Feature, N) for per-feature clustering.
        cluster_typical (array-like): Index of the typical local of each cluster.
            Shape(n_cluster) or Shape(Feature, n_cluster), matching ``cluster_vector``.
        weight_style (bool): Not used by this function.
        alpha_range (list): [min, max] line alpha. Defaults to [0.1, 1].
        width_range (list): [min, max] line width. Defaults to [0.5, 3].
        use_sigmoid (bool): Pass the cluster share through a steep logistic curve centred at 0.5.
        scale_power (float): Exponent applied to the style ratio.
        folder_ (pathlib.Path): Output folder.

    Returns:
        None. Figures are written to disk and closed.
    """

    if alpha_range is None:
        alpha_range = [0.1, 1]
    if width_range is None:
        width_range = [0.5, 3]

    # Shape(N,) means one merged clustering shared by every feature.
    is_integrated = len(cluster_vector.shape) == 1

    feature_count = len(feature_partial)
    local_count = len(feature_partial[0])

    # Matplotlib plot grid
    col = 3
    row = math.ceil(feature_count / col)
    col_length = 3
    row_length = 2

    # Close interactive mode so figures are only rendered to file.
    plt.ioff()

    fig, axs = plt.subplots(
        ncols=col, nrows=row, sharey='none',
        figsize=(col * col_length, (row + 1) * row_length)
    )

    # 2d-ndarray flatten
    axs = axs.flatten()

    # Remove the unused trailing cells of the grid.
    for ax_remove_index in range(col * row - feature_count):
        fig.delaxes(axs[- ax_remove_index - 1])

    # Iterate each feature
    for feature_index in range(feature_count):
        ax = axs[feature_index]

        if is_integrated:
            inner_vector = np.copy(cluster_vector)
            inner_typical = np.copy(cluster_typical)
        else:
            inner_vector = cluster_vector[feature_index]
            inner_typical = cluster_typical[feature_index]

        # Style the line by the cluster size.
        values, counts = np.unique(inner_vector, return_counts=True)
        if np.max(counts) == np.min(counts):
            style_ratios = np.ones(local_count)
        else:
            style_ratios = counts / local_count
            if use_sigmoid:
                # Steep logistic centred at half of the locals: emphasises large clusters.
                style_ratios = logistic.cdf((style_ratios - 0.5) * 10)
            style_ratios = style_ratios ** scale_power
        style_alpha = np.zeros(local_count)
        style_width = np.zeros(local_count)
        for value, style_ratio in zip(values, style_ratios):
            cluster_index = np.nonzero(inner_vector == value)
            style_alpha[cluster_index] = alpha_range[0] + (alpha_range[1] - alpha_range[0]) * style_ratio
            style_width[cluster_index] = width_range[0] + (width_range[1] - width_range[0]) * style_ratio

        # Keep only the typical local of each cluster.
        inner_partial = feature_partial[feature_index, inner_typical]
        inner_vector = inner_vector[inner_typical]
        style_alpha = style_alpha[inner_typical]
        style_width = style_width[inner_typical]

        color_vector = vector_to_color(inner_vector, stringify=False)

        # Matplotlib line kwargs shared by the grid subplot and the standalone figure.
        line_styles = [
            {
                # Receive color tuple/list/array
                "color": color_vector[local_index],
                "alpha": style_alpha[local_index],
                "linewidth": style_width[local_index],
                "label": f'Cluster {inner_vector[local_index]}',
            }
            for local_index in range(len(inner_partial))
        ]

        # Grid subplot
        for line, style in zip(inner_partial, line_styles):
            ax.plot(*line, **style)
        ax.set_title(f'Feature {feature_index + 1}')

        # Individual file for each feature
        fig_ind = plt.figure(figsize=(5, 4), constrained_layout=True)
        ax_ind = fig_ind.gca()
        for line, style in zip(inner_partial, line_styles):
            ax_ind.plot(*line, **style)
        ax_ind.set_xlabel('Independent Value')
        ax_ind.set_ylabel('Partial Dependent Value')

        # Padding according to the number of legend rows (5 entries per row) placed above the axes.
        ax_ind.set_title(
            f'SPPDP of Typical Cluster in Feature {feature_index + 1}',
            pad=10 + 15 * math.ceil(len(inner_vector) / 5)
        )
        ax_ind.legend(
            loc='lower center', bbox_to_anchor=(0.5, 1), ncol=5,
            columnspacing=0.2, fontsize='x-small', numpoints=2
        )
        fig_ind.savefig(
            folder_ / f'SPPDP_Typical{"_Merged" if is_integrated else ""}{feature_index + 1}',
            dpi=300
        )
        # Close (not just clear) so pyplot releases the figure; clear() kept one figure alive per feature.
        plt.close(fig_ind)

    fig.supxlabel('Independent Value')
    fig.supylabel('Partial Dependent Value')

    fig.tight_layout(h_pad=1.5)
    fig.subplots_adjust(top=0.85)
    fig.suptitle('SPPDP')

    if is_integrated:
        # Clusters are shared by all features, so the legend of the last subplot covers the whole grid.
        handles, labels = ax.get_legend_handles_labels()
        # put the center upper edge of the bounding box at the coordinates(bbox_to_anchor)
        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.965), ncol=6)

    fig.savefig(folder_ / f'SPPDP{"_Merged" if is_integrated else ""}.png')
    plt.close(fig)


def partial_plot_3d(
        partial, temporal, cluster_label=None,
        sample_size=None, quantile=None,
        feature_name=None
):
    """
    Build a 3D plotly figure of partial dependence lines, one line per selected local.

    Each line is placed at its time slice on the x axis, with the independent value on the y axis and the
    partial value on the z axis.

    Args:
        partial (np.ndarray): Shape(N, 2); ``partial[i]`` is the (x values, partial values) pair of local ``i``.
        temporal (np.ndarray): Time slice of each local, Shape(N,) or Shape(N, 1) (only column 0 is used
            when 2-D).
        cluster_label (np.ndarray): Optional Shape(N,) label used to color and group the lines in the legend.
        sample_size: Forwarded to ``sample_partial`` to subsample the lines.
        quantile: Forwarded to ``sample_partial`` to select lines by quantile.
        feature_name (str): Used in the figure title.

    Returns:
        plotly.graph_objects.Figure
    """
    sample_indices = sample_partial(partial, sample_size, quantile, cluster_label)
    local_count = len(sample_indices)

    # Original position of each selected local, shown in the hover text.
    index_label = np.arange(len(partial))
    index_label = index_label[sample_indices]

    # Select partial, temporal and cluster.
    partial = partial[sample_indices]
    temporal = temporal[sample_indices]
    # One time value per line. vector_to_color normalises along the last axis, so an (N, 1) column would
    # be normalised row by row (each row against itself) instead of across lines.
    temporal_flat = temporal[:, 0] if np.ndim(temporal) > 1 else temporal
    if cluster_label is not None:
        cluster_label = cluster_label[sample_indices]

    # Trace name and show_legend control
    if cluster_label is not None:
        colors = vector_to_color(cluster_label)

        def inner_naming(label):
            return f'Cluster {label}'

        v_naming = np.vectorize(inner_naming)
        names = v_naming(cluster_label)

        def inner_unique(label):
            # Only the first line of each cluster shows a legend entry; the others share its legend group.
            _, first_index = np.unique(label, return_index=True)
            first_vector = np.zeros_like(label, dtype=bool)
            first_vector[first_index] = True
            return first_vector

        show_vector = np.apply_along_axis(inner_unique, -1, cluster_label)

    else:
        colors = vector_to_color(temporal_flat)
        names = np.empty_like(temporal_flat, dtype=object)
        show_vector = np.zeros_like(temporal_flat, dtype=bool)

    # Each local corresponds to each trace
    trace_list = []
    for local_index in range(local_count):
        x = partial[local_index, 0]
        y = partial[local_index, 1]
        trace = go.Scatter3d(
            y=x, z=y,
            x=np.tile(temporal_flat[local_index], len(x)),
            text=y,
            mode='lines',
            line=dict(
                # Receive Color String
                color=colors[local_index],
                width=2,
            ),
            name=names[local_index],
            legendgroup=names[local_index],
            showlegend=bool(show_vector[local_index]),
            hovertemplate=
            '<b>X Value</b>: %{y} <br />' +
            '<b>Time Slice</b>: %{x}  <br />' +
            f'<b>Index</b>: {index_label[local_index]}  <br />' +
            '<b>Partial Value</b>: %{z}  <br />'

        )
        trace_list.append(trace)

    fig = go.Figure(data=trace_list)
    if feature_name:
        title = f'SPPDP of Feature {feature_name}'
    else:
        title = 'SPPDP'

    fig.update_layout(
        title={
            'text': title,
            'xanchor': 'center',
            'x': 0.45,
            'yanchor': 'top',
            'y': 0.99,
        },
        margin=dict(l=0, r=0, t=50, b=0, pad=0),
        legend_title="Cluster Legend",
        font=dict(
            size=12,
        ),
        template="seaborn",
        font_family="Times New Roman"
    )

    # Fix range while toggling trace.
    y_max = np.max([np.max(y) for y in partial[:, 1]])
    y_min = np.min([np.min(y) for y in partial[:, 1]])

    x_max = np.max([np.max(x) for x in partial[:, 0]])
    x_min = np.min([np.min(x) for x in partial[:, 0]])

    fig.update_scenes(
        xaxis_title='Time Slice',
        xaxis_range=range_margin(vector=temporal_flat),
        yaxis_title='Independent / X value',
        yaxis_range=range_margin(value_min=x_min, value_max=x_max),
        zaxis_title='Dependent / Partial Value',
        zaxis_range=range_margin(value_min=y_min, value_max=y_max),
    )

    return fig


def partials_plot_3d(
        feature_partial, temporal, cluster_labels=None,
        sample_size=None, quantile=None, feature_names=None
):
    """
    Build one 3D partial dependence figure per feature via ``partial_plot_3d``.

    Args:
        feature_partial (): Shape(Feature, N, 2) partial results, one entry per feature.
        temporal (): Time slice of each local, shared by all features.
        cluster_labels (): Shape(N,) merged labels (reused for every feature) or Shape(Feature, N).
        sample_size (): Forwarded to ``partial_plot_3d``.
        quantile (): Forwarded to ``partial_plot_3d``.
        feature_names (): Optional name per feature, used in the titles.

    Returns:
        list: One figure per feature, in feature order.
    """
    feature_count = len(feature_partial)

    # Broadcast a merged Shape(N,) labelling to every feature so each figure can take its own row.
    if cluster_labels is not None and len(cluster_labels.shape) == 1:
        cluster_labels = np.tile(cluster_labels.reshape(1, -1), (feature_count, 1))

    return [
        partial_plot_3d(
            partial=feature_partial[index],
            temporal=temporal,
            cluster_label=None if cluster_labels is None else cluster_labels[index],
            sample_size=sample_size,
            quantile=quantile,
            feature_name=None if feature_names is None else feature_names[index],
        )
        for index in range(feature_count)
    ]


def _overlap_line_distance(x_origin, y_origin, x_dest, y_dest):
    """Shift-invariant weighted mean squared gap of two lines over their shared x range (``np.inf`` if disjoint)."""
    # Overlapped range of two lines. (Max of line start point, Min of line end point)
    overlap_start = max(x_origin[0], x_dest[0])
    overlap_end = min(x_origin[-1], x_dest[-1])

    # No overlapped range.
    if overlap_start >= overlap_end:
        return np.inf

    # Points of both lines inside the overlapped range (always contains both range ends).
    x_merge = np.unique(np.concatenate([x_origin, x_dest]))
    x_merge = x_merge[(overlap_start <= x_merge) & (x_merge <= overlap_end)]

    # Linear interpolate for the overlapped range.
    y_merge_origin = np.interp(x_merge, x_origin, y_origin)
    y_merge_dest = np.interp(x_merge, x_dest, y_dest)

    # The vertical offset minimising the squared gap is minus the mean difference, so a pure shift costs 0.
    intercept = - np.sum(y_merge_origin - y_merge_dest) / len(x_merge)
    pointwise_distance = (y_merge_origin - y_merge_dest + intercept) ** 2

    # Weighting according to the point interval in the bi-direction.
    distance_weight = np.zeros_like(x_merge)
    distance_weight[1:-1] = x_merge[2:] - x_merge[:-2]
    distance_weight[0] = (x_merge[1] - x_merge[0]) * 2
    distance_weight[-1] = (x_merge[-1] - x_merge[-2]) * 2
    distance_weight = distance_weight / np.sum(distance_weight)

    return np.average(pointwise_distance, weights=distance_weight)


def partial_distance(partial):
    """
    Calculation distance between partial lines.

    Lines whose x ranges do not overlap get twice the largest finite distance (or 1.0 when every finite
    distance is 0), so they are never treated as identical.

    Args:
        partial (np.ndarray): partial result of a feature. Shape(N, 2); ``partial[i]`` is the
            (ascending x values, y values) pair of line ``i``.

    Returns:
        distance_matrix (np.ndarray): Shape(N, N), symmetric with a zero diagonal.
    """

    N = partial.shape[0]
    line_distance_matrix = np.zeros((N, N))

    # Fill the upper triangle (including the diagonal) row by row.
    for origin_index, (x_origin, y_origin) in enumerate(partial):
        line_distance_matrix[origin_index, origin_index:] = [
            _overlap_line_distance(x_origin, y_origin, x_dest, y_dest)
            for x_dest, y_dest in partial[origin_index:]
        ]

    # Fill Infinity value by twice the max distance; fall back to 1.0 so it never collapses to 0.
    finite_distance = line_distance_matrix[np.isfinite(line_distance_matrix)]
    max_finite = finite_distance.max() if finite_distance.size else 0.0
    fill_value = max_finite * 2 if max_finite > 0 else 1.0
    line_distance_matrix = np.nan_to_num(line_distance_matrix, posinf=fill_value)

    # Mirror the upper triangle into the lower one.
    return line_distance_matrix + np.transpose(line_distance_matrix)


def features_partial_distance(features_partial):
    """
    Compute the line distance matrix of every feature in parallel.

    Args:
        features_partial (np.ndarray): Shape(Feature, N, 2)

    Returns:
        feature_distance (np.ndarray): Shape(Feature, N, N), one ``partial_distance`` matrix per feature.
    """
    # One joblib task per feature, spread over all available cores.
    tasks = (delayed(partial_distance)(single_feature) for single_feature in features_partial)
    return np.array(Parallel(n_jobs=-1)(tasks))


def partial_cluster(
        partial=None, distance=None,
        n_neighbours=5, min_dist=0.1, n_components=2,
        min_cluster_size=10, min_samples=3, cluster_selection_epsilon=1,

        plot=False, select_clusters=False,
        plot_title='Condensed trees', plot_filename='CondensedTrees.png', plot_folder=default_folder,
):
    """
    Cluster locals by their partial dependence lines.

    The line distance matrix is embedded with UMAP (metric='precomputed') and the embedding is clustered
    with HDBSCAN. A 2D embedding is always produced for visualisation. When ``n_components`` is not 2, a
    second embedding with ``n_components`` dimensions is used for clustering.

    Args:
        partial (np.ndarray): Shape(N, 2); used to compute ``distance`` when it is not given.
        distance (np.ndarray): Shape(N, N) precomputed line distance matrix.
        n_neighbours: UMAP ``n_neighbors``.
        min_dist: UMAP ``min_dist``.
        n_components: Dimension of the embedding used for clustering.
        min_cluster_size: HDBSCAN ``min_cluster_size``.
        min_samples: HDBSCAN ``min_samples``.
        cluster_selection_epsilon: HDBSCAN ``cluster_selection_epsilon``.
        plot: Save the HDBSCAN condensed tree plot.
        select_clusters: Highlight the selected clusters in the condensed tree plot.
        plot_title: Title of the condensed tree plot (``None`` means the default).
        plot_filename: File name of the condensed tree plot (``None`` means the default).
        plot_folder: Output folder of the condensed tree plot.

    Returns:
        tuple: (2D embedding Shape(N, 2), cluster labels Shape(N,), distance matrix Shape(N, N)).
    """

    from hdbscan import HDBSCAN
    from umap import UMAP

    plot_title = 'Condensed trees' if plot_title is None else plot_title
    plot_filename = 'CondensedTrees.png' if plot_filename is None else plot_filename

    if partial is None and distance is None:
        raise Exception('Feature partial or feature distance matrix should be provided.')

    if distance is None:
        distance = partial_distance(partial)

    # TODO: Stable Reproducible result.
    # TODO: Range of UMAP embedding value?
    umap_params = dict(random_state=42, n_neighbors=n_neighbours, min_dist=min_dist, metric='precomputed')
    standard_embedding = UMAP(**umap_params).fit_transform(distance)
    clusterable_embedding = (
        standard_embedding
        if n_components == 2
        else UMAP(n_components=n_components, **umap_params).fit_transform(distance)
    )

    clusterer = HDBSCAN(
        min_cluster_size=min_cluster_size,
        min_samples=min_samples,
        cluster_selection_epsilon=cluster_selection_epsilon,
    ).fit(clusterable_embedding)

    if plot:
        clusterer.condensed_tree_.plot(select_clusters=select_clusters)
        plt.title(plot_title)
        plt.savefig(plot_folder / plot_filename)
        plt.clf()

    return standard_embedding, clusterer.labels_, distance


def features_partial_cluster(
        features_partial=None, features_distance=None,
        n_neighbours=5, min_dist=0.1, n_components=2,
        min_cluster_size=10, min_samples=3, cluster_selection_epsilon=1,
        select_clusters=False,
        labels=None, only_integrated=False, folder=default_folder,
        plot=False,
):
    """
    Cluster data points based on partial dependency, separately for every feature.

    Args:
        features_partial (np.ndarray): Shape(Feature, N, 2); used to compute ``features_distance`` when absent.
        features_distance (np.ndarray): Shape(Feature, N, N) precomputed line distance matrices.
        n_neighbours (): Forwarded to ``partial_cluster``.
        min_dist (): Forwarded to ``partial_cluster``.
        n_components (): Forwarded to ``partial_cluster``.
        min_cluster_size (): Forwarded to ``partial_cluster``.
        min_samples (): Forwarded to ``partial_cluster``.
        cluster_selection_epsilon (): Forwarded to ``partial_cluster``.
        select_clusters (): Forwarded to ``partial_cluster``.
        labels (): Optional feature labels, used in the condensed tree plot titles.
        only_integrated (): Not used; kept for backward compatibility.
        folder (): Output folder of the condensed tree plots.
        plot (bool): Save one condensed tree plot per feature (``CondensedTrees_<feature number>.png``).

    Returns:
        tuple: (features_embedding, features_cluster_label, features_distance)
            - features_embedding (np.ndarray): Shape(Feature, N, 2) visualisation embedding per feature.
            - features_cluster_label (np.ndarray): Shape(Feature, N) cluster label per feature.
            - features_distance (np.ndarray): Shape(Feature, N, N) distance matrices (given or computed).
    """

    # TODO: More fine-tuning control on the multi-features and integrate-feature.

    # Parameter check
    if features_partial is None and features_distance is None:
        raise Exception('Feature partial or feature distance should be provided.')

    # Ensure feature distance is available.
    if features_distance is None:
        features_distance = features_partial_distance(features_partial)

    feature_count = features_distance.shape[0]

    features_embedding = []  # Shape(Feature, N, 2)
    features_cluster_label = []  # Shape(Feature, N)
    for feature_index in range(feature_count):
        cluster_embedding, cluster_label, _ = partial_cluster(
            distance=features_distance[feature_index],
            n_neighbours=n_neighbours, min_dist=min_dist, n_components=n_components,
            min_cluster_size=min_cluster_size, min_samples=min_samples,
            cluster_selection_epsilon=cluster_selection_epsilon,
            # The plot title/filename/folder below were previously dead because `plot` was never forwarded.
            plot=plot, select_clusters=select_clusters,
            plot_title=f'Condensed trees of Feature {feature_index + 1} {labels[feature_index] if labels is not None else ""}',
            plot_filename=f'CondensedTrees_{feature_index + 1}.png',
            plot_folder=folder,
        )

        features_embedding.append(cluster_embedding)
        features_cluster_label.append(cluster_label)

    return np.array(features_embedding), np.array(features_cluster_label), features_distance


def choose_cluster_typical(embedding, cluster_vector):
    """
    Pick the most central member (medoid) of every cluster.

    The medoid is the member whose summed Euclidean distance to the other members of the same cluster
    is smallest; ties resolve to the earliest member.

    Args:
        embedding (np.ndarray): Shape(N, D) point coordinates.
        cluster_vector (np.ndarray): Shape(N,) cluster label of each point.

    Returns:
        list: Index of the typical point for each cluster, ordered by sorted cluster label.
    """
    typical_indices = []
    for label in np.unique(cluster_vector):
        member_indices = np.flatnonzero(cluster_vector == label)
        # Row sums of the in-cluster pairwise distance matrix = total distance of each member to the rest.
        total_distance = squareform(pdist(embedding[member_indices])).sum(axis=1)
        typical_indices.append(member_indices[total_distance.argmin()])
    return typical_indices


def embedding_plot(
        embedding, cluster, temporal_vector, feature_name
):
    """
    Scatter the 2D low-dimension embedding, one WebGL trace per cluster so clusters toggle in the legend.

    Args:
        embedding (np.ndarray): Shape(N, 2)
        cluster (np.ndarray): Shape(N,) cluster label of each point.
        temporal_vector (np.ndarray): Shape(N, 1) time slice of each point; joined column-wise with the
            point index to build the hover data.
        feature_name (str): Appended to the title when truthy.

    Returns:
        plotly.graph_objects.Figure
    """
    point_index = np.arange(embedding.shape[0]).reshape(-1, 1)
    hover_data = np.concatenate([temporal_vector, point_index], axis=1)
    point_colors = vector_to_color(cluster)
    hover_template = (
        '<b>Cluster</b> : %{text} <br />'
        '<b>Time Slice</b> : %{customdata[0]} <br />'
        '<b>Index</b> : %{customdata[1]} <br />'
        '<extra></extra>'
    )

    fig = go.Figure()
    for cluster_value in np.unique(cluster):
        member_mask = cluster == cluster_value
        trace_name = f'Cluster {cluster_value}'
        fig.add_trace(
            go.Scattergl(
                x=embedding[member_mask, 0], y=embedding[member_mask, 1],
                customdata=hover_data[member_mask], mode='markers',
                name=trace_name,
                legendgroup=trace_name,
                marker={'color': point_colors[member_mask], 'size': 5},
                text=cluster[member_mask],
                hovertemplate=hover_template,
            )
        )

    title = 'Low dimension embedding' + (f' of {feature_name}' if feature_name else '')

    fig.update_layout(
        title=title,
        legend_title="clusters",
        template="seaborn",
        font_family="Times New Roman"
    )
    fig.update_xaxes(title="Embedding dimension X", range=range_margin(embedding[:, 0]))
    # Lock the aspect ratio so embedding distances look the same along both axes.
    fig.update_yaxes(
        title="Embedding dimension Y",
        range=range_margin(embedding[:, 1]),
        scaleanchor="x",
        scaleratio=1,
    )

    return fig


def compass_plot(
        cluster_fig, partial_fig, embedding_fig,
):
    """
    Combine three figures into one 2x2 dashboard; the cluster scene spans both left rows.
    [cluster plot, partial plot  ]
    [cluster plot, embedding plot]

    Returns:
        plotly.graph_objects.Figure titled 'SPPDP Compass'.
    """
    placements = ((cluster_fig, 1, 1), (partial_fig, 1, 2), (embedding_fig, 2, 2))

    compound = make_subplots(
        cols=2, rows=2,
        column_widths=[0.5, 0.5], row_heights=[0.6, 0.4],
        horizontal_spacing=0.02, vertical_spacing=0.05,
        specs=[
            [{'rowspan': 2, "type": "scene"}, {"type": "scene"}],
            [None, {"type": "xy"}]
        ],
        subplot_titles=tuple(source.layout.title.text for source, _, _ in placements),
    )

    for source, row, col in placements:
        compound.add_traces(source.data, rows=row, cols=col)

    # Global layout comes from the cluster figure; each panel then gets its own scene/axes back.
    compound.update_layout(cluster_fig.layout)
    compound.update_scenes(cluster_fig.layout.scene, row=1, col=1)
    compound.update_scenes(partial_fig.layout.scene, row=1, col=2)
    compound.update_xaxes(embedding_fig.layout.xaxis, row=2, col=2)
    compound.update_yaxes(embedding_fig.layout.yaxis, row=2, col=2)
    compound.update_layout(title_text='SPPDP Compass')

    return compound


def partial_compound_plot(
        geo_vector, temporal_vector, feature_partial,
        embedding, cluster_label,
        sample_size=None, quantile=None,
        feature_names=None, folder=default_folder,
):
    """
    Build the per-feature partial, embedding, cluster and compass (2x2 dashboard) figures.
    [cluster plot, partial plot  ]
    [cluster plot, embedding plot]

    With a merged clustering (2-D ``embedding`` and 1-D ``cluster_label``) the same embedding and cluster
    figure is shared by every feature; otherwise each feature gets its own.

    Args:
        geo_vector (): Spatial coordinates of each local, forwarded to ``scatter_3d``.
        temporal_vector (): Time slice of each local.
        feature_partial (): Shape(Feature, N, 2)
        embedding (): Shape(Feature, N, 2) for individual cluster, Shape(N, 2) for merged cluster.
        cluster_label (): Shape(Feature, N) for individual cluster, Shape(N,) for merged cluster.
        sample_size (): Forwarded to ``partials_plot_3d``.
        quantile: Forwarded to ``partials_plot_3d``.
        feature_names (): Shape(Feature)
        folder (): Output folder for the HTML files written by ``scatter_3d``.

    Returns:
        tuple: (partial_figs, embedding_figs, cluster_figs, compass_figs), each a list with one figure per feature.
    """

    # TODO: Add hover highlight.

    feature_count = len(feature_partial)

    def feature_label(index):
        # 'Feature <n> <name>' (keeps a trailing space when no names are given).
        return f'Feature {index + 1} {feature_names[index] if feature_names is not None else ""}'

    partial_figs = partials_plot_3d(
        feature_partial, temporal_vector, cluster_label,
        sample_size=sample_size, quantile=quantile, feature_names=feature_names
    )

    is_merged = embedding.ndim == 2 and cluster_label.ndim == 1
    if is_merged:
        shared_embedding_fig = embedding_plot(embedding, cluster_label, temporal_vector, 'total features')
        shared_cluster_fig = scatter_3d(
            geo_vector, temporal_vector, cluster_label,
            'Merged Spatio-temporal Cluster Plot', 'Cluster Label',
            filename='Cluster_Merged', is_cluster=True, folder=folder)
        embedding_figs = [shared_embedding_fig] * feature_count
        cluster_figs = [shared_cluster_fig] * feature_count
    else:
        embedding_figs = [
            embedding_plot(embedding[index], cluster_label[index], temporal_vector, feature_label(index))
            for index in range(feature_count)
        ]
        cluster_figs = [
            scatter_3d(
                geo_vector, temporal_vector, cluster_label[index],
                f'Spatio-temporal Cluster Plot of {feature_label(index)}',
                'Cluster Label',
                filename=f'Cluster_{index + 1}', is_cluster=True, folder=folder)
            for index in range(feature_count)
        ]

    compass_figs = [
        compass_plot(cluster_fig, partial_fig, embedding_fig)
        for cluster_fig, partial_fig, embedding_fig in zip(cluster_figs, partial_figs, embedding_figs)
    ]

    return partial_figs, embedding_figs, cluster_figs, compass_figs


if __name__ == '__main__':
    # NOTE(review): sample_partial is called here with no arguments, while partial_plot_3d calls it as
    # sample_partial(partial, sample_size, quantile, cluster_label). Unless its signature (not visible
    # here) gives `partial` a default, running this module directly raises TypeError; confirm.
    sample_partial()


from pathlib import Path

import numpy as np
from plotly import express as px, graph_objects as go
from georegression.visualize.utils import vector_to_color

from georegression.visualize import default_folder


def scatter_3d(
        geo_vector, temporal_vector, value,
        figure_title, value_name, filename=None,
        is_cluster=False,
        folder=default_folder
):
    """
    Spatio-temporal 3D scatter: (x, y) position on the horizontal plane, time slice on the vertical axis.

    With ``is_cluster`` each label becomes its own trace (toggleable in the legend). Otherwise a single
    trace is coloured continuously with a colour bar ticked at the quartiles. A translucent surface is
    drawn just below every time slice. The figure is written to ``folder / f'{filename}.html'``.

    Args:
        geo_vector (np.ndarray): Columns 0 and 1 are the x and y positions.
        temporal_vector (np.ndarray): Column 0 is the time slice.
        value (array-like): Shape(N,) value or cluster label of each point.
        figure_title (str): Figure title.
        value_name (str): Name of ``value`` in the legend, colour bar and hover text.
        filename (str): Output file name without extension; defaults to ``f'{figure_title}_{value_name}'``.
        is_cluster (bool): Treat ``value`` as discrete cluster labels.
        folder (pathlib.Path): Output folder.

    Returns:
        plotly.graph_objects.Figure
    """
    # Shape(N, )
    x = geo_vector[:, 0]
    y = geo_vector[:, 1]
    z = temporal_vector[:, 0]

    x_min = np.min(x)
    x_max = np.max(x)
    x_interval = x_max - x_min
    y_min = np.min(y)
    y_max = np.max(y)
    y_interval = y_max - y_min
    z_min = np.min(z)
    z_max = np.max(z)
    z_interval = z_max - z_min
    z_unique = np.unique(z)
    z_step = z_interval / len(z_unique)
    # A single time slice gives a zero step, which would collapse the z-axis range to a single point.
    if z_step == 0:
        z_step = 1

    value = np.array(value)
    count = value.shape[0]

    # Index for each point, shown in the hover text.
    custom_data = np.arange(count)

    # Multiple legend for cluster input.
    if is_cluster:
        fig = go.Figure()
        color = vector_to_color(value)

        for cluster_value in np.unique(value):
            cluster_index = value == cluster_value
            fig.add_trace(
                go.Scatter3d(
                    x=x[cluster_index], y=y[cluster_index], z=z[cluster_index], mode='markers',
                    # Name of trace for legend display
                    name=f'Cluster {cluster_value}',
                    legendgroup=f'Cluster {cluster_value}',
                    marker={
                        'color': color[cluster_index],
                        'size': 5,
                    },
                    text=value[cluster_index],
                    # Only this cluster's indices. Passing the full index array showed the position within
                    # the cluster instead of the global point index.
                    customdata=custom_data[cluster_index],
                    hovertemplate=
                    f'<b>Time Slice</b> :' + ' %{z} <br />' +
                    f'<b>Index</b> :' + ' %{customdata} <br />' +
                    f'<b>{value_name}</b> :' + ' %{text} <br />' +
                    '<extra></extra>',
                )
            )
    # Continuous value case. Single legend/trace.
    else:
        tick_value = np.quantile(value, [0, 0.25, 0.5, 0.75, 1], interpolation='nearest')
        fig = go.Figure(data=[
            # Data Point
            go.Scatter3d(
                x=x, y=y, z=z, mode='markers',
                # Name of trace for legend display
                name=f'{value_name}',
                marker={
                    'size': 5,
                    'color': value,
                    'colorscale': 'Portland',
                    # Input dict of properties to construct the ColorBar Instance
                    'colorbar': {
                        'x': 0.8,
                        'title': f'{value_name} Color Bar<br>(Quartile tick)<br> <br>',
                        'tickvals': tick_value,
                        'tickformat': '.3~f',
                    },
                },
                text=value,
                customdata=custom_data,
                hovertemplate=
                '<b>Time Slice</b> :' + ' %{z} <br />' +
                f'<b>Index</b> :' + ' %{customdata} <br />' +
                f'<b>{value_name}</b> :' + ' %{text:.3~f} <br />' +
                '<extra></extra>'
            ),
        ])

    # TODO: Add Joint Line

    # Time Surface
    x_surface = (x_min, x_min, x_max, x_max)
    y_surface = (y_min, y_max, y_max, y_min)
    z_shift = z_step * 0.08
    surface_color = vector_to_color(z_unique)

    fig.add_traces([
        go.Mesh3d(
            x=x_surface, y=y_surface,
            # Shift the surface down a little to avoid overlay
            z=[slice_z - z_shift] * 4, opacity=0.3,
            color=surface_color[z_index],
            hoverinfo='skip',
            name='Auxiliary Surface Group',
            legendgroup='Auxiliary Surface Group',
            # A single legend entry toggles the whole surface group.
            showlegend=z_index == 0,
        )
        for z_index, slice_z in enumerate(z_unique)
    ])

    # Set figure, axis and other things. The shorter horizontal side gets aspect 1.
    if x_interval < y_interval:
        x_aspect = 1
        y_aspect = y_interval / x_interval
        z_aspect = y_aspect * len(z_unique) * 0.6
    else:
        y_aspect = 1
        x_aspect = x_interval / y_interval
        z_aspect = x_aspect * len(z_unique) * 0.6

    fig.update_layout(
        # Clear margin
        margin=dict(l=0, r=0, t=50, b=0, pad=0),

        # Figure title
        title={
            'text': figure_title,
            'xanchor': 'center',
            'x': 0.45,
            'yanchor': 'top',
            'y': 0.99,
        },

        # Global font
        font=dict(size=12),

        # Legend
        legend_title="Point and Surface Legend",

        template="seaborn",
        font_family="Times New Roman"
    )

    fig.update_scenes(
        # Change Projection to Orthogonal
        camera_projection_type="orthographic",

        # Set axis ratio
        aspectmode='manual',
        aspectratio=dict(x=x_aspect, y=y_aspect, z=z_aspect),

        # Axis label
        xaxis=dict(
            title='X Position',
            ticktext=['Neg', 'Pos'],
            tickvals=[x_min, x_max],
            range=[x_min - x_interval * 0.12, x_max + x_interval * 0.12],
            showbackground=True,
            showspikes=False
        ),
        yaxis=dict(
            title='Y Position',
            ticktext=['Neg', 'Pos'],
            tickvals=[y_min, y_max],
            range=[y_min - y_interval * 0.12, y_max + y_interval * 0.12],
            showbackground=True,
            showspikes=False
        ),
        zaxis=dict(
            title='Temporal Slice Index',
            tickvals=z_unique,
            range=[z_min - z_step * 0.5, z_max + z_step * 0.5],
            showbackground=True,
        ),
    )

    # Output to disk file
    if filename is None:
        filename = f'{figure_title}_{value_name}'

    fig.write_html(folder / f'{filename}.html')

    return fig


import numpy as np
from matplotlib import colors, cm
from plotly import graph_objects as go


def color_to_str(color_vector):
    """
    Convert RGBA float colors (components in [0, 1]) into CSS/plotly ``'rgba(r,g,b,a)'`` strings.

    Red, green and blue are scaled to 0-255 integers (truncated). Alpha stays a fraction in [0, 1], as the
    CSS ``rgba()`` syntax requires; it was previously scaled to 0-255, which is outside that range.

    Args:
        color_vector (array-like): RGBA colors with the 4 channels on the last axis, e.g. the output of a
            matplotlib colormap.

    Returns:
        np.ndarray: Object array of strings with shape ``color_vector.shape[:-1]``.
    """
    color_vector = np.asarray(color_vector, dtype=float)
    rgb = (color_vector[..., :3] * 255).astype(int)
    alpha = color_vector[..., 3]

    # Filling an explicitly shaped object array also works for empty input, unlike apply_along_axis.
    color_str = np.empty(color_vector.shape[:-1], dtype=object)
    for index in np.ndindex(color_str.shape):
        red, green, blue = rgb[index]
        color_str[index] = f'rgba({red},{green},{blue},{alpha[index]:g})'
    return color_str


def vector_to_color(vector, stringify=True, colormap='viridis'):
    """Map numeric values to colormap colors after min-max normalization.

    Each 1-D slice along the last axis of ``vector`` is rescaled independently
    to ``[0, 1]`` and then looked up in ``colormap``.

    Args:
        vector: Array-like of numbers; normalization is done along the last
            axis. A slice whose values are all equal maps to the bottom of the
            colormap.
        stringify: If True, return CSS ``rgba(...)`` strings (via
            ``color_to_str``); otherwise return the raw RGBA float array.
        colormap: Name of a registered matplotlib colormap, a
            ``matplotlib.colors.Colormap`` instance, or None to use the
            ``image.cmap`` rcParam.

    Returns:
        RGBA float array with a trailing axis of length 4, or, if
        ``stringify`` is True, an object array of color strings with the same
        shape as ``vector``.
    """
    from matplotlib import colormaps, rcParams

    values = np.asarray(vector, dtype=float)
    lower = np.min(values, axis=-1, keepdims=True)
    span = np.max(values, axis=-1, keepdims=True) - lower
    # A constant slice has zero span. Dividing by 1 instead maps it to 0
    # rather than to NaN, which the colormap would render as its "bad" color.
    span = np.where(span == 0, 1, span)
    vector_normed = (values - lower) / span

    if isinstance(colormap, colors.Colormap):
        cmap = colormap
    else:
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the
        # colormap registry instead.
        cmap = colormaps[rcParams['image.cmap'] if colormap is None else colormap]

    color = cmap(vector_normed)
    if not stringify:
        return color
    return color_to_str(color)


def range_margin(vector=None, value_min=None, value_max=None, margin=0.05):
    """Return ``[low, high]`` axis bounds padded by a fraction of the range.

    Either ``vector`` or both ``value_min`` and ``value_max`` must be given.
    When ``vector`` is given, its min and max take precedence over any
    explicit ``value_min`` / ``value_max``.

    Args:
        vector: Optional array-like; its min and max define the data range.
        value_min: Lower end of the data range (used when ``vector`` is None).
        value_max: Upper end of the data range (used when ``vector`` is None).
        margin: Fraction of ``value_max - value_min`` added on each side.

    Returns:
        Two-element list ``[value_min - pad, value_max + pad]`` where
        ``pad = (value_max - value_min) * margin``.

    Raises:
        ValueError: If neither ``vector`` nor both explicit bounds are given.
    """
    if vector is not None:
        value_min = np.min(vector)
        value_max = np.max(vector)
    elif value_min is None or value_max is None:
        raise ValueError(
            'range_margin requires either `vector` or both `value_min` and `value_max`'
        )

    pad = (value_max - value_min) * margin
    return [value_min - pad, value_max + pad]


def aspect_ratio_scale(fig: go.Figure, scale: float = 0.2):
    """Multiply every component of the 3-D scene aspect ratio by ``scale``.

    The figure is modified in place through ``update_layout`` and returned so
    calls can be chained.

    Args:
        fig: Plotly figure whose ``layout.scene.aspectratio`` x, y and z must
            already hold numeric values, since they are multiplied directly.
        scale: Common factor applied to the x, y and z ratios.

    Returns:
        The same ``fig`` object.
    """
    current = fig.layout.scene.aspectratio
    scaled = {axis: getattr(current, axis) * scale for axis in ('x', 'y', 'z')}
    fig.update_layout(scene_aspectratio=scaled)
    return fig


def camera_position(fig: go.Figure, rotation=-70, pitch=40):
    """Point the 3-D scene camera using a horizontal rotation and a pitch.

    The camera eye is placed at ``(cos(rotation), sin(rotation), sin(pitch))``,
    with both angles given in degrees. The horizontal part is not scaled by
    ``cos(pitch)``, so ``pitch`` is not an exact elevation angle. The eye also
    moves further from the origin as ``|pitch|`` grows.

    Args:
        fig: Plotly figure containing a 3-D scene.
        rotation: Angle in degrees in the x-y plane, measured from +x toward +y.
        pitch: Angle in degrees whose sine gives the eye's z coordinate.

    Returns:
        The same ``fig``, updated in place via ``update_layout``.
    """
    azimuth = rotation * np.pi / 180
    elevation = pitch * np.pi / 180
    eye = {'x': np.cos(azimuth), 'y': np.sin(azimuth), 'z': np.sin(elevation)}
    fig.update_layout(scene_camera={'eye': eye})
    return fig


from pathlib import Path
import matplotlib.pyplot as plt

# Create folder for saving plot (relative to the current working directory)
default_folder = Path('Plot')
default_folder.mkdir(exist_ok=True)

# Set default style for matplotlib, kept consistent with the Plotly figures.
# Matplotlib renamed its bundled 'seaborn' style to 'seaborn-v0_8'; try the new
# name first and fall back to the old one on releases that predate the rename
# (plt.style.use raises OSError for an unknown style name).
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')

# Font settings
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.size"] = 12
plt.rcParams["axes.labelsize"] = 12
