class OnlineDRO:
    class OnlineCressieReadLB:
        from math import inf
        
        @staticmethod
        def intervalimpl(n, sumw, sumwsq, sumwr, sumwsqr, sumwsqrsq,
                         wmin, wmax, alpha=0.05,
                         rmin=0, rmax=1, raiseonerr=False):
            from math import inf, isclose, sqrt
            from scipy.stats import chi2

            assert wmin < 1
            assert wmax > 1
            assert rmin <= rmax

            uncwfake = wmax if sumw < n else wmin
            if uncwfake == inf:
                uncgstar = 1 + 1 / n
            else:
                unca = (uncwfake + sumw) / (1 + n)
                uncb = (uncwfake**2 + sumwsq) / (1 + n)
                uncgstar = (n + 1) * (unca - 1)**2 / (uncb - unca*unca)
            Delta = chi2.isf(q=alpha, df=1)
            phi = (-uncgstar - Delta) / (2 * (n + 1))

            bounds = []
            for r, sign in ((rmin, 1),):
                candidates = []
                for wfake in (wmin, wmax):
                    if wfake == inf:
                        x = sign * (r + (sumwr - sumw * r) / n)
                        y = (  (r * sumw - sumwr)**2 / (n * (1 + n))
                             - (r**2 * sumwsq - 2 * r * sumwsqr + sumwsqrsq) / (1 + n)
                            )
                        z = phi + 1 / (2 * n)
                        if isclose(y*z, 0, abs_tol=1e-9):
                            y = 0

                        if z <= 0 and y * z >= 0:
                            kappa = sqrt(y / (2 * z))
                            if isclose(kappa, 0):
                                candidates.append((sign * r, None))
                            else:
                                gstar = x - sqrt(2 * y * z)
                                gamma = ( -kappa * (1 + n) / n
                                         + sign * (r * sumw - sumwr) / n )
                                beta = -sign * r
                                candidates.append((gstar, {
                                    'kappastar': kappa,
                                    'betastar': beta,
                                    'gammastar': gamma,
                                    'wfake': wfake,
                                # Q_{w,r} &= -\frac{\gamma + \beta w + w r}{(N+1) \kappa} \\
                                    'qfunc': lambda c, w, r, k=kappa, g=gamma, b=beta, s=sign, num=n: -c * (g + (b + s * r) * w) / ((num + 1) * k),
                                }))
                    else:
                        barw = (wfake + sumw) / (1 + n)
                        barwsq = (wfake*wfake + sumwsq) / (1 + n)
                        barwr = sign * (wfake * r + sumwr) / (1 + n)
                        barwsqr = sign * (wfake * wfake * r + sumwsqr) / (1 + n)
                        barwsqrsq = (wfake * wfake * r * r + sumwsqrsq) / (1 + n)
                        
                        if barwsq > barw**2:
                            x = barwr + ((1 - barw) * (barwsqr - barw * barwr) / (barwsq - barw**2))
                            y = (barwsqr - barw * barwr)**2 / (barwsq - barw**2) - (barwsqrsq - barwr**2)
                            z = phi + (1/2) * (1 - barw)**2 / (barwsq - barw**2)
                            
                            if isclose(y*z, 0, abs_tol=1e-9):
                                y = 0

                            if z <= 0 and y * z >= 0:
                                kappa = sqrt(y / (2 * z)) if y * z > 0 else 0
                                if isclose(kappa, 0):
                                    candidates.append((sign * r, None))
                                else:
                                    gstar = x - sqrt(2 * y * z)
                                    beta = (-kappa * (1 - barw) - (barwsqr - barw * barwr)) / (barwsq - barw*barw)
                                    gamma = -kappa - beta * barw - barwr
                                    
                                    candidates.append((gstar, {
                                        'kappastar': kappa,
                                        'betastar': beta,
                                        'gammastar': gamma,
                                        'wfake': wfake,
                                    # Q_{w,r} &= -\frac{\gamma + \beta w + w r}{(N+1) \kappa} \\
                                        'qfunc': lambda c, w, r, k=kappa, g=gamma, b=beta, s=sign, num=n: -c * (g + (b + s * r) * w) / ((num + 1) * k),
                                    }))

                best = min(candidates, key=lambda x: x[0])
                vbound = min(rmax, max(rmin, sign*best[0]))
                bounds.append((vbound, best[1]))

            return (bounds[0][0], ), (bounds[0][1], )
        
        def __init__(self, alpha, tau=1, wmin=0, wmax=inf):
            import numpy as np
            
            self.alpha = alpha
            self.tau = tau
            self.n = 0
            self.sumw = 0
            self.sumwsq = 0
            self.sumwr = 0
            self.sumwsqr = 0
            self.sumwsqrsq = 0
            self.wmin = wmin
            self.wmax = wmax
            
            self.duals = None
            self.mleduals = None
            
        def update(self, c, w, r):
            if c > 0:
                assert w + 1e-6 >= self.wmin and w <= self.wmax + 1e-6, 'w = {} < {} < {}'.format(self.wmin, w, self.wmax)
                assert r >= 0 and r <= 1, 'r = {}'.format(r)
                
                decay = self.tau ** c
                self.n = decay * self.n + c
                self.sumw = decay * self.sumw + c * w
                self.sumwsq = decay * self.sumwsq + c * w**2
                self.sumwr = decay * self.sumwr + c * w * r
                self.sumwsqr = decay * self.sumwsqr + c * (w**2) * r
                self.sumwsqrsq = decay * self.sumwsqrsq + c * (w**2) * (r**2)
                    
                self.duals = None
                self.mleduals = None
                
            return self
        
        def recomputeduals(self):
            from crminustwo import CrMinusTwo as CrMinusTwo
            
            self.duals = self.intervalimpl(self.n, self.sumw, self.sumwsq, 
                                           self.sumwr, self.sumwsqr, self.sumwsqrsq,
                                           self.wmin, self.wmax, self.alpha, raiseonerr=True)
            
        def qlb(self, w, r):
            if self.duals is None:
                self.recomputeduals()
                
                assert self.duals is not None
                
            return self.duals[1][0]['qfunc'](1, w, r) if self.duals[1][0] is not None else 1
            
    def recompute_duals_test():
        from math import exp, pi
        from pprint import pformat
        import numpy
        
        ocrl = OnlineDRO.OnlineCressieReadLB(alpha=0.05, tau=0.999)
        
        ws = numpy.random.RandomState(seed=42).exponential(size=10)
        rs = numpy.random.RandomState(seed=2112).random_sample(size=10)
        duals = []
        
        print(list(zip(ws, rs)))
        
        for (w, r) in zip(ws, rs):
            ocrl.update(1, w, r)
            ocrl.recomputeduals()
            if ocrl.duals[1][0] is None:
                duals.append( ( True, 0, 0, 0, 0 ) )
            else:               
                duals.append( ( False, ocrl.duals[1][0]['kappastar'], 
                               ocrl.duals[1][0]['gammastar'], ocrl.duals[1][0]['betastar'], ocrl.n ) )
            
        print(pformat(duals))
        
    def qlb_test():
        from math import exp, pi
        from pprint import pformat
        import numpy
        
        ocrl = OnlineDRO.OnlineCressieReadLB(alpha=0.05, tau=0.999)
        
        ws = numpy.random.RandomState(seed=42).exponential(size=10)
        rs = numpy.random.RandomState(seed=2112).random_sample(size=10)
        qlbs = []
        
        print(list(zip(ws, rs)))
        
        for (w, r) in zip(ws, rs):
            ocrl.update(1, w, r)
            qlbs.append(ocrl.qlb(w, r))
            
        print(pformat(qlbs))
        
OnlineDRO.recompute_duals_test()
OnlineDRO.qlb_test()
