from abc import ABC, abstractmethod

import numpy as np
import scipy
from calibration._linalg import gram_schmidt_orthonormalization
from scipy.special import expit
from sklearn.utils import check_random_state


def slope(x, theta=np.pi/4, t=0.5):
    """Evaluate a line through ``(t, t)`` at ``x`` and clip it to ``[0, 1]``.

    Parameters
    ----------
    x : scalar or array-like
        Point(s) at which the line is evaluated.
    theta : float
        Angle of the line in radians; its gradient is ``tan(theta)``.
    t : float
        Pivot: the unclipped line passes through ``(t, t)``.

    Returns
    -------
    The values ``tan(theta)*(x - t) + t`` limited to the interval ``[0, 1]``.
    """
    gradient = np.tan(theta)
    line = gradient*(x - t) + t
    return np.clip(line, 0, 1)


class BaseExample(ABC):
    """Abstract base for synthetic binary-classification examples.

    A concrete example supplies the true posterior ``Q``, the confidence
    score ``S``, the calibrated score ``C`` and the feature distribution
    ``dist``. From these the class derives accuracies, the grouping loss (GL)
    and several upper bounds (UB) on it.

    The integral-based metrics integrate over ``[self.x_min, self.x_max]``
    with ``scipy.integrate.quad``. Those two attributes are not set by this
    class; a subclass has to provide them before calling such a metric.
    """

    def __init__(self, threshold=0.5) -> None:
        """Store the decision threshold as ``self.t``."""
        super().__init__()
        self.t = threshold

    @abstractmethod
    def Q(self, X):
        """The true posterior probabilities of the positive class."""
        pass

    @abstractmethod
    def S(self, X):
        """The confidence scores of the positive class."""
        pass

    @abstractmethod
    def C(self, X):
        """The calibrated scores of the positive class."""
        pass

    @abstractmethod
    def dist(self):
        """Return the feature distribution.

        The returned object is used through ``rvs(size=..., random_state=...)``
        for sampling and through ``pdf(x)`` for the integral-based metrics.
        """
        pass

    def generate_X_y(self, n, random_state=0):
        """Sample ``n`` points ``X`` and draw labels ``y ~ Bernoulli(Q(X))``.

        Returns
        -------
        X, y : the samples (reshaped to a column if ``rvs`` returned a 1-D
        array) and the binary labels.
        """
        rng = check_random_state(random_state)
        dist = self.dist()
        X = dist.rvs(size=n, random_state=rng)
        # Labels are drawn before the reshape, i.e. Q sees X as rvs returned it.
        y = rng.binomial(1, self.Q(X))

        # Turn X 2d if not already
        if X.ndim == 1:
            X = X.reshape(-1, 1)

        return X, y

    def _expectation(self, f):
        """Integrate ``f(x)*pdf(x)`` over ``[self.x_min, self.x_max]``.

        Returns the ``(value, abserr)`` pair produced by
        ``scipy.integrate.quad``.
        """
        pdf = self.dist().pdf
        return scipy.integrate.quad(lambda x: f(x)*pdf(x), self.x_min, self.x_max)

    def acc_bayes(self, return_err=False):
        """Accuracy obtained by thresholding the true posterior ``Q`` at ``t``.

        Returns the accuracy, or ``(accuracy, abserr)`` if ``return_err``.
        """
        def d(x):
            Q = self.Q(x)
            return Q*(Q >= self.t) + (1 - Q)*(Q < self.t)

        acc, abserr = self._expectation(d)
        return (acc, abserr) if return_err else acc

    def acc_s(self, return_err=False):
        """Accuracy obtained by thresholding the score ``S`` at ``t``.

        Returns the accuracy, or ``(accuracy, abserr)`` if ``return_err``.
        """
        def d(x):
            S = self.S(x)
            Q = self.Q(x)
            return Q*(S >= self.t) + (1 - Q)*(S < self.t)

        acc, abserr = self._expectation(d)
        return (acc, abserr) if return_err else acc

    def acc_s_wrt_q(self, return_err=False):
        """Probability that thresholding ``S`` and thresholding ``Q`` agree.

        Returns the agreement rate, or ``(rate, abserr)`` if ``return_err``.
        """
        def d(x):
            Q = self.Q(x)
            S = self.S(x)
            return (S >= self.t)*(Q >= self.t) + (S < self.t)*(Q < self.t)

        acc, abserr = self._expectation(d)
        return (acc, abserr) if return_err else acc

    def ER(self, return_err=False):
        """Bayes error rate."""
        acc, abserr = self.acc_bayes(True)
        ER = 1 - acc

        if return_err:
            return ER, abserr

        return ER

    def GL(self, return_err=False):
        """Grouping loss computed by numerical integration (1D examples only).

        Raises
        ------
        NotImplementedError
            If ``self`` is not a ``Base1DExample``.
        """
        # Explicit raise instead of assert so the guard survives ``python -O``.
        if not isinstance(self, Base1DExample):
            raise NotImplementedError('GL is implemented for 1D examples only.')

        def d(x):
            # GL = MSE(C, Q) with brier score as proper scoring rule
            return np.square(self.C(x) - self.Q(x))

        GL, abserr = self._expectation(d)
        return (GL, abserr) if return_err else GL

    def _base_emp(self, f, N=100000, random_state=0):
        """Base code for empirical estimation: mean of ``f`` over ``N`` samples."""
        rng = check_random_state(random_state)
        dist = self.dist()
        X = dist.rvs(size=N, random_state=rng)
        return np.mean(f(X))

    def GL_emp(self, N=100000, random_state=0):
        """Empirical estimation of grouping loss."""
        def d(x):
            # GL = MSE(C, Q) with brier score as proper scoring rule
            return np.square(self.C(x) - self.Q(x))

        return self._base_emp(d, N, random_state)

    def UB_known(self, return_err=False):
        """Grouping loss upper bound C(1 - C). S is assumed to be calibrated."""
        def d(x):
            C = self.C(x)
            return C*(1 - C)

        UB, abserr = self._expectation(d)
        return (UB, abserr) if return_err else UB

    def UB_ER(self, return_err=False):
        """Grouping loss upper bound with Bayes error rate."""
        UB_known, abserr1 = self.UB_known(True)
        ER, abserr2 = self.ER(return_err=True)
        # UB = UB_known - 0.5*ER, so ER's error enters with weight 0.5.
        abserr = abserr1 + 0.5*abserr2

        UB = UB_known - 0.5*ER

        if return_err:
            return UB, abserr

        return UB

    def UB_acc_sq1(self, return_err=False):
        """Grouping loss upper bound with Acc(S||Q)."""
        acc, abserr1 = self.acc_s_wrt_q(return_err=True)

        def d(x):
            S = self.S(x)
            C = self.C(x)
            return C*(0.5 - C)*(S < 0.5) + (C - 0.5)*(1 - C)*(S >= 0.5)

        UB, abserr2 = self._expectation(d)

        # UB = integral + 0.5*(1 - acc): the integral's error counts fully,
        # the accuracy's error is the one scaled by 0.5.
        abserr = abserr2 + 0.5*abserr1
        UB += 0.5*(1 - acc)

        if return_err:
            return UB, abserr

        return UB

    def UB_acc_sq2(self, return_err=False):
        """Grouping loss upper bound with Acc(S||Q)."""
        acc, abserr1 = self.acc_s_wrt_q(return_err=True)

        def d(x):
            S = self.S(x)
            C = self.C(x)
            return (C - 0.5)*(1 - C)*(S < 0.5) + (0.5 - C)*C*(S >= 0.5)

        UB, abserr2 = self._expectation(d)

        # UB = integral + 0.5*acc: the integral's error counts fully,
        # the accuracy's error is the one scaled by 0.5.
        abserr = abserr2 + 0.5*abserr1
        UB += 0.5*acc

        if return_err:
            return UB, abserr

        return UB

    def UB_acc_sq(self, return_err=False):
        """Grouping loss upper bound with Acc(S||Q)."""
        acc_s_wrt_q, abserr1 = self.acc_s_wrt_q(return_err=True)
        acc_s, abserr2 = self.acc_s(return_err=True)
        UB, abserr3 = self.UB_known(return_err=True)

        UB -= 0.5*np.abs(acc_s_wrt_q - acc_s)

        if return_err:
            abserr = abserr3 + 0.5*(abserr1 + abserr2)
            return UB, abserr

        return UB


def get_dist_1d(name, x_min=None, x_max=None):
    """Build a frozen 1D scipy distribution from its name.

    Parameters
    ----------
    name : str
        ``'uniform'`` for a uniform law on ``[x_min, x_max]`` or
        ``'gaussian'`` for a standard normal law.
    x_min, x_max : float, optional
        Bounds of the uniform law. They are not used for ``'gaussian'``.

    Raises
    ------
    ValueError
        If ``name`` is neither ``'uniform'`` nor ``'gaussian'``.
    """
    if name == 'gaussian':
        return scipy.stats.norm()

    if name == 'uniform':
        width = x_max - x_min
        return scipy.stats.uniform(loc=x_min, scale=width)

    raise ValueError(f'Unsupported dist "{name}"')


class Base1DExample(BaseExample, ABC):
    """Base class for one-dimensional examples.

    The score is ``S = _S(x)``. The true posterior is ``h(S)`` for ``x > 0``,
    ``g(S)`` for ``x < 0`` and ``S`` at ``x == 0``; the calibrated score is
    the average of ``h(S)`` and ``g(S)`` away from zero. Subclasses provide
    ``h``, ``g`` and ``_S``.
    """

    def __init__(self, threshold=0.5, dist='uniform') -> None:
        """Store the threshold and distribution name and set the x-range.

        Raises
        ------
        ValueError
            If ``dist`` is neither ``'uniform'`` nor ``'gaussian'``.
        """
        super().__init__(threshold)
        self._dist = dist

        if dist == 'uniform':
            x_min, x_max = -1, 1
        elif dist == 'gaussian':
            x_min, x_max = -10, 10
        else:
            # Fail with a clear message rather than an UnboundLocalError on
            # x_min below.
            raise ValueError(f'Unsupported dist "{dist}"')

        self.x_min = x_min
        self.x_max = x_max

    def dist(self):
        """Return the frozen distribution built by ``get_dist_1d``."""
        return get_dist_1d(self._dist, self.x_min, self.x_max)

    def Q(self, x):
        """True posterior: ``h(S)`` if x > 0, ``g(S)`` if x < 0, ``S`` if x == 0.

        Unlike ``S``, the input is not squeezed, so the output has the shape
        of ``x``. Entries matching none of the three cases (NaN) stay NaN.
        """
        S = self._S(x)
        idx_pos = x > 0
        idx_neg = x < 0
        idx_nul = x == 0

        # full_like would inherit an integer dtype from an integer x, which
        # can hold neither NaN nor the fractional values written below.
        x_dtype = np.asarray(x).dtype
        q_dtype = x_dtype if np.issubdtype(x_dtype, np.inexact) else np.float64
        q = np.full_like(x, np.nan, dtype=q_dtype)
        q[idx_pos] = self.h(S[idx_pos])
        q[idx_neg] = self.g(S[idx_neg])
        q[idx_nul] = S[idx_nul]
        return q

    def C(self, x):
        """Calibrated score: ``(h(S) + g(S))/2`` for x != 0 and ``S`` at x == 0."""
        S = self._S(x)
        return (x != 0)*(self.h(S) + self.g(S))/2 + (x == 0)*S

    @abstractmethod
    def h(self, x):
        """Map the score to the posterior used where the input is positive."""
        pass

    @abstractmethod
    def g(self, x):
        """Map the score to the posterior used where the input is negative."""
        pass

    @abstractmethod
    def _S(self, x):
        """Score as a function of the input, without any reshaping."""
        pass

    def S(self, x):
        """Confidence score; ``x`` is squeezed before ``_S`` is applied."""
        x = np.squeeze(x)
        return self._S(x)


class Link1DExample(Base1DExample):
    """1D example whose posteriors are the score perturbed by a link function.

    ``h(s) = s + perturbation(s)`` and ``g(s) = 2*s - h(s)``, so the average
    of ``h`` and ``g`` is ``s``. ``alpha`` scales the perturbation, ``link``
    selects its form (``'sin'``, ``'sin2'`` or ``'poly'``) and ``s`` selects
    the score function (``'bowl'`` or ``'sin'``).
    """

    def __init__(self, threshold=0.5, alpha=1, link='sin', s='bowl', dist='uniform'):
        super().__init__(threshold, dist)
        self.alpha = alpha
        self.link = link
        self.s = s

    def h(self, s):
        """Posterior for positive inputs.

        Raises
        ------
        ValueError
            If ``self.link`` is not one of ``'sin'``, ``'sin2'``, ``'poly'``.
        """
        if self.link == 'sin':
            return self.alpha/np.pi*np.sin(np.pi*s) + s

        elif self.link == 'sin2':
            return self.alpha/(2*np.pi)*np.sin(2*np.pi*s) + s

        elif self.link == 'poly':
            return self.alpha*(-np.square(s) + s) + s

        else:
            # Mirror _S: reject unknown choices instead of returning None.
            raise ValueError(f'Unknown choice {self.link} for link.')

    def g(self, s):
        """Posterior for negative inputs: the mirror of ``h`` around ``s``."""
        return 2*s - self.h(s)

    def _S(self, x):
        """Score function selected by ``self.s``.

        Raises
        ------
        ValueError
            If ``self.s`` is neither ``'bowl'`` nor ``'sin'``.
        """
        if self.s == 'bowl':
            return 1 - np.exp(-np.square(1.8*x))

        elif self.s == 'sin':
            return 0.5*(np.sin(10*np.pi*x) + 1)

        else:
            raise ValueError(f'Unknown choice {self.s} for s.')

    def __repr__(self) -> str:
        kwargs = {
            'alpha': self.alpha,
            'link': self.link,
            's': self.s,
            'dist': self._dist,
        }
        return self.__class__.__name__+'('+','.join([f'{k}={v}' for k, v in kwargs.items()])+')'


class Slope1DExample(Base1DExample):
    """1D example with a clipped-line posterior on one side and 0.5 on the other.

    ``alpha`` is forwarded to ``slope`` as its angle argument.
    """

    def __init__(self, threshold=0.5, alpha=1, dist='uniform'):
        super().__init__(threshold, dist)
        self.alpha = alpha

    def h(self, s):
        """Posterior for positive inputs: ``slope`` of angle ``alpha``."""
        angle = self.alpha
        return slope(s, angle)

    def g(self, s):
        """Posterior for negative inputs: the constant 0.5, whatever ``s`` is."""
        return 0.5

    def _S(self, x):
        """Bowl-shaped score ``1 - exp(-(1.8*x)**2)``."""
        scaled = 1.8*x
        return 1 - np.exp(-np.square(scaled))


class Slopes1DExample(Base1DExample):
    """1D example whose two posteriors are clipped lines of opposite angles.

    ``alpha`` is forwarded to ``slope`` as its angle argument, with a
    positive sign for ``h`` and a negative sign for ``g``.
    """

    def __init__(self, threshold=0.5, alpha=1, dist='uniform'):
        super().__init__(threshold, dist)
        self.alpha = alpha

    def h(self, s):
        """Posterior for positive inputs: ``slope`` of angle ``alpha``."""
        angle = self.alpha
        return slope(s, angle)

    def g(self, s):
        """Posterior for negative inputs: ``slope`` of angle ``-alpha``."""
        angle = -self.alpha
        return slope(s, angle)

    def _S(self, x):
        """Bowl-shaped score ``1 - exp(-(1.8*x)**2)``."""
        scaled = 1.8*x
        return 1 - np.exp(-np.square(scaled))


class Squares1DExample(Base1DExample):
    """1D example with piecewise-constant posteriors split at the threshold."""

    def __init__(self, threshold=0.5, dist='uniform'):
        super().__init__(threshold, dist)

    def h(self, s):
        """Posterior for positive inputs: 0.75 where ``s < t``, 1 where ``s >= t``."""
        below = s < self.t
        at_or_above = s >= self.t
        return 0.75*below + 1*at_or_above

    def g(self, s):
        """Posterior for negative inputs: the boolean indicator ``s >= t``."""
        at_or_above = s >= self.t
        return at_or_above

    def _S(self, x):
        """Bowl-shaped score ``1 - exp(-(1.8*x)**2)``."""
        scaled = 1.8*x
        return 1 - np.exp(-np.square(scaled))


class SigmoidExample(BaseExample):
    """Multi-dimensional example with a sigmoid score and a perturbed posterior.

    The score is ``S(X) = expit(X @ w)``. The true posterior is
    ``Q(X) = S(X) + psi(X @ w_perp) * delta_max(X @ w)``, and the calibrated
    score ``C`` is returned as ``S`` itself. Features follow a zero-mean
    multivariate normal whose covariance is built from ``w``, ``w_perp`` and
    ``lbd``.

    Parameters
    ----------
    w, w_perp : array-like
        Direction of the score and direction of the perturbation. Both are
        squeezed, and they fill columns 0 and 1 of the basis matrix, so at
        least two dimensions are needed.
    bayes_opt : bool
        If true, the perturbation is additionally capped by ``|S - 1/2|``.
    alpha : float
        Scale applied to ``X @ w_perp`` inside ``psi``.
    lbd : float
        Value placed at position ``[0, 0]`` of the diagonal factor ``D``.
    psi : str
        One of ``'expit'``, ``'sin'``, ``'signsin'``, ``'step'``.
    threshold : float
        Decision threshold stored as ``self.t`` by the base class.
    """

    def __init__(self, w, w_perp, bayes_opt=False, alpha=1, lbd=10, psi='expit',
                 threshold=0.5):
        # Run the base initializer so that self.t exists for the inherited
        # threshold-based methods.
        super().__init__(threshold)
        self.w = np.squeeze(w)
        self.w_perp = np.squeeze(w_perp)
        self.bayes_opt = bayes_opt
        self.alpha = alpha
        self.lbd = lbd
        self._psi = psi

        d = self.w.shape[0]
        P = np.eye(d)
        P[:, 0] = self.w
        P[:, 1] = self.w_perp
        # NOTE(review): presumably orthonormalizes the columns of P; confirm
        # against calibration._linalg.
        P = gram_schmidt_orthonormalization(P)
        D = np.eye(d)
        D[0, 0] = lbd
        # Create PSD cov from PDP^-1 decomposition
        self.cov = P @ D @ np.linalg.inv(P)
        self.mean = np.zeros_like(self.w)

    def dist(self):
        """Return the multivariate normal with the stored mean and covariance."""
        return scipy.stats.multivariate_normal(mean=self.mean, cov=self.cov)

    def phi(self, dot):
        """Sigmoid link applied to the projection on ``w``."""
        return expit(dot)

    def S(self, X):
        """Confidence score ``expit(X @ w)``."""
        dot = np.dot(X, self.w)
        return self.phi(dot)

    def psi(self, dot_perp):
        """Signed perturbation profile evaluated on the projection on ``w_perp``.

        Raises
        ------
        ValueError
            If the configured ``psi`` name is unknown.
        """
        if self._psi == 'expit':
            y = 2*expit(self.alpha*dot_perp) - 1
            # sign(y)*|y| is y itself; kept as written (flagged '?' by the author).
            return np.sign(y)*np.abs(y)  # ?

        elif self._psi == 'sin':
            return np.sin(self.alpha*np.pi*dot_perp)

        elif self._psi == 'signsin':
            return np.sign(np.sin(self.alpha*np.pi*dot_perp))

        elif self._psi == 'step':
            # +1 for positive, -1 for negative, 0 at exactly zero.
            y = np.array(dot_perp > 0).astype(float) - np.array(dot_perp < 0).astype(float)
            return y

        else:
            raise ValueError(f'Unknown choice {self._psi} for psi.')

    def delta_max(self, dot):
        """Largest perturbation magnitude allowed at a given projection.

        It is ``min(phi, 1 - phi)``, and with ``bayes_opt`` it is further
        capped by ``|phi - 1/2|``.
        """
        _delta_max = np.minimum(1 - self.phi(dot), self.phi(dot))
        if self.bayes_opt:
            _delta_max = np.minimum(_delta_max, np.abs(self.phi(dot) - 1/2))
        return _delta_max

    def _delta(self, dot, dot_perp):
        """Perturbation ``psi(dot_perp) * delta_max(dot)`` from the two projections."""
        _delta_max = self.delta_max(dot)
        return np.multiply(self.psi(dot_perp), _delta_max)

    def delta(self, X):
        """Perturbation added to the score to obtain the true posterior."""
        dot_perp = np.dot(X, self.w_perp)
        dot = np.dot(X, self.w)
        return self._delta(dot, dot_perp)

    def C(self, X):
        """Calibrated score, returned as ``S(X)``."""
        return self.S(X)

    def Q(self, X):
        """True posterior ``S(X) + delta(X)``."""
        return self.S(X) + self.delta(X)

    def __repr__(self) -> str:
        kwargs = {
            'alpha': self.alpha,
            'lbd': self.lbd,
            'psi': self._psi,
        }
        return self.__class__.__name__+'('+','.join([f'{k}={v}' for k, v in kwargs.items()])+')'
