class CrMinusTwo:
    """Closed-form value estimators and confidence intervals on reweighted data.

    Every method is a ``@staticmethod``.  Data is supplied through a
    zero-argument callable ``datagen`` returning an iterable of tuples whose
    first element ``c`` is the multiplicity (count) of that row.

    The estimators append one pseudo-observation (the "fake" point) at an
    extreme weight (``wmin``/``wmax``, and ``umin``/``umax`` for the
    difference estimators) and give every point a weight that is linear in its
    importance weight(s); see the ``qfunc`` entries of the returned
    dictionaries.  NOTE(review): the class name suggests the Cressie-Read
    divergence with lambda = -2 -- confirm against the original derivation.
    """

    @staticmethod
    def estimate(datagen, wmin, wmax, rmin=0, rmax=1, raiseonerr=False, censored=False):
        """Point estimate of the value from ``(count, weight, reward)`` rows.

        Accumulates sufficient statistics over ``datagen()`` and delegates to
        :meth:`estimateimpl`.

        Args:
            datagen: zero-argument callable yielding ``(c, w, r)`` tuples.
                ``r`` may be ``None``; such rows contribute to the weight sums
                but not to the reward sums.
            wmin, wmax: weight bounds, ``0 <= wmin < 1 < wmax`` (``wmax`` may
                be ``math.inf``).
            rmin, rmax: reward bounds, ``rmin <= rmax``.
            raiseonerr: forwarded to :meth:`estimateimpl`.
            censored: forwarded to :meth:`estimateimpl`.

        Returns:
            ``(vhat, info)`` exactly as returned by :meth:`estimateimpl`.

        Raises:
            AssertionError: if the total count ``n`` is not positive, or if
                :meth:`estimateimpl` rejects the inputs.
        """
        n, sumw, sumwsq, sumwr, sumwsqr, sumwany, sumwsqany = 0, 0, 0, 0, 0, 0, 0
        for c, w, r in datagen():
            n += c
            sumw += c*w
            sumwsq += c*w*w
            if r is not None:
                # Sums restricted to rows with an observed reward.
                sumwany += c*w
                sumwsqany += c*w*w
                sumwr += c*w*r
                sumwsqr += c*w*w*r

        assert n > 0

        return CrMinusTwo.estimateimpl(
                n, sumw, sumwsq, sumwr, sumwsqr, sumwany, sumwsqany,
                wmin, wmax, rmin, rmax, raiseonerr, censored)

    @staticmethod
    def estimateimpl(n, sumw, sumwsq, sumwr, sumwsqr, sumwany, sumwsqany,
                     wmin, wmax, rmin=0, rmax=1, raiseonerr=False, censored=False):
        """Point estimate from precomputed sufficient statistics.

        Args:
            n: total count, ``sum(c)``.
            sumw, sumwsq: ``sum(c*w)`` and ``sum(c*w*w)`` over all rows.
            sumwr, sumwsqr: ``sum(c*w*r)`` and ``sum(c*w*w*r)`` over rows whose
                reward is not ``None``.
            sumwany, sumwsqany: ``sum(c*w)`` and ``sum(c*w*w)`` over rows whose
                reward is not ``None``.
            wmin, wmax: weight bounds, ``0 <= wmin < 1 < wmax``; ``wmax`` may be
                ``math.inf``.
            rmin, rmax: reward bounds, ``rmin <= rmax``.
            raiseonerr: accepted for interface compatibility; unused here.
            censored: if true, ``None`` rewards are treated as censored and the
                estimate targets ``E[w r 1{r observed}] / E[w 1{r observed}]``.

        Returns:
            ``(vhat, info)``.  ``vhat`` is clipped to ``[rmin, rmax]``; in the
            censored case it is ``None`` when no candidate ratio has a positive
            denominator.  ``info`` holds ``primal``, ``gammastar``, ``betastar``,
            ``vmin``, ``vmax``, ``num`` and ``qfunc(c, w, r)``, the fitted
            weight of a row.

        Raises:
            AssertionError: on invalid bounds, or when the augmented weight
                sample has zero variance.
        """
        from math import inf

        assert wmin >= 0
        assert wmin < 1
        assert wmax > 1
        assert rmax >= rmin

        # Pseudo-observation: if observed weights average below 1, add the
        # largest allowed weight, otherwise the smallest (presumably so the
        # normalisation constraints stay satisfiable).
        wfake = wmax if sumw < n else wmin

        if wfake == inf:
            # Limiting solution as the pseudo-observation weight grows unboundedly.
            gammastar = -(1 + n) / n
            betastar = 0
            gstar = 1 + 1 / n
        else:
            a = (wfake + sumw) / (1 + n)
            b = (wfake**2 + sumwsq) / (1 + n)
            assert a*a < b
            gammastar = (b - a) / (a*a - b)
            betastar = (1 - a) / (a*a - b)
            gstar = (n + 1) * (a - 1)**2 / (b - a*a)

        vhat = (-gammastar * sumwr - betastar * sumwsqr) / (1 + n)
        # Weight mass (sum of q*w) not carried by the observed rows, i.e.
        # attributed to the pseudo-observation whose reward is unknown.
        missing = max(0, 1 - (-gammastar * sumw - betastar * sumwsq) / (1 + n))

        if censored:
            # vhat = E[w r 1_{r is not None}] / E[w 1_{r is not None}]

            vnumhat = vhat
            vdenomhat = (-gammastar * sumwany - betastar * sumwsqany) / (1 + n)

            # Minimize[{ (x + a r) / (y + a),  y >= 0, 0 <= a <= m }, a]
            #
            # extrema is always at endpoints of interval

            vmincandidates = []
            vmaxcandidates = []
            if vdenomhat > 0:
                vmincandidates.append(vnumhat / vdenomhat)
                vmaxcandidates.append(vnumhat / vdenomhat)

            if vdenomhat + missing > 0:
                vmincandidates.append((vnumhat + missing * rmin) / (vdenomhat + missing))
                vmaxcandidates.append((vnumhat + missing * rmax) / (vdenomhat + missing))

            vmin = min(vmincandidates, default=None)
            vmax = max(vmaxcandidates, default=None)
            vhat = None if vmin is None or vmax is None else (vmin + vmax) / 2
        else:
            # Bound the unknown mass by the extreme rewards; report the midpoint.
            vmin = vhat + missing * rmin
            vmax = vhat + missing * rmax
            vhat += missing * (rmin + rmax) / 2

        vmin, vmax, vhat = (None if x is None else min(rmax, max(rmin, x))
                            for x in (vmin, vmax, vhat))

        return vhat, {
            'primal': gstar,
            'gammastar': gammastar,
            'betastar': betastar,
            'vmin': vmin,
            'vmax': vmax,
            'num': n,
            'qfunc': lambda c, w, r: (c/(1 + n)) * (-gammastar - betastar * w),
        }

    @staticmethod
    def estimatediff(datagen, umin, umax, wmin, wmax, rmin=0, rmax=1, raiseonerr=False, censored=False):
        """Point estimate of the difference between ``u``- and ``w``-weighted values.

        Each row (plus one pseudo-observation at ``(ufake, wfake)``) receives the
        weight ``q = c/(n+1) * (-beta - gamma*u - tau*w)``.  The coefficients
        solve the linear system that makes ``sum q``, ``sum q*u`` and
        ``sum q*w`` all equal 1.  The estimate is ``sum q*(u - w)*r`` over the
        observed rows, plus the pseudo-observation's share bounded by
        ``[rmin, rmax]``.

        Args:
            datagen: zero-argument callable yielding ``(c, u, w, r)`` tuples.
            umin, umax, wmin, wmax: weight bounds, each ``0 <= min < 1 < max``.
            rmin, rmax: reward bounds, ``rmin <= rmax``.
            raiseonerr: accepted for interface compatibility; unused here.
            censored: must be false; censoring is not supported.

        Returns:
            ``(deltavhat, info)`` where ``deltavhat`` and ``info['deltavmin']`` /
            ``info['deltavmax']`` are clipped to ``[rmin - rmax, rmax - rmin]``.
            ``info`` also holds ``num``, ``betastar``, ``gammastar``,
            ``taustar``, ``primal`` and ``qfunc(c, u, w, r)``.

        Raises:
            AssertionError: on invalid bounds, ``censored=True``, or no data.
        """
        import numpy as np

        assert umin >= 0
        assert umin < 1
        assert umax > 1
        assert wmin >= 0
        assert wmin < 1
        assert wmax > 1
        assert rmax >= rmin
        assert not censored

        n, sumu, sumw, sumuw, sumusq, sumwsq = 0, 0, 0, 0, 0, 0
        sumuMwr, sumuuMwr, sumwuMwr = 0, 0, 0
        for c, u, w, r in datagen():
            n += c
            sumu += c * u
            sumw += c * w
            sumuw += c * u * w
            sumusq += c * u**2
            sumwsq += c * w**2
            sumuMwr += c * (u - w) * r
            sumuuMwr += c * u * (u - w) * r
            sumwuMwr += c * w * (u - w) * r

        assert n > 0

        # Pseudo-observation on the side that can pull each average towards 1.
        ufake = umax if sumu < n else umin
        wfake = wmax if sumw < n else wmin

        ubar = (sumu + ufake) / (n + 1)
        usqbar = (sumusq + ufake**2) / (n + 1)
        uwbar = (sumuw + ufake * wfake) / (n + 1)
        wbar = (sumw + wfake) / (n + 1)
        wsqbar = (sumwsq + wfake**2) / (n + 1)

        A = np.array([ [ -1, -ubar, -wbar ],
                       [ -ubar, -usqbar, -uwbar  ],
                       [ -wbar, -uwbar, -wsqbar ] ],
                     dtype='float64')
        b = np.ones(3, dtype='float64')

        xstar = np.linalg.lstsq(A, b, rcond=-1)[0]
        beta, gamma, tau = xstar

        # Observed-row contribution sum_i q_i (u_i - w_i) r_i.  The row weights
        # q_i are normalised by (n + 1) (n rows + the pseudo-observation), the
        # same normalisation used by `missing` and `qfunc` below.
        deltavhat = (- beta * sumuMwr - gamma * sumuuMwr - tau * sumwuMwr) / (n + 1)
        # q_fake * (ufake - wfake): the mass whose reward is unknown.
        missing = (
                    - beta * (ufake - wfake)
                    - gamma * (ufake**2 - ufake * wfake)
                    - tau * (ufake * wfake - wfake**2)
                  ) / (n + 1)

        deltavmin = deltavhat + min(rmin * missing, rmax * missing)
        deltavmax = deltavhat + max(rmin * missing, rmax * missing)
        deltavhat = (deltavmin + deltavmax) / 2

        deltavmin, deltavmax, deltavhat = (min(rmax - rmin, max(rmin - rmax, x))
                                           for x in (deltavmin,
                                                     deltavmax,
                                                     deltavhat))

        qfunc = lambda c, u, w, r, n=n, b=beta, g=gamma, t=tau: (c/(n+1))*(-b -g*u -t*w)

        return deltavhat, {
                'deltavmin': deltavmin,
                'deltavmax': deltavmax,
                'num': n,
                'betastar': beta,
                'gammastar': gamma,
                'taustar': tau,
                'primal': -(n+1)*(1 + beta + gamma + tau),
                'qfunc': qfunc,
        }

    @staticmethod
    def interval(datagen, wmin, wmax, alpha=0.05,
                 rmin=0, rmax=1, raiseonerr=False):
        """Confidence interval for the value from ``(count, weight, reward)`` rows.

        Accumulates sufficient statistics over ``datagen()`` and delegates to
        :meth:`intervalimpl`.

        Args:
            datagen: zero-argument callable yielding ``(c, w, r)`` tuples.
            wmin, wmax: weight bounds, ``wmin < 1 < wmax``.
            alpha: significance level.
            rmin, rmax: reward bounds, ``rmin <= rmax``.
            raiseonerr: forwarded to :meth:`intervalimpl`.

        Returns:
            ``((lower, upper), (lowerinfo, upperinfo))`` as returned by
            :meth:`intervalimpl`.

        Raises:
            AssertionError: on invalid bounds or no data.
        """
        assert wmin < 1
        assert wmax > 1
        assert rmin <= rmax

        n, sumw, sumwsq, sumwr, sumwsqr, sumwsqrsq = 0, 0, 0, 0, 0, 0

        for c, w, r in datagen():
            n += c
            sumw += c * w
            sumwsq += c * w**2
            sumwr += c * w * r
            sumwsqr += c * w**2 * r
            sumwsqrsq += c * w**2 * r**2
        assert n > 0

        return CrMinusTwo.intervalimpl(n, sumw, sumwsq, sumwr, sumwsqr, sumwsqrsq, wmin, wmax, alpha, rmin, rmax, raiseonerr)

    @staticmethod
    def intervalimpl(n, sumw, sumwsq, sumwr, sumwsqr, sumwsqrsq,
                     wmin, wmax, alpha=0.05,
                     rmin=0, rmax=1, raiseonerr=False):
        """Confidence interval from precomputed sufficient statistics.

        Args:
            n, sumw, sumwsq: ``sum(c)``, ``sum(c*w)``, ``sum(c*w**2)``.
            sumwr, sumwsqr, sumwsqrsq: ``sum(c*w*r)``, ``sum(c*w**2*r)``,
                ``sum(c*w**2*r**2)``.
            wmin, wmax: weight bounds, ``wmin < 1 < wmax``; ``wmax`` may be
                ``math.inf``.
            alpha: significance level; the critical value is
                ``scipy.stats.f.isf(alpha, 1, n)``.
            rmin, rmax: reward bounds, ``rmin <= rmax``.
            raiseonerr: accepted for interface compatibility; unused here.

        Returns:
            ``((lower, upper), (lowerinfo, upperinfo))``.  Each bound is clipped
            to ``[rmin, rmax]``.  Each info is ``None`` (optimum reached with
            ``kappa`` ~ 0) or a dict with ``kappastar``, ``betastar``,
            ``gammastar``, ``wfake`` and ``qfunc``.

        Raises:
            AssertionError: on invalid bounds.
            ValueError: if no feasible candidate exists for a bound.
        """
        from math import inf, isclose, sqrt
        from scipy.stats import f

        assert wmin < 1
        assert wmax > 1
        assert rmin <= rmax

        # Unconstrained primal value with the same pseudo-observation rule as
        # `estimateimpl`; it shifts the constraint level `phi` below.
        uncwfake = wmax if sumw < n else wmin
        if uncwfake == inf:
            uncgstar = 1 + 1 / n
        else:
            unca = (uncwfake + sumw) / (1 + n)
            uncb = (uncwfake**2 + sumwsq) / (1 + n)
            uncgstar = (n + 1) * (unca - 1)**2 / (uncb - unca*unca)
        Delta = f.isf(q=alpha, dfn=1, dfd=n)
        phi = (-uncgstar - Delta) / (2 * (n + 1))

        bounds = []
        # Lower bound: minimise with the pseudo-reward at rmin (sign=+1).
        # Upper bound: minimise the negated objective with pseudo-reward at rmax.
        for r, sign in ((rmin, 1), (rmax, -1)):
            candidates = []
            for wfake in (wmin, wmax):
                if wfake == inf:
                    x = sign * (r + (sumwr - sumw * r) / n)
                    y = (  (r * sumw - sumwr)**2 / (n * (1 + n))
                         - (r**2 * sumwsq - 2 * r * sumwsqr + sumwsqrsq) / (1 + n)
                        )
                    z = phi + 1 / (2 * n)
                    if isclose(y*z, 0, abs_tol=1e-9):
                        y = 0

                    if z <= 0 and y * z >= 0:
                        # Guard z == 0 (admitted by the test above) exactly as the
                        # finite-wfake branch does, instead of dividing by zero.
                        kappa = sqrt(y / (2 * z)) if y * z > 0 else 0
                        if isclose(kappa, 0):
                            candidates.append((sign * r, None))
                        else:
                            gstar = x - sqrt(2 * y * z)
                            gamma = ( -kappa * (1 + n) / n
                                     + sign * (r * sumw - sumwr) / n )
                            beta = -sign * r
                            candidates.append((gstar, {
                                'kappastar': kappa,
                                'betastar': beta,
                                'gammastar': gamma,
                                'wfake': wfake,
                            # Q_{w,r} &= -\frac{\gamma + \beta w + w r}{(N+1) \kappa} \\
                                'qfunc': lambda c, w, r, k=kappa, g=gamma, b=beta, s=sign, num=n: -c * (g + (b + s * r) * w) / ((num + 1) * k),
                            }))
                else:
                    barw = (wfake + sumw) / (1 + n)
                    barwsq = (wfake*wfake + sumwsq) / (1 + n)
                    barwr = sign * (wfake * r + sumwr) / (1 + n)
                    barwsqr = sign * (wfake * wfake * r + sumwsqr) / (1 + n)
                    barwsqrsq = (wfake * wfake * r * r + sumwsqrsq) / (1 + n)

                    # Requires a non-degenerate (positive-variance) weight sample.
                    if barwsq > barw**2:
                        x = barwr + ((1 - barw) * (barwsqr - barw * barwr) / (barwsq - barw**2))
                        y = (barwsqr - barw * barwr)**2 / (barwsq - barw**2) - (barwsqrsq - barwr**2)
                        z = phi + (1/2) * (1 - barw)**2 / (barwsq - barw**2)
                        if isclose(y*z, 0, abs_tol=1e-9):
                            y = 0

                        if z <= 0 and y * z >= 0:
                            kappa = sqrt(y / (2 * z)) if y * z > 0 else 0
                            if isclose(kappa, 0):
                                candidates.append((sign * r, None))
                            else:
                                gstar = x - sqrt(2 * y * z)
                                beta = (-kappa * (1 - barw) - (barwsqr - barw * barwr)) / (barwsq - barw*barw)
                                gamma = -kappa - beta * barw - barwr
                                candidates.append((gstar, {
                                    'kappastar': kappa,
                                    'betastar': beta,
                                    'gammastar': gamma,
                                    'wfake': wfake,
                                # Q_{w,r} &= -\frac{\gamma + \beta w + w r}{(N+1) \kappa} \\
                                    'qfunc': lambda c, w, r, k=kappa, g=gamma, b=beta, s=sign, num=n: -c * (g + (b + s * r) * w) / ((num + 1) * k),
                                }))

            if not candidates:
                raise ValueError('intervalimpl: no feasible candidate for sign=%d' % sign)

            best = min(candidates, key=lambda x: x[0])
            vbound = min(rmax, max(rmin, sign*best[0]))
            bounds.append((vbound, best[1]))

        return (bounds[0][0], bounds[1][0]), (bounds[0][1], bounds[1][1])

    @staticmethod
    def intervaldiff(datagen, umin, umax, wmin, wmax, alpha=0.05,
                     rmin=0, rmax=1, raiseonerr=False):
        """Confidence interval for the ``u``-minus-``w`` value difference.

        Calls :meth:`estimatediff` first (for the sample size and the
        unconstrained primal value), then makes a second pass over
        ``datagen()``.  Each bound is taken over all four pseudo-observations
        ``(ufake, wfake)`` in ``{umin, umax} x {wmin, wmax}``.

        Args:
            datagen: zero-argument callable yielding ``(c, u, w, r)`` tuples;
                it is iterated twice.
            umin, umax, wmin, wmax: weight bounds, each ``0 <= min < 1 < max``.
            alpha: significance level for the ``F(1, num - 1)`` critical value.
            rmin, rmax: reward bounds, ``rmin <= rmax``.
            raiseonerr: forwarded to :meth:`estimatediff`.

        Returns:
            ``((lower, upper), (lowerinfo, upperinfo), candidates)``.  Bounds are
            clipped to ``[rmin - rmax, rmax - rmin]``.  Each info is ``None`` or a
            dict describing the optimum.  ``candidates`` is the raw
            ``(gstar, info)`` list of the *last* (upper-bound) search only.
            NOTE(review): this looks like a debugging aid; it is kept for
            backward compatibility.

        Raises:
            AssertionError: on invalid bounds or no data.
            ValueError: if no feasible candidate exists for a bound.
        """
        import numpy as np
        from itertools import product
        from math import isclose, sqrt
        from scipy.stats import f

        assert umin >= 0
        assert umin < 1
        assert umax > 1
        assert wmin >= 0
        assert wmin < 1
        assert wmax > 1
        assert rmax >= rmin

        _, mle = CrMinusTwo.estimatediff(datagen, umin, umax, wmin, wmax, rmin, rmax, raiseonerr=raiseonerr)

        Delta = f.isf(q=alpha, dfn=1, dfd=mle['num']-1)
        phi = (-Delta - mle['primal']) / (2 * (mle['num'] + 1))

        n, sumu, sumw, sumuw, sumusq, sumwsq = 0, 0, 0, 0, 0, 0
        sumuMwr, sumuuMwr, sumwuMwr, sumuMwsqrsq = 0, 0, 0, 0
        for c, u, w, r in datagen():
            n += c
            sumu += c * u
            sumw += c * w
            sumuw += c * u * w
            sumusq += c * u**2
            sumwsq += c * w**2
            sumuMwr += c * (u - w) * r
            sumuuMwr += c * u * (u - w) * r
            sumwuMwr += c * w * (u - w) * r
            sumuMwsqrsq += c * (u - w)**2 * r**2

        assert n > 0

        bounds = []
        for sign in (1, -1):
            candidates = []
            for ufake, wfake in product((umin, umax), (wmin, wmax)):
                # Reward for the pseudo-observation that minimises sign*(u - w)*r.
                rex = rmin if sign * ufake >= sign * wfake else rmax

                baru = (sumu + ufake) / (n + 1)
                barw = (sumw + wfake) / (n + 1)
                barusq = (sumusq + ufake**2) / (n + 1)
                barwsq = (sumwsq + wfake**2) / (n + 1)
                baruw = (sumuw + ufake * wfake) / (n + 1)
                baruMwr = sign * (sumuMwr + (ufake - wfake) * rex) / (n + 1)
                baruuMwr = sign * (sumuuMwr + ufake * (ufake - wfake) * rex) / (n + 1)
                barwuMwr = sign * (sumwuMwr + wfake * (ufake - wfake) * rex) / (n + 1)
                baruMwsqrsq = (sumuMwsqrsq + (ufake - wfake)**2 * rex**2) / (n + 1)

                # Second-moment matrix of (1, u, w) over the augmented sample.
                C = np.array([ [ 1.0, baru, barw ],
                               [ baru, barusq, baruw ],
                               [ barw, baruw, barwsq ],
                             ], dtype='float64')
                d = np.array([ baruMwr, baruuMwr, barwuMwr ],
                             dtype='float64')

                a = np.linalg.lstsq(C, np.ones(3), rcond=-1)[0]
                b = np.linalg.lstsq(C, d, rcond=-1)[0]
                x = np.sum(b)
                y = np.dot(d, b) - baruMwsqrsq
                z = phi - 0.5 + 0.5 * np.sum(a)

                if isclose(y*z, 0, abs_tol=1e-9):
                    y = 0

                if z <= 0 and y * z >= 0:
                    gstar = x - sqrt(2 * y * z)
                    kappa = sqrt(y / (2 * z)) if y * z > 0 else 0
                    beta, gamma, tau = -kappa * a - b

                    candidates.append((gstar, None if isclose(kappa, 0) else {
                            'kappastar': kappa,
                            'betastar': beta,
                            'gammastar': gamma,
                            'taustar': tau,
                            'ufake': ufake,
                            'wfake': wfake,
                            'rfake': rex,
                            'qfunc': lambda c, u, w, r, k=kappa, g=gamma, b=beta, t=tau, s=sign, num=n: -c * (b + g * u + t * w + s * (u - w) * r) / ((num + 1) * k),
                            'mle': mle,
                        }))

            if not candidates:
                raise ValueError('intervaldiff: no feasible candidate for sign=%d' % sign)

            best = min(candidates, key=lambda x: x[0])
            vbound = min(rmax - rmin, max(rmin - rmax, sign*best[0]))
            bounds.append((vbound, best[1]))

        return (bounds[0][0], bounds[1][0]), (bounds[0][1], bounds[1][1]), candidates
