from typing import List

import numpy as np
from scipy.stats import norm

from .utils import (QuantileFunction, SecondQuantileFunction,
                    bootstrap_multiprocessing)


class StochasticOrderTesting:
    """Bootstrap-based absolute and relative stochastic order tests.

    Given one 1D array of scores per model, the class evaluates
    ``QuantileFunction`` (QS) and ``SecondQuantileFunction`` (IQS) on a grid
    of probabilities, derives pairwise violation ratios from those curves and
    thresholds them into per-model win counts or a Borda ordering.

    Bootstrap replicates and the statistics derived from them are computed
    lazily on the first call to ``compute_relative_test`` or
    ``compute_absolute_test`` and then cached on the instance.
    """

    def __init__(self,
                 scores_list: List[np.ndarray],
                 n_bootstrap: int = 1000,
                 num_workers: int = 4,
                 dp: float = 0.001,
                 verbose: bool = True) -> None:
        r"""Absolute and relative Stochastic Order Testing

        Args:
            scores_list (List[np.ndarray]): List of scores (1D arrays) to compare
            n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 1000.
            num_workers (int, optional): Number of workers for parallelization. Defaults to 4.
            dp (float, optional): Step size for quantile functions. Defaults to 0.001.
            verbose (bool, optional): Forwarded to ``bootstrap_multiprocessing``. Defaults to True.

        Raises:
            ValueError: If ``scores_list`` is not a non-empty list, or if one
                of its elements is not one-dimensional.
        """

        # Check and cache inputs. An empty list is rejected here so that the
        # caller gets a ValueError instead of an IndexError further down.
        if not isinstance(scores_list, list) or not scores_list:
            raise ValueError(
                "Scores should be a list of 1D arrays (one per model)")

        # np.asarray returns ndarray inputs unchanged and lets plain
        # sequences go through the same dimensionality check.
        converted = [np.asarray(scores) for scores in scores_list]
        for scores in converted:
            if scores.ndim != 1:
                raise ValueError("Scores should be a 1D array")

        self.scores_list = converted
        self.n_bootstrap = n_bootstrap
        self.num_workers = num_workers
        self.dp = dp
        self.verbose = verbose
        self.k = len(converted)

        # Bootstrap resamples are drawn with the size of the smallest sample.
        self.n_samples = min(len(scores) for scores in converted)

    def _compute_bootstrap(self) -> None:
        """Cache QS / IQS curves for the data and for every bootstrap resample.

        Sets ``self.qs`` and ``self.iqs`` (one row per model) and
        ``self.qs_b`` and ``self.iqs_b`` (one entry per bootstrap replicate).
        """
        # Grid of probabilities: dp, 2 * dp, ... strictly below 1
        # (about m = 1 / dp - 1 points).
        p = np.arange(self.dp, 1, self.dp)

        self.qs = np.r_[[QuantileFunction(scores)(p)
                         for scores in self.scores_list]]  # k x m
        self.iqs = np.r_[[SecondQuantileFunction(scores)(p)
                          for scores in self.scores_list]]  # k x m

        def resample(scores, seed):
            # A fresh generator is seeded per call, so every model is
            # resampled from a generator started with the same seed.
            rng = np.random.default_rng(seed)
            return rng.choice(scores, self.n_samples)

        # Compute quantiles for one bootstrap replicate (all models)
        def qs_fn(seed):
            return np.r_[[
                QuantileFunction(resample(scores, seed))(p)
                for scores in self.scores_list
            ]]

        # Compute integrated quantiles for one bootstrap replicate
        def iqs_fn(seed):
            return np.r_[[
                SecondQuantileFunction(resample(scores, seed))(p)
                for scores in self.scores_list
            ]]

        # Compute bootstraps in parallel
        self.qs_b = np.r_[bootstrap_multiprocessing(
            qs_fn,
            num_workers=self.num_workers,
            n_bootstrap=self.n_bootstrap,
            desc="bootstrap quantiles",
            verbose=self.verbose)]  # B x k x m

        self.iqs_b = np.r_[bootstrap_multiprocessing(
            iqs_fn,
            num_workers=self.num_workers,
            n_bootstrap=self.n_bootstrap,
            desc="bootstrap second quantiles",
            verbose=self.verbose)]  # B x k x m

    @staticmethod
    def _violation_ratios(qs: np.ndarray, dp: float) -> np.ndarray:
        """Return the pairwise violation ratios between the rows of ``qs``.

        For ``i != j`` the entry ``[i, j]`` is the share of the discretized
        squared distance between rows ``i`` and ``j`` that comes from grid
        points where row ``j`` exceeds row ``i``. It is ``0.5`` when the two
        rows coincide, and the diagonal is ``0.0``.

        Args:
            qs (np.ndarray): One curve per row (k x m).
            dp (float): Grid step used as the integration weight.

        Returns:
            np.ndarray: k x k matrix of ratios in ``[0, 1]``.
        """
        n_models = len(qs)
        eps = np.zeros((n_models, n_models))
        for i in range(n_models):
            for j in range(n_models):
                if i == j:
                    continue  # diagonal stays at 0.0
                diff = qs[j] - qs[i]
                sq_wd = np.sum(diff**2 * dp)  # sq_wasserstein_dist
                if sq_wd > 0:
                    eps[i, j] = np.sum(diff[diff > 0]**2 * dp) / sq_wd
                else:
                    # Identical curves: neither side dominates.
                    eps[i, j] = 0.5
        return eps

    def _compute_violations(self) -> None:
        """Cache violation ratios and their bootstrap standard deviations.

        Requires ``_compute_bootstrap`` to have run. Sets ``eps_qs``,
        ``eps_i_qs``, ``eps_iqs``, ``eps_i_iqs`` and the ``sigma_abs_*`` /
        ``sigma_rel_*`` matrices.
        """
        # Number of opponents per model. Clamped to 1 so that a single model
        # does not trigger a 0 / 0 division in the averages below.
        k_m = max(self.k - 1, 1)

        # Compute violations statistics for original data
        self.eps_qs = self._violation_ratios(self.qs, self.dp)  # k x k
        self.eps_i_qs = self.eps_qs.sum(axis=1, keepdims=True) / k_m  # k x 1

        self.eps_iqs = self._violation_ratios(self.iqs, self.dp)  # k x k
        self.eps_i_iqs = self.eps_iqs.sum(axis=1, keepdims=True) / k_m  # kx1

        # Compute violations statistics for all bootstraps. The value handed
        # over by bootstrap_multiprocessing is used as an index into the
        # cached replicates (assumes it ranges over 0..n_bootstrap-1).
        def eps_qs_fn(seed):
            return self._violation_ratios(self.qs_b[seed], self.dp)

        def eps_iqs_fn(seed):
            return self._violation_ratios(self.iqs_b[seed], self.dp)

        eps_qs_b = np.c_[bootstrap_multiprocessing(
            eps_qs_fn,
            num_workers=self.num_workers,
            n_bootstrap=self.n_bootstrap,
            desc="bootstrap quantiles violations",
            verbose=self.verbose)]  # B x k x k

        eps_iqs_b = np.c_[bootstrap_multiprocessing(
            eps_iqs_fn,
            num_workers=self.num_workers,
            n_bootstrap=self.n_bootstrap,
            desc="bootstrap second quantiles violations",
            verbose=self.verbose)]  # B x k x k

        # Compute bootstrap standard deviations for absolute test
        self.sigma_abs_qs = np.std(eps_qs_b, ddof=1, axis=0)  # k x k
        self.sigma_abs_iqs = np.std(eps_iqs_b, ddof=1, axis=0)  # k x k

        # Compute bootstrap standard deviations for relative test: the
        # broadcasted difference (k x B x 1) - (1 x B x k) holds, for every
        # replicate, the gap between the average ratios of models i and j.
        eps_qs_i_b = eps_qs_b.sum(axis=2, keepdims=True) / k_m  # B x k x 1
        eps_qs_i_b = eps_qs_i_b.swapaxes(0, 1)  # k x B x 1
        self.sigma_rel_qs = np.std(eps_qs_i_b - eps_qs_i_b.T, ddof=1,
                                   axis=1)  # k x k

        eps_iqs_i_b = eps_iqs_b.sum(axis=2, keepdims=True) / k_m  # B x k x 1
        eps_iqs_i_b = eps_iqs_i_b.swapaxes(0, 1)  # k x B x 1
        self.sigma_rel_iqs = np.std(eps_iqs_i_b - eps_iqs_i_b.T,
                                    ddof=1,
                                    axis=1)  # k x k

    def _get_wins(self,
                  eps_0: np.ndarray,
                  sigma: np.ndarray,
                  alpha: float,
                  tau: float = 0.0) -> np.ndarray:
        """Threshold a statistic matrix into a 0/1 win matrix.

        Entry ``[i, j]`` is 1 when ``eps_0[i, j]`` does not exceed
        ``sigma[i, j] * phi / sqrt(n_samples) + tau``, where ``phi`` is the
        standard normal quantile at ``alpha / k**2``.

        Args:
            eps_0 (np.ndarray): Statistic to test (k x k).
            sigma (np.ndarray): Bootstrap standard deviations (k x k).
            alpha (float): Significance level, divided by ``k**2``.
            tau (float, optional): Offset added to the threshold. Defaults to 0.0.

        Returns:
            np.ndarray: Integer matrix of wins with the shape of ``eps_0``.
        """
        phi = norm.ppf(alpha / self.k**2)
        th = 1 / np.sqrt(self.n_samples) * sigma * phi + tau
        wins = (eps_0 <= th).astype(int)
        return wins

    def _ensure_statistics(self) -> None:
        """Run the bootstrap and violation computations if not cached yet.

        The guards look at the last attribute each step assigns, so a step
        that was interrupted half-way is run again instead of leaving later
        calls with missing attributes.
        """
        if not hasattr(self, "iqs_b"):
            self._compute_bootstrap()
        if not hasattr(self, "sigma_rel_iqs"):
            self._compute_violations()

    @staticmethod
    def _summarize(wins_qs: np.ndarray, wins_iqs: np.ndarray,
                   return_wins: bool):
        """Reduce the two win matrices to win counts or Borda orderings."""
        if return_wins:
            return wins_qs.sum(axis=1), wins_iqs.sum(axis=1)
        # Compute ranks using Borda
        rank_qs = borda(wins_qs)
        rank_iqs = borda(wins_iqs)
        return rank_qs, rank_iqs

    def compute_relative_test(self, alpha: float = 0.05, return_wins=False):
        """Compute relative test

        Args:
            alpha (float, optional): Significance level. Defaults to 0.05.
            return_wins (bool, optional): Whether to return wins instead of ranks. Defaults to False.

        Returns:
            Tuple[np.ndarray, np.ndarray]: QS and IQS results. With
            ``return_wins=False`` each array lists model indices ordered from
            most to fewest wins; with ``return_wins=True`` each array holds
            the number of wins per model (the diagonal self-comparison is
            part of the summed matrix).

        Example:
            >>> import numpy as np
            >>> from soe.testing import StochasticOrderTesting
            >>> means = np.random.permutation(15)
            >>> scores_list = [m + np.random.randn(100) for m in means]
            >>> test = StochasticOrderTesting(scores_list, n_bootstrap=100)
            >>> rank_qs, rank_iqs = test.compute_relative_test(alpha=0.05)
        """
        self._ensure_statistics()

        # Pairwise gap between the models' average violation ratios (k x k)
        eps_qs_0 = self.eps_i_qs - self.eps_i_qs.T
        wins_qs = self._get_wins(eps_qs_0, self.sigma_rel_qs, alpha)

        eps_iqs_0 = self.eps_i_iqs - self.eps_i_iqs.T
        wins_iqs = self._get_wins(eps_iqs_0, self.sigma_rel_iqs, alpha)

        return self._summarize(wins_qs, wins_iqs, return_wins)

    def compute_absolute_test(self,
                              alpha: float = 0.05,
                              tau: float = 0.25,
                              return_wins=False):
        """Compute absolute test

        Args:
            alpha (float, optional): Significance level. Defaults to 0.05.
            tau (float, optional): Statistical threshold. Defaults to 0.25.
            return_wins (bool, optional): Whether to return wins instead of ranks. Defaults to False.

        Returns:
            Tuple[np.ndarray, np.ndarray]: QS and IQS results. With
            ``return_wins=False`` each array lists model indices ordered from
            most to fewest wins; with ``return_wins=True`` each array holds
            the number of wins per model (the diagonal self-comparison is
            part of the summed matrix).

        Example:
            >>> import numpy as np
            >>> from soe.testing import StochasticOrderTesting
            >>> means = np.random.permutation(15)
            >>> scores_list = [m + np.random.randn(100) for m in means]
            >>> test = StochasticOrderTesting(scores_list, n_bootstrap=100)
            >>> rank_qs, rank_iqs = test.compute_absolute_test(tau=0.25)
        """
        self._ensure_statistics()

        wins_qs = self._get_wins(self.eps_qs, self.sigma_abs_qs, alpha, tau)
        wins_iqs = self._get_wins(self.eps_iqs, self.sigma_abs_iqs, alpha, tau)

        return self._summarize(wins_qs, wins_iqs, return_wins)


def borda(wins: np.ndarray):
    """Order the rows of a win matrix by their total number of wins.

    Args:
        wins (np.ndarray): 2D matrix whose row ``i`` holds the wins of item ``i``.

    Returns:
        np.ndarray: Row indices sorted from the largest to the smallest row
        sum. A random jitter below 0.01, drawn from NumPy's global random
        state, is added to every row sum before sorting to break ties.
    """
    jitter = 0.01 * np.random.rand(wins.shape[0])  # break ties
    totals = wins.sum(axis=1) + jitter
    ascending = np.argsort(totals)
    return ascending[::-1]
