from typing import Callable

import numpy as np
from tqdm import tqdm


def _outer(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return a[..., np.newaxis] * b[..., np.newaxis, :]


class SVGD:
    def __init__(
        self,
        kernel: Callable[[np.ndarray], tuple[np.ndarray, np.ndarray, float]],
        score_fn: Callable[[np.ndarray], np.ndarray],
        x0: np.ndarray,
    ) -> None:
        self.kernel = kernel
        self.score_fn = score_fn
        self.x0 = np.copy(x0)
        self.x = np.copy(x0)
        self.N, self.d = x0.shape

    def update(
        self,
        n_steps: int = 1000,
        step_size: float = 1e-2,
        show_progress: bool = True,
        return_history: bool = False,
        hist_freq: int = 1,
        use_adadelta: bool = True,
        progress_desc: str | None = None,
        progress_bar_position: int = 0,
        tol: float = 1e-6,
        h_dim: int = 1,
    ) -> tuple:
        if use_adadelta:
            alpha = 0.9
            historical_grad = np.zeros_like(self.x)
        if return_history:
            num_save_steps = (n_steps - 1) // hist_freq + 1
            x_hist = np.empty((num_save_steps, self.N, self.d))
            if h_dim > 1:
                h_hist = np.empty((num_save_steps, h_dim))
            else:
                h_hist = np.empty(num_save_steps)
        for step in tqdm(
            range(n_steps),
            disable=not show_progress,
            desc=progress_desc,
            position=progress_bar_position,
        ):
            kernel, grad_kernel, h = self.kernel(self.x)
            if return_history and step % hist_freq == 0:
                x_hist[step // hist_freq] = self.x
                h_hist[step // hist_freq] = h
            neg_grad = (kernel @ self.score_fn(self.x) + grad_kernel) / self.N
            if use_adadelta:
                # Adadelta optimisation (https://github.com/DartML/Stein-Variational-Gradient-Descent/blob/master/python/svgd.py)
                if step == 0:
                    historical_grad = neg_grad**2
                else:
                    historical_grad = (
                        alpha * historical_grad + (1 - alpha) * neg_grad**2
                    )
                neg_grad /= np.sqrt(historical_grad) + 1e-6
            if np.sum(np.linalg.norm(neg_grad, axis=1)) < tol:
                if return_history:
                    x_hist = x_hist[: step // hist_freq + 1]
                    h_hist = h_hist[: step // hist_freq + 1]
                break
            self.x += step_size * neg_grad
        if return_history:
            return x_hist, h_hist
        return ()

    def reset(self, x0: np.ndarray | None = None) -> None:
        if x0 is None:
            self.x = np.copy(self.x0)
        else:
            self.x = np.copy(x0)
        self.N, self.d = self.x.shape


class SVGDAdaptive:
    """adaptive SVGD with product kernel exp(-sum_i(|x_i - y_i|^p / h_i))"""

    def __init__(
        self,
        h: np.ndarray,
        score_fn: Callable[[np.ndarray], np.ndarray],
        x0: np.ndarray,
        min_h: float = 1e-6,
        max_h: float = 1e3,
        p: float = 2.0,
    ) -> None:
        self.h0 = np.copy(h)
        self.h = h
        self.x0 = np.copy(x0)
        self.x = np.copy(x0)
        self.score_fn = score_fn
        self.N, self.d = x0.shape
        self.min_h = min_h
        self.max_h = max_h
        self.p = p
        assert self.h.shape == (self.d,)

    def update(
        self,
        n_steps: int = 1000,
        step_size: float = 1e-2,
        show_progress: bool = True,
        return_history: bool = False,
        hist_freq: int = 1,
        use_adadelta: bool = False,
        adaptive_h: bool = True,
        progress_desc: str | None = None,
        progress_bar_position: int = 0,
        h_steps_per_x_step: int = 1,
        h_step_size: float = 1.0,
        tol: float = 1e-6,
    ) -> tuple:
        if use_adadelta:
            alpha = 0.9
            historical_grad = np.zeros_like(self.x)
        if return_history:
            num_save_steps = (n_steps - 1) // hist_freq + 1
            x_hist = np.empty((num_save_steps, self.N, self.d))
            h_hist = np.empty((num_save_steps, self.d))
        diag_helper = np.zeros((self.N, self.N, self.d, self.d))
        mask = ~np.eye(self.N, dtype=bool)[..., np.newaxis]
        for step in tqdm(
            range(n_steps),
            disable=not show_progress,
            desc=progress_desc,
            position=progress_bar_position,
        ):
            if return_history and step % hist_freq == 0:
                x_hist[step // hist_freq] = self.x.copy()
                h_hist[step // hist_freq] = self.h.copy()
            # preparations
            pairwise_diff = self.x - self.x[:, np.newaxis, :]
            sgn_pairwise_diff = np.sign(pairwise_diff)
            abs_pairwise_diff = np.abs(pairwise_diff)
            grad_log_likelihood = self.score_fn(self.x)
            abs_pow_p_minus_2 = np.power(
                abs_pairwise_diff,
                self.p - 2,
                out=np.zeros_like(abs_pairwise_diff),
                where=abs_pairwise_diff != 0,
            )
            abs_pow_p_minus_1 = abs_pow_p_minus_2 * abs_pairwise_diff
            abs_pow_p = abs_pow_p_minus_1 * abs_pairwise_diff

            if adaptive_h:
                if step % h_steps_per_x_step == 0:
                    kernel: np.ndarray = np.exp(-np.sum(abs_pow_p / self.h, axis=2))
                    grad_h_kernel = kernel[..., np.newaxis] * abs_pow_p / self.h**2
                    dp_score_fn = np.tensordot(
                        grad_log_likelihood, grad_log_likelihood, axes=(1, 1)
                    )
                    diag_helper[..., np.arange(self.d), np.arange(self.d)] = (
                        sgn_pairwise_diff * abs_pow_p_minus_1 / self.h**2
                    )
                    grad_h_grad_kernel = (
                        kernel[..., np.newaxis, np.newaxis]
                        * self.p
                        * (
                            _outer(
                                abs_pow_p / self.h**2,
                                sgn_pairwise_diff * abs_pow_p_minus_1 / self.h,
                            )
                            - diag_helper
                        )
                    )
                    grad_h_trace = grad_h_kernel * self.p * np.sum(
                        (self.p - 1) * abs_pow_p_minus_2 / self.h
                        - self.p * (abs_pow_p_minus_1 / self.h) ** 2,
                        axis=2,
                    )[..., np.newaxis] + kernel[..., np.newaxis] * self.p * (
                        2 * self.p * abs_pow_p_minus_1**2 / self.h**3
                        - (self.p - 1) * abs_pow_p_minus_2 / self.h**2
                    )

                    grad_h_ksd = np.mean(
                        grad_h_kernel * dp_score_fn[..., np.newaxis]
                        + 2
                        * np.squeeze(
                            grad_h_grad_kernel
                            @ grad_log_likelihood[np.newaxis, ..., np.newaxis],
                            axis=-1,
                        )
                        + grad_h_trace,
                        axis=(0, 1),
                        where=mask,
                    )
                    self.h += h_step_size * grad_h_ksd
                    self.h = np.clip(self.h, self.min_h, self.max_h)
            else:
                self.h = np.array(
                    [
                        np.median(np.sum(abs_pow_p, axis=2)) / np.log(self.N)
                        for i in range(self.d)
                    ]
                )

            # update particles
            kernel: np.ndarray = np.exp(-np.sum(abs_pow_p / self.h, axis=2))
            psi_opt = np.mean(
                kernel[..., np.newaxis]
                * (
                    grad_log_likelihood[:, np.newaxis, :]
                    + self.p * sgn_pairwise_diff * abs_pow_p_minus_1 / self.h
                ),
                axis=0,
            )
            if use_adadelta:
                # Adadelta optimisation (https://github.com/DartML/Stein-Variational-Gradient-Descent/blob/master/python/svgd.py)
                if step == 0:
                    historical_grad = psi_opt**2
                else:
                    historical_grad = alpha * historical_grad + (1 - alpha) * psi_opt**2
                psi_opt /= np.sqrt(historical_grad) + 1e-6
            if np.sum(np.linalg.norm(psi_opt, axis=1)) < tol:
                if return_history:
                    x_hist = x_hist[: step // hist_freq + 1]
                    h_hist = h_hist[: step // hist_freq + 1]
                break
            self.x += step_size * psi_opt
        if return_history:
            return x_hist, h_hist
        return ()

    def reset(self, x0: np.ndarray | None = None) -> None:
        if x0 is None:
            self.x = np.copy(self.x0)
        else:
            self.x = np.copy(x0)
        self.h = np.copy(self.h0)
        self.N, self.d = self.x.shape
